Logo ROOT   6.30.04
Reference Guide
 All Namespaces Files Pages
TMVACrossValidation.C
Go to the documentation of this file.
1 /// \file
2 /// \ingroup tutorial_tmva
3 /// \notebook -nodraw
4 /// This macro provides an example of how to use TMVA for k-folds cross
5 /// evaluation.
6 ///
7 /// The input data is a toy-MC sample consisting of two Gaussian
8 /// distributions.
9 ///
10 /// The output file "TMVA.root" can be analysed with the use of dedicated
11 /// macros (simply say: root -l <macro.C>), which can be conveniently
12 /// invoked through a GUI that will appear at the end of the run of this macro.
13 /// Launch the GUI via the command:
14 ///
15 /// ```
16 /// root -l -e 'TMVA::TMVAGui("TMVA.root")'
17 /// ```
18 ///
19 /// ## Cross Evaluation
20 /// Cross evaluation is a special case of k-folds cross validation where the
21 /// splitting into k folds is computed deterministically. This ensures that
22 /// a given event will always end up in the same fold.
23 ///
24 /// In addition all resulting classifiers are saved and can be applied to new
25 /// data using `MethodCrossValidation`. One requirement for this to work is a
26 /// splitting function that is evaluated for each event to determine into what
27 /// fold it goes (for training/evaluation) or to what classifier (for
28 /// application).
29 ///
30 /// ## Split Expression
31 /// Cross evaluation uses a deterministic split to partition the data into
32 /// folds called the split expression. The expression can be any valid
33 /// `TFormula` as long as all parts used are defined.
34 ///
35 /// For each event the split expression is evaluated to a number and the event
36 /// is put in the fold corresponding to that number.
37 ///
38 /// It is recommended to always use `%int([NumFolds])` at the end of the
39 /// expression.
40 ///
41 /// The split expression has access to all spectators and variables defined in
42 /// the dataloader. Additionally, the number of folds in the split can be
43 /// accessed with `NumFolds` (or `numFolds`).
44 ///
45 /// ### Example
46 /// ```
47 /// "int(fabs([eventID]))%int([NumFolds])"
48 /// ```
49 ///
50 /// - Project : TMVA - a ROOT-integrated toolkit for multivariate data analysis
51 /// - Package : TMVA
52 /// - Root Macro: TMVACrossValidation
53 ///
54 /// \macro_output
55 /// \macro_code
56 /// \author Kim Albertsson (adapted from code originally by Andreas Hoecker)
57 
58 #include <cstdlib>
59 #include <iostream>
60 #include <map>
61 #include <string>
62 
63 #include "TChain.h"
64 #include "TFile.h"
65 #include "TTree.h"
66 #include "TString.h"
67 #include "TObjString.h"
68 #include "TSystem.h"
69 #include "TROOT.h"
70 
71 #include "TMVA/CrossValidation.h"
72 #include "TMVA/DataLoader.h"
73 #include "TMVA/Factory.h"
74 #include "TMVA/Tools.h"
75 #include "TMVA/TMVAGui.h"
76 
77 // Helper function to load data into TTrees.
78 TTree *genTree(Int_t nPoints, Double_t offset, Double_t scale, UInt_t seed = 100)
79 {
80  TRandom3 rng(seed);
81  Float_t x = 0;
82  Float_t y = 0;
83  UInt_t eventID = 0;
84 
85  TTree *data = new TTree();
86  data->Branch("x", &x, "x/F");
87  data->Branch("y", &y, "y/F");
88  data->Branch("eventID", &eventID, "eventID/I");
89 
90  for (Int_t n = 0; n < nPoints; ++n) {
91  x = rng.Gaus(offset, scale);
92  y = rng.Gaus(offset, scale);
93 
94  // For our simple example it is enough that the id's are uniformly
95  // distributed and independent of the data.
96  ++eventID;
97 
98  data->Fill();
99  }
100 
101  // Important: Disconnects the tree from the memory locations of x and y.
102  data->ResetBranchAddresses();
103  return data;
104 }
105 
106 int TMVACrossValidation()
107 {
108  // This loads the library
109  TMVA::Tools::Instance();
110 
111  // --------------------------------------------------------------------------
112 
113  // Load the data into TTrees. If you load data from file you can use a
114  // variant of
115  // ```
116  // TString filename = "/path/to/file";
117  // TFile * input = TFile::Open( filename );
118  // TTree * signalTree = (TTree*)input->Get("TreeName");
119  // ```
120  TTree *sigTree = genTree(1000, 1.0, 1.0, 100);
121  TTree *bkgTree = genTree(1000, -1.0, 1.0, 101);
122 
123  // Create a ROOT output file where TMVA will store ntuples, histograms, etc.
124  TString outfileName("TMVA.root");
125  TFile *outputFile = TFile::Open(outfileName, "RECREATE");
126 
127  // DataLoader definitions; We declare variables in the tree so that TMVA can
128  // find them. For more information see TMVAClassification tutorial.
129  TMVA::DataLoader *dataloader = new TMVA::DataLoader("dataset");
130 
131  // Data variables
132  dataloader->AddVariable("x", 'F');
133  dataloader->AddVariable("y", 'F');
134 
135  // Spectator used for split
136  dataloader->AddSpectator("eventID", 'I');
137 
138  // NOTE: Currently TMVA treats all input variables, spectators etc as
139  // floats. Thus, if the absolute value of the input is too large
140  // there can be precision loss. This can especially be a problem for
141  // cross validation with large event numbers.
142  // A workaround is to define your splitting variable as:
143  // `dataloader->AddSpectator("eventID := eventID % 4096", 'I');`
144  // where 4096 should be a number much larger than the number of folds
145  // you intend to run with.
146 
147  // Attaches the trees so they can be read from
148  dataloader->AddSignalTree(sigTree, 1.0);
149  dataloader->AddBackgroundTree(bkgTree, 1.0);
150 
151  // The CV mechanism of TMVA splits up the training set into several folds.
152  // The test set is currently left unused. The `nTest_ClassName=1` assigns
153  // one event to the the test set for each class and puts the rest in the
154  // training set. A value of 0 is a special value and would split the
155  // datasets 50 / 50.
156  dataloader->PrepareTrainingAndTestTree("", "",
157  "nTest_Signal=1"
158  ":nTest_Background=1"
159  ":SplitMode=Random"
160  ":NormMode=NumEvents"
161  ":!V");
162 
163  // --------------------------------------------------------------------------
164 
165  //
166  // This sets up a CrossValidation class (which wraps a TMVA::Factory
167  // internally) for 2-fold cross validation.
168  //
169  // The split type can be "Random", "RandomStratified" or "Deterministic".
170  // For the last option, check the comment below. Random splitting randomises
171  // the order of events and distributes events as evenly as possible.
172  // RandomStratified applies the same logic but distributes events within a
173  // class as evenly as possible over the folds.
174  //
175  UInt_t numFolds = 2;
176  TString analysisType = "Classification";
177  TString splitType = "Random";
178  TString splitExpr = "";
179 
180  //
181  // One can also use a custom splitting function for producing the folds.
182  // The example uses a dataset spectator `eventID`.
183  //
184  // The idea here is that eventID should be an event number that is integral,
185  // random and independent of the data, generated only once. This last
186  // property ensures that if a calibration is changed the same event will
187  // still be assigned the same fold.
188  //
189  // This can be used to use the cross validated classifiers in application,
190  // a technique that can simplify statistical analysis.
191  //
192  // If you want to run TMVACrossValidationApplication, make sure you have
193  // run this tutorial with the below line uncommented first.
194  //
195 
196  // TString splitExpr = "int(fabs([eventID]))%int([NumFolds])";
197 
198  TString cvOptions = Form("!V"
199  ":!Silent"
200  ":ModelPersistence"
201  ":AnalysisType=%s"
202  ":SplitType=%s"
203  ":NumFolds=%i"
204  ":SplitExpr=%s",
205  analysisType.Data(), splitType.Data(), numFolds,
206  splitExpr.Data());
207 
208  TMVA::CrossValidation cv{"TMVACrossValidation", dataloader, outputFile, cvOptions};
209 
210  // --------------------------------------------------------------------------
211 
212  //
213  // Books a method to use for evaluation
214  //
215  cv.BookMethod(TMVA::Types::kBDT, "BDTG",
216  "!H:!V:NTrees=100:MinNodeSize=2.5%:BoostType=Grad"
217  ":NegWeightTreatment=Pray:Shrinkage=0.10:nCuts=20"
218  ":MaxDepth=2");
219 
220  cv.BookMethod(TMVA::Types::kFisher, "Fisher",
221  "!H:!V:Fisher:VarTransform=None");
222 
223  // --------------------------------------------------------------------------
224 
225  //
226  // Train, test and evaluate the booked methods.
227  // Evaluates the booked methods once for each fold and aggregates the result
228  // in the specified output file.
229  //
230  cv.Evaluate();
231 
232  // --------------------------------------------------------------------------
233 
234  //
235  // Process some output programatically, printing the ROC score for each
236  // booked method.
237  //
238  size_t iMethod = 0;
239  for (auto && result : cv.GetResults()) {
240  std::cout << "Summary for method " << cv.GetMethods()[iMethod++].GetValue<TString>("MethodName")
241  << std::endl;
242  for (UInt_t iFold = 0; iFold<cv.GetNumFolds(); ++iFold) {
243  std::cout << "\tFold " << iFold << ": "
244  << "ROC int: " << result.GetROCValues()[iFold]
245  << ", "
246  << "BkgEff@SigEff=0.3: " << result.GetEff30Values()[iFold]
247  << std::endl;
248  }
249  }
250 
251  // --------------------------------------------------------------------------
252 
253  //
254  // Save the output
255  //
256  outputFile->Close();
257 
258  std::cout << "==> Wrote root file: " << outputFile->GetName() << std::endl;
259  std::cout << "==> TMVACrossValidation is done!" << std::endl;
260 
261  // --------------------------------------------------------------------------
262 
263  //
264  // Launch the GUI for the root macros
265  //
266  if (!gROOT->IsBatch()) {
267  // Draw cv-specific graphs
268  cv.GetResults()[0].DrawAvgROCCurve(kTRUE, "Avg ROC for BDTG");
269  cv.GetResults()[0].DrawAvgROCCurve(kTRUE, "Avg ROC for Fisher");
270 
271  // You can also use the classical gui
272  TMVA::TMVAGui(outfileName);
273  }
274 
275  return 0;
276 }
277 
278 //
279 // This is used if the macro is compiled. If run through ROOT with
280 // `root -l -b -q MACRO.C` or similar it is unused.
281 //
282 int main(int argc, char **argv)
283 {
284  TMVACrossValidation();
285 }