// Constructs a result container for a cross-validation run with `numFolds` folds.
// Allocates an empty TMultiGraph to hold the per-fold ROC curves and sizes each
// per-fold metric vector below to `numFolds` entries.
// NOTE(review): the embedded listing numbers jump (42 -> 44, 52 -> 56), so the
// brace lines of this definition are not visible in this excerpt.
41 TMVA::CrossValidationResult::CrossValidationResult(UInt_t numFolds)
42 :fROCCurves(new TMultiGraph())
// Test-sample metrics, one slot per fold.
44 fSigs.resize(numFolds);
45 fSeps.resize(numFolds);
46 fEff01s.resize(numFolds);
47 fEff10s.resize(numFolds);
48 fEff30s.resize(numFolds);
49 fEffAreas.resize(numFolds);
// Training-sample efficiencies, one slot per fold.
50 fTrainEff01s.resize(numFolds);
51 fTrainEff10s.resize(numFolds);
52 fTrainEff30s.resize(numFolds);
// Copy constructor: copies the ROC-curve holder and the per-fold metric vectors
// from `obj`.
// NOTE(review): `fROCCurves` is assigned as a whole (it is used via `.get()`
// elsewhere in this file, so it appears to be a smart pointer). If it has shared
// ownership, the copy and `obj` refer to the same TMultiGraph — confirm this is intended.
// NOTE(review): listing lines 57-58 and 60-62 are not visible here; copies of
// fROCs, fSigs and fSeps are presumably on those lines — verify against the full file.
56 TMVA::CrossValidationResult::CrossValidationResult(
const CrossValidationResult &obj)
59 fROCCurves = obj.fROCCurves;
63 fEff01s = obj.fEff01s;
64 fEff10s = obj.fEff10s;
65 fEff30s = obj.fEff30s;
66 fEffAreas = obj.fEffAreas;
67 fTrainEff01s = obj.fTrainEff01s;
68 fTrainEff10s = obj.fTrainEff10s;
69 fTrainEff30s = obj.fTrainEff30s;
// Records the metrics of a single fold in this result object.
// @param fr  fold result; `fr.fFold` selects the slot that is written.
// NOTE(review): `iFold` is used as an index into the per-fold vectors without a
// bounds check in the visible lines; it is presumably < numFolds — confirm at callers.
73 void TMVA::CrossValidationResult::Fill(CrossValidationFoldResult
const & fr)
75 UInt_t iFold = fr.fFold;
// ROC integral is stored keyed by fold index.
77 fROCs[iFold] = fr.fROCIntegral;
// A clone of the fold's ROC graph is added, so `fr` keeps its own graph.
// NOTE(review): the dynamic_cast result is passed to Add() unchecked.
78 fROCCurves->Add(dynamic_cast<TGraph *>(fr.fROC.Clone()));
80 fSigs[iFold] = fr.fSig;
81 fSeps[iFold] = fr.fSep;
82 fEff01s[iFold] = fr.fEff01;
83 fEff10s[iFold] = fr.fEff10;
84 fEff30s[iFold] = fr.fEff30;
85 fEffAreas[iFold] = fr.fEffArea;
86 fTrainEff01s[iFold] = fr.fTrainEff01;
87 fTrainEff10s[iFold] = fr.fTrainEff10;
88 fTrainEff30s[iFold] = fr.fTrainEff30;
// Returns a raw, non-owning pointer to the multigraph holding the per-fold ROC
// curves; the object stays managed by `fROCCurves`.
// The Bool_t parameter is unnamed and not used in the visible line.
92 TMultiGraph *TMVA::CrossValidationResult::GetROCCurves(Bool_t )
94 return fROCCurves.get();
// Builds the fold-averaged ROC curve.
// Evaluates every stored fold ROC graph at `numSamples` equally spaced x values
// starting at 0 with step 1/(numSamples-1), and averages the results per x.
// @param numSamples  number of sample points on the x axis.
// @return a TGraph allocated with `new`; the caller presumably owns it.
// NOTE(review): `numSamples` is unsigned, so numSamples == 1 gives a zero divisor
// and numSamples == 0 wraps around in `numSamples-1`.
// NOTE(review): the declaration of `rocSum` and the assignment to x[iSample] are
// on listing lines not visible in this excerpt.
107 TGraph *TMVA::CrossValidationResult::GetAvgROCCurve(UInt_t numSamples)
const
110 Double_t increment = 1.0 / (numSamples-1);
111 std::vector<Double_t> x(numSamples), y(numSamples);
113 TList *rocCurveList = fROCCurves.get()->GetListOfGraphs();
115 for(UInt_t iSample = 0; iSample < numSamples; iSample++) {
116 Double_t xPoint = iSample * increment;
// Sum the value of every fold curve at this x position.
119 for(Int_t iGraph = 0; iGraph < rocCurveList->GetSize(); iGraph++) {
120 TGraph *foldROC =
static_cast<TGraph *
>(rocCurveList->At(iGraph));
121 rocSum += foldROC->Eval(xPoint);
// Mean over folds.
// NOTE(review): the divisor is zero when the list holds no graphs.
125 y[iSample] = rocSum/rocCurveList->GetSize();
128 return new TGraph(numSamples, &x[0], &y[0]);
// Returns the accumulated value `avg` divided by the number of entries in fROCs.
// NOTE(review): the declaration of `avg` and the loop body that accumulates into
// it are on listing lines not visible here; presumably it sums `roc.second`.
// NOTE(review): the divisor is zero when fROCs is empty.
132 Float_t TMVA::CrossValidationResult::GetROCAverage()
const
135 for(
auto &roc : fROCs) {
138 return avg/fROCs.size();
// Returns the standard deviation of the per-fold ROC integrals, using an
// (n - 1) divisor: sqrt( sum (roc - avg)^2 / (n - 1) ).
// NOTE(review): with a single entry in fROCs the divisor is zero.
// NOTE(review): the accumulator is a local named `std`; its declaration is on a
// listing line not visible here.
142 Float_t TMVA::CrossValidationResult::GetROCStandardDeviation()
const
146 Float_t avg=GetROCAverage();
147 for(
auto &roc : fROCs) {
148 std+=TMath::Power(roc.second-avg, 2);
150 return TMath::Sqrt(std/
float(fROCs.size()-1.0));
// Logs a summary of the cross-validation results: the ROC integral of every
// fold, followed by their average and standard deviation.
// Output is force-enabled first (logger output on, global config not silent).
154 void TMVA::CrossValidationResult::Print()
const
156 TMVA::MsgLogger::EnableOutput();
157 TMVA::gConfig().SetSilent(kFALSE);
159 MsgLogger fLogger(
"CrossValidation");
160 fLogger << kHEADER <<
" ==== Results ====" << Endl;
// One line per fold: item.first is the fold index, item.second the ROC integral.
// NOTE(review): this statement ends with std::endl while every other log
// statement in this function ends with Endl — confirm this is intended.
161 for(
auto &item:fROCs) {
162 fLogger << kINFO << Form(
"Fold %i ROC-Int : %.4f",item.first,item.second) << std::endl;
165 fLogger << kINFO <<
"------------------------" << Endl;
166 fLogger << kINFO << Form(
"Average ROC-Int : %.4f",GetROCAverage()) << Endl;
167 fLogger << kINFO << Form(
"Std-Dev ROC-Int : %.4f",GetROCStandardDeviation()) << Endl;
// NOTE(review): silent mode is switched on unconditionally; the state that was
// active before this call is not saved or restored in the visible lines.
169 TMVA::gConfig().SetSilent(kTRUE);
// Draws all per-fold ROC curves on a newly allocated canvas.
// @param name  name given to the new TCanvas.
// The canvas is created with `new`; the return statement is on a listing line
// not visible here (presumably returns `c`).
173 TCanvas* TMVA::CrossValidationResult::Draw(
const TString name)
const
175 auto *c =
new TCanvas(name.Data());
176 fROCCurves->Draw(
"AL");
177 fROCCurves->GetXaxis()->SetTitle(
" Signal Efficiency ");
178 fROCCurves->GetYaxis()->SetTitle(
" Background Rejection ");
// The legend's upper-right corner is scaled by 1% per stored fold result.
179 Float_t adjust=1+fROCs.size()*0.01;
180 c->BuildLegend(0.15,0.15,0.4*adjust,0.5*adjust);
181 c->SetTitle(
"Cross Validation ROC Curves");
// Draws the fold-averaged ROC curve, together with clones of the per-fold
// curves, on a newly allocated canvas with a legend.
// @param drawFolds  NOTE(review): the condition using this flag is on a listing
//                   line not visible here; presumably it guards the fold loop.
// @param title      plot title. NOTE(review): the assignment of the default title
//                   below is presumably guarded by an empty-title check that is
//                   not visible here.
// The declaration of `rocs` and the return statement are also not visible.
187 TCanvas* TMVA::CrossValidationResult::DrawAvgROCCurve(Bool_t drawFolds, TString title)
const
// Each fold graph is cloned so the stored curves keep their own style.
193 for (
auto foldRocObj : *(*fROCCurves).GetListOfGraphs()) {
194 TGraph * foldRocGraph =
dynamic_cast<TGraph *
>(foldRocObj->Clone());
195 foldRocGraph->SetLineColor(1);
196 foldRocGraph->SetLineWidth(1);
197 rocs.Add(foldRocGraph);
// Average curve sampled at 100 points, added last and drawn with a wider line.
202 TGraph *avgRocGraph = GetAvgROCCurve(100);
203 avgRocGraph->SetTitle(
"Avg ROC Curve");
204 avgRocGraph->SetLineColor(2);
205 avgRocGraph->SetLineWidth(3);
206 rocs.Add(avgRocGraph);
209 TCanvas *c =
new TCanvas();
212 title =
"Cross Validation Average ROC Curve";
215 rocs.SetTitle(title);
216 rocs.GetXaxis()->SetTitle(
"Signal Efficiency");
217 rocs.GetYaxis()->SetTitle(
"Background Rejection");
218 rocs.DrawClone(
"AL");
221 TLegend *leg =
new TLegend();
222 TList *ROCCurveList = rocs.GetListOfGraphs();
// The average curve is the last list entry (added last above); entry 0 is used
// as the representative for the fold curves.
225 Int_t nCurves = ROCCurveList->GetSize();
226 leg->AddEntry(static_cast<TGraph *>(ROCCurveList->At(nCurves-1)),
227 "Avg ROC Curve",
"l");
228 leg->AddEntry(static_cast<TGraph *>(ROCCurveList->At(0)),
229 "Fold ROC Curves",
"l");
236 c->SetTitle(
"Cross Validation Average ROC Curve");
// Main constructor: forwards job name, data loader and options to the Envelope
// base, sets member defaults, then parses the options.
// @param jobName     passed to TMVA::Envelope.
// @param dataloader  passed to TMVA::Envelope.
// @param outputFile  stored in fOutputFile; may be nullptr (see the delegating
//                    constructor that passes nullptr).
// NOTE(review): listing line 277 (the rest of the parameter list, including
// `options`) and several initializer lines (286-289, 292, 294, 296-299) are not
// visible in this excerpt.
276 TMVA::CrossValidation::CrossValidation(TString jobName, TMVA::DataLoader *dataloader, TFile *outputFile,
278 : TMVA::Envelope(jobName, dataloader, nullptr, options),
279 fAnalysisType(Types::kMaxAnalysisType),
280 fAnalysisTypeStr(
"Auto"),
281 fSplitTypeStr(
"Random"),
282 fCorrelations(kFALSE),
283 fCvFactoryOptions(
""),
284 fDrawProgressBar(kFALSE),
285 fFoldFileOutput(kFALSE),
290 fOutputFactoryOptions(
""),
291 fOutputFile(outputFile),
293 fSplitExprString(
""),
295 fTransformations(
""),
// ParseOptions is called with an explicit class qualifier.
300 CrossValidation::ParseOptions();
301 CheckForUnusedOptions();
// Convenience constructor without an output file: delegates to the main
// constructor with `outputFile` set to nullptr.
307 TMVA::CrossValidation::CrossValidation(TString jobName, TMVA::DataLoader *dataloader, TString options)
308 : CrossValidation(jobName, dataloader, nullptr, options)
// Destructor is compiler-generated; it is defined here in the implementation file.
315 TMVA::CrossValidation::~CrossValidation() =
default;
// Declares every option string understood by CrossValidation and binds each one
// to the member that receives its parsed value. AddPreDefVal calls list the
// accepted values for the option declared immediately before them.
// NOTE(review): a few continuation lines of the help texts (listing lines
// 335-336, 362-363) are not visible in this excerpt.
320 void TMVA::CrossValidation::InitOptions()
// --- Logging / verbosity ---
323 DeclareOptionRef(fSilent,
"Silent",
324 "Batch mode: boolean silent flag inhibiting any output from TMVA after the creation of the factory "
325 "class object (default: False)");
326 DeclareOptionRef(fVerbose,
"V",
"Verbose flag");
// fVerboseLevel is assigned its default ("Info") inside the declaring call.
327 DeclareOptionRef(fVerboseLevel = TString(
"Info"),
"VerboseLevel",
"VerboseLevel (Debug/Verbose/Info)");
328 AddPreDefVal(TString(
"Debug"));
329 AddPreDefVal(TString(
"Verbose"));
330 AddPreDefVal(TString(
"Info"));
// --- Options forwarded to the factories (see ParseOptions) ---
332 DeclareOptionRef(fTransformations,
"Transformations",
333 "List of transformations to test; formatting example: \"Transformations=I;D;P;U;G,D\", for "
334 "identity, decorrelation, PCA, Uniform and Gaussianisation followed by decorrelation "
337 DeclareOptionRef(fDrawProgressBar,
"DrawProgressBar",
"Boolean to show draw progress bar");
338 DeclareOptionRef(fCorrelations,
"Correlations",
"Boolean to show correlation in output");
339 DeclareOptionRef(fROC,
"ROC",
"Boolean to show ROC in output");
// NOTE(review): the local `analysisType` is not referenced in any visible line
// of this function; the option below binds fAnalysisTypeStr instead.
341 TString analysisType(
"Auto");
342 DeclareOptionRef(fAnalysisTypeStr,
"AnalysisType",
343 "Set the analysis type (Classification, Regression, Multiclass, Auto) (default: Auto)");
344 AddPreDefVal(TString(
"Classification"));
345 AddPreDefVal(TString(
"Regression"));
346 AddPreDefVal(TString(
"Multiclass"));
347 AddPreDefVal(TString(
"Auto"));
// --- Fold splitting ---
350 DeclareOptionRef(fSplitTypeStr,
"SplitType",
351 "Set the split type (Deterministic, Random, RandomStratified) (default: Random)");
352 AddPreDefVal(TString(
"Deterministic"));
353 AddPreDefVal(TString(
"Random"));
354 AddPreDefVal(TString(
"RandomStratified"));
356 DeclareOptionRef(fSplitExprString,
"SplitExpr",
"The expression used to assign events to folds");
357 DeclareOptionRef(fNumFolds,
"NumFolds",
"Number of folds to generate");
358 DeclareOptionRef(fNumWorkerProcs,
"NumWorkerProcs",
359 "Determines how many processes to use for evaluation. 1 means no"
360 " parallelisation. 2 means use 2 processes. 0 means figure out the"
361 " number automatically based on the number of cpus available. Default"
// --- Output handling ---
364 DeclareOptionRef(fFoldFileOutput,
"FoldFileOutput",
365 "If given a TMVA output file will be generated for each fold. Filename will be the same as "
366 "specifed for the combined output with a _foldX suffix. (default: false)");
// fOutputEnsembling is assigned its default ("None") inside the declaring call.
368 DeclareOptionRef(fOutputEnsembling = TString(
"None"),
"OutputEnsembling",
369 "Combines output from contained methods. If None, no combination is performed. (default None)");
370 AddPreDefVal(TString(
"None"));
371 AddPreDefVal(TString(
"Avg"));
// Parses the option string (via Envelope::ParseOptions), validates the result,
// builds the option strings for the per-fold factory and the output factory,
// creates both factories, and creates the fold splitter.
// NOTE(review): the conditions that choose between the paired "X:" / "!X:"
// appends below (verbose, correlations, ROC, silent) and the `else` keywords are
// on listing lines not visible in this excerpt.
377 void TMVA::CrossValidation::ParseOptions()
379 this->Envelope::ParseOptions();
// A split expression is only accepted together with deterministic splitting.
381 if (fSplitTypeStr !=
"Deterministic" && fSplitExprString !=
"") {
382 Log() << kFATAL <<
"SplitExpr can only be used with Deterministic Splitting" << Endl;
// Map the lower-cased analysis type string onto the Types enum.
386 fAnalysisTypeStr.ToLower();
387 if (fAnalysisTypeStr ==
"classification") {
388 fAnalysisType = Types::kClassification;
389 }
else if (fAnalysisTypeStr ==
"regression") {
390 fAnalysisType = Types::kRegression;
391 }
else if (fAnalysisTypeStr ==
"multiclass") {
392 fAnalysisType = Types::kMulticlass;
393 }
else if (fAnalysisTypeStr ==
"auto") {
394 fAnalysisType = Types::kNoAnalysisType;
// From here on both factory option strings receive the same fragments.
398 fCvFactoryOptions +=
"V:";
399 fOutputFactoryOptions +=
"V:";
401 fCvFactoryOptions +=
"!V:";
402 fOutputFactoryOptions +=
"!V:";
405 fCvFactoryOptions += Form(
"VerboseLevel=%s:", fVerboseLevel.Data());
406 fOutputFactoryOptions += Form(
"VerboseLevel=%s:", fVerboseLevel.Data());
408 fCvFactoryOptions += Form(
"AnalysisType=%s:", fAnalysisTypeStr.Data());
409 fOutputFactoryOptions += Form(
"AnalysisType=%s:", fAnalysisTypeStr.Data());
411 if (!fDrawProgressBar) {
412 fCvFactoryOptions +=
"!DrawProgressBar:";
413 fOutputFactoryOptions +=
"!DrawProgressBar:";
// Transformations are only forwarded when the user supplied any.
416 if (fTransformations !=
"") {
417 fCvFactoryOptions += Form(
"Transformations=%s:", fTransformations.Data());
418 fOutputFactoryOptions += Form(
"Transformations=%s:", fTransformations.Data());
422 fCvFactoryOptions +=
"Correlations:";
423 fOutputFactoryOptions +=
"Correlations:";
425 fCvFactoryOptions +=
"!Correlations:";
426 fOutputFactoryOptions +=
"!Correlations:";
430 fCvFactoryOptions +=
"ROC:";
431 fOutputFactoryOptions +=
"ROC:";
433 fCvFactoryOptions +=
"!ROC:";
434 fOutputFactoryOptions +=
"!ROC:";
438 fCvFactoryOptions += Form(
"Silent:");
439 fOutputFactoryOptions += Form(
"Silent:");
// Per-fold output files are placed relative to the output file, so one is required.
443 if (fFoldFileOutput && fOutputFile ==
nullptr) {
444 Log() << kFATAL <<
"No output file given, cannot generate per fold output." << Endl;
// Factory used while processing individual folds (no output file here).
449 fFoldFactory = std::make_unique<TMVA::Factory>(fJobName, fCvFactoryOptions);
// Factory for the combined output; constructed with the output file when one exists.
454 if (fOutputFile ==
nullptr) {
455 fFactory = std::make_unique<TMVA::Factory>(fJobName, fOutputFactoryOptions);
457 fFactory = std::make_unique<TMVA::Factory>(fJobName, fOutputFile, fOutputFactoryOptions);
// Splitter: third constructor argument is kFALSE for "Random", kTRUE for
// "RandomStratified", and left at its default otherwise.
460 if(fSplitTypeStr ==
"Random"){
461 fSplit = std::unique_ptr<CvSplitKFolds>(
new CvSplitKFolds(fNumFolds, fSplitExprString, kFALSE));
462 }
else if(fSplitTypeStr ==
"RandomStratified"){
463 fSplit = std::unique_ptr<CvSplitKFolds>(
new CvSplitKFolds(fNumFolds, fSplitExprString, kTRUE));
465 fSplit = std::unique_ptr<CvSplitKFolds>(
new CvSplitKFolds(fNumFolds, fSplitExprString));
// Changes the number of folds; when `i` differs from the current value the
// splitter is rebuilt and the k-fold data set is regenerated.
// NOTE(review): the assignment of `i` to fNumFolds is on a listing line (474)
// not visible here.
// NOTE(review): the splitter is rebuilt with the two-argument constructor
// regardless of fSplitTypeStr, unlike ParseOptions which passes a third
// argument for "Random"/"RandomStratified" — confirm this is intended.
471 void TMVA::CrossValidation::SetNumFolds(UInt_t i)
473 if (i != fNumFolds) {
475 fSplit = std::make_unique<CvSplitKFolds>(fNumFolds, fSplitExprString);
476 fDataLoader->MakeKFoldDataSet(*fSplit);
// Changes the split expression; when it differs from the current one the
// splitter is rebuilt and the k-fold data set is regenerated.
// NOTE(review): the splitter is rebuilt with the two-argument constructor
// regardless of fSplitTypeStr — confirm this is intended.
484 void TMVA::CrossValidation::SetSplitExpr(TString splitExpr)
486 if (splitExpr != fSplitExprString) {
487 fSplitExprString = splitExpr;
488 fSplit = std::make_unique<CvSplitKFolds>(fNumFolds, fSplitExprString);
489 fDataLoader->MakeKFoldDataSet(*fSplit);
// Trains, tests and evaluates one method on one fold and collects its metrics.
// @param iFold       zero-based fold index; the fold title uses iFold + 1.
// @param methodInfo  option map providing "MethodName", "MethodTitle" and
//                    "MethodOptions".
// @return a CrossValidationFoldResult for this fold (the return statement and
//         some assignments, e.g. of `gr` into the result and the declaration of
//         `err`, are on listing lines not visible in this excerpt).
504 TMVA::CrossValidationFoldResult TMVA::CrossValidation::ProcessFold(UInt_t iFold,
const OptionMap & methodInfo)
506 TString methodTypeName = methodInfo.GetValue<TString>(
"MethodName");
507 TString methodTitle = methodInfo.GetValue<TString>(
"MethodTitle");
508 TString methodOptions = methodInfo.GetValue<TString>(
"MethodOptions");
509 TString foldTitle = methodTitle + TString(
"_fold") + TString::Format(
"%i", iFold + 1);
511 Log() << kDEBUG <<
"Processing " << methodTitle <<
" fold " << iFold << Endl;
// Optional per-fold output: a file named after the fold title is created in the
// directory of the combined output file, and the fold factory is replaced by one
// writing to it.
514 TFile *foldOutputFile =
nullptr;
516 if (fFoldFileOutput && fOutputFile !=
nullptr) {
517 TString path = std::string(
"") + gSystem->DirName(fOutputFile->GetName()) +
"/" + foldTitle +
".root";
518 foldOutputFile = TFile::Open(path,
"RECREATE");
519 Log() << kINFO <<
"Creating fold output at:" << path << Endl;
520 fFoldFactory = std::make_unique<TMVA::Factory>(fJobName, foldOutputFile, fCvFactoryOptions);
// Select this fold's data, then book the method under the fold title.
523 fDataLoader->PrepareFoldDataSet(*fSplit, iFold, TMVA::Types::kTraining);
524 MethodBase *smethod = fFoldFactory->BookMethod(fDataLoader.get(), methodTypeName, foldTitle, methodOptions);
// The global training flag is set only for the duration of TrainMethod().
527 Event::SetIsTraining(kTRUE);
528 smethod->TrainMethod();
529 Event::SetIsTraining(kFALSE);
531 fFoldFactory->TestAllMethods();
532 fFoldFactory->EvaluateAllMethods();
534 TMVA::CrossValidationFoldResult result(iFold);
// ROC information is collected for classification and multiclass analyses.
537 if (fAnalysisType == Types::kClassification || fAnalysisType == Types::kMulticlass) {
538 result.fROCIntegral = fFoldFactory->GetROCIntegral(fDataLoader->GetName(), foldTitle);
540 TGraph *gr = fFoldFactory->GetROCCurve(fDataLoader->GetName(), foldTitle,
true);
541 gr->SetLineColor(iFold + 1);
543 gr->SetTitle(foldTitle.Data());
546 result.fSig = smethod->GetSignificance();
547 result.fSep = smethod->GetSeparation();
// Efficiency figures are only filled for plain classification.
549 if (fAnalysisType == Types::kClassification) {
551 result.fEff01 = smethod->GetEfficiency(
"Efficiency:0.01", Types::kTesting, err);
552 result.fEff10 = smethod->GetEfficiency(
"Efficiency:0.10", Types::kTesting, err);
553 result.fEff30 = smethod->GetEfficiency(
"Efficiency:0.30", Types::kTesting, err);
554 result.fEffArea = smethod->GetEfficiency(
"", Types::kTesting, err);
555 result.fTrainEff01 = smethod->GetTrainingEfficiency(
"Efficiency:0.01");
556 result.fTrainEff10 = smethod->GetTrainingEfficiency(
"Efficiency:0.10");
557 result.fTrainEff30 = smethod->GetTrainingEfficiency(
"Efficiency:0.30");
558 }
else if (fAnalysisType == Types::kMulticlass) {
// Close the per-fold file if one was opened above.
564 if (fFoldFileOutput && foldOutputFile !=
nullptr) {
565 foldOutputFile->Close();
// Clean-up: drop this fold's stored results and remove the booked method from
// the fold factory.
570 smethod->Data()->DeleteAllResults(Types::kTraining, smethod->GetAnalysisType());
571 smethod->Data()->DeleteAllResults(Types::kTesting, smethod->GetAnalysisType());
574 fFoldFactory->DeleteAllMethods();
575 fFoldFactory->fMethodsMap.clear();
// Runs the full cross-validation: for every booked method, processes all folds
// (in-process or through a TProcessExecutor), stores the collected result, and
// books a CrossValidation wrapper method in the output factory. Afterwards the
// folds are recombined and the wrapper methods are trained, tested and evaluated.
// NOTE(review): the conditions choosing between the serial and the parallel path,
// the condition around the GetNumWorkers() call, and the declaration of `options`
// are on listing lines not visible in this excerpt.
585 void TMVA::CrossValidation::Evaluate()
// Generate the folds once for all methods.
589 fDataLoader->MakeKFoldDataSet(*fSplit);
593 fResults.reserve(fMethods.size());
594 for (
auto & methodInfo : fMethods) {
595 CrossValidationResult result{fNumFolds};
597 TString methodTypeName = methodInfo.GetValue<TString>(
"MethodName");
598 TString methodTitle = methodInfo.GetValue<TString>(
"MethodTitle");
600 if (methodTypeName ==
"") {
601 Log() << kFATAL <<
"No method booked for cross-validation" << Endl;
604 TMVA::MsgLogger::EnableOutput();
605 Log() << kINFO << Endl;
606 Log() << kINFO << Endl;
607 Log() << kINFO <<
"========================================" << Endl;
608 Log() << kINFO <<
"Processing folds for method " << methodTitle << Endl;
609 Log() << kINFO <<
"========================================" << Endl;
610 Log() << kINFO << Endl;
// Worker count starts from the NumWorkerProcs option and may be replaced by the
// globally configured number of workers.
613 auto nWorkers = fNumWorkerProcs;
616 nWorkers = TMVA::gConfig().GetNumWorkers();
// Serial path: folds are processed one after another in this process.
619 for (UInt_t iFold = 0; iFold < fNumFolds; ++iFold) {
620 auto fold_result = ProcessFold(iFold, methodInfo);
621 result.Fill(fold_result);
// Parallel path: each fold index is mapped to a ProcessFold call on the executor;
// the lambda captures `this` and a copy of `methodInfo`.
625 ROOT::TProcessExecutor workers(nWorkers);
626 std::vector<CrossValidationFoldResult> result_vector;
628 auto workItem = [
this, methodInfo](UInt_t iFold) {
629 return ProcessFold(iFold, methodInfo);
632 result_vector = workers.Map(workItem, ROOT::TSeqI(fNumFolds));
634 for (
auto && fold_result : result_vector) {
635 result.Fill(fold_result);
640 fResults.push_back(result);
// Options for the wrapper method that represents the whole cross-validation in
// the output factory.
644 Form(
"SplitExpr=%s:NumFolds=%i"
645 ":EncapsulatedMethodName=%s"
646 ":EncapsulatedMethodTypeName=%s"
647 ":OutputEnsembling=%s",
648 fSplitExprString.Data(), fNumFolds, methodTitle.Data(), methodTypeName.Data(), fOutputEnsembling.Data());
650 fFactory->BookMethod(fDataLoader.get(), Types::kCrossValidation, methodTitle, options);
// Hand the event-to-fold assignment of the splitter to the wrapper method.
// NOTE(review): the dynamic_cast result is dereferenced without a null check.
654 IMethod *method_interface = fFactory->GetMethod(fDataLoader->GetName(), methodTitle);
655 auto *method =
dynamic_cast<MethodCrossValidation *
>(method_interface);
657 method->fEventToFoldMapping = fSplit->fEventToFoldMapping;
660 Log() << kINFO << Endl;
661 Log() << kINFO << Endl;
662 Log() << kINFO <<
"========================================" << Endl;
663 Log() << kINFO <<
"Folds processed for all methods, evaluating." << Endl;
664 Log() << kINFO <<
"========================================" << Endl;
665 Log() << kINFO << Endl;
// Merge the folds back before the combined evaluation.
668 fDataLoader->RecombineKFoldDataSet(*fSplit);
// Second pass: train each wrapper method booked above.
671 for (
auto & methodInfo : fMethods) {
672 TString methodTypeName = methodInfo.GetValue<TString>(
"MethodName");
673 TString methodTitle = methodInfo.GetValue<TString>(
"MethodTitle");
675 IMethod *method_interface = fFactory->GetMethod(fDataLoader->GetName(), methodTitle);
676 auto method =
dynamic_cast<MethodCrossValidation *
>(method_interface);
// Data set information is only written when an output file exists.
678 if (fOutputFile !=
nullptr) {
679 fFactory->WriteDataInformation(method->fDataSetInfo);
682 Event::SetIsTraining(kTRUE);
683 method->TrainMethod();
684 Event::SetIsTraining(kFALSE);
688 fFactory->TestAllMethods();
691 fFactory->EvaluateAllMethods();
693 Log() << kINFO <<
"Evaluation done." << Endl;
// Accessor for the per-method cross-validation results; logs a kFATAL message
// when no results have been stored yet.
// NOTE(review): the return statement is on a listing line not visible here
// (presumably returns fResults).
697 const std::vector<TMVA::CrossValidationResult> &TMVA::CrossValidation::GetResults()
const
699 if (fResults.empty()) {
700 Log() << kFATAL <<
"No cross-validation results available" << Endl;