Logo ROOT   6.30.04
Reference Guide
 All Namespaces Files Pages
CrossValidation.cxx
Go to the documentation of this file.
1 // @(#)root/tmva $Id$
2 // Author: Omar Zapata, Thomas James Stevenson and Pourya Vakilipourtakalou
3 // Modified: Kim Albertsson 2017
4 
5 /*************************************************************************
6  * Copyright (C) 2018, Rene Brun and Fons Rademakers. *
7  * All rights reserved. *
8  * *
9  * For the licensing terms see $ROOTSYS/LICENSE. *
10  * For the list of contributors see $ROOTSYS/README/CREDITS. *
11  *************************************************************************/
12 
13 #include "TMVA/CrossValidation.h"
14 
15 #include "TMVA/ClassifierFactory.h"
16 #include "TMVA/Config.h"
17 #include "TMVA/CvSplit.h"
18 #include "TMVA/DataSet.h"
19 #include "TMVA/Event.h"
20 #include "TMVA/MethodBase.h"
22 #include "TMVA/MsgLogger.h"
24 #include "TMVA/ResultsMulticlass.h"
25 #include "TMVA/ROCCurve.h"
26 #include "TMVA/tmvaglob.h"
27 #include "TMVA/Types.h"
28 
29 #include "TSystem.h"
30 #include "TAxis.h"
31 #include "TCanvas.h"
32 #include "TGraph.h"
33 #include "TMath.h"
34 
35 #include "ROOT/RMakeUnique.hxx"
36 
37 #include <iostream>
38 #include <memory>
39 
40 //_______________________________________________________________________
41 TMVA::CrossValidationResult::CrossValidationResult(UInt_t numFolds)
42 :fROCCurves(new TMultiGraph())
43 {
44  fSigs.resize(numFolds);
45  fSeps.resize(numFolds);
46  fEff01s.resize(numFolds);
47  fEff10s.resize(numFolds);
48  fEff30s.resize(numFolds);
49  fEffAreas.resize(numFolds);
50  fTrainEff01s.resize(numFolds);
51  fTrainEff10s.resize(numFolds);
52  fTrainEff30s.resize(numFolds);
53 }
54 
55 //_______________________________________________________________________
56 TMVA::CrossValidationResult::CrossValidationResult(const CrossValidationResult &obj)
57 {
58  fROCs=obj.fROCs;
59  fROCCurves = obj.fROCCurves;
60 
61  fSigs = obj.fSigs;
62  fSeps = obj.fSeps;
63  fEff01s = obj.fEff01s;
64  fEff10s = obj.fEff10s;
65  fEff30s = obj.fEff30s;
66  fEffAreas = obj.fEffAreas;
67  fTrainEff01s = obj.fTrainEff01s;
68  fTrainEff10s = obj.fTrainEff10s;
69  fTrainEff30s = obj.fTrainEff30s;
70 }
71 
72 //_______________________________________________________________________
73 void TMVA::CrossValidationResult::Fill(CrossValidationFoldResult const & fr)
74 {
75  UInt_t iFold = fr.fFold;
76 
77  fROCs[iFold] = fr.fROCIntegral;
78  fROCCurves->Add(dynamic_cast<TGraph *>(fr.fROC.Clone()));
79 
80  fSigs[iFold] = fr.fSig;
81  fSeps[iFold] = fr.fSep;
82  fEff01s[iFold] = fr.fEff01;
83  fEff10s[iFold] = fr.fEff10;
84  fEff30s[iFold] = fr.fEff30;
85  fEffAreas[iFold] = fr.fEffArea;
86  fTrainEff01s[iFold] = fr.fTrainEff01;
87  fTrainEff10s[iFold] = fr.fTrainEff10;
88  fTrainEff30s[iFold] = fr.fTrainEff30;
89 }
90 
91 //_______________________________________________________________________
92 TMultiGraph *TMVA::CrossValidationResult::GetROCCurves(Bool_t /*fLegend*/)
93 {
94  return fROCCurves.get();
95 }
96 
97 ////////////////////////////////////////////////////////////////////////////////
98 /// \brief Generates a multigraph that contains an average ROC Curve.
99 ///
100 /// \note You own the returned pointer.
101 ///
102 /// \param numSamples[in] Number of samples used for generating the average ROC
103 /// Curve. Avg. curve will be evaluated only at these
104 /// points (using interpolation if necessary).
105 ///
106 
107 TGraph *TMVA::CrossValidationResult::GetAvgROCCurve(UInt_t numSamples) const
108 {
109  // `numSamples * increment` should equal 1.0!
110  Double_t increment = 1.0 / (numSamples-1);
111  std::vector<Double_t> x(numSamples), y(numSamples);
112 
113  TList *rocCurveList = fROCCurves.get()->GetListOfGraphs();
114 
115  for(UInt_t iSample = 0; iSample < numSamples; iSample++) {
116  Double_t xPoint = iSample * increment;
117  Double_t rocSum = 0;
118 
119  for(Int_t iGraph = 0; iGraph < rocCurveList->GetSize(); iGraph++) {
120  TGraph *foldROC = static_cast<TGraph *>(rocCurveList->At(iGraph));
121  rocSum += foldROC->Eval(xPoint);
122  }
123 
124  x[iSample] = xPoint;
125  y[iSample] = rocSum/rocCurveList->GetSize();
126  }
127 
128  return new TGraph(numSamples, &x[0], &y[0]);
129 }
130 
131 //_______________________________________________________________________
132 Float_t TMVA::CrossValidationResult::GetROCAverage() const
133 {
134  Float_t avg=0;
135  for(auto &roc : fROCs) {
136  avg+=roc.second;
137  }
138  return avg/fROCs.size();
139 }
140 
141 //_______________________________________________________________________
142 Float_t TMVA::CrossValidationResult::GetROCStandardDeviation() const
143 {
144  // NOTE: We are using here the unbiased estimation of the standard deviation.
145  Float_t std=0;
146  Float_t avg=GetROCAverage();
147  for(auto &roc : fROCs) {
148  std+=TMath::Power(roc.second-avg, 2);
149  }
150  return TMath::Sqrt(std/float(fROCs.size()-1.0));
151 }
152 
153 //_______________________________________________________________________
154 void TMVA::CrossValidationResult::Print() const
155 {
156  TMVA::MsgLogger::EnableOutput();
157  TMVA::gConfig().SetSilent(kFALSE);
158 
159  MsgLogger fLogger("CrossValidation");
160  fLogger << kHEADER << " ==== Results ====" << Endl;
161  for(auto &item:fROCs) {
162  fLogger << kINFO << Form("Fold %i ROC-Int : %.4f",item.first,item.second) << std::endl;
163  }
164 
165  fLogger << kINFO << "------------------------" << Endl;
166  fLogger << kINFO << Form("Average ROC-Int : %.4f",GetROCAverage()) << Endl;
167  fLogger << kINFO << Form("Std-Dev ROC-Int : %.4f",GetROCStandardDeviation()) << Endl;
168 
169  TMVA::gConfig().SetSilent(kTRUE);
170 }
171 
172 //_______________________________________________________________________
173 TCanvas* TMVA::CrossValidationResult::Draw(const TString name) const
174 {
175  auto *c = new TCanvas(name.Data());
176  fROCCurves->Draw("AL");
177  fROCCurves->GetXaxis()->SetTitle(" Signal Efficiency ");
178  fROCCurves->GetYaxis()->SetTitle(" Background Rejection ");
179  Float_t adjust=1+fROCs.size()*0.01;
180  c->BuildLegend(0.15,0.15,0.4*adjust,0.5*adjust);
181  c->SetTitle("Cross Validation ROC Curves");
182  c->Draw();
183  return c;
184 }
185 
186 //
187 TCanvas* TMVA::CrossValidationResult::DrawAvgROCCurve(Bool_t drawFolds, TString title) const
188 {
189  TMultiGraph rocs{};
190 
191  // Potentially add the folds
192  if (drawFolds) {
193  for (auto foldRocObj : *(*fROCCurves).GetListOfGraphs()) {
194  TGraph * foldRocGraph = dynamic_cast<TGraph *>(foldRocObj->Clone());
195  foldRocGraph->SetLineColor(1);
196  foldRocGraph->SetLineWidth(1);
197  rocs.Add(foldRocGraph);
198  }
199  }
200 
201  // Add the average roc curve
202  TGraph *avgRocGraph = GetAvgROCCurve(100);
203  avgRocGraph->SetTitle("Avg ROC Curve");
204  avgRocGraph->SetLineColor(2);
205  avgRocGraph->SetLineWidth(3);
206  rocs.Add(avgRocGraph);
207 
208  // Draw
209  TCanvas *c = new TCanvas();
210 
211  if (title != "") {
212  title = "Cross Validation Average ROC Curve";
213  }
214 
215  rocs.SetTitle(title);
216  rocs.GetXaxis()->SetTitle("Signal Efficiency");
217  rocs.GetYaxis()->SetTitle("Background Rejection");
218  rocs.DrawClone("AL");
219 
220  // Build legend
221  TLegend *leg = new TLegend();
222  TList *ROCCurveList = rocs.GetListOfGraphs();
223 
224  if (drawFolds) {
225  Int_t nCurves = ROCCurveList->GetSize();
226  leg->AddEntry(static_cast<TGraph *>(ROCCurveList->At(nCurves-1)),
227  "Avg ROC Curve", "l");
228  leg->AddEntry(static_cast<TGraph *>(ROCCurveList->At(0)),
229  "Fold ROC Curves", "l");
230  leg->Draw();
231  } else {
232  c->BuildLegend();
233  }
234 
235  // Draw Canvas
236  c->SetTitle("Cross Validation Average ROC Curve");
237  c->Draw();
238  return c;
239 }
240 
241 /**
242 * \class TMVA::CrossValidation
243 * \ingroup TMVA
244 * \brief
245 
246 Use html for explicit line breaking<br>
247 Markdown links? [class reference](#reference)?
248 
249 
250 ~~~{.cpp}
251 ce->BookMethod(dataloader, options);
252 ce->Evaluate();
253 ~~~
254 
Cross-evaluation will generate a new training and a test set dynamically
from `K` folds. These `K` folds are generated by splitting the input training
set. The input test set is currently ignored.
258 
259 This means that when you specify your DataSet you should include all events
260 in your training set. One way of doing this would be the following:
261 
262 ~~~{.cpp}
263 dataloader->AddTree( signalTree, "cls1" );
264 dataloader->AddTree( background, "cls2" );
265 dataloader->PrepareTrainingAndTestTree( "", "", "nTest_cls1=1:nTest_cls2=1" );
266 ~~~
267 
268 ## Split Expression
269 See CVSplit documentation?
270 
271 */
272 
////////////////////////////////////////////////////////////////////////////////
/// Constructs a cross validation job and parses its option string.
///
/// The Envelope base class is constructed with a null output file; the file
/// passed here is kept in fOutputFile instead. All option-backed members are
/// given their defaults in the initialiser list (2 folds, 1 worker process,
/// split type "Random", analysis type "Auto", ROC enabled), after which the
/// options are declared, parsed and checked for unused entries.
///
/// \param[in] jobName    Name of the job; also stored in fJobName.
/// \param[in] dataloader Data loader forwarded to the Envelope base class.
/// \param[in] outputFile Output file; may be null.
/// \param[in] options    Option string handed to the Envelope base class.
///

TMVA::CrossValidation::CrossValidation(TString jobName, TMVA::DataLoader *dataloader, TFile *outputFile,
                                       TString options)
   : TMVA::Envelope(jobName, dataloader, nullptr, options),
     fAnalysisType(Types::kMaxAnalysisType),
     fAnalysisTypeStr("Auto"),
     fSplitTypeStr("Random"),
     fCorrelations(kFALSE),
     fCvFactoryOptions(""),
     fDrawProgressBar(kFALSE),
     fFoldFileOutput(kFALSE),
     fFoldStatus(kFALSE),
     fJobName(jobName),
     fNumFolds(2),
     fNumWorkerProcs(1),
     fOutputFactoryOptions(""),
     fOutputFile(outputFile),
     fSilent(kFALSE),
     fSplitExprString(""),
     fROC(kTRUE),
     fTransformations(""),
     fVerbose(kFALSE),
     // NOTE(review): initialised from the enum value kINFO here, then
     // overwritten with the string "Info" in InitOptions().
     fVerboseLevel(kINFO)
{
   InitOptions();
   // Called with explicit class qualification, i.e. always this class's version.
   CrossValidation::ParseOptions();
   CheckForUnusedOptions();
}
303 
////////////////////////////////////////////////////////////////////////////////
/// Constructs a cross validation job without an output file.
///
/// Delegates to the four-argument constructor, passing a null output file.
///
/// \param[in] jobName    Name of the job.
/// \param[in] dataloader Data loader to run the cross validation on.
/// \param[in] options    Option string.
///

TMVA::CrossValidation::CrossValidation(TString jobName, TMVA::DataLoader *dataloader, TString options)
   : CrossValidation(jobName, dataloader, nullptr, options)
{
}
311 
////////////////////////////////////////////////////////////////////////////////
/// Destructor (compiler-generated).
///

TMVA::CrossValidation::~CrossValidation() = default;
316 
317 ////////////////////////////////////////////////////////////////////////////////
318 ///
319 
320 void TMVA::CrossValidation::InitOptions()
321 {
322  // Forwarding of Factory options
323  DeclareOptionRef(fSilent, "Silent",
324  "Batch mode: boolean silent flag inhibiting any output from TMVA after the creation of the factory "
325  "class object (default: False)");
326  DeclareOptionRef(fVerbose, "V", "Verbose flag");
327  DeclareOptionRef(fVerboseLevel = TString("Info"), "VerboseLevel", "VerboseLevel (Debug/Verbose/Info)");
328  AddPreDefVal(TString("Debug"));
329  AddPreDefVal(TString("Verbose"));
330  AddPreDefVal(TString("Info"));
331 
332  DeclareOptionRef(fTransformations, "Transformations",
333  "List of transformations to test; formatting example: \"Transformations=I;D;P;U;G,D\", for "
334  "identity, decorrelation, PCA, Uniform and Gaussianisation followed by decorrelation "
335  "transformations");
336 
337  DeclareOptionRef(fDrawProgressBar, "DrawProgressBar", "Boolean to show draw progress bar");
338  DeclareOptionRef(fCorrelations, "Correlations", "Boolean to show correlation in output");
339  DeclareOptionRef(fROC, "ROC", "Boolean to show ROC in output");
340 
341  TString analysisType("Auto");
342  DeclareOptionRef(fAnalysisTypeStr, "AnalysisType",
343  "Set the analysis type (Classification, Regression, Multiclass, Auto) (default: Auto)");
344  AddPreDefVal(TString("Classification"));
345  AddPreDefVal(TString("Regression"));
346  AddPreDefVal(TString("Multiclass"));
347  AddPreDefVal(TString("Auto"));
348 
349  // Options specific to CE
350  DeclareOptionRef(fSplitTypeStr, "SplitType",
351  "Set the split type (Deterministic, Random, RandomStratified) (default: Random)");
352  AddPreDefVal(TString("Deterministic"));
353  AddPreDefVal(TString("Random"));
354  AddPreDefVal(TString("RandomStratified"));
355 
356  DeclareOptionRef(fSplitExprString, "SplitExpr", "The expression used to assign events to folds");
357  DeclareOptionRef(fNumFolds, "NumFolds", "Number of folds to generate");
358  DeclareOptionRef(fNumWorkerProcs, "NumWorkerProcs",
359  "Determines how many processes to use for evaluation. 1 means no"
360  " parallelisation. 2 means use 2 processes. 0 means figure out the"
361  " number automatically based on the number of cpus available. Default"
362  " 1.");
363 
364  DeclareOptionRef(fFoldFileOutput, "FoldFileOutput",
365  "If given a TMVA output file will be generated for each fold. Filename will be the same as "
366  "specifed for the combined output with a _foldX suffix. (default: false)");
367 
368  DeclareOptionRef(fOutputEnsembling = TString("None"), "OutputEnsembling",
369  "Combines output from contained methods. If None, no combination is performed. (default None)");
370  AddPreDefVal(TString("None"));
371  AddPreDefVal(TString("Avg"));
372 }
373 
374 ////////////////////////////////////////////////////////////////////////////////
375 ///
376 
377 void TMVA::CrossValidation::ParseOptions()
378 {
379  this->Envelope::ParseOptions();
380 
381  if (fSplitTypeStr != "Deterministic" && fSplitExprString != "") {
382  Log() << kFATAL << "SplitExpr can only be used with Deterministic Splitting" << Endl;
383  }
384 
385  // Factory options
386  fAnalysisTypeStr.ToLower();
387  if (fAnalysisTypeStr == "classification") {
388  fAnalysisType = Types::kClassification;
389  } else if (fAnalysisTypeStr == "regression") {
390  fAnalysisType = Types::kRegression;
391  } else if (fAnalysisTypeStr == "multiclass") {
392  fAnalysisType = Types::kMulticlass;
393  } else if (fAnalysisTypeStr == "auto") {
394  fAnalysisType = Types::kNoAnalysisType;
395  }
396 
397  if (fVerbose) {
398  fCvFactoryOptions += "V:";
399  fOutputFactoryOptions += "V:";
400  } else {
401  fCvFactoryOptions += "!V:";
402  fOutputFactoryOptions += "!V:";
403  }
404 
405  fCvFactoryOptions += Form("VerboseLevel=%s:", fVerboseLevel.Data());
406  fOutputFactoryOptions += Form("VerboseLevel=%s:", fVerboseLevel.Data());
407 
408  fCvFactoryOptions += Form("AnalysisType=%s:", fAnalysisTypeStr.Data());
409  fOutputFactoryOptions += Form("AnalysisType=%s:", fAnalysisTypeStr.Data());
410 
411  if (!fDrawProgressBar) {
412  fCvFactoryOptions += "!DrawProgressBar:";
413  fOutputFactoryOptions += "!DrawProgressBar:";
414  }
415 
416  if (fTransformations != "") {
417  fCvFactoryOptions += Form("Transformations=%s:", fTransformations.Data());
418  fOutputFactoryOptions += Form("Transformations=%s:", fTransformations.Data());
419  }
420 
421  if (fCorrelations) {
422  fCvFactoryOptions += "Correlations:";
423  fOutputFactoryOptions += "Correlations:";
424  } else {
425  fCvFactoryOptions += "!Correlations:";
426  fOutputFactoryOptions += "!Correlations:";
427  }
428 
429  if (fROC) {
430  fCvFactoryOptions += "ROC:";
431  fOutputFactoryOptions += "ROC:";
432  } else {
433  fCvFactoryOptions += "!ROC:";
434  fOutputFactoryOptions += "!ROC:";
435  }
436 
437  if (fSilent) {
438  fCvFactoryOptions += Form("Silent:");
439  fOutputFactoryOptions += Form("Silent:");
440  }
441 
442  // CE specific options
443  if (fFoldFileOutput && fOutputFile == nullptr) {
444  Log() << kFATAL << "No output file given, cannot generate per fold output." << Endl;
445  }
446 
447  // Initialisations
448 
449  fFoldFactory = std::make_unique<TMVA::Factory>(fJobName, fCvFactoryOptions);
450 
451  // The fOutputFactory should always have !ModelPersistence set since we use a custom code path for this.
452  // In this case we create a special method (MethodCrossValidation) that can only be used by
453  // CrossValidation and the Reader.
454  if (fOutputFile == nullptr) {
455  fFactory = std::make_unique<TMVA::Factory>(fJobName, fOutputFactoryOptions);
456  } else {
457  fFactory = std::make_unique<TMVA::Factory>(fJobName, fOutputFile, fOutputFactoryOptions);
458  }
459 
460  if(fSplitTypeStr == "Random"){
461  fSplit = std::unique_ptr<CvSplitKFolds>(new CvSplitKFolds(fNumFolds, fSplitExprString, kFALSE));
462  } else if(fSplitTypeStr == "RandomStratified"){
463  fSplit = std::unique_ptr<CvSplitKFolds>(new CvSplitKFolds(fNumFolds, fSplitExprString, kTRUE));
464  } else {
465  fSplit = std::unique_ptr<CvSplitKFolds>(new CvSplitKFolds(fNumFolds, fSplitExprString));
466  }
467 
468 }
469 
470 //_______________________________________________________________________
471 void TMVA::CrossValidation::SetNumFolds(UInt_t i)
472 {
473  if (i != fNumFolds) {
474  fNumFolds = i;
475  fSplit = std::make_unique<CvSplitKFolds>(fNumFolds, fSplitExprString);
476  fDataLoader->MakeKFoldDataSet(*fSplit);
477  fFoldStatus = kTRUE;
478  }
479 }
480 
481 ////////////////////////////////////////////////////////////////////////////////
482 ///
483 
484 void TMVA::CrossValidation::SetSplitExpr(TString splitExpr)
485 {
486  if (splitExpr != fSplitExprString) {
487  fSplitExprString = splitExpr;
488  fSplit = std::make_unique<CvSplitKFolds>(fNumFolds, fSplitExprString);
489  fDataLoader->MakeKFoldDataSet(*fSplit);
490  fFoldStatus = kTRUE;
491  }
492 }
493 
494 ////////////////////////////////////////////////////////////////////////////////
495 /// Evaluates each fold in turn.
496 /// - Prepares train and test data sets
497 /// - Trains method
498 /// - Evalutes on test set
499 /// - Stores the evaluation internally
500 ///
501 /// @param iFold fold to evaluate
502 ///
503 
504 TMVA::CrossValidationFoldResult TMVA::CrossValidation::ProcessFold(UInt_t iFold, const OptionMap & methodInfo)
505 {
506  TString methodTypeName = methodInfo.GetValue<TString>("MethodName");
507  TString methodTitle = methodInfo.GetValue<TString>("MethodTitle");
508  TString methodOptions = methodInfo.GetValue<TString>("MethodOptions");
509  TString foldTitle = methodTitle + TString("_fold") + TString::Format("%i", iFold + 1);
510 
511  Log() << kDEBUG << "Processing " << methodTitle << " fold " << iFold << Endl;
512 
513  // Only used if fFoldOutputFile == true
514  TFile *foldOutputFile = nullptr;
515 
516  if (fFoldFileOutput && fOutputFile != nullptr) {
517  TString path = std::string("") + gSystem->DirName(fOutputFile->GetName()) + "/" + foldTitle + ".root";
518  foldOutputFile = TFile::Open(path, "RECREATE");
519  Log() << kINFO << "Creating fold output at:" << path << Endl;
520  fFoldFactory = std::make_unique<TMVA::Factory>(fJobName, foldOutputFile, fCvFactoryOptions);
521  }
522 
523  fDataLoader->PrepareFoldDataSet(*fSplit, iFold, TMVA::Types::kTraining);
524  MethodBase *smethod = fFoldFactory->BookMethod(fDataLoader.get(), methodTypeName, foldTitle, methodOptions);
525 
526  // Train method (train method and eval train set)
527  Event::SetIsTraining(kTRUE);
528  smethod->TrainMethod();
529  Event::SetIsTraining(kFALSE);
530 
531  fFoldFactory->TestAllMethods();
532  fFoldFactory->EvaluateAllMethods();
533 
534  TMVA::CrossValidationFoldResult result(iFold);
535 
536  // Results for aggregation (ROC integral, efficiencies etc.)
537  if (fAnalysisType == Types::kClassification || fAnalysisType == Types::kMulticlass) {
538  result.fROCIntegral = fFoldFactory->GetROCIntegral(fDataLoader->GetName(), foldTitle);
539 
540  TGraph *gr = fFoldFactory->GetROCCurve(fDataLoader->GetName(), foldTitle, true);
541  gr->SetLineColor(iFold + 1);
542  gr->SetLineWidth(2);
543  gr->SetTitle(foldTitle.Data());
544  result.fROC = *gr;
545 
546  result.fSig = smethod->GetSignificance();
547  result.fSep = smethod->GetSeparation();
548 
549  if (fAnalysisType == Types::kClassification) {
550  Double_t err;
551  result.fEff01 = smethod->GetEfficiency("Efficiency:0.01", Types::kTesting, err);
552  result.fEff10 = smethod->GetEfficiency("Efficiency:0.10", Types::kTesting, err);
553  result.fEff30 = smethod->GetEfficiency("Efficiency:0.30", Types::kTesting, err);
554  result.fEffArea = smethod->GetEfficiency("", Types::kTesting, err);
555  result.fTrainEff01 = smethod->GetTrainingEfficiency("Efficiency:0.01");
556  result.fTrainEff10 = smethod->GetTrainingEfficiency("Efficiency:0.10");
557  result.fTrainEff30 = smethod->GetTrainingEfficiency("Efficiency:0.30");
558  } else if (fAnalysisType == Types::kMulticlass) {
559  // Nothing here for now
560  }
561  }
562 
563  // Per-fold file output
564  if (fFoldFileOutput && foldOutputFile != nullptr) {
565  foldOutputFile->Close();
566  }
567 
568  // Clean-up for this fold
569  {
570  smethod->Data()->DeleteAllResults(Types::kTraining, smethod->GetAnalysisType());
571  smethod->Data()->DeleteAllResults(Types::kTesting, smethod->GetAnalysisType());
572  }
573 
574  fFoldFactory->DeleteAllMethods();
575  fFoldFactory->fMethodsMap.clear();
576 
577  return result;
578 }
579 
////////////////////////////////////////////////////////////////////////////////
/// Does training, test set evaluation and performance evaluation using
/// cross-validation.
///
/// For every booked method:
///  - all folds are processed through ProcessFold(), either sequentially or
///    through a ROOT::TProcessExecutor, and the fold results are collected in
///    a CrossValidationResult appended to fResults;
///  - a method of type Types::kCrossValidation is booked on the output
///    factory and handed the event-to-fold mapping of the splitter.
///
/// Afterwards the fold data sets are recombined and the booked
/// cross-validation methods are trained, tested and evaluated through the
/// output factory.
///

void TMVA::CrossValidation::Evaluate()
{
   // Generate K folds on given dataset
   // (skipped if SetNumFolds()/SetSplitExpr() already generated them)
   if (!fFoldStatus) {
      fDataLoader->MakeKFoldDataSet(*fSplit);
      fFoldStatus = kTRUE;
   }

   fResults.reserve(fMethods.size());
   for (auto & methodInfo : fMethods) {
      CrossValidationResult result{fNumFolds};

      TString methodTypeName = methodInfo.GetValue<TString>("MethodName");
      TString methodTitle = methodInfo.GetValue<TString>("MethodTitle");

      if (methodTypeName == "") {
         Log() << kFATAL << "No method booked for cross-validation" << Endl;
      }

      TMVA::MsgLogger::EnableOutput();
      Log() << kINFO << Endl;
      Log() << kINFO << Endl;
      Log() << kINFO << "========================================" << Endl;
      Log() << kINFO << "Processing folds for method " << methodTitle << Endl;
      Log() << kINFO << "========================================" << Endl;
      Log() << kINFO << Endl;

      // Process K folds
      auto nWorkers = fNumWorkerProcs;
      if (nWorkers == 1) {
         // Fall back to global config
         nWorkers = TMVA::gConfig().GetNumWorkers();
      }
      if (nWorkers == 1) {
         // Sequential processing in this process
         for (UInt_t iFold = 0; iFold < fNumFolds; ++iFold) {
            auto fold_result = ProcessFold(iFold, methodInfo);
            result.Fill(fold_result);
         }
      } else {
         // NOTE(review): when _MSC_VER is defined this branch is empty, so no
         // fold is processed if more than one worker is requested.
#ifndef _MSC_VER
         ROOT::TProcessExecutor workers(nWorkers);
         std::vector<CrossValidationFoldResult> result_vector;

         // methodInfo is captured by value for the work item
         auto workItem = [this, methodInfo](UInt_t iFold) {
            return ProcessFold(iFold, methodInfo);
         };

         result_vector = workers.Map(workItem, ROOT::TSeqI(fNumFolds));

         for (auto && fold_result : result_vector) {
            result.Fill(fold_result);
         }
#endif
      }

      fResults.push_back(result);

      // Serialise the cross evaluated method
      TString options =
         Form("SplitExpr=%s:NumFolds=%i"
              ":EncapsulatedMethodName=%s"
              ":EncapsulatedMethodTypeName=%s"
              ":OutputEnsembling=%s",
              fSplitExprString.Data(), fNumFolds, methodTitle.Data(), methodTypeName.Data(), fOutputEnsembling.Data());

      fFactory->BookMethod(fDataLoader.get(), Types::kCrossValidation, methodTitle, options);

      // Feed EventToFold mapping used when random fold assignments are used
      // (when splitExpr="").
      IMethod *method_interface = fFactory->GetMethod(fDataLoader->GetName(), methodTitle);
      // NOTE(review): the result of the dynamic_cast is used without a null check.
      auto *method = dynamic_cast<MethodCrossValidation *>(method_interface);

      method->fEventToFoldMapping = fSplit->fEventToFoldMapping;
   }

   Log() << kINFO << Endl;
   Log() << kINFO << Endl;
   Log() << kINFO << "========================================" << Endl;
   Log() << kINFO << "Folds processed for all methods, evaluating." << Endl;
   Log() << kINFO << "========================================" << Endl;
   Log() << kINFO << Endl;

   // Recombination of data (making sure there is data in training and testing trees).
   fDataLoader->RecombineKFoldDataSet(*fSplit);

   // "Eval" on training set
   for (auto & methodInfo : fMethods) {
      TString methodTypeName = methodInfo.GetValue<TString>("MethodName");
      TString methodTitle = methodInfo.GetValue<TString>("MethodTitle");

      IMethod *method_interface = fFactory->GetMethod(fDataLoader->GetName(), methodTitle);
      // NOTE(review): the result of the dynamic_cast is used without a null check.
      auto method = dynamic_cast<MethodCrossValidation *>(method_interface);

      // Data set information is only written when an output file was given.
      if (fOutputFile != nullptr) {
         fFactory->WriteDataInformation(method->fDataSetInfo);
      }

      Event::SetIsTraining(kTRUE);
      method->TrainMethod();
      Event::SetIsTraining(kFALSE);
   }

   // Eval on Testing set
   fFactory->TestAllMethods();

   // Calc statistics
   fFactory->EvaluateAllMethods();

   Log() << kINFO << "Evaluation done." << Endl;
}
695 
696 //_______________________________________________________________________
697 const std::vector<TMVA::CrossValidationResult> &TMVA::CrossValidation::GetResults() const
698 {
699  if (fResults.empty()) {
700  Log() << kFATAL << "No cross-validation results available" << Endl;
701  }
702  return fResults;
703 }