Logo ROOT   6.30.04
Reference Guide
 All Namespaces Files Pages
classification.C
Go to the documentation of this file.
1 /// \file
2 /// \ingroup tutorial_tmva_envelope
3 /// \notebook -nodraw
4 
5 /// \macro_output
6 /// \macro_code
7 
8 
9 #include "TMVA/Factory.h"
10 #include "TMVA/DataLoader.h"
11 #include "TMVA/Tools.h"
12 #include "TMVA/Classification.h"
13 
14 void classification(UInt_t jobs = 4)
15 {
16  TMVA::Tools::Instance();
17 
18  TFile *input(0);
19  TString fname = "./tmva_class_example.root";
20  if (!gSystem->AccessPathName(fname)) {
21  input = TFile::Open(fname); // check if file in local directory exists
22  } else {
23  TFile::SetCacheFileDir(".");
24  input = TFile::Open("http://root.cern.ch/files/tmva_class_example.root", "CACHEREAD");
25  }
26  if (!input) {
27  std::cout << "ERROR: could not open data file" << std::endl;
28  exit(1);
29  }
30 
31  // Register the training and test trees
32 
33  TTree *signalTree = (TTree *)input->Get("TreeS");
34  TTree *background = (TTree *)input->Get("TreeB");
35 
36  TMVA::DataLoader *dataloader = new TMVA::DataLoader("dataset");
37  // If you wish to modify default settings
38  // (please check "src/Config.h" to see all available global options)
39  //
40  // (TMVA::gConfig().GetVariablePlotting()).fTimesRMS = 8.0;
41  // (TMVA::gConfig().GetIONames()).fWeightFileDir = "myWeightDirectory";
42 
43  // Define the input variables that shall be used for the MVA training
44  // note that you may also use variable expressions, such as: "3*var1/var2*abs(var3)"
45  // [all types of expressions that can also be parsed by TTree::Draw( "expression" )]
46  dataloader->AddVariable("myvar1 := var1+var2", 'F');
47  dataloader->AddVariable("myvar2 := var1-var2", "Expression 2", "", 'F');
48  dataloader->AddVariable("var3", "Variable 3", "units", 'F');
49  dataloader->AddVariable("var4", "Variable 4", "units", 'F');
50 
51  // You can add so-called "Spectator variables", which are not used in the MVA training,
52  // but will appear in the final "TestTree" produced by TMVA. This TestTree will contain the
53  // input variables, the response values of all trained MVAs, and the spectator variables
54 
55  dataloader->AddSpectator("spec1 := var1*2", "Spectator 1", "units", 'F');
56  dataloader->AddSpectator("spec2 := var1*3", "Spectator 2", "units", 'F');
57 
58  // global event weights per tree (see below for setting event-wise weights)
59  Double_t signalWeight = 1.0;
60  Double_t backgroundWeight = 1.0;
61 
62  // You can add an arbitrary number of signal or background trees
63  dataloader->AddSignalTree(signalTree, signalWeight);
64  dataloader->AddBackgroundTree(background, backgroundWeight);
65 
66  // Set individual event weights (the variables must exist in the original TTree)
67  // - for signal : `dataloader->SetSignalWeightExpression ("weight1*weight2");`
68  // - for background: `dataloader->SetBackgroundWeightExpression("weight1*weight2");`
69  dataloader->SetBackgroundWeightExpression("weight");
70  dataloader->PrepareTrainingAndTestTree(
71  "", "", "nTrain_Signal=1000:nTrain_Background=1000:SplitMode=Random:NormMode=NumEvents:!V");
72 
73  TFile *outputFile = TFile::Open("TMVAClass.root", "RECREATE");
74 
75  TMVA::Experimental::Classification *cl = new TMVA::Experimental::Classification(dataloader, Form("Jobs=%d", jobs));
76 
77  cl->BookMethod(TMVA::Types::kBDT, "BDTG", "!H:!V:NTrees=2000:MinNodeSize=2.5%:BoostType=Grad:Shrinkage=0.10:"
78  "UseBaggedBoost:BaggedSampleFraction=0.5:nCuts=20:MaxDepth=2");
79  cl->BookMethod(TMVA::Types::kSVM, "SVM", "Gamma=0.25:Tol=0.001:VarTransform=Norm");
80 
81  cl->BookMethod(TMVA::Types::kBDT, "BDTB", "!H:!V:NTrees=2000:BoostType=Bagging:SeparationType=GiniIndex:nCuts=20");
82 
83  cl->BookMethod(TMVA::Types::kCuts, "Cuts", "!H:!V:FitMethod=MC:EffSel:SampleSize=200000:VarProp=FSmart");
84 
85  cl->Evaluate(); // Train and Test all methods
86 
87  auto &results = cl->GetResults();
88 
89  TCanvas *c = new TCanvas(Form("ROC"));
90  c->SetTitle("ROC-Integral Curve");
91 
92  auto mg = new TMultiGraph();
93  for (UInt_t i = 0; i < results.size(); i++) {
94  if (!results[i].IsCutsMethod()) {
95  auto roc = results[i].GetROCGraph();
96  roc->SetLineColorAlpha(i + 1, 0.1);
97  mg->Add(roc);
98  }
99  }
100  mg->Draw("AL");
101  mg->GetXaxis()->SetTitle(" Signal Efficiency ");
102  mg->GetYaxis()->SetTitle(" Background Rejection ");
103  c->BuildLegend(0.15, 0.15, 0.3, 0.3);
104  c->Draw();
105 
106  outputFile->Close();
107  delete cl;
108 }