Logo ROOT   6.30.04
Reference Guide
 All Namespaces Files Pages
Factory.h
Go to the documentation of this file.
1 // @(#)root/tmva $Id$
2 // Author: Andreas Hoecker, Peter Speckmayer, Joerg Stelzer, Helge Voss, Kai Voss, Eckhard von Toerne, Jan Therhaag
3 // Updated by: Omar Zapata, Lorenzo Moneta, Sergei Gleyzer
4 
5 /**********************************************************************************
6  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
7  * Package: TMVA *
8  * Class : Factory *
9  * Web : http://tmva.sourceforge.net *
10  * *
11  * Description: *
12  * This is the main MVA steering class: it creates (books) all MVA methods, *
13  * and guides them through the training, testing and evaluation phases. *
14  * *
15  * Authors (alphabetical): *
16  * Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland *
17  * Joerg Stelzer <stelzer@cern.ch> - DESY, Germany *
18  * Peter Speckmayer <peter.speckmayer@cern.ch> - CERN, Switzerland *
19  * Jan Therhaag <Jan.Therhaag@cern.ch> - U of Bonn, Germany *
20  * Eckhard v. Toerne <evt@uni-bonn.de> - U of Bonn, Germany *
21  * Helge Voss <Helge.Voss@cern.ch> - MPI-K Heidelberg, Germany *
22  * Kai Voss <Kai.Voss@cern.ch> - U. of Victoria, Canada *
23  * Omar Zapata <Omar.Zapata@cern.ch> - UdeA/ITM Colombia *
24  * Lorenzo Moneta <Lorenzo.Moneta@cern.ch> - CERN, Switzerland *
25  * Sergei Gleyzer <Sergei.Gleyzer@cern.ch> - U of Florida & CERN *
26  * *
27  * Copyright (c) 2005-2011: *
28  * CERN, Switzerland *
29  * U. of Victoria, Canada *
30  * MPI-K Heidelberg, Germany *
31  * U. of Bonn, Germany *
32  * UdeA/ITM, Colombia *
33  * U. of Florida, USA *
34  * *
35  * Redistribution and use in source and binary forms, with or without *
36  * modification, are permitted according to the terms listed in LICENSE *
37  * (http://tmva.sourceforge.net/LICENSE) *
38  **********************************************************************************/
39 
40 #ifndef ROOT_TMVA_Factory
41 #define ROOT_TMVA_Factory
42 
43 //////////////////////////////////////////////////////////////////////////
44 // //
45 // Factory //
46 // //
47 // This is the main MVA steering class: it creates all MVA methods, //
48 // and guides them through the training, testing and evaluation //
49 // phases //
50 // //
51 //////////////////////////////////////////////////////////////////////////
52 
53 #include <string>
54 #include <vector>
55 #include <map>
56 #include "TCut.h"
57 
58 #include "TMVA/Configurable.h"
59 #include "TMVA/Types.h"
60 #include "TMVA/DataSet.h"
61 
62 class TCanvas;
63 class TDirectory;
64 class TFile;
65 class TGraph;
66 class TH1F;
67 class TMultiGraph;
68 class TTree;
69 namespace TMVA {
70 
71  class IMethod;
72  class MethodBase;
73  class DataInputHandler;
74  class DataSetInfo;
75  class DataSetManager;
76  class DataLoader;
77  class ROCCurve;
78  class VariableTransformBase;
79 
80 
81  class Factory : public Configurable { // main MVA steering class: books the MVA methods and guides them through training, testing and evaluation
82  friend class CrossValidation; // CrossValidation is granted access to the private members below
83  public:
84 
85  typedef std::vector<IMethod*> MVector;
86  std::map<TString,MVector*> fMethodsMap;// all methods for every dataset; the TString key is presumably the dataset name (cf. GetMethod) - TODO confirm
87 
88  // no default constructor: both constructors require a job name
89  Factory( TString theJobName, TFile* theTargetFile, TString theOption = "" );
90 
91  // constructor to work without a target file
92  Factory( TString theJobName, TString theOption = "" );
93 
94  // virtual destructor
95  virtual ~Factory();
96 
97  // use TName::GetName and define correct name in constructor
98  //virtual const char* GetName() const { return "Factory"; }
99 
100 
101  MethodBase* BookMethod( DataLoader *loader, TString theMethodName, TString methodTitle, TString theOption = "" );
102  MethodBase* BookMethod( DataLoader *loader, Types::EMVA theMethod, TString methodTitle, TString theOption = "" );
103  MethodBase* BookMethod( DataLoader *, TMVA::Types::EMVA /*theMethod*/,
104  TString /*methodTitle*/,
105  TString /*methodOption*/,
106  TMVA::Types::EMVA /*theComposite*/,
107  TString /*compositeOption = ""*/ ) { return 0; } // inline stub: ignores every argument and returns a null pointer
108 
109  // optimize all booked methods (well, if desired by the method); the two ...For... variants forward to OptimizeAllMethods and discard the returned map
110  std::map<TString,Double_t> OptimizeAllMethods (TString fomType="ROCIntegral", TString fitType="FitGA");
111  void OptimizeAllMethodsForClassification(TString fomType="ROCIntegral", TString fitType="FitGA") { OptimizeAllMethods(fomType,fitType); }
112  void OptimizeAllMethodsForRegression (TString fomType="ROCIntegral", TString fitType="FitGA") { OptimizeAllMethods(fomType,fitType); }
113 
114  // training for all booked methods (the two ...For... variants simply call TrainAllMethods)
115  void TrainAllMethods ();
116  void TrainAllMethodsForClassification( void ) { TrainAllMethods(); }
117  void TrainAllMethodsForRegression ( void ) { TrainAllMethods(); }
118 
119  // testing
120  void TestAllMethods();
121 
122  // performance evaluation
123  void EvaluateAllMethods( void );
124  void EvaluateAllVariables(DataLoader *loader, TString options = "" );
125 
126  TH1F* EvaluateImportance( DataLoader *loader,VIType vitype, Types::EMVA theMethod, TString methodTitle, const char *theOption = "" );
127 
128  // delete all methods and reset the method vector
129  void DeleteAllMethods( void );
130 
131  // accessors
132  IMethod* GetMethod( const TString& datasetname, const TString& title ) const;
133  Bool_t HasMethod( const TString& datasetname, const TString& title ) const;
134 
135  Bool_t Verbose( void ) const { return fVerbose; }
136  void SetVerbose( Bool_t v=kTRUE );
137 
138  // make ROOT-independent C++ class for classifier response
139  // (classifier-specific implementation)
140  // NOTE(review): the original text here repeated the PrintHelpMessage wording; presumably an
141  // empty methodTitle means all booked classifiers are processed - TODO confirm in the implementation
142  virtual void MakeClass(const TString& datasetname , const TString& methodTitle = "" ) const;
143 
144  // prints classifier-specific help messages, dedicated to
145  // help with the optimisation and configuration options tuning.
146  // If no classifier name is given, help messages for all booked
147  // classifiers are printed
148  void PrintHelpMessage(const TString& datasetname , const TString& methodTitle = "" ) const;
149 
150  TDirectory* RootBaseDir() { return (TDirectory*)fgTargetFile; } // returns the fgTargetFile pointer cast to TDirectory*
151 
152  Bool_t IsSilentFile() const { return fSilentFile;}
153  Bool_t IsModelPersistence() const { return fModelPersistence; }
154 
155  Double_t GetROCIntegral(DataLoader *loader, TString theMethodName, UInt_t iClass = 0);
156  Double_t GetROCIntegral(TString datasetname, TString theMethodName, UInt_t iClass = 0);
157 
158  // Methods to get a TGraph for an indicated method in dataset.
159  // Optional title and axis labels are added with setTitles=kTRUE.
160  // Argument iClass used in multiclass settings, otherwise ignored.
161  TGraph* GetROCCurve(DataLoader *loader, TString theMethodName, Bool_t setTitles=kTRUE, UInt_t iClass=0);
162  TGraph* GetROCCurve(TString datasetname, TString theMethodName, Bool_t setTitles=kTRUE, UInt_t iClass=0);
163 
164  // Methods to get a TMultiGraph for a given class and all methods in dataset.
165  TMultiGraph* GetROCCurveAsMultiGraph(DataLoader *loader, UInt_t iClass);
166  TMultiGraph* GetROCCurveAsMultiGraph(TString datasetname, UInt_t iClass);
167 
168  // Draw all ROC curves of a given class for all methods in the dataset.
169  TCanvas* GetROCCurve(DataLoader *loader, UInt_t iClass=0);
170  TCanvas* GetROCCurve(TString datasetname, UInt_t iClass=0);
171 
172  private:
173 
174  // the beautiful greeting message
175  void Greetings();
176 
177  // evaluate the simple case, i.e. removing one variable at a time
178  TH1F* EvaluateImportanceShort( DataLoader *loader,Types::EMVA theMethod, TString methodTitle, const char *theOption = "" );
179  // evaluate all variable combinations
180  TH1F* EvaluateImportanceAll( DataLoader *loader,Types::EMVA theMethod, TString methodTitle, const char *theOption = "" );
181  // evaluate randomly, given a number of seeds
182  TH1F* EvaluateImportanceRandom( DataLoader *loader,UInt_t nseeds, Types::EMVA theMethod, TString methodTitle, const char *theOption = "" );
183 
184  TH1F* GetImportance(const int nbits,std::vector<Double_t> importances,std::vector<TString> varNames);
185 
186  // Helpers for the public-facing ROC methods (the tree type defaults to Types::kTesting)
187  ROCCurve *GetROC(DataLoader *loader, TString theMethodName, UInt_t iClass = 0,
188  Types::ETreeType type = Types::kTesting);
189  ROCCurve *GetROC(TString datasetname, TString theMethodName, UInt_t iClass = 0,
190  Types::ETreeType type = Types::kTesting);
191 
192  void WriteDataInformation(DataSetInfo& fDataSetInfo);
193 
194  void SetInputTreesFromEventAssignTrees();
195 
196  MethodBase* BookMethodWeightfile(DataLoader *dataloader, TMVA::Types::EMVA methodType, const TString &weightfile);
197 
198  private:
199 
200  // data members
201 
202  TFile* fgTargetFile; //! ROOT output file
203 
204 
205  std::vector<TMVA::VariableTransformBase*> fDefaultTrfs; //! list of transformations on default DataSet
206 
207  // cd to local directory
208  TString fOptions; //! option string given by construction (presently only "V")
209  TString fTransformations; //! list of transformations to test
210  Bool_t fVerbose; //! verbose mode
211  TString fVerboseLevel; //! verbosity level, controls granularity of logging
212  Bool_t fCorrelations; //! enable to calculate correlations
213  Bool_t fROC; //! enable to calculate ROC values
214  Bool_t fSilentFile; //! used in constructor without file
215 
216  TString fJobName; //! jobname, used as extension in weight file names
217 
218  Types::EAnalysisType fAnalysisType; //! the training type
219  Bool_t fModelPersistence;//! option to save the trained model in xml file or using serialization
220 
221 
222  protected:
223 
224  ClassDef(Factory,0); // The factory creates all MVA methods, and performs their training and testing
225  };
226 
227 } // namespace TMVA
228 
229 #endif