Logo ROOT   6.30.04
Reference Guide
 All Namespaces Files Pages
TMVACrossValidationApplication.C
Go to the documentation of this file.
1 /// \file
2 /// \ingroup tutorial_tmva
3 /// \notebook -nodraw
4 /// This macro provides an example of how to apply methods trained with TMVA
5 /// k-fold cross validation to new data.
6 ///
7 /// This requires that CrossValidation was run with a deterministic split, such
8 /// as `"...:splitExpr=int([eventID])%int([numFolds]):..."`.
9 ///
10 /// - Project : TMVA - a ROOT-integrated toolkit for multivariate data analysis
11 /// - Package : TMVA
12 /// - Root Macro: TMVACrossValidationApplication
13 ///
14 /// \macro_output
15 /// \macro_code
16 /// \author Kim Albertsson (adapted from code originally by Andreas Hoecker)
17 
#include <cstdlib>
#include <iostream>
#include <map>
#include <string>

#include "TChain.h"
#include "TFile.h"
#include "TH1F.h"
#include "TTree.h"
#include "TString.h"
#include "TObjString.h"
#include "TRandom3.h"
#include "TSystem.h"
#include "TROOT.h"

#include "TMVA/Factory.h"
#include "TMVA/DataLoader.h"
#include "TMVA/Reader.h"
#include "TMVA/Tools.h"
#include "TMVA/TMVAGui.h"
35 
36 // Helper function to load data into TTrees.
37 TTree *fillTree(TTree * tree, Int_t nPoints, Double_t offset, Double_t scale, UInt_t seed = 100)
38 {
39  TRandom3 rng(seed);
40  Float_t x = 0;
41  Float_t y = 0;
42  Int_t eventID = 0;
43 
44  tree->SetBranchAddress("x", &x);
45  tree->SetBranchAddress("y", &y);
46  tree->SetBranchAddress("eventID", &eventID);
47 
48  for (Int_t n = 0; n < nPoints; ++n) {
49  x = rng.Gaus(offset, scale);
50  y = rng.Gaus(offset, scale);
51 
52  // For our simple example it is enough that the id's are uniformly
53  // distributed and independent of the data.
54  ++eventID;
55 
56  tree->Fill();
57  }
58 
59  // Important: Disconnects the tree from the memory locations of x and y.
60  tree->ResetBranchAddresses();
61  return tree;
62 }
63 
64 int TMVACrossValidationApplication()
65 {
66  // This loads the library
67  TMVA::Tools::Instance();
68 
69  // Set up the TMVA::Reader
70  TMVA::Reader *reader = new TMVA::Reader("!Color:!Silent:!V");
71 
72  Float_t x;
73  Float_t y;
74  Int_t eventID;
75 
76  reader->AddVariable("x", &x);
77  reader->AddVariable("y", &y);
78  reader->AddSpectator("eventID", &eventID);
79 
80  // Book the serialised methods
81  TString jobname("TMVACrossEvaluation");
82  {
83  TString methodName = "BDTG";
84  TString weightfile = TString("dataset/weights/") + jobname + "_" + methodName + TString(".weights.xml");
85 
86  Bool_t weightfileExists = (gSystem->AccessPathName(weightfile) == kFALSE);
87  if (weightfileExists) {
88  reader->BookMVA(methodName, weightfile);
89  } else {
90  std::cout << "Weightfile for method " << methodName << " not found."
91  " Did you run TMVACrossValidation with a specified"
92  " splitExpr?" << std::endl;
93  exit(0);
94  }
95 
96  }
97  {
98  TString methodName = "Fisher";
99  TString weightfile = TString("dataset/weights/") + jobname + "_" + methodName + TString(".weights.xml");
100 
101  Bool_t weightfileExists = (gSystem->AccessPathName(weightfile) == kFALSE);
102  if (weightfileExists) {
103  reader->BookMVA(methodName, weightfile);
104  } else {
105  std::cout << "Weightfile for method " << methodName << " not found."
106  " Did you run TMVACrossValidation with a specified"
107  " splitExpr?" << std::endl;
108  exit(0);
109  }
110  }
111 
112  // Load data
113  TTree *tree = new TTree();
114  tree->Branch("x", &x, "x/F");
115  tree->Branch("y", &y, "y/F");
116  tree->Branch("eventID", &eventID, "eventID/I");
117 
118  fillTree(tree, 1000, 1.0, 1.0, 100);
119  fillTree(tree, 1000, -1.0, 1.0, 101);
120  tree->SetBranchAddress("x", &x);
121  tree->SetBranchAddress("y", &y);
122  tree->SetBranchAddress("eventID", &eventID);
123 
124  // Prepare histograms
125  Int_t nbin = 100;
126  TH1F histBDTG{"BDTG", "BDTG", nbin, -1, 1};
127  TH1F histFisher{"Fisher", "Fisher", nbin, -1, 1};
128 
129  // Evaluate classifiers
130  for (Long64_t ievt = 0; ievt < tree->GetEntries(); ievt++) {
131  tree->GetEntry(ievt);
132 
133  Double_t valBDTG = reader->EvaluateMVA("BDTG");
134  Double_t valFisher = reader->EvaluateMVA("Fisher");
135 
136  histBDTG.Fill(valBDTG);
137  histFisher.Fill(valFisher);
138  }
139 
140  tree->ResetBranchAddresses();
141  delete tree;
142 
143  { // Write histograms to output file
144  TFile *target = new TFile("TMVACrossEvaluationApp.root", "RECREATE");
145  histBDTG.Write();
146  histFisher.Write();
147  target->Close();
148  delete target;
149  }
150 
151  delete reader;
152 
153  return 0;
154 }
155 
156 //
157 // This is used if the macro is compiled. If run through ROOT with
158 // `root -l -b -q MACRO.C` or similar it is unused.
159 //
160 int main(int argc, char **argv)
161 {
162  TMVACrossValidationApplication();
163 }