Logo ROOT   6.30.04
Reference Guide
 All Namespaces Files Pages
Classification.h
Go to the documentation of this file.
1 // @(#)root/tmva $Id$ 2017
2 // Authors: Omar Zapata, Andreas Hoecker, Peter Speckmayer, Joerg Stelzer, Helge Voss, Kai Voss, Eckhard von Toerne,
3 // Jan Therhaag
4 
5 #ifndef ROOT_TMVA_Classification
6 #define ROOT_TMVA_Classification
7 
8 #include <TString.h>
9 #include <TMultiGraph.h>
10 
11 #include <TMVA/IMethod.h>
12 #include <TMVA/MethodBase.h>
13 #include <TMVA/Configurable.h>
14 #include <TMVA/Types.h>
15 #include <TMVA/DataSet.h>
16 #include <TMVA/Event.h>
17 #include <TMVA/Results.h>
19 #include <TMVA/ResultsMulticlass.h>
20 #include <TMVA/Factory.h>
21 #include <TMVA/DataLoader.h>
22 #include <TMVA/OptionMap.h>
23 #include <TMVA/Envelope.h>
24 
25 /*! \class TMVA::Experimental::ClassificationResult
26  * Class to save the results of the classifier.
27  * Every booked machine learning method has an object holding its results
28  * from the classification process; this class stores the MVA values,
29  * the data loader name and the ML method name and title.
30  * You can display the results by calling the method Show, get the ROC integral with the
31  * method GetROCIntegral, or get the TMVA::ROCCurve object by calling GetROC.
32 \ingroup TMVA
33 */
34 
35 /*! \class TMVA::Experimental::Classification
36  * Class to perform two-class classification.
37  * The first step before any analysis is to prepare the data.
38  * To do that you need to create a TMVA::DataLoader object
39  * and configure in it the variables and the number of events
40  * to train/test.
41  * The class TMVA::Experimental::Classification needs a TMVA::DataLoader object,
42  * optionally a TFile object to save the results, and some extra options in a string
43  * like "V:Color:Transformations=I;D;P;U;G:Silent:DrawProgressBar:ModelPersistence:Jobs=2" where:
44  * V = verbose output
45  * Color = coloured screen output
46  * Silent = batch mode: boolean silent flag inhibiting any output from TMVA
47  * Transformations = list of transformations to test.
48  * DrawProgressBar = draw progress bar to display training and testing.
49  * ModelPersistence = to save the trained model in xml or serialized files.
50  * Jobs = number of ML methods to train/test in parallel using MultiProc; requires calling the Evaluate method.
51  * Basic example.
52  * \code
53 void classification(UInt_t jobs = 2)
54 {
55  TMVA::Tools::Instance();
56 
57  TFile *input(0);
58  TString fname = "./tmva_class_example.root";
59  if (!gSystem->AccessPathName(fname)) {
60  input = TFile::Open(fname); // check if file in local directory exists
61  } else {
62  TFile::SetCacheFileDir(".");
63  input = TFile::Open("http://root.cern.ch/files/tmva_class_example.root", "CACHEREAD");
64  }
65  if (!input) {
66  std::cout << "ERROR: could not open data file" << std::endl;
67  exit(1);
68  }
69 
70  // Register the training and test trees
71 
72  TTree *signalTree = (TTree *)input->Get("TreeS");
73  TTree *background = (TTree *)input->Get("TreeB");
74 
75  TMVA::DataLoader *dataloader = new TMVA::DataLoader("dataset");
76 
77  dataloader->AddVariable("myvar1 := var1+var2", 'F');
78  dataloader->AddVariable("myvar2 := var1-var2", "Expression 2", "", 'F');
79  dataloader->AddVariable("var3", "Variable 3", "units", 'F');
80  dataloader->AddVariable("var4", "Variable 4", "units", 'F');
81 
82  dataloader->AddSpectator("spec1 := var1*2", "Spectator 1", "units", 'F');
83  dataloader->AddSpectator("spec2 := var1*3", "Spectator 2", "units", 'F');
84 
85  // global event weights per tree (see below for setting event-wise weights)
86  Double_t signalWeight = 1.0;
87  Double_t backgroundWeight = 1.0;
88 
89  dataloader->SetBackgroundWeightExpression("weight");
90 
91  TMVA::Experimental::Classification *cl = new TMVA::Experimental::Classification(dataloader, Form("Jobs=%d", jobs));
92 
93  cl->BookMethod(TMVA::Types::kBDT, "BDTG", "!H:!V:NTrees=2000:MinNodeSize=2.5%:BoostType=Grad:Shrinkage=0.10:"
94  "UseBaggedBoost:BaggedSampleFraction=0.5:nCuts=20:MaxDepth=2");
95  cl->BookMethod(TMVA::Types::kSVM, "SVM", "Gamma=0.25:Tol=0.001:VarTransform=Norm");
96 
97  cl->Evaluate(); // Train and Test all methods
98 
99  auto &results = cl->GetResults();
100 
101  TCanvas *c = new TCanvas(Form("ROC"));
102  c->SetTitle("ROC-Integral Curve");
103 
104  auto mg = new TMultiGraph();
105  for (UInt_t i = 0; i < results.size(); i++) {
106  auto roc = results[i].GetROCGraph();
107  roc->SetLineColorAlpha(i + 1, 0.1);
108  mg->Add(roc);
109  }
110  mg->Draw("AL");
111  mg->GetXaxis()->SetTitle(" Signal Efficiency ");
112  mg->GetYaxis()->SetTitle(" Background Rejection ");
113  c->BuildLegend(0.15, 0.15, 0.3, 0.3);
114  c->Draw();
115 
116  delete cl;
117 }
118  * \endcode
119  *
120 \ingroup TMVA
121 */
122 
123 namespace TMVA {
124 class ResultsClassification;
125 namespace Experimental {
// Result record for one booked classification method.
// Holds the method's option map (name/title), the data loader name, the
// per-key MVA outputs for the training and testing samples, and a ROC integral.
// Only declarations are visible here; the out-of-line members are defined elsewhere.
// NOTE: trailing `//` comments on data members are left untouched on purpose,
// since ROOT's dictionary generator reads them.
126 class ClassificationResult : public TObject {
 // Classification is granted access to the private members below.
127  friend class Classification;
128 
129 private:
 // Option map of the method; the getters below read its "MethodName" and
 // "MethodTitle" entries.
130  OptionMap fMethod; //
 // Name of the data loader, returned by GetDataLoaderName().
131  TString fDataLoaderName; //
 // Training-sample outputs: map from a UInt_t key to a vector of
 // (Float_t, Float_t, Bool_t) tuples.
 // NOTE(review): presumably key = class index and tuple = (mva value, weight,
 // is-signal flag) -- confirm against the implementation file.
132  std::map<UInt_t, std::vector<std::tuple<Float_t, Float_t, Bool_t>>> fMvaTrain; // Mvas for two-class classification
 // Testing-sample outputs, same container layout as fMvaTrain.
133  std::map<UInt_t, std::vector<std::tuple<Float_t, Float_t, Bool_t>>>
134  fMvaTest; // Mvas for two-class and multiclass classification
 // List of class names (presumably one entry per output class -- TODO confirm).
135  std::vector<TString> fClassNames; //
136 
 // Private helper taking a method name and title.
 // NOTE(review): presumably reports whether this result belongs to that
 // method -- the definition is not visible here, confirm.
137  Bool_t IsMethod(TString methodname, TString methodtitle);
 // Flag returned by IsCutsMethod().
138  Bool_t fIsCuts; // if it is a method cuts need special output
 // Stored ROC integral value (how/when it is filled is not visible here).
139  Double_t fROCIntegral; //
140 
141 public:
 // Default constructor, copy constructor (both defined elsewhere) and an
 // empty inline destructor.
142  ClassificationResult();
143  ClassificationResult(const ClassificationResult &cr);
144  ~ClassificationResult() {}
145 
 // Return the "MethodName" / "MethodTitle" entry of fMethod as a TString (by value).
146  const TString GetMethodName() const { return fMethod.GetValue<TString>("MethodName"); }
147  const TString GetMethodTitle() const { return fMethod.GetValue<TString>("MethodTitle"); }
 // ROC curve object / ROC integral for class index iClass (default 0) on the
 // given tree type (default: testing sample). Defined elsewhere.
148  ROCCurve *GetROC(UInt_t iClass = 0, TMVA::Types::ETreeType type = TMVA::Types::kTesting);
149  Double_t GetROCIntegral(UInt_t iClass = 0, TMVA::Types::ETreeType type = TMVA::Types::kTesting);
 // Return a copy of the data loader name / the cuts-method flag.
 // NOTE(review): these two accessors only read members but are not
 // const-qualified, unlike GetMethodName()/GetMethodTitle().
150  TString GetDataLoaderName() { return fDataLoaderName; }
151  Bool_t IsCutsMethod() { return fIsCuts; }
152 
 // Display the results (per the class description above). Defined elsewhere.
153  void Show();
154 
 // ROC curve as a TGraph; same parameters and defaults as GetROC().
 // NOTE(review): ownership of the returned pointer is not visible here.
155  TGraph *GetROCGraph(UInt_t iClass = 0, TMVA::Types::ETreeType type = TMVA::Types::kTesting);
 // Copy assignment. Defined elsewhere.
156  ClassificationResult &operator=(const ClassificationResult &r);
157 
 // ROOT dictionary macro; the second argument (3) is the class version.
158  ClassDef(ClassificationResult, 3);
159 };
160 
// Envelope-derived driver for classification: it takes a DataLoader,
// optionally a TFile and an option string, exposes Train/Test/Evaluate entry
// points, and collects one ClassificationResult per method.
// Only declarations are visible here; every member function is defined elsewhere.
161 class Classification : public Envelope {
 // All data members carry the `//!` marker (ROOT: transient, not persisted).
 // Collected results, returned by reference from GetResults().
162  std::vector<ClassificationResult> fResults; //!
163  std::vector<IMethod *> fIMethods; //! vector of objects with booked methods
 // Analysis type used by the envelope.
 // NOTE(review): where it is set is not visible here.
164  Types::EAnalysisType fAnalysisType; //!
 // Boolean switches; presumably enable correlation output and ROC
 // computation respectively -- TODO confirm in the implementation.
165  Bool_t fCorrelations; //!
166  Bool_t fROC; //!
167 public:
 // Construct from a data loader, an optional output file and an option string
 // (see the class description above for the option format).
 // NOTE(review): `explicit` on these multi-argument constructors only affects
 // copy-list-initialization.
168  explicit Classification(DataLoader *loader, TFile *file, TString options);
169  explicit Classification(DataLoader *loader, TString options);
170  ~Classification();
171 
 // Training entry points: all methods, or a single method selected by
 // name/title or by Types::EMVA enum/title.
172  virtual void Train();
173  virtual void TrainMethod(TString methodname, TString methodtitle);
174  virtual void TrainMethod(Types::EMVA method, TString methodtitle);
175 
 // Testing entry points, mirroring the Train* overloads.
176  virtual void Test();
177  virtual void TestMethod(TString methodname, TString methodtitle);
178  virtual void TestMethod(Types::EMVA method, TString methodtitle);
179 
 // Train and test all methods (as used in the example in the class description).
180  virtual void Evaluate();
181 
 // Mutable reference to the vector of per-method results.
182  std::vector<ClassificationResult> &GetResults();
183 
 // Look up a method object by name and title.
 // NOTE(review): behaviour for an unknown method is not visible here.
184  MethodBase *GetMethod(TString methodname, TString methodtitle);
185 
186 protected:
 // Helpers keyed by method name/title; definitions are not visible here.
 // HasMethodObject additionally takes an Int_t reference `index`
 // (presumably set to the position of the method on success -- TODO confirm).
187  TString GetMethodOptions(TString methodname, TString methodtitle);
188  Bool_t HasMethodObject(TString methodname, TString methodtitle, Int_t &index);
189  Bool_t IsCutsMethod(TMVA::MethodBase *method);
 // ROC curve for a method object, or for a method given by name/title;
 // iClass defaults to 0 and the tree type defaults to the testing sample.
190  TMVA::ROCCurve *
191  GetROC(TMVA::MethodBase *method, UInt_t iClass = 0, TMVA::Types::ETreeType type = TMVA::Types::kTesting);
192  TMVA::ROCCurve *GetROC(TString methodname, TString methodtitle, UInt_t iClass = 0,
193  TMVA::Types::ETreeType type = TMVA::Types::kTesting);
194 
 // ROC integral for the method given by name/title (iClass defaults to 0).
195  Double_t GetROCIntegral(TString methodname, TString methodtitle, UInt_t iClass = 0);
196 
 // Single result entry for the method given by name/title (by reference).
197  ClassificationResult &GetResults(TString methodname, TString methodtitle);
 // File-handling helpers.
 // NOTE(review): presumably used to gather per-job output into the output
 // file when running in parallel -- confirm.
198  void CopyFrom(TDirectory *src, TFile *file);
199  void MergeFiles();
200 
 // ROOT dictionary macro with class version 0.
201  ClassDef(Classification, 0);
202 };
203 } // namespace Experimental
204 } // namespace TMVA
205 
206 #endif // ROOT_TMVA_Classification