14 void classification(UInt_t jobs = 4)
16 TMVA::Tools::Instance();
19 TString fname =
"./tmva_class_example.root";
20 if (!gSystem->AccessPathName(fname)) {
21 input = TFile::Open(fname);
23 TFile::SetCacheFileDir(
".");
24 input = TFile::Open(
"http://root.cern.ch/files/tmva_class_example.root",
"CACHEREAD");
27 std::cout <<
"ERROR: could not open data file" << std::endl;
33 TTree *signalTree = (TTree *)input->Get(
"TreeS");
34 TTree *background = (TTree *)input->Get(
"TreeB");
36 TMVA::DataLoader *dataloader =
new TMVA::DataLoader(
"dataset");
46 dataloader->AddVariable(
"myvar1 := var1+var2",
'F');
47 dataloader->AddVariable(
"myvar2 := var1-var2",
"Expression 2",
"",
'F');
48 dataloader->AddVariable(
"var3",
"Variable 3",
"units",
'F');
49 dataloader->AddVariable(
"var4",
"Variable 4",
"units",
'F');
55 dataloader->AddSpectator(
"spec1 := var1*2",
"Spectator 1",
"units",
'F');
56 dataloader->AddSpectator(
"spec2 := var1*3",
"Spectator 2",
"units",
'F');
59 Double_t signalWeight = 1.0;
60 Double_t backgroundWeight = 1.0;
63 dataloader->AddSignalTree(signalTree, signalWeight);
64 dataloader->AddBackgroundTree(background, backgroundWeight);
69 dataloader->SetBackgroundWeightExpression(
"weight");
70 dataloader->PrepareTrainingAndTestTree(
71 "",
"",
"nTrain_Signal=1000:nTrain_Background=1000:SplitMode=Random:NormMode=NumEvents:!V");
73 TFile *outputFile = TFile::Open(
"TMVAClass.root",
"RECREATE");
75 TMVA::Experimental::Classification *cl =
new TMVA::Experimental::Classification(dataloader, Form(
"Jobs=%d", jobs));
77 cl->BookMethod(TMVA::Types::kBDT,
"BDTG",
"!H:!V:NTrees=2000:MinNodeSize=2.5%:BoostType=Grad:Shrinkage=0.10:"
78 "UseBaggedBoost:BaggedSampleFraction=0.5:nCuts=20:MaxDepth=2");
79 cl->BookMethod(TMVA::Types::kSVM,
"SVM",
"Gamma=0.25:Tol=0.001:VarTransform=Norm");
81 cl->BookMethod(TMVA::Types::kBDT,
"BDTB",
"!H:!V:NTrees=2000:BoostType=Bagging:SeparationType=GiniIndex:nCuts=20");
83 cl->BookMethod(TMVA::Types::kCuts,
"Cuts",
"!H:!V:FitMethod=MC:EffSel:SampleSize=200000:VarProp=FSmart");
87 auto &results = cl->GetResults();
89 TCanvas *c =
new TCanvas(Form(
"ROC"));
90 c->SetTitle(
"ROC-Integral Curve");
92 auto mg =
new TMultiGraph();
93 for (UInt_t i = 0; i < results.size(); i++) {
94 if (!results[i].IsCutsMethod()) {
95 auto roc = results[i].GetROCGraph();
96 roc->SetLineColorAlpha(i + 1, 0.1);
101 mg->GetXaxis()->SetTitle(
" Signal Efficiency ");
102 mg->GetYaxis()->SetTitle(
" Background Rejection ");
103 c->BuildLegend(0.15, 0.15, 0.3, 0.3);