// getDataFile: obtain a TFile handle for the regression example data.
// If `fname` is accessible on the local filesystem it is opened directly;
// otherwise the code appears to fall back to fetching a fixed copy from the
// ROOT web server, cached in the current directory.
// NOTE(review): this listing skips several source lines (the embedded line
// numbers jump, e.g. 81 -> 84 and 85 -> 89), so the declaration of `input`,
// the else-branch opener, the failure condition and the return statement are
// not visible here -- confirm against the full source.
77 TFile * getDataFile(TString fname) {
// TSystem::AccessPathName returns false when the path IS accessible, hence
// the negation: this branch handles "file exists locally".
80 if (!gSystem->AccessPathName(fname)) {
81 input = TFile::Open(fname);
// Fallback: store cached remote files under the current directory.
84 TFile::SetCacheFileDir(
".");
// NOTE(review): the URL is hard-coded and does not depend on `fname`, so any
// requested name that is missing locally still yields tmva_reg_example.root.
// "CACHEREAD" opens the remote file through the local cache directory set above.
85 input = TFile::Open(
"http://root.cern.ch/files/tmva_reg_example.root",
"CACHEREAD");
// Error report for when no file could be opened; the condition guarding it is
// on a line not shown here (presumably a null check on `input` -- confirm).
89 std::cout <<
"ERROR: could not open data file " << fname << std::endl;
// TMVACrossValidationRegression: train and evaluate a regression using TMVA
// cross-validation. It does the following:
//   - reads tree "TreeR" from the example data file;
//   - declares var1/var2 as float inputs and fvalue as the regression target;
//   - books a gradient-boosted BDT ("BDTG") and writes results to TMVARegCv.root;
//   - outside batch mode, opens the TMVA GUI on that file.
// NOTE(review): the embedded line numbers skip many source lines. Missing
// from this listing are:
//   - the opening brace;
//   - the definition of numFolds;
//   - the remaining cvOptions format fields;
//   - any train/evaluate calls;
//   - file cleanup and the return value.
// Confirm details against the full source.
96 int TMVACrossValidationRegression()
// Ensure the TMVA tools singleton exists before TMVA objects are created.
99 TMVA::Tools::Instance();
// Output file for the cross-validation results; "RECREATE" overwrites any
// existing file with the same name.
104 TString outfileName(
"TMVARegCv.root");
105 TFile * outputFile = TFile::Open(outfileName,
"RECREATE");
// Input data: getDataFile opens the local path if it is accessible, otherwise
// it falls back to a cached download.
107 TString infileName(
"./files/tmva_reg_example.root");
108 TFile * inputFile = getDataFile(infileName);
// NOTE(review): heap-allocated with no visible delete in this listing --
// check who owns and frees the DataLoader.
110 TMVA::DataLoader *dataloader=
new TMVA::DataLoader(
"dataset");
// Input variables as (expression, title, unit, type); 'F' declares float.
112 dataloader->AddVariable(
"var1",
"Variable 1",
"units",
'F');
113 dataloader->AddVariable(
"var2",
"Variable 2",
"units",
'F');
// Regression target variable.
116 dataloader->AddTarget(
"fvalue");
// NOTE(review): the result of Get() is C-style cast and not null-checked on
// the visible lines; a missing "TreeR" would yield a null tree pointer here.
118 TTree * regTree = (TTree*)inputFile->Get(
"TreeR");
// Register the regression tree with a global event weight of 1.0.
119 dataloader->AddRegressionTree(regTree, 1.0);
124 std::cout <<
"--- TMVACrossValidationRegression: Using input file: " << inputFile->GetName() << std::endl;
// Empty cut: no event selection is applied.
133 TCut selectionCut =
"";
// Requests a single test event, presumably because CrossValidation performs
// its own fold splitting (confirm). The rest of this option string and the
// closing parenthesis are on lines not shown.
134 dataloader->PrepareTrainingAndTestTree(selectionCut,
"nTest_Regression=1"
136 ":NormMode=NumEvents"
// Cross-validation settings: regression analysis. The split expression is
// empty, presumably meaning no user-defined fold assignment -- confirm in the
// TMVA::CrossValidation documentation.
150 TString analysisType =
"Regression";
151 TString splitExpr =
"";
// Option string built from analysisType, numFolds and splitExpr. The format
// fields after "!V" and the definition of numFolds are on lines not shown.
153 TString cvOptions = Form(
"!V"
160 analysisType.Data(), numFolds, splitExpr.Data());
// Cross-validation driver bound to the dataloader and the output file.
162 TMVA::CrossValidation cv{
"TMVACrossValidationRegression", dataloader, outputFile, cvOptions};
// Gradient-boosted BDT with the following options:
//   - 500 trees, shrinkage 0.1;
//   - bagging with a 50% sample fraction;
//   - nCuts=20 grid points for the split search;
//   - maximum tree depth 3.
169 cv.BookMethod(TMVA::Types::kBDT,
"BDTG",
170 "!H:!V:NTrees=500:BoostType=Grad:Shrinkage=0.1:"
171 "UseBaggedBoost:BaggedSampleFraction=0.5:nCuts=20:MaxDepth=3");
189 std::cout <<
"==> Wrote root file: " << outputFile->GetName() << std::endl;
190 std::cout <<
"==> TMVACrossValidationRegression is done!" << std::endl;
// Launch the interactive TMVA GUI on the output file only outside batch mode.
197 if (!gROOT->IsBatch()) {
198 TMVA::TMVAGui(outfileName);
// Standalone entry point: runs the cross-validation regression.
// NOTE(review): argc/argv are not used on the visible lines, and the int
// returned by TMVACrossValidationRegression() is discarded. The function's
// braces and any return statement are on lines not shown here.
208 int main(
int argc,
char **argv)
210 TMVACrossValidationRegression();