Logo ROOT   6.30.04
Reference Guide
 All Namespaces Files Pages
TMVACrossValidationRegression.C
Go to the documentation of this file.
1 /// \file
2 /// \ingroup tutorial_tmva
3 /// \notebook -nodraw
4 /// This macro provides an example of how to use TMVA for k-folds cross
5 /// evaluation.
6 ///
7 /// The input data is a toy-MC sample consisting of two Gaussian
8 /// distributions.
9 ///
10 /// The output file "TMVARegCv.root" can be analysed with the use of dedicated
11 /// macros (simply say: root -l <macro.C>), which can be conveniently
12 /// invoked through a GUI that will appear at the end of the run of this macro.
13 /// Launch the GUI via the command:
14 ///
15 /// ```
16 /// root -l -e 'TMVA::TMVAGui("TMVARegCv.root")'
17 /// ```
18 ///
19 /// ## Cross Evaluation
20 /// Cross evaluation is a special case of k-folds cross validation where the
21 /// splitting into k folds is computed deterministically. This ensures that
22 /// a given event will always end up in the same fold.
23 ///
24 /// In addition all resulting classifiers are saved and can be applied to new
25 /// data using `MethodCrossValidation`. One requirement for this to work is a
26 /// splitting function that is evaluated for each event to determine into what
27 /// fold it goes (for training/evaluation) or to what classifier (for
28 /// application).
29 ///
30 /// ## Split Expression
31 /// Cross evaluation uses a deterministic split to partition the data into
32 /// folds called the split expression. The expression can be any valid
33 /// `TFormula` as long as all parts used are defined.
34 ///
35 /// For each event the split expression is evaluated to a number and the event
36 /// is put in the fold corresponding to that number.
37 ///
38 /// It is recommended to always use `%int([NumFolds])` at the end of the
39 /// expression.
40 ///
41 /// The split expression has access to all spectators and variables defined in
42 /// the dataloader. Additionally, the number of folds in the split can be
43 /// accessed with `NumFolds` (or `numFolds`).
44 ///
45 /// ### Example
46 /// ```
47 /// "int(fabs([eventID]))%int([NumFolds])"
48 /// ```
49 ///
50 /// - Project : TMVA - a ROOT-integrated toolkit for multivariate data analysis
51 /// - Package : TMVA
52 /// - Root Macro: TMVACrossValidationRegression
53 ///
54 /// \macro_output
55 /// \macro_code
56 /// \author Kim Albertsson (adapted from code originally by Andreas Hoecker)
57 
58 #include <cstdlib>
59 #include <iostream>
60 #include <map>
61 #include <string>
62 
63 #include "TChain.h"
64 #include "TFile.h"
65 #include "TTree.h"
66 #include "TString.h"
67 #include "TObjString.h"
68 #include "TSystem.h"
69 #include "TROOT.h"
70 
71 #include "TMVA/Factory.h"
72 #include "TMVA/DataLoader.h"
73 #include "TMVA/Tools.h"
74 #include "TMVA/TMVAGui.h"
75 #include "TMVA/CrossValidation.h"
76 
77 TFile * getDataFile(TString fname) {
78  TFile *input(0);
79 
80  if (!gSystem->AccessPathName(fname)) {
81  input = TFile::Open(fname); // check if file in local directory exists
82  } else {
83  // if not: download from ROOT server
84  TFile::SetCacheFileDir(".");
85  input = TFile::Open("http://root.cern.ch/files/tmva_reg_example.root", "CACHEREAD");
86  }
87 
88  if (!input) {
89  std::cout << "ERROR: could not open data file " << fname << std::endl;
90  exit(1);
91  }
92 
93  return input;
94 }
95 
96 int TMVACrossValidationRegression()
97 {
98  // This loads the library
99  TMVA::Tools::Instance();
100 
101  // --------------------------------------------------------------------------
102 
103  // Create a ROOT output file where TMVA will store ntuples, histograms, etc.
104  TString outfileName("TMVARegCv.root");
105  TFile * outputFile = TFile::Open(outfileName, "RECREATE");
106 
107  TString infileName("./files/tmva_reg_example.root");
108  TFile * inputFile = getDataFile(infileName);
109 
110  TMVA::DataLoader *dataloader=new TMVA::DataLoader("dataset");
111 
112  dataloader->AddVariable("var1", "Variable 1", "units", 'F');
113  dataloader->AddVariable("var2", "Variable 2", "units", 'F');
114 
115  // Add the variable carrying the regression target
116  dataloader->AddTarget("fvalue");
117 
118  TTree * regTree = (TTree*)inputFile->Get("TreeR");
119  dataloader->AddRegressionTree(regTree, 1.0);
120 
121  // Individual events can be weighted
122  // dataloader->SetWeightExpression("weight", "Regression");
123 
124  std::cout << "--- TMVACrossValidationRegression: Using input file: " << inputFile->GetName() << std::endl;
125 
126  // Bypasses the normal splitting mechanism, CV uses a new system for this.
127  // Unfortunately the old system is unhappy if we leave the test set empty so
128  // we ensure that there is at least one event by placing the first event in
129  // it.
130  // You can with the selection cut place a global cut on the defined
131  // variables. Only events passing the cut will be using in training/testing.
132  // Example: `TCut selectionCut = "var1 < 1";`
133  TCut selectionCut = "";
134  dataloader->PrepareTrainingAndTestTree(selectionCut, "nTest_Regression=1"
135  ":SplitMode=Block"
136  ":NormMode=NumEvents"
137  ":!V");
138 
139  // --------------------------------------------------------------------------
140 
141  //
142  // This sets up a CrossValidation class (which wraps a TMVA::Factory
143  // internally) for 2-fold cross validation. The data will be split into the
144  // two folds randomly if `splitExpr` is `""`.
145  //
146  // One can also give a deterministic split using spectator variables. An
147  // example would be e.g. `"int(fabs([spec1]))%int([NumFolds])"`.
148  //
149  UInt_t numFolds = 2;
150  TString analysisType = "Regression";
151  TString splitExpr = "";
152 
153  TString cvOptions = Form("!V"
154  ":!Silent"
155  ":ModelPersistence"
156  ":!FoldFileOutput"
157  ":AnalysisType=%s"
158  ":NumFolds=%i"
159  ":SplitExpr=%s",
160  analysisType.Data(), numFolds, splitExpr.Data());
161 
162  TMVA::CrossValidation cv{"TMVACrossValidationRegression", dataloader, outputFile, cvOptions};
163 
164  // --------------------------------------------------------------------------
165 
166  //
167  // Books a method to use for evaluation
168  //
169  cv.BookMethod(TMVA::Types::kBDT, "BDTG",
170  "!H:!V:NTrees=500:BoostType=Grad:Shrinkage=0.1:"
171  "UseBaggedBoost:BaggedSampleFraction=0.5:nCuts=20:MaxDepth=3");
172 
173  // --------------------------------------------------------------------------
174 
175  //
176  // Train, test and evaluate the booked methods.
177  // Evaluates the booked methods once for each fold and aggregates the result
178  // in the specified output file.
179  //
180  cv.Evaluate();
181 
182  // --------------------------------------------------------------------------
183 
184  //
185  // Save the output
186  //
187  outputFile->Close();
188 
189  std::cout << "==> Wrote root file: " << outputFile->GetName() << std::endl;
190  std::cout << "==> TMVACrossValidationRegression is done!" << std::endl;
191 
192  // --------------------------------------------------------------------------
193 
194  //
195  // Launch the GUI for the root macros
196  //
197  if (!gROOT->IsBatch()) {
198  TMVA::TMVAGui(outfileName);
199  }
200 
201  return 0;
202 }
203 
204 //
205 // This is used if the macro is compiled. If run through ROOT with
206 // `root -l -b -q MACRO.C` or similar it is unused.
207 //
208 int main(int argc, char **argv)
209 {
210  TMVACrossValidationRegression();
211 }