Logo ROOT   6.30.04
Reference Guide
 All Namespaces Files Pages
MethodCompositeBase.cxx
Go to the documentation of this file.
1 // @(#)root/tmva $Id$
2 // Author: Andreas Hoecker, Joerg Stelzer, Helge Voss, Kai Voss,Or Cohen
3 
4 /*****************************************************************************
5  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6  * Package: TMVA *
7  * Class : MethodCompositeBase *
8  * Web : http://tmva.sourceforge.net *
9  * *
10  * Description: *
11  * Virtual base class for composite MVA methods *
12  * *
13  * Authors (alphabetical): *
14  * Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland *
15  * Joerg Stelzer <Joerg.Stelzer@cern.ch> - MSU, USA *
16  * Helge Voss <Helge.Voss@cern.ch> - MPI-K Heidelberg, Germany *
17  * Kai Voss <Kai.Voss@cern.ch> - U. of Victoria, Canada *
18  * Or Cohen <orcohenor@gmail.com> - Weizmann Inst., Israel *
19  * *
20  * Copyright (c) 2005: *
21  * CERN, Switzerland *
22  * U. of Victoria, Canada *
23  * MPI-K Heidelberg, Germany *
24  * LAPP, Annecy, France *
25  * *
26  * Redistribution and use in source and binary forms, with or without *
27  * modification, are permitted according to the terms listed in LICENSE *
28  * (http://tmva.sourceforge.net/LICENSE) *
29  *****************************************************************************/
30 
31 /*! \class TMVA::MethodCompositeBase
32 \ingroup TMVA
33 
34 Virtual base class for combining several TMVA methods.
35 
36 This is a virtual class meant to combine more than one classifier
37 together. The training of the classifiers is done by classes that are
38 derived from this one, while the saving and loading of the weights file
39 and the evaluation are done here.
40 */
41 
43 
44 #include "TMVA/ClassifierFactory.h"
45 #include "TMVA/DataSetInfo.h"
46 #include "TMVA/Factory.h"
47 #include "TMVA/IMethod.h"
48 #include "TMVA/MethodBase.h"
49 #include "TMVA/MethodBoost.h"
50 #include "TMVA/MsgLogger.h"
51 #include "TMVA/Tools.h"
52 #include "TMVA/Types.h"
53 #include "TMVA/Config.h"
54 
55 #include "Riostream.h"
56 #include "TRandom3.h"
57 #include "TMath.h"
58 #include "TObjString.h"
59 
60 #include <algorithm>
61 #include <iomanip>
62 #include <vector>
63 
64 
65 using std::vector;
66 
67 ClassImp(TMVA::MethodCompositeBase);
68 
69 ////////////////////////////////////////////////////////////////////////////////
70 
71 TMVA::MethodCompositeBase::MethodCompositeBase( const TString& jobName,
72  Types::EMVA methodType,
73  const TString& methodTitle,
74  DataSetInfo& theData,
75  const TString& theOption )
76 : TMVA::MethodBase( jobName, methodType, methodTitle, theData, theOption),
77  fCurrentMethodIdx(0), fCurrentMethod(0)
78 {}
79 
80 ////////////////////////////////////////////////////////////////////////////////
81 
82 TMVA::MethodCompositeBase::MethodCompositeBase( Types::EMVA methodType,
83  DataSetInfo& dsi,
84  const TString& weightFile)
85  : TMVA::MethodBase( methodType, dsi, weightFile),
86  fCurrentMethodIdx(0), fCurrentMethod(0)
87 {}
88 
89 ////////////////////////////////////////////////////////////////////////////////
90 /// returns pointer to MVA that corresponds to given method title
91 
92 TMVA::IMethod* TMVA::MethodCompositeBase::GetMethod( const TString &methodTitle ) const
93 {
94  std::vector<IMethod*>::const_iterator itrMethod = fMethods.begin();
95  std::vector<IMethod*>::const_iterator itrMethodEnd = fMethods.end();
96 
97  for (; itrMethod != itrMethodEnd; ++itrMethod) {
98  MethodBase* mva = dynamic_cast<MethodBase*>(*itrMethod);
99  if ( (mva->GetMethodName())==methodTitle ) return mva;
100  }
101  return 0;
102 }
103 
104 ////////////////////////////////////////////////////////////////////////////////
105 /// returns pointer to MVA that corresponds to given method index
106 
107 TMVA::IMethod* TMVA::MethodCompositeBase::GetMethod( const Int_t index ) const
108 {
109  std::vector<IMethod*>::const_iterator itrMethod = fMethods.begin()+index;
110  if (itrMethod<fMethods.end()) return *itrMethod;
111  else return 0;
112 }
113 
114 
115 ////////////////////////////////////////////////////////////////////////////////
116 
117 void TMVA::MethodCompositeBase::AddWeightsXMLTo( void* parent ) const
118 {
119  void* wght = gTools().AddChild(parent, "Weights");
120  gTools().AddAttr( wght, "NMethods", fMethods.size() );
121  for (UInt_t i=0; i< fMethods.size(); i++)
122  {
123  void* methxml = gTools().AddChild( wght, "Method" );
124  MethodBase* method = dynamic_cast<MethodBase*>(fMethods[i]);
125  gTools().AddAttr(methxml,"Index", i );
126  gTools().AddAttr(methxml,"Weight", fMethodWeight[i]);
127  gTools().AddAttr(methxml,"MethodSigCut", method->GetSignalReferenceCut());
128  gTools().AddAttr(methxml,"MethodSigCutOrientation", method->GetSignalReferenceCutOrientation());
129  gTools().AddAttr(methxml,"MethodTypeName", method->GetMethodTypeName());
130  gTools().AddAttr(methxml,"MethodName", method->GetMethodName() );
131  gTools().AddAttr(methxml,"JobName", method->GetJobName());
132  gTools().AddAttr(methxml,"Options", method->GetOptions());
133  if (method->fTransformationPointer)
134  gTools().AddAttr(methxml,"UseMainMethodTransformation", TString("true"));
135  else
136  gTools().AddAttr(methxml,"UseMainMethodTransformation", TString("false"));
137  method->AddWeightsXMLTo(methxml);
138  }
139 }
140 
141 ////////////////////////////////////////////////////////////////////////////////
142 /// delete methods
143 
144 TMVA::MethodCompositeBase::~MethodCompositeBase( void )
145 {
146  std::vector<IMethod*>::iterator itrMethod = fMethods.begin();
147  for (; itrMethod != fMethods.end(); ++itrMethod) {
148  Log() << kVERBOSE << "Delete method: " << (*itrMethod)->GetName() << Endl;
149  delete (*itrMethod);
150  }
151  fMethods.clear();
152 }
153 
154 ////////////////////////////////////////////////////////////////////////////////
155 /// XML streamer
156 
157 void TMVA::MethodCompositeBase::ReadWeightsFromXML( void* wghtnode )
158 {
159  UInt_t nMethods;
160  TString methodName, methodTypeName, jobName, optionString;
161 
162  for (UInt_t i=0;i<fMethods.size();i++) delete fMethods[i];
163  fMethods.clear();
164  fMethodWeight.clear();
165  gTools().ReadAttr( wghtnode, "NMethods", nMethods );
166  void* ch = gTools().GetChild(wghtnode);
167  for (UInt_t i=0; i< nMethods; i++) {
168  Double_t methodWeight, methodSigCut, methodSigCutOrientation;
169  gTools().ReadAttr( ch, "Weight", methodWeight );
170  gTools().ReadAttr( ch, "MethodSigCut", methodSigCut);
171  gTools().ReadAttr( ch, "MethodSigCutOrientation", methodSigCutOrientation);
172  gTools().ReadAttr( ch, "MethodTypeName", methodTypeName );
173  gTools().ReadAttr( ch, "MethodName", methodName );
174  gTools().ReadAttr( ch, "JobName", jobName );
175  gTools().ReadAttr( ch, "Options", optionString );
176 
177  // Bool_t rerouteTransformation = kFALSE;
178  if (gTools().HasAttr( ch, "UseMainMethodTransformation")) {
179  TString rerouteString("");
180  gTools().ReadAttr( ch, "UseMainMethodTransformation", rerouteString );
181  rerouteString.ToLower();
182  // if (rerouteString=="true")
183  // rerouteTransformation=kTRUE;
184  }
185 
186  //remove trailing "~" to signal that options have to be reused
187  optionString.ReplaceAll("~","");
188  //ignore meta-options for method Boost
189  optionString.ReplaceAll("Boost_","~Boost_");
190  optionString.ReplaceAll("!~","~!");
191 
192  if (i==0){
193  // the cast on MethodBoost is ugly, but a similar line is also in ReadWeightsFromFile --> needs to be fixed later
194  ((TMVA::MethodBoost*)this)->BookMethod( Types::Instance().GetMethodType( methodTypeName), methodName, optionString );
195  }
196  fMethods.push_back(
197  ClassifierFactory::Instance().Create(methodTypeName.Data(), jobName, methodName, DataInfo(), optionString));
198 
199  fMethodWeight.push_back(methodWeight);
200  MethodBase* meth = dynamic_cast<MethodBase*>(fMethods.back());
201 
202  if(meth==0)
203  Log() << kFATAL << "Could not read method from XML" << Endl;
204 
205  void* methXML = gTools().GetChild(ch);
206 
207  TString _fFileDir= meth->DataInfo().GetName();
208  _fFileDir+="/"+gConfig().GetIONames().fWeightFileDir;
209  meth->SetWeightFileDir(_fFileDir);
210  meth->SetModelPersistence(IsModelPersistence());
211  meth->SetSilentFile(IsSilentFile());
212  meth->SetupMethod();
213  meth->SetMsgType(kWARNING);
214  meth->ParseOptions();
215  meth->ProcessSetup();
216  meth->CheckSetup();
217  meth->ReadWeightsFromXML(methXML);
218  meth->SetSignalReferenceCut(methodSigCut);
219  meth->SetSignalReferenceCutOrientation(methodSigCutOrientation);
220 
221  meth->RerouteTransformationHandler (&(this->GetTransformationHandler()));
222 
223  ch = gTools().GetNextChild(ch);
224  }
225  //Log() << kINFO << "Reading methods from XML done " << Endl;
226 }
227 
228 ////////////////////////////////////////////////////////////////////////////////
229 /// text streamer
230 
231 void TMVA::MethodCompositeBase::ReadWeightsFromStream( std::istream& istr )
232 {
233  TString var, dummy;
234  TString methodName, methodTitle=GetMethodName(),
235  jobName=GetJobName(),optionString=GetOptions();
236  UInt_t methodNum; Double_t methodWeight;
237  // and read the Weights (BDT coefficients)
238  // coverity[tainted_data_argument]
239  istr >> dummy >> methodNum;
240  Log() << kINFO << "Read " << methodNum << " Classifiers" << Endl;
241  for (UInt_t i=0;i<fMethods.size();i++) delete fMethods[i];
242  fMethods.clear();
243  fMethodWeight.clear();
244  for (UInt_t i=0; i<methodNum; i++) {
245  istr >> dummy >> methodName >> dummy >> fCurrentMethodIdx >> dummy >> methodWeight;
246  if ((UInt_t)fCurrentMethodIdx != i) {
247  Log() << kFATAL << "Error while reading weight file; mismatch MethodIndex="
248  << fCurrentMethodIdx << " i=" << i
249  << " MethodName " << methodName
250  << " dummy " << dummy
251  << " MethodWeight= " << methodWeight
252  << Endl;
253  }
254  if (GetMethodType() != Types::kBoost || i==0) {
255  istr >> dummy >> jobName;
256  istr >> dummy >> methodTitle;
257  istr >> dummy >> optionString;
258  if (GetMethodType() == Types::kBoost)
259  ((TMVA::MethodBoost*)this)->BookMethod( Types::Instance().GetMethodType( methodName), methodTitle, optionString );
260  }
261  else methodTitle=Form("%s (%04i)",GetMethodName().Data(),fCurrentMethodIdx);
262  fMethods.push_back(
263  ClassifierFactory::Instance().Create(methodName.Data(), jobName, methodTitle, DataInfo(), optionString));
264  fMethodWeight.push_back( methodWeight );
265  if(MethodBase* m = dynamic_cast<MethodBase*>(fMethods.back()) )
266  m->ReadWeightsFromStream(istr);
267  }
268 }
269 
270 ////////////////////////////////////////////////////////////////////////////////
271 /// return composite MVA response
272 
273 Double_t TMVA::MethodCompositeBase::GetMvaValue( Double_t* err, Double_t* errUpper )
274 {
275  Double_t mvaValue = 0;
276  for (UInt_t i=0;i< fMethods.size(); i++) mvaValue+=fMethods[i]->GetMvaValue()*fMethodWeight[i];
277 
278  // cannot determine error
279  NoErrorCalc(err, errUpper);
280 
281  return mvaValue;
282 }