// Builds a toy TTree with `nPoints` entries for the cross-validation example.
// Branches: "x" and "y" (leaflist type F, 32-bit float) and "eventID"
// (leaflist type I, 32-bit int).
// Each iteration draws x and y independently via rng.Gaus(offset, scale),
// i.e. a Gaussian with mean `offset` and sigma `scale`.
// NOTE(review): `seed` is presumably used to construct/seed `rng`, and
// eventID/Fill/return presumably happen in lines not shown here — confirm.
// The tree is heap-allocated with `new TTree()`; ownership of the returned
// pointer is not established by the visible code — confirm with the caller.
78 TTree *genTree(Int_t nPoints, Double_t offset, Double_t scale, UInt_t seed = 100)
// Register the three output branches against the addresses of x, y, eventID.
85 TTree *data =
new TTree();
86 data->Branch(
"x", &x,
"x/F");
87 data->Branch(
"y", &y,
"y/F");
88 data->Branch(
"eventID", &eventID,
"eventID/I");
// Generate nPoints events; x and y share the same offset/scale.
90 for (Int_t n = 0; n < nPoints; ++n) {
91 x = rng.Gaus(offset, scale);
92 y = rng.Gaus(offset, scale);
// Detach the branches from &x/&y/&eventID registered above. Assuming those
// are locals of this function, the tree would otherwise keep dangling
// addresses after return.
102 data->ResetBranchAddresses();
// Runs TMVA cross-validation on two toy datasets from genTree:
// signal with offset +1.0 and background with offset -1.0, both with scale
// 1.0 and 1000 points.
// It writes results to "TMVA.root" (opened with RECREATE) and books a
// gradient-boosted BDT ("BDTG") and a Fisher discriminant ("Fisher").
// For every method and fold it prints the ROC integral and BkgEff@SigEff=0.3.
// Outside batch mode it draws the average ROC curve per method and opens the
// TMVA GUI.
106 int TMVACrossValidation()
109 TMVA::Tools::Instance();
// Signal and background toy samples; distinct seeds (100, 101).
120 TTree *sigTree = genTree(1000, 1.0, 1.0, 100);
121 TTree *bkgTree = genTree(1000, -1.0, 1.0, 101);
124 TString outfileName(
"TMVA.root");
125 TFile *outputFile = TFile::Open(outfileName,
"RECREATE");
// NOTE(review): dataloader is allocated with `new` and no delete is visible
// here — confirm whether CrossValidation or this function owns it.
129 TMVA::DataLoader *dataloader =
new TMVA::DataLoader(
"dataset");
132 dataloader->AddVariable(
"x",
'F');
133 dataloader->AddVariable(
"y",
'F');
// eventID is carried along as a spectator (not a training input).
136 dataloader->AddSpectator(
"eventID",
'I');
148 dataloader->AddSignalTree(sigTree, 1.0);
149 dataloader->AddBackgroundTree(bkgTree, 1.0);
156 dataloader->PrepareTrainingAndTestTree(
"",
"",
158 ":nTest_Background=1"
160 ":NormMode=NumEvents"
176 TString analysisType =
"Classification";
177 TString splitType =
"Random";
178 TString splitExpr =
"";
198 TString cvOptions = Form(
"!V"
205 analysisType.Data(), splitType.Data(), numFolds,
208 TMVA::CrossValidation cv{
"TMVACrossValidation", dataloader, outputFile, cvOptions};
// Booking order (BDTG first, Fisher second) defines the index of each
// method's entry in cv.GetMethods() / cv.GetResults().
215 cv.BookMethod(TMVA::Types::kBDT,
"BDTG",
216 "!H:!V:NTrees=100:MinNodeSize=2.5%:BoostType=Grad"
217 ":NegWeightTreatment=Pray:Shrinkage=0.10:nCuts=20"
220 cv.BookMethod(TMVA::Types::kFisher,
"Fisher",
221 "!H:!V:Fisher:VarTransform=None");
// Per-method, per-fold summary; results are paired with methods by position.
239 for (
auto && result : cv.GetResults()) {
240 std::cout <<
"Summary for method " << cv.GetMethods()[iMethod++].GetValue<TString>(
"MethodName")
242 for (UInt_t iFold = 0; iFold<cv.GetNumFolds(); ++iFold) {
243 std::cout <<
"\tFold " << iFold <<
": "
244 <<
"ROC int: " << result.GetROCValues()[iFold]
246 <<
"BkgEff@SigEff=0.3: " << result.GetEff30Values()[iFold]
258 std::cout <<
"==> Wrote root file: " << outputFile->GetName() << std::endl;
259 std::cout <<
"==> TMVACrossValidation is done!" << std::endl;
// Interactive plots only when ROOT is not running in batch mode.
266 if (!gROOT->IsBatch()) {
// Result index 0 is BDTG (first booked method).
268 cv.GetResults()[0].DrawAvgROCCurve(kTRUE,
"Avg ROC for BDTG");
// Result index 1 is Fisher (second booked method). Previously this drew
// GetResults()[0] again, i.e. the BDTG curve mislabelled as Fisher.
269 cv.GetResults()[1].DrawAvgROCCurve(kTRUE,
"Avg ROC for Fisher");
272 TMVA::TMVAGui(outfileName);
// Entry point for a standalone (compiled) build: runs the
// TMVACrossValidation() example. argc/argv are not used in the visible lines,
// and the int returned by TMVACrossValidation() is discarded here.
282 int main(
int argc,
char **argv)
284 TMVACrossValidation();