#include <algorithm>
#include <memory>
#include <stdexcept>

using namespace TMVA::Experimental;
15 void train(
const std::string &filename)
18 auto output = TFile::Open(
"TMVA.root",
"RECREATE");
19 auto factory =
new TMVA::Factory(
"tmva003",
20 output,
"!V:!DrawProgressBar:AnalysisType=Classification");
23 auto data = TFile::Open(filename.c_str());
24 auto signal = (TTree *)data->Get(
"TreeS");
25 auto background = (TTree *)data->Get(
"TreeB");
28 auto dataloader =
new TMVA::DataLoader(
"tmva003_BDT");
29 const std::vector<std::string> variables = {
"var1",
"var2",
"var3",
"var4"};
30 for (
const auto &var : variables) {
31 dataloader->AddVariable(var);
33 dataloader->AddSignalTree(signal, 1.0);
34 dataloader->AddBackgroundTree(background, 1.0);
35 dataloader->PrepareTrainingAndTestTree(
"",
"");
38 factory->BookMethod(dataloader, TMVA::Types::kBDT,
"BDT",
"!V:!H:NTrees=300:MaxDepth=2");
39 factory->TrainAllMethods();
42 void tmva003_RReader()
45 const std::string filename =
"http://root.cern.ch/files/tmva_class_example.root";
49 RReader model(
"tmva003_BDT/weights/tmva003_BDT.weights.xml");
53 auto variables = model.GetVariableNames();
64 auto prediction = model.Compute({0.5, 1.0, -0.2, 1.5});
65 std::cout <<
"Single-event inference: " << prediction[0] <<
"\n\n";
71 ROOT::RDataFrame df(
"TreeS", filename);
72 auto df2 = df.Range(3);
73 auto x = AsTensor<float>(df2, variables);
74 auto y = model.Compute(x);
76 std::cout <<
"RTensor input for inference on data of multiple events:\n" << x <<
"\n\n";
77 std::cout <<
"Prediction performed on multiple events: " << y <<
"\n\n";
82 auto make_histo = [&](
const std::string &treename) {
83 ROOT::RDataFrame df(treename, filename);
84 auto df2 = df.Define(
"y", Compute<4, float>(model), variables);
85 return df2.Histo1D({treename.c_str(),
";BDT score;N_{Events}", 30, -0.5, 0.5},
"y");
88 auto sig = make_histo(
"TreeS");
89 auto bkg = make_histo(
"TreeB");
92 gStyle->SetOptStat(0);
93 auto c =
new TCanvas(
"",
"", 800, 800);
95 sig->SetLineColor(kRed);
96 bkg->SetLineColor(kBlue);
100 sig->Draw(
"HIST SAME");
102 TLegend legend(0.7, 0.7, 0.89, 0.89);
103 legend.SetBorderSize(0);
104 legend.AddEntry(
"TreeS",
"Signal",
"l");
105 legend.AddEntry(
"TreeB",
"Background",
"l");