Logo ROOT   6.30.04
Reference Guide
 All Namespaces Files Pages
tmva003_RReader.C
Go to the documentation of this file.
/// \file
/// \ingroup tutorial_tmva
/// \notebook -nodraw
/// This tutorial shows how to apply models saved in TMVA XML files using the
/// modern TMVA interface, TMVA::Experimental::RReader.
///
/// \macro_code
/// \macro_output
///
/// \date July 2019
/// \author Stefan Wunsch
12 
#include <memory>
#include <stdexcept>

using namespace TMVA::Experimental;
14 
15 void train(const std::string &filename)
16 {
17  // Create factory
18  auto output = TFile::Open("TMVA.root", "RECREATE");
19  auto factory = new TMVA::Factory("tmva003",
20  output, "!V:!DrawProgressBar:AnalysisType=Classification");
21 
22  // Open trees with signal and background events
23  auto data = TFile::Open(filename.c_str());
24  auto signal = (TTree *)data->Get("TreeS");
25  auto background = (TTree *)data->Get("TreeB");
26 
27  // Add variables and register the trees with the dataloader
28  auto dataloader = new TMVA::DataLoader("tmva003_BDT");
29  const std::vector<std::string> variables = {"var1", "var2", "var3", "var4"};
30  for (const auto &var : variables) {
31  dataloader->AddVariable(var);
32  }
33  dataloader->AddSignalTree(signal, 1.0);
34  dataloader->AddBackgroundTree(background, 1.0);
35  dataloader->PrepareTrainingAndTestTree("", "");
36 
37  // Train a TMVA method
38  factory->BookMethod(dataloader, TMVA::Types::kBDT, "BDT", "!V:!H:NTrees=300:MaxDepth=2");
39  factory->TrainAllMethods();
40 }
41 
42 void tmva003_RReader()
43 {
44  // First, let's train a model with TMVA.
45  const std::string filename = "http://root.cern.ch/files/tmva_class_example.root";
46  train(filename);
47 
48  // Next, we load the model from the TMVA XML file.
49  RReader model("tmva003_BDT/weights/tmva003_BDT.weights.xml");
50 
51  // In case you need a reminder of the names and order of the variables during
52  // training, you can ask the model for it.
53  auto variables = model.GetVariableNames();
54 
55  // The model can now be applied in different scenarios:
56  // 1) Event-by-event inference
57  // 2) Batch inference on data of multiple events
58  // 3) Inference as part of an RDataFrame graph
59 
60  // 1) Event-by-event inference
61  // The event-by-event inference takes the values of the variables as a std::vector<float>.
62  // Note that the return value is as well a std::vector<float> since the reader
63  // is also capable to process models with multiple outputs.
64  auto prediction = model.Compute({0.5, 1.0, -0.2, 1.5});
65  std::cout << "Single-event inference: " << prediction[0] << "\n\n";
66 
67  // 2) Batch inference on data of multiple events
68  // For batch inference, the data needs to be structured as a matrix. For this
69  // purpose, TMVA makes use of the RTensor class. For convenience, we use RDataFrame
70  // and the AsTensor utility to make the read-out from the ROOT file.
71  ROOT::RDataFrame df("TreeS", filename);
72  auto df2 = df.Range(3); // Read only a small subset of the dataset
73  auto x = AsTensor<float>(df2, variables);
74  auto y = model.Compute(x);
75 
76  std::cout << "RTensor input for inference on data of multiple events:\n" << x << "\n\n";
77  std::cout << "Prediction performed on multiple events: " << y << "\n\n";
78 
79  // 3) Perform inference as part of an RDataFrame graph
80  // We write a small lambda function that performs for us the inference on
81  // a dataframe to omit code duplication.
82  auto make_histo = [&](const std::string &treename) {
83  ROOT::RDataFrame df(treename, filename);
84  auto df2 = df.Define("y", Compute<4, float>(model), variables);
85  return df2.Histo1D({treename.c_str(), ";BDT score;N_{Events}", 30, -0.5, 0.5}, "y");
86  };
87 
88  auto sig = make_histo("TreeS");
89  auto bkg = make_histo("TreeB");
90 
91  // Make plot
92  gStyle->SetOptStat(0);
93  auto c = new TCanvas("", "", 800, 800);
94 
95  sig->SetLineColor(kRed);
96  bkg->SetLineColor(kBlue);
97  sig->SetLineWidth(2);
98  bkg->SetLineWidth(2);
99  bkg->Draw("HIST");
100  sig->Draw("HIST SAME");
101 
102  TLegend legend(0.7, 0.7, 0.89, 0.89);
103  legend.SetBorderSize(0);
104  legend.AddEntry("TreeS", "Signal", "l");
105  legend.AddEntry("TreeB", "Background", "l");
106  legend.Draw();
107 
108  c->DrawClone();
109 }