Logo ROOT   6.30.04
Reference Guide
 All Namespaces Files Pages
MethodRSNNS.cxx
Go to the documentation of this file.
1 // @(#)root/tmva/rmva $Id$
2 // Author: Omar Zapata,Lorenzo Moneta, Sergei Gleyzer 2015
3 
4 
5 /**********************************************************************************
6  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
7  * Package: TMVA *
8  * Class : MethodRSNNS *
9  * Web : http://oproject.org *
10  * *
11  * Description: *
12  * Neural Networks in R using the Stuttgart Neural Network Simulator *
13  * *
14  * *
15  * Redistribution and use in source and binary forms, with or without *
16  * modification, are permitted according to the terms listed in LICENSE *
17  * (http://tmva.sourceforge.net/LICENSE) *
18  * *
19  **********************************************************************************/
20 
21 #include <iomanip>
22 
23 #include "TMath.h"
24 #include "Riostream.h"
25 #include "TMatrix.h"
26 #include "TMatrixD.h"
27 #include "TVectorD.h"
28 
30 #include "TMVA/MethodRSNNS.h"
31 #include "TMVA/Tools.h"
32 #include "TMVA/Config.h"
33 #include "TMVA/Ranking.h"
34 #include "TMVA/Types.h"
35 #include "TMVA/PDF.h"
36 #include "TMVA/ClassifierFactory.h"
37 
38 #include "TMVA/Results.h"
39 #include "TMVA/Timer.h"
40 
using namespace TMVA;

// Register this method with the TMVA classifier factory under the name "RSNNS".
REGISTER_METHOD(RSNNS)

ClassImp(MethodRSNNS);

// Load the RSNNS package into the embedded R interpreter once, at static
// initialisation time; Init() checks this flag before any training is done.
Bool_t MethodRSNNS::IsModuleLoaded = ROOT::R::TRInterface::Instance().Require("RSNNS");
49 
50 //_______________________________________________________________________
51 MethodRSNNS::MethodRSNNS(const TString &jobName,
52  const TString &methodTitle,
53  DataSetInfo &dsi,
54  const TString &theOption) :
55  RMethodBase(jobName, Types::kRSNNS, methodTitle, dsi, theOption),
56  fMvaCounter(0),
57  predict("predict"),
58  mlp("mlp"),
59  asfactor("as.factor"),
60  fModel(NULL)
61 {
62  fNetType = methodTitle;
63  if (fNetType != "RMLP") {
64  Log() << kFATAL << " Unknow Method" + fNetType
65  << Endl;
66  return;
67  }
68 
69  // standard constructor for the RSNNS
70  //RSNNS Options for all NN methods
71  fSize = "c(5)";
72  fMaxit = 100;
73 
74  fInitFunc = "Randomize_Weights";
75  fInitFuncParams = "c(-0.3,0.3)"; //the maximun number of pacameter is 5 see RSNNS::getSnnsRFunctionTable() type 6
76 
77  fLearnFunc = "Std_Backpropagation"; //
78  fLearnFuncParams = "c(0.2,0)";
79 
80  fUpdateFunc = "Topological_Order";
81  fUpdateFuncParams = "c(0)";
82 
83  fHiddenActFunc = "Act_Logistic";
84  fShufflePatterns = kTRUE;
85  fLinOut = kFALSE;
86  fPruneFunc = "NULL";
87  fPruneFuncParams = "NULL";
88 
89 }
90 
91 //_______________________________________________________________________
92 MethodRSNNS::MethodRSNNS(DataSetInfo &theData, const TString &theWeightFile)
93  : RMethodBase(Types::kRSNNS, theData, theWeightFile),
94  fMvaCounter(0),
95  predict("predict"),
96  mlp("mlp"),
97  asfactor("as.factor"),
98  fModel(NULL)
99 
100 {
101  fNetType = "RMLP"; //GetMethodName();//GetMethodName() is not returning RMLP is reting MethodBase why?
102  if (fNetType != "RMLP") {
103  Log() << kFATAL << " Unknow Method = " + fNetType
104  << Endl;
105  return;
106  }
107 
108  // standard constructor for the RSNNS
109  //RSNNS Options for all NN methods
110  fSize = "c(5)";
111  fMaxit = 100;
112 
113  fInitFunc = "Randomize_Weights";
114  fInitFuncParams = "c(-0.3,0.3)"; //the maximun number of pacameter is 5 see RSNNS::getSnnsRFunctionTable() type 6
115 
116  fLearnFunc = "Std_Backpropagation"; //
117  fLearnFuncParams = "c(0.2,0)";
118 
119  fUpdateFunc = "Topological_Order";
120  fUpdateFuncParams = "c(0)";
121 
122  fHiddenActFunc = "Act_Logistic";
123  fShufflePatterns = kTRUE;
124  fLinOut = kFALSE;
125  fPruneFunc = "NULL";
126  fPruneFuncParams = "NULL";
127 }
128 
129 
130 //_______________________________________________________________________
131 MethodRSNNS::~MethodRSNNS(void)
132 {
133  if (fModel) delete fModel;
134 }
135 
136 //_______________________________________________________________________
137 Bool_t MethodRSNNS::HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t /*numberTargets*/)
138 {
139  if (type == Types::kClassification && numberClasses == 2) return kTRUE;
140  return kFALSE;
141 }
142 
143 
144 //_______________________________________________________________________
145 void MethodRSNNS::Init()
146 {
147  if (!IsModuleLoaded) {
148  Error("Init", "R's package RSNNS can not be loaded.");
149  Log() << kFATAL << " R's package RSNNS can not be loaded."
150  << Endl;
151  return;
152  }
153  //factors creations
154  //RSNNS mlp require a numeric factor then background=0 signal=1 from fFactorTrain/fFactorTest
155  UInt_t size = fFactorTrain.size();
156  fFactorNumeric.resize(size);
157 
158  for (UInt_t i = 0; i < size; i++) {
159  if (fFactorTrain[i] == "signal") fFactorNumeric[i] = 1;
160  else fFactorNumeric[i] = 0;
161  }
162 }
163 
164 void MethodRSNNS::Train()
165 {
166  if (Data()->GetNTrainingEvents() == 0) Log() << kFATAL << "<Train> Data() has zero events" << Endl;
167  if (fNetType == "RMLP") {
168  ROOT::R::TRObject PruneFunc;
169  if (fPruneFunc == "NULL") PruneFunc = r.Eval("NULL");
170  else PruneFunc = r.Eval(Form("'%s'", fPruneFunc.Data()));
171 
172  SEXP Model = mlp(ROOT::R::Label["x"] = fDfTrain,
173  ROOT::R::Label["y"] = fFactorNumeric,
174  ROOT::R::Label["size"] = r.Eval(fSize),
175  ROOT::R::Label["maxit"] = fMaxit,
176  ROOT::R::Label["initFunc"] = fInitFunc,
177  ROOT::R::Label["initFuncParams"] = r.Eval(fInitFuncParams),
178  ROOT::R::Label["learnFunc"] = fLearnFunc,
179  ROOT::R::Label["learnFuncParams"] = r.Eval(fLearnFuncParams),
180  ROOT::R::Label["updateFunc"] = fUpdateFunc,
181  ROOT::R::Label["updateFuncParams"] = r.Eval(fUpdateFuncParams),
182  ROOT::R::Label["hiddenActFunc"] = fHiddenActFunc,
183  ROOT::R::Label["shufflePatterns"] = fShufflePatterns,
184  ROOT::R::Label["libOut"] = fLinOut,
185  ROOT::R::Label["pruneFunc"] = PruneFunc,
186  ROOT::R::Label["pruneFuncParams"] = r.Eval(fPruneFuncParams));
187  fModel = new ROOT::R::TRObject(Model);
188  //if model persistence is enabled saving it is R serialziation.
189  if (IsModelPersistence())
190  {
191  TString path = GetWeightFileDir() + "/" + GetName() + ".RData";
192  Log() << Endl;
193  Log() << gTools().Color("bold") << "--- Saving State File In:" << gTools().Color("reset") << path << Endl;
194  Log() << Endl;
195  r["RMLPModel"] << Model;
196  r << "save(RMLPModel,file='" + path + "')";
197  }
198  }
199 }
200 
201 //_______________________________________________________________________
202 void MethodRSNNS::DeclareOptions()
203 {
204  //RSNNS Options for all NN methods
205 // TVectorF fSize;//number of units in the hidden layer(s)
206  DeclareOptionRef(fSize, "Size", "number of units in the hidden layer(s)");
207  DeclareOptionRef(fMaxit, "Maxit", "Maximum of iterations to learn");
208 
209  DeclareOptionRef(fInitFunc, "InitFunc", "the initialization function to use");
210  DeclareOptionRef(fInitFuncParams, "InitFuncParams", "the parameters for the initialization function");
211 
212  DeclareOptionRef(fLearnFunc, "LearnFunc", "the learning function to use");
213  DeclareOptionRef(fLearnFuncParams, "LearnFuncParams", "the parameters for the learning function");
214 
215  DeclareOptionRef(fUpdateFunc, "UpdateFunc", "the update function to use");
216  DeclareOptionRef(fUpdateFuncParams, "UpdateFuncParams", "the parameters for the update function");
217 
218  DeclareOptionRef(fHiddenActFunc, "HiddenActFunc", "the activation function of all hidden units");
219  DeclareOptionRef(fShufflePatterns, "ShufflePatterns", "should the patterns be shuffled?");
220  DeclareOptionRef(fLinOut, "LinOut", "sets the activation function of the output units to linear or logistic");
221 
222  DeclareOptionRef(fPruneFunc, "PruneFunc", "the prune function to use");
223  DeclareOptionRef(fPruneFuncParams, "PruneFuncParams", "the parameters for the pruning function. Unlike the\
224  other functions, these have to be given in a named list. See\
225  the pruning demos for further explanation.the update function to use");
226 
227 }
228 
229 //_______________________________________________________________________
230 void MethodRSNNS::ProcessOptions()
231 {
232  if (fMaxit <= 0) {
233  Log() << kERROR << " fMaxit <=0... that does not work !! "
234  << " I set it to 50 .. just so that the program does not crash"
235  << Endl;
236  fMaxit = 1;
237  }
238  // standard constructor for the RSNNS
239  //RSNNS Options for all NN methods
240 
241 }
242 
243 //_______________________________________________________________________
244 void MethodRSNNS::TestClassification()
245 {
246  Log() << kINFO << "Testing Classification " << fNetType << " METHOD " << Endl;
247 
248  MethodBase::TestClassification();
249 }
250 
251 
252 //_______________________________________________________________________
253 Double_t MethodRSNNS::GetMvaValue(Double_t *errLower, Double_t *errUpper)
254 {
255  NoErrorCalc(errLower, errUpper);
256  Double_t mvaValue;
257  const TMVA::Event *ev = GetEvent();
258  const UInt_t nvar = DataInfo().GetNVariables();
259  ROOT::R::TRDataFrame fDfEvent;
260  for (UInt_t i = 0; i < nvar; i++) {
261  fDfEvent[DataInfo().GetListOfVariables()[i].Data()] = ev->GetValues()[i];
262  }
263  //if using persistence model
264  if (IsModelPersistence()) ReadModelFromFile();
265 
266  TVectorD result = predict(*fModel, fDfEvent, ROOT::R::Label["type"] = "prob");
267  mvaValue = result[0]; //returning signal prob
268  return mvaValue;
269 }
270 
271 ////////////////////////////////////////////////////////////////////////////////
272 /// get all the MVA values for the events of the current Data type
273 std::vector<Double_t> MethodRSNNS::GetMvaValues(Long64_t firstEvt, Long64_t lastEvt, Bool_t logProgress)
274 {
275  Long64_t nEvents = Data()->GetNEvents();
276  if (firstEvt > lastEvt || lastEvt > nEvents) lastEvt = nEvents;
277  if (firstEvt < 0) firstEvt = 0;
278 
279  nEvents = lastEvt-firstEvt;
280 
281  UInt_t nvars = Data()->GetNVariables();
282 
283  // use timer
284  Timer timer( nEvents, GetName(), kTRUE );
285  if (logProgress)
286  Log() << kINFO<<Form("Dataset[%s] : ",DataInfo().GetName())<< "Evaluation of " << GetMethodName() << " on "
287  << (Data()->GetCurrentType()==Types::kTraining?"training":"testing") << " sample (" << nEvents << " events)" << Endl;
288 
289 
290  // fill R DATA FRAME with events data
291  std::vector<std::vector<Float_t> > inputData(nvars);
292  for (UInt_t i = 0; i < nvars; i++) {
293  inputData[i] = std::vector<Float_t>(nEvents);
294  }
295 
296  for (Int_t ievt=firstEvt; ievt<lastEvt; ievt++) {
297  Data()->SetCurrentEvent(ievt);
298  const TMVA::Event *e = Data()->GetEvent();
299  assert(nvars == e->GetNVariables());
300  for (UInt_t i = 0; i < nvars; i++) {
301  inputData[i][ievt] = e->GetValue(i);
302  }
303  // if (ievt%100 == 0)
304  // std::cout << "Event " << ievt << " type" << DataInfo().IsSignal(e) << " : " << pValue[ievt*nvars] << " " << pValue[ievt*nvars+1] << " " << pValue[ievt*nvars+2] << std::endl;
305  }
306 
307  ROOT::R::TRDataFrame evtData;
308  for (UInt_t i = 0; i < nvars; i++) {
309  evtData[DataInfo().GetListOfVariables()[i].Data()] = inputData[i];
310  }
311  //if using persistence model
312  if (IsModelPersistence()) ReadModelFromFile();
313 
314  std::vector<Double_t> mvaValues(nEvents);
315  ROOT::R::TRObject result = predict(*fModel, evtData, ROOT::R::Label["type"] = "prob");
316  //std::vector<Double_t> probValues(2*nEvents);
317  mvaValues = result.As<std::vector<Double_t>>();
318  // assert(probValues.size() == 2*mvaValues.size());
319  // std::copy(probValues.begin()+nEvents, probValues.end(), mvaValues.begin() );
320 
321  if (logProgress) {
322  Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Elapsed time for evaluation of " << nEvents << " events: "
323  << timer.GetElapsedTime() << " " << Endl;
324  }
325 
326  return mvaValues;
327 
328 }
329 
330 
331 //_______________________________________________________________________
332 void TMVA::MethodRSNNS::ReadModelFromFile()
333 {
334  ROOT::R::TRInterface::Instance().Require("RSNNS");
335  TString path = GetWeightFileDir() + "/" + GetName() + ".RData";
336  Log() << Endl;
337  Log() << gTools().Color("bold") << "--- Loading State File From:" << gTools().Color("reset") << path << Endl;
338  Log() << Endl;
339  r << "load('" + path + "')";
340  SEXP Model;
341  r["RMLPModel"] >> Model;
342  fModel = new ROOT::R::TRObject(Model);
343 
344 }
345 
346 
347 //_______________________________________________________________________
348 void MethodRSNNS::GetHelpMessage() const
349 {
350 // get help message text
351 //
352 // typical length of text line:
353 // "|--------------------------------------------------------------|"
354  Log() << Endl;
355  Log() << gTools().Color("bold") << "--- Short description:" << gTools().Color("reset") << Endl;
356  Log() << Endl;
357  Log() << "Decision Trees and Rule-Based Models " << Endl;
358  Log() << Endl;
359  Log() << gTools().Color("bold") << "--- Performance optimisation:" << gTools().Color("reset") << Endl;
360  Log() << Endl;
361  Log() << Endl;
362  Log() << gTools().Color("bold") << "--- Performance tuning via configuration options:" << gTools().Color("reset") << Endl;
363  Log() << Endl;
364  Log() << "<None>" << Endl;
365 }
366