// Method registration for the R/RSNNS-backed TMVA method.
// NOTE(review): REGISTER_METHOD and ClassImp are macros defined outside this
// file; presumably they register MethodRSNNS with the TMVA method factory and
// with ROOT's class dictionary respectively — confirm against their definitions.
43 REGISTER_METHOD(RSNNS)
45 ClassImp(MethodRSNNS);
// Asks the R interface to load the R package "RSNNS" once, when this static
// member is initialized. Init() checks the cached result and logs a kFATAL
// message if the package could not be loaded.
48 Bool_t MethodRSNNS::IsModuleLoaded = ROOT::R::TRInterface::Instance().Require("RSNNS");
// Standard constructor: creates an RSNNS method from a job name, a method title
// and an option string, and sets the default values of the options that
// Train() forwards to RSNNS's mlp().
//
// The network type is taken from methodTitle; only "RMLP" is accepted, and any
// other title is reported with a kFATAL log message.
// NOTE(review): this view is missing lines of the original definition. For
// example, the `dsi` argument passed to RMethodBase is not among the visible
// parameters, and the end of the if-block and of the constructor are not shown.
// NOTE(review): "Unknow Method" in the runtime message is a typo for "Unknown".
// NOTE(review): the option defaults below are duplicated verbatim in the
// weight-file constructor; a shared helper would keep the two in sync.
51 MethodRSNNS::MethodRSNNS(const TString &jobName,
52 const TString &methodTitle,
54 const TString &theOption) :
55 RMethodBase(jobName, Types::kRSNNS, methodTitle, dsi, theOption),
59 asfactor("as.factor"),
62 fNetType = methodTitle;
63 if (fNetType !=
"RMLP") {
64 Log() << kFATAL <<
" Unknow Method" + fNetType
// Default values for the options that Train() passes to mlp().
74 fInitFunc =
"Randomize_Weights";
75 fInitFuncParams =
"c(-0.3,0.3)";
77 fLearnFunc =
"Std_Backpropagation";
78 fLearnFuncParams =
"c(0.2,0)";
80 fUpdateFunc =
"Topological_Order";
81 fUpdateFuncParams =
"c(0)";
83 fHiddenActFunc =
"Act_Logistic";
84 fShufflePatterns = kTRUE;
// Train() evaluates this string as R code via r.Eval(), so "NULL" becomes R's NULL.
87 fPruneFuncParams =
"NULL";
// Constructor used when the method is re-created from a weight file
// (theWeightFile) for the data set described by theData. It applies the same
// "RMLP" check and sets the same option defaults as the standard constructor.
// NOTE(review): fNetType is compared against "RMLP" here, but no assignment to
// fNetType is visible in this constructor. Confirm it is set before the check;
// otherwise the comparison runs against whatever value the member starts with.
// NOTE(review): "Unknow Method" in the runtime message is a typo for "Unknown".
// NOTE(review): the option defaults below duplicate those of the standard
// constructor; a shared helper would keep the two in sync.
92 MethodRSNNS::MethodRSNNS(DataSetInfo &theData,
const TString &theWeightFile)
93 : RMethodBase(Types::kRSNNS, theData, theWeightFile),
97 asfactor(
"as.factor"),
102 if (fNetType !=
"RMLP") {
103 Log() << kFATAL <<
" Unknow Method = " + fNetType
// Default values for the options that Train() passes to mlp().
113 fInitFunc =
"Randomize_Weights";
114 fInitFuncParams =
"c(-0.3,0.3)";
116 fLearnFunc =
"Std_Backpropagation";
117 fLearnFuncParams =
"c(0.2,0)";
119 fUpdateFunc =
"Topological_Order";
120 fUpdateFuncParams =
"c(0)";
122 fHiddenActFunc =
"Act_Logistic";
123 fShufflePatterns = kTRUE;
// Train() evaluates this string as R code via r.Eval(), so "NULL" becomes R's NULL.
126 fPruneFuncParams =
"NULL";
// Destructor: deletes the R model wrapper held in fModel.
// NOTE(review): the null check is redundant, because `delete` on a null pointer
// is a no-op.
131 MethodRSNNS::~MethodRSNNS(
void)
133 if (fModel)
delete fModel;
// Reports which analysis types this method supports. It returns kTRUE for
// two-class classification (Types::kClassification with numberClasses == 2).
// The return value for all other cases is not visible in this excerpt.
137 Bool_t MethodRSNNS::HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t )
139 if (type == Types::kClassification && numberClasses == 2)
return kTRUE;
// Initialization. First it checks that the RSNNS R package was loaded (see
// IsModuleLoaded) and reports an error plus a kFATAL log message if not.
// It then builds fFactorNumeric, a numeric class vector with one entry per
// element of fFactorTrain: 1 for the label "signal" and 0 for any other label.
// Train() passes fFactorNumeric to mlp() as the target `y`.
145 void MethodRSNNS::Init()
147 if (!IsModuleLoaded) {
148 Error(
"Init",
"R's package RSNNS can not be loaded.");
149 Log() << kFATAL <<
" R's package RSNNS can not be loaded."
155 UInt_t size = fFactorTrain.size();
156 fFactorNumeric.resize(size);
158 for (UInt_t i = 0; i < size; i++) {
159 if (fFactorTrain[i] ==
"signal") fFactorNumeric[i] = 1;
160 else fFactorNumeric[i] = 0;
// Trains the network. It logs a kFATAL message if there are no training events.
// For the "RMLP" type it then calls RSNNS's mlp() through the R interface with:
//   - the training data frame fDfTrain,
//   - the 0/1 targets fFactorNumeric,
//   - the configured options.
// Size and the *FuncParams strings are R expressions evaluated with r.Eval().
// PruneFunc "NULL" maps to R's NULL; any other PruneFunc value is passed as a
// single-quoted R string. The resulting model is wrapped in a new TRObject and
// stored in fModel.
// If model persistence is enabled, the model is also bound to RMLPModel in the
// R session and saved to GetWeightFileDir() + "/" + GetName() + ".RData".
// NOTE(review): fModel is overwritten with a fresh `new` and any previous value
// is not released, so calling Train() twice on one object leaks.
// NOTE(review): the path is spliced into R code inside single quotes; a path
// containing a quote character would break the save() command.
164 void MethodRSNNS::Train()
166 if (Data()->GetNTrainingEvents() == 0) Log() << kFATAL <<
"<Train> Data() has zero events" << Endl;
167 if (fNetType ==
"RMLP") {
168 ROOT::R::TRObject PruneFunc;
169 if (fPruneFunc ==
"NULL") PruneFunc = r.Eval(
"NULL");
170 else PruneFunc = r.Eval(Form(
"'%s'", fPruneFunc.Data()));
172 SEXP Model = mlp(ROOT::R::Label[
"x"] = fDfTrain,
173 ROOT::R::Label[
"y"] = fFactorNumeric,
174 ROOT::R::Label[
"size"] = r.Eval(fSize),
175 ROOT::R::Label[
"maxit"] = fMaxit,
176 ROOT::R::Label[
"initFunc"] = fInitFunc,
177 ROOT::R::Label[
"initFuncParams"] = r.Eval(fInitFuncParams),
178 ROOT::R::Label[
"learnFunc"] = fLearnFunc,
179 ROOT::R::Label[
"learnFuncParams"] = r.Eval(fLearnFuncParams),
180 ROOT::R::Label[
"updateFunc"] = fUpdateFunc,
181 ROOT::R::Label[
"updateFuncParams"] = r.Eval(fUpdateFuncParams),
182 ROOT::R::Label[
"hiddenActFunc"] = fHiddenActFunc,
183 ROOT::R::Label[
"shufflePatterns"] = fShufflePatterns,
// NOTE(review): RSNNS documents this mlp() argument as `linOut`. The label
// "libOut" looks like a typo; if so, the LinOut option never reaches mlp().
// Verify against the RSNNS reference manual before renaming the label.
184 ROOT::R::Label[
"libOut"] = fLinOut,
185 ROOT::R::Label[
"pruneFunc"] = PruneFunc,
186 ROOT::R::Label[
"pruneFuncParams"] = r.Eval(fPruneFuncParams));
187 fModel =
new ROOT::R::TRObject(Model);
189 if (IsModelPersistence())
191 TString path = GetWeightFileDir() +
"/" + GetName() +
".RData";
193 Log() << gTools().Color(
"bold") <<
"--- Saving State File In:" << gTools().Color(
"reset") << path << Endl;
195 r[
"RMLPModel"] << Model;
196 r <<
"save(RMLPModel,file='" + path +
"')";
// Declares the user-configurable options. Each one corresponds to an argument of
// the mlp() call made in Train(). Size, InitFuncParams, LearnFuncParams,
// UpdateFuncParams and PruneFuncParams are R expressions written as strings;
// Train() evaluates them with r.Eval().
202 void MethodRSNNS::DeclareOptions()
206 DeclareOptionRef(fSize,
"Size",
"number of units in the hidden layer(s)");
207 DeclareOptionRef(fMaxit,
"Maxit",
"Maximum of iterations to learn");
209 DeclareOptionRef(fInitFunc,
"InitFunc",
"the initialization function to use");
210 DeclareOptionRef(fInitFuncParams,
"InitFuncParams",
"the parameters for the initialization function");
212 DeclareOptionRef(fLearnFunc,
"LearnFunc",
"the learning function to use");
213 DeclareOptionRef(fLearnFuncParams,
"LearnFuncParams",
"the parameters for the learning function");
215 DeclareOptionRef(fUpdateFunc,
"UpdateFunc",
"the update function to use");
216 DeclareOptionRef(fUpdateFuncParams,
"UpdateFuncParams",
"the parameters for the update function");
218 DeclareOptionRef(fHiddenActFunc,
"HiddenActFunc",
"the activation function of all hidden units");
219 DeclareOptionRef(fShufflePatterns,
"ShufflePatterns",
"should the patterns be shuffled?");
220 DeclareOptionRef(fLinOut,
"LinOut",
"sets the activation function of the output units to linear or logistic");
222 DeclareOptionRef(fPruneFunc,
"PruneFunc",
"the prune function to use");
// NOTE(review): the help text below ends with "the update function to use",
// which appears to be pasted from the UpdateFunc option. The backslash line
// continuations would also embed any leading whitespace of the continued lines
// into the string.
223 DeclareOptionRef(fPruneFuncParams,
"PruneFuncParams",
"the parameters for the pruning function. Unlike the\
224 other functions, these have to be given in a named list. See\
225 the pruning demos for further explanation.the update function to use");
// Post-processes the parsed options. Only the error message for an invalid
// Maxit is visible here; per that message, a non-positive fMaxit is reset to 50.
// NOTE(review): the condition and the reassignment are not visible in this
// excerpt — confirm both exist in the full definition.
230 void MethodRSNNS::ProcessOptions()
233 Log() << kERROR <<
" fMaxit <=0... that does not work !! "
234 <<
" I set it to 50 .. just so that the program does not crash"
// Logs which network type is being tested, then delegates the actual
// classification test to MethodBase::TestClassification().
244 void MethodRSNNS::TestClassification()
246 Log() << kINFO <<
"Testing Classification " << fNetType <<
" METHOD " << Endl;
248 MethodBase::TestClassification();
// Returns the MVA response for the current event; no error estimate is provided
// (NoErrorCalc). The event's variable values go into a one-row R data frame
// whose columns are named after the variables. predict(type = "prob") is then
// called on the model, and element 0 of the result is used as the MVA value.
// NOTE(review): this assumes the network has a single output unit, so that
// element 0 is the signal probability — confirm.
// NOTE(review): with model persistence enabled, the model is re-read from the
// .RData file on every call. ReadModelFromFile() allocates a new fModel without
// any visible release of the old one, so this appears to leak once per event.
// Loading the model once (e.g. only when fModel is null) would avoid both.
// NOTE(review): the local `fDfEvent` uses the member-variable `f` prefix.
// The declaration of mvaValue and the return statement are not visible here.
253 Double_t MethodRSNNS::GetMvaValue(Double_t *errLower, Double_t *errUpper)
255 NoErrorCalc(errLower, errUpper);
257 const TMVA::Event *ev = GetEvent();
258 const UInt_t nvar = DataInfo().GetNVariables();
259 ROOT::R::TRDataFrame fDfEvent;
260 for (UInt_t i = 0; i < nvar; i++) {
261 fDfEvent[DataInfo().GetListOfVariables()[i].Data()] = ev->GetValues()[i];
264 if (IsModelPersistence()) ReadModelFromFile();
266 TVectorD result = predict(*fModel, fDfEvent, ROOT::R::Label[
"type"] =
"prob");
267 mvaValue = result[0];
// Batch evaluation: computes MVA values for the events in [firstEvt, lastEvt)
// of the current data set with a single predict() call.
// Range handling:
//   - lastEvt is clamped to the number of events when it exceeds that number
//     or when firstEvt > lastEvt;
//   - a negative firstEvt is raised to 0.
// Each variable becomes one std::vector<Float_t> column of an R data frame. The
// predict(type = "prob") result is converted to std::vector<Double_t>.
// NOTE(review): if firstEvt exceeds the clamped lastEvt, nEvents becomes
// negative, and converting it to a vector size yields a huge unsigned value.
// NOTE(review): the loop counter below is Int_t while the bounds are Long64_t.
// NOTE(review): logProgress is not referenced in the visible lines. The full
// definition may guard the two log statements with it — confirm.
// NOTE(review): with model persistence enabled, the model is re-read from disk
// via ReadModelFromFile() on every call.
// NOTE(review): mvaValues is pre-sized and then immediately overwritten by the
// conversion of the predict() result.
// The return statement is not visible in this excerpt.
273 std::vector<Double_t> MethodRSNNS::GetMvaValues(Long64_t firstEvt, Long64_t lastEvt, Bool_t logProgress)
275 Long64_t nEvents = Data()->GetNEvents();
276 if (firstEvt > lastEvt || lastEvt > nEvents) lastEvt = nEvents;
277 if (firstEvt < 0) firstEvt = 0;
279 nEvents = lastEvt-firstEvt;
281 UInt_t nvars = Data()->GetNVariables();
284 Timer timer( nEvents, GetName(), kTRUE );
286 Log() << kINFO<<Form(
"Dataset[%s] : ",DataInfo().GetName())<<
"Evaluation of " << GetMethodName() <<
" on "
287 << (Data()->GetCurrentType()==Types::kTraining?
"training":
"testing") <<
" sample (" << nEvents <<
" events)" << Endl;
// One column per input variable, with nEvents rows each.
291 std::vector<std::vector<Float_t> > inputData(nvars);
292 for (UInt_t i = 0; i < nvars; i++) {
293 inputData[i] = std::vector<Float_t>(nEvents);
296 for (Int_t ievt=firstEvt; ievt<lastEvt; ievt++) {
297 Data()->SetCurrentEvent(ievt);
298 const TMVA::Event *e = Data()->GetEvent();
299 assert(nvars == e->GetNVariables());
300 for (UInt_t i = 0; i < nvars; i++) {
// NOTE(review): inputData[i] has nEvents = lastEvt - firstEvt entries, but the
// absolute event number ievt is used as the index. This runs past the end
// whenever firstEvt > 0; the index should presumably be ievt - firstEvt.
301 inputData[i][ievt] = e->GetValue(i);
307 ROOT::R::TRDataFrame evtData;
308 for (UInt_t i = 0; i < nvars; i++) {
309 evtData[DataInfo().GetListOfVariables()[i].Data()] = inputData[i];
312 if (IsModelPersistence()) ReadModelFromFile();
314 std::vector<Double_t> mvaValues(nEvents);
315 ROOT::R::TRObject result = predict(*fModel, evtData, ROOT::R::Label[
"type"] =
"prob");
317 mvaValues = result.As<std::vector<Double_t>>();
322 Log() << kINFO <<Form(
"Dataset[%s] : ",DataInfo().GetName())<<
"Elapsed time for evaluation of " << nEvents <<
" events: "
323 << timer.GetElapsedTime() <<
" " << Endl;
// Loads a persisted model, in four steps:
//   1. make sure the RSNNS package is loaded in R;
//   2. load GetWeightFileDir() + "/" + GetName() + ".RData" into the R session;
//   3. read the R object RMLPModel (the name Train() saves under) into Model;
//   4. wrap Model in a new TRObject stored in fModel.
// NOTE(review): no release of a previous fModel is visible here. GetMvaValue()
// and GetMvaValues() call this on every invocation when persistence is enabled,
// so repeated calls appear to leak one TRObject each.
// NOTE(review): the declaration of `Model` is not visible in this excerpt.
// NOTE(review): the path is spliced into R code inside single quotes; a quote
// character in the path would break the load() command.
332 void TMVA::MethodRSNNS::ReadModelFromFile()
334 ROOT::R::TRInterface::Instance().Require(
"RSNNS");
335 TString path = GetWeightFileDir() +
"/" + GetName() +
".RData";
337 Log() << gTools().Color(
"bold") <<
"--- Loading State File From:" << gTools().Color(
"reset") << path << Endl;
339 r <<
"load('" + path +
"')";
341 r[
"RMLPModel"] >> Model;
342 fModel =
new ROOT::R::TRObject(Model);
// Prints the method's help text: a short description, a performance
// optimisation section (no content visible here) and a tuning section ("<None>").
// NOTE(review): the short description "Decision Trees and Rule-Based Models"
// does not match this method, which trains a multilayer perceptron through
// RSNNS's mlp(). It looks copied from another R-based method.
348 void MethodRSNNS::GetHelpMessage()
const
355 Log() << gTools().Color(
"bold") <<
"--- Short description:" << gTools().Color(
"reset") << Endl;
357 Log() <<
"Decision Trees and Rule-Based Models " << Endl;
359 Log() << gTools().Color(
"bold") <<
"--- Performance optimisation:" << gTools().Color(
"reset") << Endl;
362 Log() << gTools().Color(
"bold") <<
"--- Performance tuning via configuration options:" << gTools().Color(
"reset") << Endl;
364 Log() <<
"<None>" << Endl;