const Bool_t EnforceNormalization__=kTRUE;

REGISTER_METHOD(TMlpANN)

ClassImp(TMVA::MethodTMlpANN);
/// standard constructor
TMVA::MethodTMlpANN::MethodTMlpANN( const TString& jobName,
                                    const TString& methodTitle,
                                    DataSetInfo& theData,
                                    const TString& theOption) :
   TMVA::MethodBase( jobName, Types::kTMlpANN, methodTitle, theData, theOption),
   fMLP(0),
   fLocalTrainingTree(0),
   fNcycles(100),
   fValidationFraction(0.5),
   fLearningMethod( "" )
{
}
/// constructor from weight file
TMVA::MethodTMlpANN::MethodTMlpANN( DataSetInfo& theData,
                                    const TString& theWeightFile) :
   TMVA::MethodBase( Types::kTMlpANN, theData, theWeightFile),
   fMLP(0),
   fLocalTrainingTree(0),
   fNcycles(100),
   fValidationFraction(0.5),
   fLearningMethod( "" )
{
}
/// TMlpANN can handle classification with two classes
Bool_t TMVA::MethodTMlpANN::HasAnalysisType( Types::EAnalysisType type, UInt_t numberClasses,
                                             UInt_t /*numberTargets*/ )
{
   if (type == Types::kClassification && numberClasses == 2) return kTRUE;
   return kFALSE;
}
/// default initialisations
void TMVA::MethodTMlpANN::Init( void )
{
}
/// destructor
TMVA::MethodTMlpANN::~MethodTMlpANN( void )
{
   if (fMLP) delete fMLP;
}
/// translates options from the option string into TMlpANN language
void TMVA::MethodTMlpANN::CreateMLPOptions( TString layerSpec )
{
   // parse the hidden-layer specification, e.g. "N,N-1"
   fHiddenLayer = ":";
   while (layerSpec.Length()>0) {
      TString sToAdd = "";
      if (layerSpec.First(',')<0) {
         sToAdd    = layerSpec;
         layerSpec = "";
      }
      else {
         sToAdd    = layerSpec(0,layerSpec.First(','));
         layerSpec = layerSpec(layerSpec.First(',')+1,layerSpec.Length());
      }
      Int_t nNodes = 0;
      if (sToAdd.BeginsWith("N")) { sToAdd.Remove(0,1); nNodes = GetNvar(); }
      nNodes += atoi(sToAdd);
      fHiddenLayer = Form( "%s%i:", (const char*)fHiddenLayer, nNodes );
   }
   // set the input variables, prepending "@" to enforce normalisation
   std::vector<TString>::iterator itrVar    = (*fInputVars).begin();
   std::vector<TString>::iterator itrVarEnd = (*fInputVars).end();
   fMLPBuildOptions = "";
   for (; itrVar != itrVarEnd; ++itrVar) {
      if (EnforceNormalization__) fMLPBuildOptions += "@";
      TString myVar = *itrVar;
      fMLPBuildOptions += myVar;
      fMLPBuildOptions += ",";
   }
   fMLPBuildOptions.Chop(); // remove trailing ","

   // append the hidden-layer specification and the output branch
   fMLPBuildOptions += fHiddenLayer;
   fMLPBuildOptions += "type";

   Log() << kINFO << "Use " << fNcycles << " training cycles" << Endl;
   Log() << kINFO << "Use configuration (nodes per hidden layer): " << fHiddenLayer << Endl;
}
/// define the options (their key words) that can be set in the option string
void TMVA::MethodTMlpANN::DeclareOptions()
{
   DeclareOptionRef( fNcycles   = 200,     "NCycles",      "Number of training cycles" );
   DeclareOptionRef( fLayerSpec = "N,N-1", "HiddenLayers", "Specification of hidden layer architecture (N stands for number of variables; any integers may also be used)" );

   DeclareOptionRef( fValidationFraction = 0.5, "ValidationFraction",
                     "Fraction of events in training tree used for cross validation" );
   DeclareOptionRef( fLearningMethod = "Stochastic", "LearningMethod", "Learning method" );
   AddPreDefVal( TString("Stochastic") );
   AddPreDefVal( TString("Batch") );
   AddPreDefVal( TString("SteepestDescent") );
   AddPreDefVal( TString("RibierePolak") );
   AddPreDefVal( TString("FletcherReeves") );
   AddPreDefVal( TString("BFGS") );
}
/// builds the neural network as specified by the user
void TMVA::MethodTMlpANN::ProcessOptions()
{
   CreateMLPOptions(fLayerSpec);

   if (IgnoreEventsWithNegWeightsInTraining()) {
      Log() << kFATAL << "Mechanism to ignore events with negative weights in training not available for method "
            << GetMethodTypeName()
            << " --> please remove \"IgnoreNegWeightsInTraining\" option from booking string." << Endl;
   }
}
/// calculate the value of the neural net for the current event
Double_t TMVA::MethodTMlpANN::GetMvaValue( Double_t* err, Double_t* errUpper )
{
   const Event* ev = GetEvent();
   TTHREAD_TLS_DECL_ARG(Double_t*, d, new Double_t[Data()->GetNVariables()]);

   for (UInt_t ivar = 0; ivar<Data()->GetNVariables(); ivar++) {
      d[ivar] = (Double_t)ev->GetValue(ivar);
   }
   Double_t mvaVal = fMLP->Evaluate(0,d);

   // cannot determine error
   NoErrorCalc(err, errUpper);

   return mvaVal;
}
/// performs TMlpANN training
void TMVA::MethodTMlpANN::Train( void )
{
   // TMultiLayerPerceptron wants training and test tree at once, so copy the
   // events from the MVA factory into a local tree first
   Int_t   type;
   Float_t weight;
   const Long_t basketsize = 128000;
   Float_t* vArr = new Float_t[GetNvar()];

   TTree *localTrainingTree = new TTree( "TMLPtrain", "Local training tree for TMlpANN" );
   localTrainingTree->Branch( "type",   &type,   "type/I",   basketsize );
   localTrainingTree->Branch( "weight", &weight, "weight/F", basketsize );

   for (UInt_t ivar=0; ivar<GetNvar(); ivar++) {
      const char* myVar = GetInternalVarName(ivar).Data();
      localTrainingTree->Branch( myVar, &vArr[ivar], Form("Var%02i/F", ivar), basketsize );
   }

   for (UInt_t ievt=0; ievt<Data()->GetNEvents(); ievt++) {
      const Event *ev = GetEvent(ievt);
      for (UInt_t i=0; i<GetNvar(); i++) {
         vArr[i] = ev->GetValue( i );
      }
      type   = DataInfo().IsSignal( ev ) ? 1 : 0;
      weight = ev->GetWeight();
      localTrainingTree->Fill();
   }
   // event lists for the TMultiLayerPerceptron train method: the tree holds
   // first the signal, then the background events; within each part the leading
   // fraction is used for training, the rest for internal validation
   TString trainList = "Entry$<";
   trainList += 1.0-fValidationFraction;
   trainList += "*";
   trainList += (Int_t)Data()->GetNEvtSigTrain();
   trainList += " || (Entry$>";
   trainList += (Int_t)Data()->GetNEvtSigTrain();
   trainList += " && Entry$<";
   trainList += (Int_t)(Data()->GetNEvtSigTrain() + (1.0 - fValidationFraction)*Data()->GetNEvtBkgdTrain());
   trainList += ")";
   TString testList = TString("!(") + trainList + ")";
   // print the requirements
   Log() << kHEADER << "Requirement for training events: \"" << trainList << "\"" << Endl;
   Log() << kINFO << "Requirement for validation events: \"" << testList << "\"" << Endl;
   // create the NN
   if (fMLP != 0) { delete fMLP; fMLP = 0; }
   fMLP = new TMultiLayerPerceptron( fMLPBuildOptions.Data(),
                                     localTrainingTree,
                                     trainList,
                                     testList );
   fMLP->SetEventWeight( "weight" );
   // set the learning method
   TMultiLayerPerceptron::ELearningMethod learningMethod = TMultiLayerPerceptron::kStochastic;

   fLearningMethod.ToLower();
   if      (fLearningMethod == "stochastic"     ) learningMethod = TMultiLayerPerceptron::kStochastic;
   else if (fLearningMethod == "batch"          ) learningMethod = TMultiLayerPerceptron::kBatch;
   else if (fLearningMethod == "steepestdescent") learningMethod = TMultiLayerPerceptron::kSteepestDescent;
   else if (fLearningMethod == "ribierepolak"   ) learningMethod = TMultiLayerPerceptron::kRibierePolak;
   else if (fLearningMethod == "fletcherreeves" ) learningMethod = TMultiLayerPerceptron::kFletcherReeves;
   else if (fLearningMethod == "bfgs"           ) learningMethod = TMultiLayerPerceptron::kBFGS;
   else {
      Log() << kFATAL << "Unknown Learning Method: \"" << fLearningMethod << "\"" << Endl;
   }
   fMLP->SetLearningMethod( learningMethod );

   // train the NN
   fMLP->Train( fNcycles, "" );

   // clean up
   delete localTrainingTree;
   delete [] vArr;
}
/// write weights to XML file
void TMVA::MethodTMlpANN::AddWeightsXMLTo( void* parent ) const
{
   // first write the architecture
   void* wght = gTools().AddChild(parent, "Weights");
   void* arch = gTools().AddChild( wght, "Architecture" );
   gTools().AddAttr( arch, "BuildOptions", fMLPBuildOptions.Data() );

   // dump the weights into a temporary text file, then read them into the XML tree
   const TString tmpfile=GetWeightFileDir()+"/TMlp.nn.weights.temp";
   fMLP->DumpWeights( tmpfile.Data() );
   std::ifstream inf( tmpfile.Data() );
   char temp[256];
   TString data("");
   void *ch=NULL;
   while (inf.getline(temp,256)) {
      TString dummy(temp);
      if (dummy.BeginsWith('#')) { // a "#" section header starts a new child node
         if (ch!=0) gTools().AddRawLine( ch, data.Data() );
         dummy = dummy.Strip(TString::kLeading, '#');
         dummy = dummy(0,dummy.First(' '));
         ch = gTools().AddChild(wght, dummy);
         data.Resize(0);
         continue;
      }
      data += (dummy + " ");
   }
   if (ch != 0) gTools().AddRawLine( ch, data.Data() );

   inf.close();
}
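// The TMultiLayerPerceptron dump format groups the weights into the sections
// "#input normalization", "#output normalization", "#neurons weights" and
// "#synapses weights" (see ReadWeightsFromXML below); each "#" header becomes
// one child node of <Weights>.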
/// rebuild the temporary text file from the XML weight file and load it into the MLP
void TMVA::MethodTMlpANN::ReadWeightsFromXML( void* wghtnode )
{
   void* ch = gTools().GetChild(wghtnode);
   gTools().ReadAttr( ch, "BuildOptions", fMLPBuildOptions );

   ch = gTools().GetNextChild(ch);
   const TString fname = GetWeightFileDir()+"/TMlp.nn.weights.temp";
   std::ofstream fout( fname.Data() );
   double temp1=0,temp2=0;
   while (ch) {
      const char* nodecontent = gTools().GetContent(ch);
      std::stringstream content(nodecontent);
      if (strcmp(gTools().GetName(ch),"input")==0) {
         fout << "#input normalization" << std::endl;
         while ((content >> temp1) && (content >> temp2)) {
            fout << temp1 << " " << temp2 << std::endl;
         }
      }
      if (strcmp(gTools().GetName(ch),"output")==0) {
         fout << "#output normalization" << std::endl;
         while ((content >> temp1) && (content >> temp2)) {
            fout << temp1 << " " << temp2 << std::endl;
         }
      }
      if (strcmp(gTools().GetName(ch),"neurons")==0) {
         fout << "#neurons weights" << std::endl;
         while (content >> temp1) {
            fout << temp1 << std::endl;
         }
      }
      if (strcmp(gTools().GetName(ch),"synapses")==0) {
         fout << "#synapses weights";
         while (content >> temp1) {
            fout << std::endl << temp1;
         }
      }
      ch = gTools().GetNextChild(ch);
   }
   fout.close();
   // create a dummy tree, needed to build a minimal NN
   // for testing, evaluation and application
   TTHREAD_TLS_DECL_ARG(Double_t*, d, new Double_t[Data()->GetNVariables()]);
   TTHREAD_TLS(Int_t) type;

   gROOT->cd();
   TTree* dummyTree = new TTree("dummy","Empty dummy tree", 1);
   for (UInt_t ivar = 0; ivar<Data()->GetNVariables(); ivar++) {
      TString vn = DataInfo().GetVariableInfo(ivar).GetInternalName();
      dummyTree->Branch(Form("%s",vn.Data()), d+ivar, Form("%s/D",vn.Data()));
   }
   dummyTree->Branch("type", &type, "type/I");

   if (fMLP != 0) { delete fMLP; fMLP = 0; }
   fMLP = new TMultiLayerPerceptron( fMLPBuildOptions.Data(), dummyTree );
   fMLP->LoadWeights( fname );
}
/// read weights from stream: since the MLP cannot read from the stream,
/// write the weights to a temporary file first
void TMVA::MethodTMlpANN::ReadWeightsFromStream( std::istream& istr )
{
   std::ofstream fout( "./TMlp.nn.weights.temp" );
   fout << istr.rdbuf();
   fout.close();

   // then load the weights from the temporary file into the MLP
   Log() << kINFO << "Load TMLP weights into " << fMLP << Endl;

   Double_t* d = new Double_t[Data()->GetNVariables()];
   Int_t type;
   gROOT->cd();
   TTree* dummyTree = new TTree("dummy","Empty dummy tree", 1);
   for (UInt_t ivar = 0; ivar<Data()->GetNVariables(); ivar++) {
      TString vn = DataInfo().GetVariableInfo(ivar).GetLabel();
      dummyTree->Branch(Form("%s",vn.Data()), d+ivar, Form("%s/D",vn.Data()));
   }
   dummyTree->Branch("type", &type, "type/I");

   if (fMLP != 0) { delete fMLP; fMLP = 0; }
   fMLP = new TMultiLayerPerceptron( fMLPBuildOptions.Data(), dummyTree );

   fMLP->LoadWeights( "./TMlp.nn.weights.temp" );
   delete [] d;
}
/// create a specific standalone response class for TMultiLayerPerceptron
/// (overwrites the base class function)
void TMVA::MethodTMlpANN::MakeClass( const TString& theClassFileName ) const
{
   TString classFileName = "";
   if (theClassFileName == "")
      classFileName = GetWeightFileDir() + "/" + GetJobName() + "_" + GetMethodName() + ".class";
   else
      classFileName = theClassFileName;

   classFileName.ReplaceAll(".class","");
   Log() << kINFO << "Creating specific (TMultiLayerPerceptron) standalone response class: " << classFileName << Endl;
   fMLP->Export( classFileName.Data() );
}
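// TMultiLayerPerceptron::Export then writes standalone source files (for the
// default C++ language, a header/source pair named after classFileName) so the
// trained network can be evaluated without TMVA.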
/// write specific classifier response:
/// nothing to do here, all taken care of by TMultiLayerPerceptron
void TMVA::MethodTMlpANN::MakeClassSpecific( std::ostream& /*fout*/, const TString& /*className*/ ) const
{
}
/// get help message text
void TMVA::MethodTMlpANN::GetHelpMessage() const
{
   Log() << gTools().Color("bold") << "--- Short description:" << gTools().Color("reset") << Endl;
   Log() << "This feed-forward multilayer perceptron neural network is the" << Endl;
   Log() << "standard implementation distributed with ROOT (class TMultiLayerPerceptron)." << Endl;
   Log() << "Detailed information is available here:" << Endl;
   if (gConfig().WriteOptionsReference()) {
      Log() << "<a href=\"http://root.cern.ch/root/html/TMultiLayerPerceptron.html\">";
      Log() << "http://root.cern.ch/root/html/TMultiLayerPerceptron.html</a>" << Endl;
   }
   else Log() << "http://root.cern.ch/root/html/TMultiLayerPerceptron.html" << Endl;
}