MethodTMlpANN.cxx
// @(#)root/tmva $Id$
// Author: Andreas Hoecker, Joerg Stelzer, Helge Voss, Kai Voss, Eckhard von Toerne
/**********************************************************************************
 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis      *
 * Package: TMVA                                                                  *
 * Class  : MethodTMlpANN                                                         *
 * Web    : http://tmva.sourceforge.net                                           *
 *                                                                                *
 * Description:                                                                   *
 *      Implementation (see header for description)                               *
 *                                                                                *
 * Authors (alphabetical):                                                        *
 *      Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland              *
 *      Helge Voss      <Helge.Voss@cern.ch>     - MPI-K Heidelberg, Germany      *
 *      Kai Voss        <Kai.Voss@cern.ch>       - U. of Victoria, Canada         *
 *                                                                                *
 * Copyright (c) 2005:                                                            *
 *      CERN, Switzerland                                                         *
 *      U. of Victoria, Canada                                                    *
 *      MPI-K Heidelberg, Germany                                                 *
 *                                                                                *
 * Redistribution and use in source and binary forms, with or without            *
 * modification, are permitted according to the terms listed in LICENSE          *
 * (http://tmva.sourceforge.net/LICENSE)                                          *
 **********************************************************************************/

/*! \class TMVA::MethodTMlpANN
\ingroup TMVA

This is the TMVA TMultiLayerPerceptron interface class. It provides training
and testing of the ROOT-internal MLP class within the TMVA framework.

Available learning methods:<br>

 - Stochastic
 - Batch
 - SteepestDescent
 - RibierePolak
 - FletcherReeves
 - BFGS

See the TMultiLayerPerceptron class description
for details on this ANN.
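
Example booking (a minimal sketch: the option values shown are just the
defaults declared in DeclareOptions(), and `factory`/`dataloader` stand for an
already configured TMVA::Factory and TMVA::DataLoader):

~~~ {.cpp}
factory->BookMethod( dataloader, TMVA::Types::kTMlpANN, "TMlpANN",
                     "NCycles=200:HiddenLayers=N,N-1:LearningMethod=Stochastic:ValidationFraction=0.5" );
~~~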
*/

#include "TMVA/MethodTMlpANN.h"

#include "TMVA/Config.h"
#include "TMVA/Configurable.h"
#include "TMVA/DataSet.h"
#include "TMVA/DataSetInfo.h"
#include "TMVA/IMethod.h"
#include "TMVA/MethodBase.h"
#include "TMVA/MsgLogger.h"
#include "TMVA/Types.h"
#include "TMVA/VariableInfo.h"

#include "TMVA/ClassifierFactory.h"
#include "TMVA/Tools.h"

#include "Riostream.h"
#include "TLeaf.h"
#include "TEventList.h"
#include "TObjString.h"
#include "TROOT.h"
#include "TMultiLayerPerceptron.h"

#include <cstdlib>
#include <iostream>
#include <fstream>
#include <sstream>  // std::stringstream is used in ReadWeightsFromXML

using std::atoi;

// some additional TMlpANN options
const Bool_t EnforceNormalization__ = kTRUE;

REGISTER_METHOD(TMlpANN)

ClassImp(TMVA::MethodTMlpANN);

////////////////////////////////////////////////////////////////////////////////
/// standard constructor

TMVA::MethodTMlpANN::MethodTMlpANN( const TString& jobName,
                                    const TString& methodTitle,
                                    DataSetInfo& theData,
                                    const TString& theOption) :
   TMVA::MethodBase( jobName, Types::kTMlpANN, methodTitle, theData, theOption),
   fMLP(0),
   fLocalTrainingTree(0),
   fNcycles(100),
   fValidationFraction(0.5),
   fLearningMethod( "" )
{
}

////////////////////////////////////////////////////////////////////////////////
/// constructor from weight file

TMVA::MethodTMlpANN::MethodTMlpANN( DataSetInfo& theData,
                                    const TString& theWeightFile) :
   TMVA::MethodBase( Types::kTMlpANN, theData, theWeightFile),
   fMLP(0),
   fLocalTrainingTree(0),
   fNcycles(100),
   fValidationFraction(0.5),
   fLearningMethod( "" )
{
}

////////////////////////////////////////////////////////////////////////////////
/// TMlpANN can handle classification with 2 classes

Bool_t TMVA::MethodTMlpANN::HasAnalysisType( Types::EAnalysisType type, UInt_t numberClasses,
                                             UInt_t /*numberTargets*/ )
{
   if (type == Types::kClassification && numberClasses == 2) return kTRUE;
   return kFALSE;
}


////////////////////////////////////////////////////////////////////////////////
/// default initialisations

void TMVA::MethodTMlpANN::Init( void )
{
}

////////////////////////////////////////////////////////////////////////////////
/// destructor

TMVA::MethodTMlpANN::~MethodTMlpANN( void )
{
   if (fMLP) delete fMLP;
}

////////////////////////////////////////////////////////////////////////////////
/// translates options from the option string into TMlpANN language

void TMVA::MethodTMlpANN::CreateMLPOptions( TString layerSpec )
{
   fHiddenLayer = ":";

   while (layerSpec.Length()>0) {
      TString sToAdd="";
      if (layerSpec.First(',')<0) {
         sToAdd = layerSpec;
         layerSpec = "";
      }
      else {
         sToAdd = layerSpec(0,layerSpec.First(','));
         layerSpec = layerSpec(layerSpec.First(',')+1,layerSpec.Length());
      }
      int nNodes = 0;
      if (sToAdd.BeginsWith("N")) { sToAdd.Remove(0,1); nNodes = GetNvar(); }
      nNodes += atoi(sToAdd);
      fHiddenLayer = Form( "%s%i:", (const char*)fHiddenLayer, nNodes );
   }

   // set input vars
   std::vector<TString>::iterator itrVar    = (*fInputVars).begin();
   std::vector<TString>::iterator itrVarEnd = (*fInputVars).end();
   fMLPBuildOptions = "";
   for (; itrVar != itrVarEnd; ++itrVar) {
      if (EnforceNormalization__) fMLPBuildOptions += "@";
      TString myVar = *itrVar;
      fMLPBuildOptions += myVar;
      fMLPBuildOptions += ",";
   }
   fMLPBuildOptions.Chop(); // remove last ","

   // prepare final options for MLP kernel: append the hidden-layer layout
   // and the name of the output branch ("type")
   fMLPBuildOptions += fHiddenLayer;
   fMLPBuildOptions += "type";

   Log() << kINFO << "Use " << fNcycles << " training cycles" << Endl;
   Log() << kINFO << "Use configuration (nodes per hidden layer): " << fHiddenLayer << Endl;
}
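
// Illustration (hypothetical variable names): for two input variables "var1"
// and "var2" and layerSpec = "N,N-1", the loop above gives fHiddenLayer = ":2:1:",
// so that fMLPBuildOptions becomes "@var1,@var2:2:1:type", i.e. normalized
// inputs ("@"), hidden layers with 2 and 1 nodes, and the "type" branch as
// the training target.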

////////////////////////////////////////////////////////////////////////////////
/// define the options (their key words) that can be set in the option string
///
/// known options:
///
/// - NCycles <integer>      Number of training cycles (too many cycles could overtrain the network)
/// - HiddenLayers <string>  Layout of the hidden layers (nodes per layer)
///   * specifications for each hidden layer are separated by commas
///   * for each layer the number of nodes can be either absolute (simply a number)
///     or relative to the number of input nodes to the neural net (N)
///   * there is always a single node in the output layer
///
/// example: a net with 6 input nodes and "HiddenLayers=N-1,N-2" has 6,5,4,1 nodes in
/// layers 1,2,3,4, respectively
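///
/// e.g. (an illustrative option string, not a recommended configuration):
///
/// ~~~ {.cpp}
/// "NCycles=300:HiddenLayers=N+2,N:LearningMethod=BFGS:ValidationFraction=0.3"
/// ~~~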

void TMVA::MethodTMlpANN::DeclareOptions()
{
   DeclareOptionRef( fNcycles   = 200,     "NCycles",      "Number of training cycles" );
   DeclareOptionRef( fLayerSpec = "N,N-1", "HiddenLayers", "Specification of hidden layer architecture (N stands for number of variables; any integers may also be used)" );

   DeclareOptionRef( fValidationFraction = 0.5, "ValidationFraction",
                     "Fraction of events in training tree used for cross validation" );

   DeclareOptionRef( fLearningMethod = "Stochastic", "LearningMethod", "Learning method" );
   AddPreDefVal( TString("Stochastic") );
   AddPreDefVal( TString("Batch") );
   AddPreDefVal( TString("SteepestDescent") );
   AddPreDefVal( TString("RibierePolak") );
   AddPreDefVal( TString("FletcherReeves") );
   AddPreDefVal( TString("BFGS") );
}

////////////////////////////////////////////////////////////////////////////////
/// builds the neural network as specified by the user

void TMVA::MethodTMlpANN::ProcessOptions()
{
   CreateMLPOptions(fLayerSpec);

   if (IgnoreEventsWithNegWeightsInTraining()) {
      Log() << kFATAL << "Mechanism to ignore events with negative weights in training not available for method: "
            << GetMethodTypeName()
            << " --> please remove \"IgnoreNegWeightsInTraining\" option from booking string."
            << Endl;
   }
}

////////////////////////////////////////////////////////////////////////////////
/// calculate the value of the neural net for the current event

Double_t TMVA::MethodTMlpANN::GetMvaValue( Double_t* err, Double_t* errUpper )
{
   const Event* ev = GetEvent();
   TTHREAD_TLS_DECL_ARG(Double_t*, d, new Double_t[Data()->GetNVariables()]);

   for (UInt_t ivar = 0; ivar<Data()->GetNVariables(); ivar++) {
      d[ivar] = (Double_t)ev->GetValue(ivar);
   }
   Double_t mvaVal = fMLP->Evaluate(0,d);

   // cannot determine error
   NoErrorCalc(err, errUpper);

   return mvaVal;
}

////////////////////////////////////////////////////////////////////////////////
/// performs TMlpANN training
/// available learning methods:
///
///  - TMultiLayerPerceptron::kStochastic
///  - TMultiLayerPerceptron::kBatch
///  - TMultiLayerPerceptron::kSteepestDescent
///  - TMultiLayerPerceptron::kRibierePolak
///  - TMultiLayerPerceptron::kFletcherReeves
///  - TMultiLayerPerceptron::kBFGS
///
/// TMultiLayerPerceptron expects the training and test trees at once,
/// so the training and testing trees from the MVA factory are merged first:

void TMVA::MethodTMlpANN::Train( void )
{
   Int_t type;
   Float_t weight;
   const Long_t basketsize = 128000;
   Float_t* vArr = new Float_t[GetNvar()];

   TTree *localTrainingTree = new TTree( "TMLPtrain", "Local training tree for TMlpANN" );
   localTrainingTree->Branch( "type",   &type,   "type/I",   basketsize );
   localTrainingTree->Branch( "weight", &weight, "weight/F", basketsize );

   for (UInt_t ivar=0; ivar<GetNvar(); ivar++) {
      const char* myVar = GetInternalVarName(ivar).Data();
      localTrainingTree->Branch( myVar, &vArr[ivar], Form("Var%02i/F", ivar), basketsize );
   }

   for (UInt_t ievt=0; ievt<Data()->GetNEvents(); ievt++) {
      const Event *ev = GetEvent(ievt);
      for (UInt_t i=0; i<GetNvar(); i++) {
         vArr[i] = ev->GetValue( i );
      }
      type   = DataInfo().IsSignal( ev ) ? 1 : 0;
      weight = ev->GetWeight();
      localTrainingTree->Fill();
   }

   // These are the event lists for the mlp train method
   // first events in the tree are for training
   // the rest for internal testing (cross validation)...
   // NOTE: the training events are ordered: first part is signal, second part background
   TString trainList = "Entry$<";
   trainList += 1.0-fValidationFraction;
   trainList += "*";
   trainList += (Int_t)Data()->GetNEvtSigTrain();
   trainList += " || (Entry$>";
   trainList += (Int_t)Data()->GetNEvtSigTrain();
   trainList += " && Entry$<";
   trainList += (Int_t)(Data()->GetNEvtSigTrain() + (1.0 - fValidationFraction)*Data()->GetNEvtBkgdTrain());
   trainList += ")";
   TString testList  = TString("!(") + trainList + ")";
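
   // Example (hypothetical event counts): with fValidationFraction=0.5, 1000
   // signal and 800 background training events, trainList evaluates to
   // "Entry$<0.5*1000 || (Entry$>1000 && Entry$<1400)": the first half of the
   // signal block and the first half of the background block are used for
   // training, the remainder for internal validation.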

   // print the requirements
   Log() << kHEADER << "Requirement for training events: \"" << trainList << "\"" << Endl;
   Log() << kINFO << "Requirement for validation events: \"" << testList << "\"" << Endl;

   // localTrainingTree->Print();

   // create NN
   if (fMLP != 0) { delete fMLP; fMLP = 0; }
   fMLP = new TMultiLayerPerceptron( fMLPBuildOptions.Data(),
                                     localTrainingTree,
                                     trainList,
                                     testList );
   fMLP->SetEventWeight( "weight" );

   // set learning method
   TMultiLayerPerceptron::ELearningMethod learningMethod = TMultiLayerPerceptron::kStochastic;

   fLearningMethod.ToLower();
   if      (fLearningMethod == "stochastic"      ) learningMethod = TMultiLayerPerceptron::kStochastic;
   else if (fLearningMethod == "batch"           ) learningMethod = TMultiLayerPerceptron::kBatch;
   else if (fLearningMethod == "steepestdescent" ) learningMethod = TMultiLayerPerceptron::kSteepestDescent;
   else if (fLearningMethod == "ribierepolak"    ) learningMethod = TMultiLayerPerceptron::kRibierePolak;
   else if (fLearningMethod == "fletcherreeves"  ) learningMethod = TMultiLayerPerceptron::kFletcherReeves;
   else if (fLearningMethod == "bfgs"            ) learningMethod = TMultiLayerPerceptron::kBFGS;
   else {
      Log() << kFATAL << "Unknown Learning Method: \"" << fLearningMethod << "\"" << Endl;
   }
   fMLP->SetLearningMethod( learningMethod );

   // train NN
   fMLP->Train(fNcycles, "" ); //"text,update=50" );

   // clean up the temporary training objects; fMLP itself is kept
   // for the subsequent testing and evaluation phase
   delete localTrainingTree;
   delete [] vArr;
}

////////////////////////////////////////////////////////////////////////////////
/// write weights to xml file

void TMVA::MethodTMlpANN::AddWeightsXMLTo( void* parent ) const
{
   // first the architecture
   void *wght = gTools().AddChild(parent, "Weights");
   void* arch = gTools().AddChild( wght, "Architecture" );
   gTools().AddAttr( arch, "BuildOptions", fMLPBuildOptions.Data() );

   // dump weights first in temporary txt file, read from there into xml
   const TString tmpfile=GetWeightFileDir()+"/TMlp.nn.weights.temp";
   fMLP->DumpWeights( tmpfile.Data() );
   std::ifstream inf( tmpfile.Data() );
   char temp[256];
   TString data("");
   void *ch=NULL;
   while (inf.getline(temp,256)) {
      TString dummy(temp);
      //std::cout << dummy << std::endl; // remove annoying debug printout with std::cout
      if (dummy.BeginsWith('#')) { // a "#..." header line opens a new xml child node
         if (ch!=0) gTools().AddRawLine( ch, data.Data() );
         dummy = dummy.Strip(TString::kLeading, '#');
         dummy = dummy(0,dummy.First(' '));
         ch = gTools().AddChild(wght, dummy);
         data.Resize(0);
         continue;
      }
      data += (dummy + " ");
   }
   if (ch != 0) gTools().AddRawLine( ch, data.Data() );

   inf.close();
}
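
// For orientation (an assumption about the typical TMultiLayerPerceptron
// weight dump, inferred from the parsing in ReadWeightsFromXML below): the
// "Weights" node then contains an "Architecture" child carrying the build
// options, followed by children such as "input", "output", "neurons" and
// "synapses", whose raw text holds the corresponding weight blocks.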

////////////////////////////////////////////////////////////////////////////////
/// rebuild the temporary text file from the xml weight file and load this
/// file into the MLP

void TMVA::MethodTMlpANN::ReadWeightsFromXML( void* wghtnode )
{
   void* ch = gTools().GetChild(wghtnode);
   gTools().ReadAttr( ch, "BuildOptions", fMLPBuildOptions );

   ch = gTools().GetNextChild(ch);
   const TString fname = GetWeightFileDir()+"/TMlp.nn.weights.temp";
   std::ofstream fout( fname.Data() );
   double temp1=0,temp2=0;
   while (ch) {
      const char* nodecontent = gTools().GetContent(ch);
      std::stringstream content(nodecontent);
      if (strcmp(gTools().GetName(ch),"input")==0) {
         fout << "#input normalization" << std::endl;
         while ((content >> temp1) && (content >> temp2)) {
            fout << temp1 << " " << temp2 << std::endl;
         }
      }
      if (strcmp(gTools().GetName(ch),"output")==0) {
         fout << "#output normalization" << std::endl;
         while ((content >> temp1) && (content >> temp2)) {
            fout << temp1 << " " << temp2 << std::endl;
         }
      }
      if (strcmp(gTools().GetName(ch),"neurons")==0) {
         fout << "#neurons weights" << std::endl;
         while (content >> temp1) {
            fout << temp1 << std::endl;
         }
      }
      if (strcmp(gTools().GetName(ch),"synapses")==0) {
         fout << "#synapses weights" ;
         while (content >> temp1) {
            fout << std::endl << temp1 ;
         }
      }
      ch = gTools().GetNextChild(ch);
   }
   fout.close();

   // Here we create a dummy tree necessary to create a minimal NN
   // to be used for testing, evaluation and application
   TTHREAD_TLS_DECL_ARG(Double_t*, d, new Double_t[Data()->GetNVariables()]);
   TTHREAD_TLS(Int_t) type;

   gROOT->cd();
   TTree * dummyTree = new TTree("dummy","Empty dummy tree", 1);
   for (UInt_t ivar = 0; ivar<Data()->GetNVariables(); ivar++) {
      TString vn = DataInfo().GetVariableInfo(ivar).GetInternalName();
      dummyTree->Branch(Form("%s",vn.Data()), d+ivar, Form("%s/D",vn.Data()));
   }
   dummyTree->Branch("type", &type, "type/I");

   if (fMLP != 0) { delete fMLP; fMLP = 0; }
   fMLP = new TMultiLayerPerceptron( fMLPBuildOptions.Data(), dummyTree );
   fMLP->LoadWeights( fname );
}

////////////////////////////////////////////////////////////////////////////////
/// read weights from stream
/// since the MLP cannot read from the stream, we
/// 1st: write the weights to a temporary file

void TMVA::MethodTMlpANN::ReadWeightsFromStream( std::istream& istr )
{
   std::ofstream fout( "./TMlp.nn.weights.temp" );
   fout << istr.rdbuf();
   fout.close();
   // 2nd: load the weights from the temporary file into the MLP
   // the MLP is already built
   Log() << kINFO << "Load TMLP weights into " << fMLP << Endl;

   Double_t* d = new Double_t[Data()->GetNVariables()] ;
   Int_t type;
   gROOT->cd();
   TTree * dummyTree = new TTree("dummy","Empty dummy tree", 1);
   for (UInt_t ivar = 0; ivar<Data()->GetNVariables(); ivar++) {
      TString vn = DataInfo().GetVariableInfo(ivar).GetLabel();
      dummyTree->Branch(Form("%s",vn.Data()), d+ivar, Form("%s/D",vn.Data()));
   }
   dummyTree->Branch("type", &type, "type/I");

   if (fMLP != 0) { delete fMLP; fMLP = 0; }
   fMLP = new TMultiLayerPerceptron( fMLPBuildOptions.Data(), dummyTree );

   fMLP->LoadWeights( "./TMlp.nn.weights.temp" );
   // here we can delete the temporary file
   // how?
   delete [] d;
}

////////////////////////////////////////////////////////////////////////////////
/// create reader class for classifier -> overrides base class function
/// create specific class for TMultiLayerPerceptron

void TMVA::MethodTMlpANN::MakeClass( const TString& theClassFileName ) const
{
   // the default file name is built from weight-file directory, job name and method name
   TString classFileName = "";
   if (theClassFileName == "")
      classFileName = GetWeightFileDir() + "/" + GetJobName() + "_" + GetMethodName() + ".class";
   else
      classFileName = theClassFileName;

   classFileName.ReplaceAll(".class","");
   Log() << kINFO << "Creating specific (TMultiLayerPerceptron) standalone response class: " << classFileName << Endl;
   fMLP->Export( classFileName.Data() );
}

////////////////////////////////////////////////////////////////////////////////
/// write specific classifier response
/// nothing to do here - all taken care of by TMultiLayerPerceptron

void TMVA::MethodTMlpANN::MakeClassSpecific( std::ostream& /*fout*/, const TString& /*className*/ ) const
{
}

////////////////////////////////////////////////////////////////////////////////
/// get help message text
///
/// typical length of text line:
///         "|--------------------------------------------------------------|"

void TMVA::MethodTMlpANN::GetHelpMessage() const
{
   Log() << Endl;
   Log() << gTools().Color("bold") << "--- Short description:" << gTools().Color("reset") << Endl;
   Log() << Endl;
   Log() << "This feed-forward multilayer perceptron neural network is the " << Endl;
   Log() << "standard implementation distributed with ROOT (class TMultiLayerPerceptron)." << Endl;
   Log() << Endl;
   Log() << "Detailed information is available here:" << Endl;
   if (gConfig().WriteOptionsReference()) {
      Log() << "<a href=\"http://root.cern.ch/root/html/TMultiLayerPerceptron.html\">";
      Log() << "http://root.cern.ch/root/html/TMultiLayerPerceptron.html</a>" << Endl;
   }
   else Log() << "http://root.cern.ch/root/html/TMultiLayerPerceptron.html" << Endl;
   Log() << Endl;
}