/// Minimum number of training events a method needs. TrainMethod compares
/// GetNTrainingEvents() against this value and warns that the method is
/// "not trained" when the training sample is smaller.
51 #define MinNoTrainingEvents 10
/// Default constructor of ClassificationResult. It initialises the stored ROC
/// integral (fROCIntegral) to 0; the constructor body is not visible in this excerpt.
54 TMVA::Experimental::ClassificationResult::ClassificationResult() : fROCIntegral(0)
/// Copy constructor. Copies the TObject base and the stored result data:
/// data-loader name, training/testing MVA outputs and ROC integral.
/// NOTE(review): the visible lines do not copy fMethod or fIsCuts, although
/// operator= does copy both. This excerpt has gaps, so confirm against the
/// full file that both members are copied here as well.
59 TMVA::Experimental::ClassificationResult::ClassificationResult(
const ClassificationResult &cr) : TObject(cr)
62 fDataLoaderName = cr.fDataLoaderName;
63 fMvaTrain = cr.fMvaTrain;
64 fMvaTest = cr.fMvaTest;
66 fROCIntegral = cr.fROCIntegral;
/// Returns the ROC integral of class iClass for the given tree type
/// (training or testing). The value comes from a ROCCurve built by GetROC.
/// NOTE(review): GetROC allocates the curve with `new`. The lines that release
/// `roc` are not visible in this excerpt, so make sure it is deleted after use.
76 Double_t TMVA::Experimental::ClassificationResult::GetROCIntegral(UInt_t iClass, TMVA::Types::ETreeType type)
81 auto roc = GetROC(iClass, type);
82 auto inte = roc->GetROCIntegral();
/// Builds a ROC curve for class iClass from the stored MVA outputs.
/// It uses fMvaTest[iClass] when type is kTesting and fMvaTrain[iClass] otherwise;
/// the `else` between the two assignments is not visible in this excerpt.
/// The curve is allocated with `new`. Presumably the caller owns it and must
/// delete it — TODO confirm.
95 TMVA::ROCCurve *TMVA::Experimental::ClassificationResult::GetROC(UInt_t iClass, TMVA::Types::ETreeType type)
// This is a local variable, despite the member-style `f` prefix.
97 ROCCurve *fROCCurve =
nullptr;
98 if (type == TMVA::Types::kTesting)
99 fROCCurve =
new ROCCurve(fMvaTest[iClass]);
// Presumably the non-testing branch: use the training outputs.
101 fROCCurve =
new ROCCurve(fMvaTrain[iClass]);
/// Copy assignment. Copies from cr the method description (fMethod), the
/// data-loader name, the training/testing MVA outputs, the cuts flag and the
/// ROC integral.
/// NOTE(review): the TObject base is not assigned in the visible lines, while
/// the copy constructor does copy it. Confirm against the full file.
106 TMVA::Experimental::ClassificationResult &TMVA::Experimental::ClassificationResult::
107 operator=(
const TMVA::Experimental::ClassificationResult &cr)
109 fMethod = cr.fMethod;
110 fDataLoaderName = cr.fDataLoaderName;
111 fMvaTrain = cr.fMvaTrain;
112 fMvaTest = cr.fMvaTest;
113 fIsCuts = cr.fIsCuts;
114 fROCIntegral = cr.fROCIntegral;
/// Prints a one-row summary table (data-loader name, "Method/Title" and ROC
/// integral) through a local MsgLogger named "Classification".
/// Before printing, MsgLogger output is enabled and gConfig is un-silenced.
/// At the end gConfig is set silent unconditionally, so a previously
/// non-silent configuration is not restored.
123 void TMVA::Experimental::ClassificationResult::Show()
// This is a local logger, despite the member-style `f` prefix.
125 MsgLogger fLogger(
"Classification");
126 TMVA::MsgLogger::EnableOutput();
127 TMVA::gConfig().SetSilent(kFALSE);
128 TString hLine =
"--------------------------------------------------- :";
130 fLogger << kINFO << hLine << Endl;
131 fLogger << kINFO <<
"DataSet MVA :" << Endl;
132 fLogger << kINFO <<
"Name: Method/Title: ROC-integ :" << Endl;
133 fLogger << kINFO << hLine << Endl;
// Row layout: name left-aligned in 20 columns, "Method/Title" in 15 columns,
// then the ROC-integ column with 3 decimals. Its argument is on a line not
// visible in this excerpt.
134 fLogger << kINFO << Form(
"%-20s %-15s %#1.3f :", fDataLoaderName.Data(),
135 Form(
"%s/%s", fMethod.GetValue<TString>(
"MethodName").Data(),
136 fMethod.GetValue<TString>(
"MethodTitle").Data()),
139 fLogger << kINFO << hLine << Endl;
141 TMVA::gConfig().SetSilent(kTRUE);
/// Returns the ROC curve of class iClass on the given tree type as a TGraph.
/// The graph is named and titled "MethodName/MethodTitle", with axis titles
/// " Signal Efficiency " (x) and " Background Rejection " (y).
/// NOTE(review): GetROC returns a ROCCurve allocated with `new`, and the
/// visible lines never delete it. The TGraph is obtained from that curve, so
/// check who owns the graph before deleting the curve (a clone may be needed).
151 TGraph *TMVA::Experimental::ClassificationResult::GetROCGraph(UInt_t iClass, TMVA::Types::ETreeType type)
153 TGraph *roc = GetROC(iClass, type)->GetROCCurve();
154 roc->SetName(Form(
"%s/%s", GetMethodName().Data(), GetMethodTitle().Data()));
155 roc->SetTitle(Form(
"%s/%s", GetMethodName().Data(), GetMethodTitle().Data()));
156 roc->GetXaxis()->SetTitle(
" Signal Efficiency ");
157 roc->GetYaxis()->SetTitle(
" Background Rejection ");
/// Tells whether this result belongs to the given method: the stored
/// "MethodName" must equal methodname and "MethodTitle" must equal methodtitle.
/// The tail of the return expression is not visible in this excerpt.
168 Bool_t TMVA::Experimental::ClassificationResult::IsMethod(TString methodname, TString methodtitle)
170 return fMethod.GetValue<TString>(
"MethodName") == methodname &&
171 fMethod.GetValue<TString>(
"MethodTitle") == methodtitle
/// Constructs a classification envelope on `dataloader` that writes to `file`.
/// It declares the boolean options "Correlations" (initially kFALSE) and "ROC"
/// (initially kTRUE) and checks for unused options. When model persistence is
/// enabled, it also creates a directory named after the data loader.
/// NOTE(review): the DataLoader+options constructor calls
/// SetConfigDescription/SetConfigName, but the visible lines here do not.
/// Confirm whether that difference is intentional.
183 TMVA::Experimental::Classification::Classification(DataLoader *dataloader, TFile *file, TString options)
184 : TMVA::Envelope(
"Classification", dataloader, file, options), fAnalysisType(Types::kClassification),
185 fCorrelations(kFALSE), fROC(kTRUE)
187 DeclareOptionRef(fCorrelations,
"Correlations",
"boolean to show correlation in output");
188 DeclareOptionRef(fROC,
"ROC",
"boolean to show ROC in output");
190 CheckForUnusedOptions();
// The directory holds model output (weight files go under it, see GetMethod).
192 if (fModelPersistence)
193 gSystem->MakeDirectory(fDataLoader->GetName());
/// Constructs a classification envelope without an output file; NULL is
/// passed as the file to Envelope. It sets the configuration description and
/// name, declares the "Correlations" and "ROC" options and checks for unused
/// options. When model persistence is on, it creates the data loader's directory.
202 TMVA::Experimental::Classification::Classification(DataLoader *dataloader, TString options)
203 : TMVA::Envelope(
"Classification", dataloader, NULL, options), fAnalysisType(Types::kClassification),
204 fCorrelations(kFALSE), fROC(kTRUE)
208 SetConfigDescription(
"Configuration options for Classification running");
209 SetConfigName(GetName());
211 DeclareOptionRef(fCorrelations,
"Correlations",
"boolean to show correlation in output");
212 DeclareOptionRef(fROC,
"ROC",
"boolean to show ROC in output");
214 CheckForUnusedOptions();
215 if (fModelPersistence)
216 gSystem->MakeDirectory(fDataLoader->GetName());
// Redundant: the initializer list already sets fAnalysisType to kClassification.
217 fAnalysisType = TMVA::Types::kClassification;
/// Destructor. It iterates over the method objects held in fIMethods; the loop
/// body is not visible in this excerpt (presumably each one is deleted — TODO confirm).
221 TMVA::Experimental::Classification::~Classification()
223 for (
auto m : fIMethods) {
/// Returns the "MethodOptions" string of the booked method (in fMethods) whose
/// "MethodName" and "MethodTitle" match the arguments. The value returned when
/// nothing matches is not visible in this excerpt.
236 TString TMVA::Experimental::Classification::GetMethodOptions(TString methodname, TString methodtitle)
238 for (
auto &meth : fMethods) {
239 if (meth.GetValue<TString>(
"MethodName") == methodname && meth.GetValue<TString>(
"MethodTitle") == methodtitle)
240 return meth.GetValue<TString>(
"MethodOptions");
/// Trains and tests every booked method, then prints a ROC-integral summary.
/// One task per entry of fMethods is dispatched through fWorkers, which is
/// configured with fJobs workers. Each task trains and tests a single method
/// and returns its ClassificationResult. The collected results replace fResults.
250 void TMVA::Experimental::Classification::Evaluate()
// Instantiate every booked method up front, before the worker tasks run.
261 for (
auto &meth : fMethods) {
262 GetMethod(meth.GetValue<TString>(
"MethodName"), meth.GetValue<TString>(
"MethodTitle"));
265 fWorkers.SetNWorkers(fJobs);
// Task for method index workerID. It silences logging, colour and the progress
// bar, then trains and tests that one method.
// NOTE(review): inside a member function, [=] implicitly captures `this`
// (fMethods, fDataLoader and GetMethod are used). Implicit `this` capture via
// [=] is deprecated in C++20; consider an explicit [this] capture.
267 auto executor = [=](UInt_t workerID) -> ClassificationResult {
268 TMVA::MsgLogger::InhibitOutput();
269 TMVA::gConfig().SetSilent(kTRUE);
270 TMVA::gConfig().SetUseColor(kFALSE);
271 TMVA::gConfig().SetDrawProgressBar(kFALSE);
272 auto methodname = fMethods[workerID].GetValue<TString>(
"MethodName");
273 auto methodtitle = fMethods[workerID].GetValue<TString>(
"MethodTitle");
274 auto meth = GetMethod(methodname, methodtitle);
// In non-silent mode each task writes to its own file named
// ".<data loader><method name><method title>.root". MergeFiles reads the same name.
275 if (!IsSilentFile()) {
276 auto fname = Form(
".%s%s%s.root", fDataLoader->GetName(), methodname.Data(), methodtitle.Data());
// NOTE(review): ownership of this TFile is handled on lines not visible in this excerpt.
277 auto f =
new TFile(fname,
"RECREATE");
278 f->mkdir(fDataLoader->GetName());
282 TrainMethod(methodname, methodtitle);
283 TestMethod(methodname, methodtitle);
284 if (!IsSilentFile()) {
// Returned by value: a copy of this method's entry in fResults.
287 return GetResults(methodname, methodtitle);
// Run the task for every index 0 .. fMethods.size()-1; the collected results
// replace fResults.
291 fResults = fWorkers.Map(executor, ROOT::TSeqI(fMethods.size()));
298 TMVA::gConfig().SetSilent(kFALSE);
300 TString hLine =
"--------------------------------------------------- :";
301 Log() << kINFO << hLine << Endl;
302 Log() << kINFO <<
"DataSet MVA :" << Endl;
303 Log() << kINFO <<
"Name: Method/Title: ROC-integ :" << Endl;
304 Log() << kINFO << hLine << Endl;
// Summary table: one row per result (data-loader name, Method/Title, ROC integral).
305 for (
auto &r : fResults) {
307 Log() << kINFO << Form(
"%-20s %-15s %#1.3f :", r.GetDataLoaderName().Data(),
308 Form(
"%s/%s", r.GetMethodName().Data(), r.GetMethodTitle().Data()), r.GetROCIntegral())
311 Log() << kINFO << hLine << Endl;
// NOTE(review): "Evaluation done." is logged twice below, once at kHEADER and
// once at kINFO.
313 Log() << kINFO <<
"-----------------------------------------------------" << Endl;
314 Log() << kHEADER <<
"Evaluation done." << Endl << Endl;
315 Log() << kINFO << Form(
"Jobs = %d Real Time = %lf ", fJobs, fTimer.RealTime()) << Endl;
316 Log() << kINFO <<
"-----------------------------------------------------" << Endl;
317 Log() << kINFO <<
"Evaluation done." << Endl;
318 TMVA::gConfig().SetSilent(kTRUE);
/// Trains every booked method in fMethods order, calling
/// TrainMethod(name, title) for each one.
325 void TMVA::Experimental::Classification::Train()
327 for (
auto &meth : fMethods) {
328 TrainMethod(meth.GetValue<TString>(
"MethodName"), meth.GetValue<TString>(
"MethodTitle"));
/// Trains the booked method identified by (methodname, methodtitle).
/// GetMethod creates and configures the method on first use.
338 void TMVA::Experimental::Classification::TrainMethod(TString methodname, TString methodtitle)
340 auto method = GetMethod(methodname, methodtitle);
// Reported when the method cannot be obtained. The condition and log level are
// on lines not visible in this excerpt.
343 << Form(
"Trying to train method %s %s that maybe is not booked.", methodname.Data(), methodtitle.Data())
346 Log() << kHEADER << gTools().Color(
"bold") << Form(
"Training method %s %s", methodname.Data(), methodtitle.Data())
347 << gTools().Color(
"reset") << Endl;
// Class-wide (static) setting on Event: mark that training is in progress.
349 Event::SetIsTraining(kTRUE);
// Classification and multiclass training need at least two classes (kFATAL otherwise).
350 if ((fAnalysisType == Types::kMulticlass || fAnalysisType == Types::kClassification) &&
351 method->DataInfo().GetNClasses() < 2)
352 Log() << kFATAL <<
"You want to do classification training, but specified less than two classes." << Endl;
// Too few training events: warn that the method is not trained. The early exit
// itself is on lines not visible in this excerpt.
357 if (method->Data()->GetNTrainingEvents() < MinNoTrainingEvents) {
358 Log() << kWARNING <<
"Method " << method->GetMethodName() <<
" not trained (training tree has less entries ["
359 << method->Data()->GetNTrainingEvents() <<
"] than required [" << MinNoTrainingEvents <<
"]" << Endl;
363 Log() << kHEADER <<
"Train method: " << method->GetMethodName() <<
" for Classification" << Endl << Endl;
364 method->TrainMethod();
365 Log() << kHEADER <<
"Training finished" << Endl << Endl;
/// Convenience overload. It converts the method enum to its name with
/// Types::Instance().GetMethodName and forwards to TrainMethod(TString, TString).
374 void TMVA::Experimental::Classification::TrainMethod(Types::EMVA method, TString methodtitle)
376 TrainMethod(Types::Instance().GetMethodName(method), methodtitle);
/// Returns the MethodBase for (methodname, methodtitle).
/// If the object is already in fIMethods, the cached object is returned.
/// Otherwise the method is created (optionally wrapped in a "Boost" classifier),
/// configured for this envelope (weight-file dir, persistence, output file,
/// analysis type) and appended to fIMethods.
387 TMVA::MethodBase *TMVA::Experimental::Classification::GetMethod(TString methodname, TString methodtitle)
// Unknown (not booked) method: print name/title to std::cout and log an error.
// NOTE(review): the std::cout line bypasses MsgLogger and its silence settings.
// Consider folding the name and title into the kERROR message instead.
390 if (!HasMethod(methodname, methodtitle)) {
391 std::cout << methodname <<
" " << methodtitle << std::endl;
392 Log() << kERROR <<
"Trying to get method not booked." << Endl;
// Already instantiated: return the cached object from fIMethods.
396 if (HasMethodObject(methodname, methodtitle, index)) {
397 return dynamic_cast<MethodBase *
>(fIMethods[index]);
// Refuse to build a method when the data loader's input has too few entries
// (the threshold is on a line not visible in this excerpt).
400 if (GetDataLoaderDataInput().GetEntries() <=
402 Log() << kFATAL <<
"No input data for the training provided!" << Endl;
404 Log() << kHEADER <<
"Loading booked method: " << gTools().Color(
"bold") << methodname <<
" " << methodtitle
405 << gTools().Color(
"reset") << Endl << Endl;
407 TString moptions = GetMethodOptions(methodname, methodtitle);
// Read the optional "Boost_num" setting from the method's option string with a
// temporary Configurable (its release is on a line not visible in this excerpt).
411 auto conf =
new TMVA::Configurable(moptions);
412 conf->DeclareOptionRef(boostNum = 0,
"Boost_num",
"Number of times the classifier will be boosted");
413 conf->ParseOptions();
// With model persistence, weight files go to
// "<data loader name>/<gConfig().GetIONames().fWeightFileDir>".
417 if (fModelPersistence) {
418 fFileDir = fDataLoader->GetName();
419 fFileDir +=
"/" + gConfig().GetIONames().fWeightFileDir;
// This is a local variable, despite the member-style `f` prefix.
424 TString fJobName = GetName();
// Plain path: create the requested method directly through ClassifierFactory.
// The condition selecting this path is not visible in this excerpt.
426 im = ClassifierFactory::Instance().Create(std::string(methodname.Data()), fJobName, methodtitle,
427 GetDataLoaderDataSetInfo(), moptions);
// Boosted path (boostNum > 0, per the debug message): create a "Boost"
// classifier that wraps methodname and configure it for this envelope.
430 Log() << kDEBUG <<
"Boost Number is " << boostNum <<
" > 0: train boosted classifier" << Endl;
431 im = ClassifierFactory::Instance().Create(std::string(
"Boost"), fJobName, methodtitle, GetDataLoaderDataSetInfo(),
// NOTE(review): the kFATAL text below says "MethodCategory", but the cast is to
// MethodBoost. It looks copy-pasted from the kCategory branch.
433 MethodBoost *methBoost =
dynamic_cast<MethodBoost *
>(im);
435 Log() << kFATAL <<
"Method with type kBoost cannot be casted to MethodCategory. /Classification" << Endl;
437 if (fModelPersistence)
438 methBoost->SetWeightFileDir(fFileDir);
439 methBoost->SetModelPersistence(fModelPersistence);
440 methBoost->SetBoostedMethodName(methodname);
441 methBoost->fDataSetManager = GetDataLoaderDataSetManager();
442 methBoost->SetFile(fFile.get());
443 methBoost->SetSilentFile(IsSilentFile());
// From here on the created object is used as a MethodBase.
// NOTE(review): no null check on this dynamic_cast is visible in this excerpt.
446 MethodBase *method =
dynamic_cast<MethodBase *
>(im);
// Category methods get the same persistence/file configuration as boosted ones.
451 if (method->GetMethodType() == Types::kCategory) {
452 MethodCategory *methCat = (
dynamic_cast<MethodCategory *
>(im));
454 Log() << kFATAL <<
"Method with type kCategory cannot be casted to MethodCategory. /Classification" << Endl;
456 if (fModelPersistence)
457 methCat->SetWeightFileDir(fFileDir);
458 methCat->SetModelPersistence(fModelPersistence);
459 methCat->fDataSetManager = GetDataLoaderDataSetManager();
460 methCat->SetFile(fFile.get());
461 methCat->SetSilentFile(IsSilentFile());
// Warn when the method cannot handle this analysis type with the dataset's
// number of classes and targets.
464 if (!method->HasAnalysisType(fAnalysisType, GetDataLoaderDataSetInfo().GetNClasses(),
465 GetDataLoaderDataSetInfo().GetNTargets())) {
466 Log() << kWARNING <<
"Method " << method->GetMethodTypeName() <<
" is not capable of handling ";
467 Log() <<
"classification with " << GetDataLoaderDataSetInfo().GetNClasses() <<
" classes." << Endl;
// Common configuration for every created method, then cache it in fIMethods.
471 if (fModelPersistence)
472 method->SetWeightFileDir(fFileDir);
473 method->SetModelPersistence(fModelPersistence);
474 method->SetAnalysisType(fAnalysisType);
475 method->SetupMethod();
476 method->ParseOptions();
477 method->ProcessSetup();
478 method->SetFile(fFile.get());
479 method->SetSilentFile(IsSilentFile());
482 method->CheckSetup();
483 fIMethods.push_back(method);
/// Searches fIMethods for an object whose method type name equals methodname
/// and whose method name equals methodtitle (the booking title).
/// The function returns early when fIMethods is empty. The match handling
/// (presumably storing the position in `index` and returning kTRUE) is on lines
/// not visible in this excerpt — TODO confirm.
495 Bool_t TMVA::Experimental::Classification::HasMethodObject(TString methodname, TString methodtitle, Int_t &index)
497 if (fIMethods.empty())
499 for (UInt_t i = 0; i < fIMethods.size(); i++) {
// NOTE(review): the dynamic_cast result is dereferenced without a visible null check.
501 auto methbase =
dynamic_cast<MethodBase *
>(fIMethods[i]);
502 if (methbase->GetMethodTypeName() == methodname && methbase->GetMethodName() == methodtitle) {
/// Tests every booked method in fMethods order, calling
/// TestMethod(name, title) for each one.
514 void TMVA::Experimental::Classification::Test()
516 for (
auto &meth : fMethods) {
517 TestMethod(meth.GetValue<TString>(
"MethodName"), meth.GetValue<TString>(
"MethodTitle"));
/// Tests and evaluates the booked method (methodname, methodtitle).
/// The function:
///  - fills the method's testing output and runs TestClassification;
///  - collects significance, separation, ROC integral and signal efficiencies at
///    background efficiencies 0.01/0.10/0.30 (test and training samples);
///  - optionally prints inter-MVA / variable-MVA correlation and overlap matrices;
///  - prints ranking and overtraining tables;
///  - stores the ROC integral, the cuts flag and class information in this
///    method's ClassificationResult;
///  - when not silent, writes the test/train trees to the output file.
/// The enclosing conditions of several sections are on lines not visible in
/// this excerpt.
527 void TMVA::Experimental::Classification::TestMethod(TString methodname, TString methodtitle)
529 auto method = GetMethod(methodname, methodtitle);
// NOTE(review): this message says "train", but this is TestMethod. It is likely
// copy-pasted from TrainMethod; consider "Trying to test method ...".
532 << Form(
"Trying to train method %s %s that maybe is not booked.", methodname.Data(), methodtitle.Data())
// NOTE(review): the banner says "Test all methods" although only one method is tested here.
536 Log() << kHEADER << gTools().Color(
"bold") <<
"Test all methods" << gTools().Color(
"reset") << Endl;
// Class-wide (static) setting on Event: training is no longer in progress.
537 Event::SetIsTraining(kFALSE);
539 Types::EAnalysisType analysisType = method->GetAnalysisType();
540 Log() << kHEADER <<
"Test method: " << method->GetMethodName() <<
" for Classification"
541 <<
" performance" << Endl << Endl;
542 method->AddOutput(Types::kTesting, analysisType);
// Per-category accumulators, indexed by isel: 0 = ordinary MVA methods,
// 1 = methods whose type name contains "Variable".
552 Int_t nmeth_used[2] = {0, 0};
554 std::vector<std::vector<TString>> mname(2);
555 std::vector<std::vector<Double_t>> sig(2), sep(2), roc(2);
556 std::vector<std::vector<Double_t>> eff01(2), eff10(2), eff30(2), effArea(2);
557 std::vector<std::vector<Double_t>> eff01err(2), eff10err(2), eff30err(2);
558 std::vector<std::vector<Double_t>> trainEff01(2), trainEff10(2), trainEff30(2);
560 method->SetFile(fFile.get());
561 method->SetSilentFile(IsSilentFile());
// methodNoCuts stays NULL for a cuts method; correlation/overlap analysis uses it.
563 MethodBase *methodNoCuts = NULL;
564 if (!IsCutsMethod(method))
565 methodNoCuts = method;
567 Log() << kHEADER <<
"Evaluate classifier: " << method->GetMethodName() << Endl << Endl;
568 isel = (method->GetMethodTypeName().Contains(
"Variable")) ? 1 : 0;
571 method->TestClassification();
// Collect the figures of merit for this method in its category.
574 mname[isel].push_back(method->GetMethodName());
575 sig[isel].push_back(method->GetSignificance());
576 sep[isel].push_back(method->GetSeparation());
577 roc[isel].push_back(method->GetROCIntegral());
// Signal efficiencies (and their errors) on the test sample at the given working
// points; an empty spec gives the area value stored in effArea.
580 eff01[isel].push_back(method->GetEfficiency(
"Efficiency:0.01", Types::kTesting, err));
581 eff01err[isel].push_back(err);
582 eff10[isel].push_back(method->GetEfficiency(
"Efficiency:0.10", Types::kTesting, err));
583 eff10err[isel].push_back(err);
584 eff30[isel].push_back(method->GetEfficiency(
"Efficiency:0.30", Types::kTesting, err));
585 eff30err[isel].push_back(err);
586 effArea[isel].push_back(method->GetEfficiency(
"", Types::kTesting, err));
// Same working points on the training sample (used for the overtraining check).
588 trainEff01[isel].push_back(method->GetTrainingEfficiency(
"Efficiency:0.01"));
589 trainEff10[isel].push_back(method->GetTrainingEfficiency(
"Efficiency:0.10"));
590 trainEff30[isel].push_back(method->GetTrainingEfficiency(
"Efficiency:0.30"));
594 if (!IsSilentFile()) {
595 Log() << kDEBUG <<
"\tWrite evaluation histograms to file" << Endl;
596 method->WriteEvaluationHistosToFile(Types::kTesting);
597 method->WriteEvaluationHistosToFile(Types::kTraining);
// Sort each category's figure-of-merit vectors together (descending) with a copy
// of the method names, so that rows stay aligned. Presumably effArea (vtemp[0])
// is the sort key — TODO confirm with gTools().UsefulSortDescending. The
// write-back of vtemp[1..3] and vtemp[10..12] is on lines not visible in this excerpt.
601 for (Int_t k = 0; k < 2; k++) {
602 std::vector<std::vector<Double_t>> vtemp;
603 vtemp.push_back(effArea[k]);
604 vtemp.push_back(eff10[k]);
605 vtemp.push_back(eff01[k]);
606 vtemp.push_back(eff30[k]);
607 vtemp.push_back(eff10err[k]);
608 vtemp.push_back(eff01err[k]);
609 vtemp.push_back(eff30err[k]);
610 vtemp.push_back(trainEff10[k]);
611 vtemp.push_back(trainEff01[k]);
612 vtemp.push_back(trainEff30[k]);
613 vtemp.push_back(sig[k]);
614 vtemp.push_back(sep[k]);
615 vtemp.push_back(roc[k]);
616 std::vector<TString> vtemps = mname[k];
617 gTools().UsefulSortDescending(vtemp, &vtemps);
618 effArea[k] = vtemp[0];
622 eff10err[k] = vtemp[4];
623 eff01err[k] = vtemp[5];
624 eff30err[k] = vtemp[6];
625 trainEff10[k] = vtemp[7];
626 trainEff01[k] = vtemp[8];
627 trainEff30[k] = vtemp[9];
// Correlation/overlap analysis. nmeth is 1 for a non-cuts method and 0 otherwise;
// nvar is the number of input variables. Each event vector holds the nmeth MVA
// outputs followed by the nvar input values.
641 const Int_t nmeth = methodNoCuts == NULL ? 0 : 1;
642 const Int_t nvar = method->fDataSetInfo.GetNVariables();
// NOTE(review): dvec, tpSig, tpBkg, theVars, overlapS and overlapB are allocated
// with `new`, and their deletion is not visible in this excerpt. Verify, or use
// automatic storage / std::unique_ptr.
646 Double_t *dvec =
new Double_t[nmeth + nvar];
647 std::vector<Double_t> rvec;
650 TPrincipal *tpSig =
new TPrincipal(nmeth + nvar,
"");
651 TPrincipal *tpBkg =
new TPrincipal(nmeth + nvar,
"");
654 std::vector<TString> *theVars =
new std::vector<TString>;
655 std::vector<ResultsClassification *> mvaRes;
// NOTE(review): methodNoCuts is NULL for a cuts method. Presumably this section
// is guarded by an enclosing non-cuts condition not visible in this excerpt.
// The MVA name is used without its "MVA_" prefix; rvec holds the signal reference cut.
656 theVars->push_back(methodNoCuts->GetTestvarName());
657 rvec.push_back(methodNoCuts->GetSignalReferenceCut());
658 theVars->back().ReplaceAll(
"MVA_",
"");
659 mvaRes.push_back(dynamic_cast<ResultsClassification *>(
660 methodNoCuts->Data()->GetResults(methodNoCuts->GetMethodName(), Types::kTesting, Types::kMaxAnalysisType)));
663 TMatrixD *overlapS =
new TMatrixD(nmeth, nmeth);
664 TMatrixD *overlapB =
new TMatrixD(nmeth, nmeth);
// Loop over the test events: gather MVA outputs and input values per event.
669 DataSet *defDs = method->fDataSetInfo.GetDataSet();
670 defDs->SetCurrentType(Types::kTesting);
671 for (Int_t ievt = 0; ievt < defDs->GetNEvents(); ievt++) {
672 const Event *ev = defDs->GetEvent(ievt);
675 TMatrixD *theMat = 0;
676 for (Int_t im = 0; im < nmeth; im++) {
// MVA output of method im for this event; NaN values are reported with a warning.
678 Double_t retval = (Double_t)(*mvaRes[im])[ievt][0];
679 if (TMath::IsNaN(retval)) {
680 Log() << kWARNING <<
"Found NaN return value in event: " << ievt <<
" for method \""
681 << methodNoCuts->GetName() <<
"\"" << Endl;
686 for (Int_t iv = 0; iv < nvar; iv++)
687 dvec[iv + nmeth] = (Double_t)ev->GetValue(iv);
// Signal/background split; the branch bodies are on lines not visible in this excerpt.
688 if (method->fDataSetInfo.IsSignal(ev)) {
// Count events where MVAs im and jm are on the same side of their reference
// cuts ("conform" answers).
697 for (Int_t im = 0; im < nmeth; im++) {
698 for (Int_t jm = im; jm < nmeth; jm++) {
699 if ((dvec[im] - rvec[im]) * (dvec[jm] - rvec[jm]) > 0) {
// Normalise the overlap counts to fractions of the signal/background test events.
709 (*overlapS) *= (1.0 / defDs->GetNEvtSigTest());
710 (*overlapB) *= (1.0 / defDs->GetNEvtBkgdTest());
712 tpSig->MakePrincipals();
713 tpBkg->MakePrincipals();
715 const TMatrixD *covMatS = tpSig->GetCovarianceMatrix();
716 const TMatrixD *covMatB = tpBkg->GetCovarianceMatrix();
718 const TMatrixD *corrMatS = gTools().GetCorrelationMatrix(covMatS);
719 const TMatrixD *corrMatB = gTools().GetCorrelationMatrix(covMatB);
// If both correlation matrices exist, split them into the MVA-MVA block (first
// nmeth rows/columns) and the variable-vs-MVA block, then print both.
722 if (corrMatS != 0 && corrMatB != 0) {
725 TMatrixD mvaMatS(nmeth, nmeth);
726 TMatrixD mvaMatB(nmeth, nmeth);
727 for (Int_t im = 0; im < nmeth; im++) {
728 for (Int_t jm = 0; jm < nmeth; jm++) {
729 mvaMatS(im, jm) = (*corrMatS)(im, jm);
730 mvaMatB(im, jm) = (*corrMatB)(im, jm);
735 std::vector<TString> theInputVars;
736 TMatrixD varmvaMatS(nvar, nmeth);
737 TMatrixD varmvaMatB(nvar, nmeth);
738 for (Int_t iv = 0; iv < nvar; iv++) {
739 theInputVars.push_back(method->fDataSetInfo.GetVariableInfo(iv).GetLabel());
740 for (Int_t jm = 0; jm < nmeth; jm++) {
741 varmvaMatS(iv, jm) = (*corrMatS)(nmeth + iv, jm);
742 varmvaMatB(iv, jm) = (*corrMatB)(nmeth + iv, jm);
747 Log() << kINFO << Endl;
748 Log() << kINFO << Form(
"Dataset[%s] : ", method->fDataSetInfo.GetName())
749 <<
"Inter-MVA correlation matrix (signal):" << Endl;
750 gTools().FormattedOutput(mvaMatS, *theVars, Log());
751 Log() << kINFO << Endl;
753 Log() << kINFO << Form(
"Dataset[%s] : ", method->fDataSetInfo.GetName())
754 <<
"Inter-MVA correlation matrix (background):" << Endl;
755 gTools().FormattedOutput(mvaMatB, *theVars, Log());
756 Log() << kINFO << Endl;
759 Log() << kINFO << Form(
"Dataset[%s] : ", method->fDataSetInfo.GetName())
760 <<
"Correlations between input variables and MVA response (signal):" << Endl;
761 gTools().FormattedOutput(varmvaMatS, theInputVars, *theVars, Log());
762 Log() << kINFO << Endl;
764 Log() << kINFO << Form(
"Dataset[%s] : ", method->fDataSetInfo.GetName())
765 <<
"Correlations between input variables and MVA response (background):" << Endl;
766 gTools().FormattedOutput(varmvaMatB, theInputVars, *theVars, Log());
767 Log() << kINFO << Endl;
769 Log() << kWARNING << Form(
"Dataset[%s] : ", method->fDataSetInfo.GetName())
770 <<
"<TestAllMethods> cannot compute correlation matrices" << Endl;
// Explain and print the signal/background overlap matrices and the reference cuts.
773 Log() << kINFO << Form(
"Dataset[%s] : ", method->fDataSetInfo.GetName())
774 <<
"The following \"overlap\" matrices contain the fraction of events for which " << Endl;
775 Log() << kINFO << Form(
"Dataset[%s] : ", method->fDataSetInfo.GetName())
776 <<
"the MVAs 'i' and 'j' have returned conform answers about \"signal-likeness\"" << Endl;
777 Log() << kINFO << Form(
"Dataset[%s] : ", method->fDataSetInfo.GetName())
778 <<
"An event is signal-like, if its MVA output exceeds the following value:" << Endl;
779 gTools().FormattedOutput(rvec, *theVars,
"Method",
"Cut value", Log());
780 Log() << kINFO << Form(
"Dataset[%s] : ", method->fDataSetInfo.GetName())
781 <<
"which correspond to the working point: eff(signal) = 1 - eff(background)" << Endl;
785 Log() << kINFO << Form(
"Dataset[%s] : ", method->fDataSetInfo.GetName())
786 <<
"Note: no correlations and overlap with cut method are provided at present" << Endl;
789 Log() << kINFO << Endl;
790 Log() << kINFO << Form(
"Dataset[%s] : ", method->fDataSetInfo.GetName())
791 <<
"Inter-MVA overlap matrix (signal):" << Endl;
792 gTools().FormattedOutput(*overlapS, *theVars, Log());
793 Log() << kINFO << Endl;
795 Log() << kINFO << Form(
"Dataset[%s] : ", method->fDataSetInfo.GetName())
796 <<
"Inter-MVA overlap matrix (background):" << Endl;
797 gTools().FormattedOutput(*overlapB, *theVars, Log());
// This method's result entry (GetResults appends one on first access).
817 auto &fResult = GetResults(methodname, methodtitle);
// Force output on for the summary tables below.
821 Log().EnableOutput();
822 gConfig().SetSilent(kFALSE);
824 TString hLine =
"------------------------------------------------------------------------------------------"
825 "-------------------------";
// Ranking table: one row per method (ROC integral); "Variable" methods are listed
// under a separate "Input Variables" header.
826 Log() << kINFO <<
"Evaluation results ranked by best signal efficiency and purity (area)" << Endl;
827 Log() << kINFO << hLine << Endl;
828 Log() << kINFO <<
"DataSet MVA " << Endl;
829 Log() << kINFO <<
"Name: Method: ROC-integ" << Endl;
831 Log() << kDEBUG << hLine << Endl;
832 for (Int_t k = 0; k < 2; k++) {
833 if (k == 1 && nmeth_used[k] > 0) {
834 Log() << kINFO << hLine << Endl;
835 Log() << kINFO <<
"Input Variables: " << Endl << hLine << Endl;
837 for (Int_t i = 0; i < nmeth_used[k]; i++) {
838 TString datasetName = fDataLoader->GetName();
839 TString methodName = mname[k][i];
842 methodName.ReplaceAll(
"Variable_",
"");
845 TMVA::DataSet *dataset = method->Data();
846 TMVA::Results *results = dataset->GetResults(methodName, Types::kTesting, this->fAnalysisType);
847 std::vector<Bool_t> *mvaResType =
dynamic_cast<ResultsClassification *
>(results)->GetValueVectorTypes();
// NOTE(review): rocIntegral uses the function arguments (methodname, methodtitle),
// not this row's methodName. That only matches when the table holds just the
// method under test.
849 Double_t rocIntegral = 0.0;
850 if (mvaResType->size() != 0) {
851 rocIntegral = GetROCIntegral(methodname, methodtitle);
// A negative separation or significance makes effArea the reported ROC integral.
854 if (sep[k][i] < 0 || sig[k][i] < 0) {
856 fResult.fROCIntegral = effArea[k][i];
858 << Form(
"%-13s %-15s: %#1.3f", fDataLoader->GetName(), methodName.Data(), fResult.fROCIntegral)
861 fResult.fROCIntegral = rocIntegral;
862 Log() << kINFO << Form(
"%-13s %-15s: %#1.3f", datasetName.Data(), methodName.Data(), rocIntegral)
// Overtraining check: test-sample signal efficiency next to training-sample
// efficiency (in parentheses) at background efficiencies 0.01, 0.10 and 0.30.
867 Log() << kINFO << hLine << Endl;
868 Log() << kINFO << Endl;
869 Log() << kINFO <<
"Testing efficiency compared to training efficiency (overtraining check)" << Endl;
870 Log() << kINFO << hLine << Endl;
872 <<
"DataSet MVA Signal efficiency: from test sample (from training sample) "
874 Log() << kINFO <<
"Name: Method: @B=0.01 @B=0.10 @B=0.30 "
876 Log() << kINFO << hLine << Endl;
877 for (Int_t k = 0; k < 2; k++) {
878 if (k == 1 && nmeth_used[k] > 0) {
879 Log() << kINFO << hLine << Endl;
880 Log() << kINFO <<
"Input Variables: " << Endl << hLine << Endl;
882 for (Int_t i = 0; i < nmeth_used[k]; i++) {
884 mname[k][i].ReplaceAll(
"Variable_",
"");
886 Log() << kINFO << Form(
"%-20s %-15s: %#1.3f (%#1.3f) %#1.3f (%#1.3f) %#1.3f (%#1.3f)",
887 method->fDataSetInfo.GetName(), (
const char *)mname[k][i], eff01[k][i],
888 trainEff01[k][i], eff10[k][i], trainEff10[k][i], eff30[k][i], trainEff30[k][i])
892 Log() << kINFO << hLine << Endl;
893 Log() << kINFO << Endl;
// Re-inhibit output if the envelope options request silent mode.
895 if (gTools().CheckForSilentOption(GetOptions()))
896 Log().InhibitOutput();
// Cuts methods: a negative separation or significance makes effArea the ROC integral.
897 }
else if (IsCutsMethod(method)) {
898 for (Int_t k = 0; k < 2; k++) {
899 for (Int_t i = 0; i < nmeth_used[k]; i++) {
901 if (sep[k][i] < 0 || sig[k][i] < 0) {
903 fResult.fROCIntegral = effArea[k][i];
// Store the ROC data in the result. The full branch structure around these
// lines is not visible in this excerpt.
909 TMVA::DataSet *dataset = method->Data();
910 dataset->SetCurrentType(Types::kTesting);
912 if (IsCutsMethod(method)) {
913 fResult.fIsCuts = kTRUE;
// NOTE(review): GetROC allocates rocCurveTest with `new`, and no delete is
// visible in this excerpt.
915 auto rocCurveTest = GetROC(methodname, methodtitle, 0, Types::kTesting);
916 fResult.fMvaTest[0] = rocCurveTest->GetMvas();
917 fResult.fROCIntegral = GetROCIntegral(methodname, methodtitle);
// Record the name of class 0.
919 TString className = method->DataInfo().GetClassInfo(0)->GetName();
920 fResult.fClassNames.push_back(className);
// When not silent, write the test and training trees into this dataset's
// directory under RootBaseDir(), overwriting previous copies.
922 if (!IsSilentFile()) {
924 RootBaseDir()->cd(method->fDataSetInfo.GetName());
925 method->fDataSetInfo.GetDataSet()->GetTree(Types::kTesting)->Write(
"", TObject::kOverwrite);
926 method->fDataSetInfo.GetDataSet()->GetTree(Types::kTraining)->Write(
"", TObject::kOverwrite);
/// Convenience overload. It converts the method enum to its name with
/// Types::Instance().GetMethodName and forwards to TestMethod(TString, TString).
936 void TMVA::Experimental::Classification::TestMethod(Types::EMVA method, TString methodtitle)
938 TestMethod(Types::Instance().GetMethodName(method), methodtitle);
/// Returns the vector of all classification results.
/// It logs a kFATAL message when no results are available (fResults is empty).
946 std::vector<TMVA::Experimental::ClassificationResult> &TMVA::Experimental::Classification::GetResults()
948 if (fResults.size() == 0)
949 Log() << kFATAL <<
"No Classification results available" << Endl;
/// Returns kTRUE when `method` is a rectangular-cuts method
/// (GetMethodType() == Types::kCuts), kFALSE otherwise.
958 Bool_t TMVA::Experimental::Classification::IsCutsMethod(TMVA::MethodBase *method)
960 return method->GetMethodType() == Types::kCuts ? kTRUE : kFALSE;
/// Returns the result entry for (methodname, methodtitle).
/// If no entry exists, a new one is appended, initialised with the method name,
/// title and data-loader name.
/// NOTE(review): the returned reference points into fResults. GetResults()
/// returns that container as a std::vector, so a later push_back invalidates
/// the reference; keep it short-lived.
970 TMVA::Experimental::ClassificationResult &
971 TMVA::Experimental::Classification::GetResults(TString methodname, TString methodtitle)
973 for (
auto &result : fResults) {
974 if (result.IsMethod(methodname, methodtitle))
977 ClassificationResult result;
978 result.fMethod[
"MethodName"] = methodname;
979 result.fMethod[
"MethodTitle"] = methodtitle;
980 result.fDataLoaderName = fDataLoader->GetName();
981 fResults.push_back(result);
982 return fResults.back();
/// Builds a ROC curve for `method` from its stored Results on tree type `type`.
///  - Classification: uses the stored MVA values, their signal/background flags
///    and the per-event weights.
///  - Multiclass: one-vs-rest for class iClass. The MVA value is column iClass of
///    each output, and an event counts as "signal" when its class equals iClass.
/// Side effect: sets the dataset's current tree type to `type`.
/// The curve is allocated with `new`; presumably the caller owns it — TODO confirm.
/// For other analysis types, rocCurve is not assigned in the visible lines.
994 TMVA::Experimental::Classification::GetROC(TMVA::MethodBase *method, UInt_t iClass, Types::ETreeType type)
996 TMVA::DataSet *dataset = method->Data();
997 dataset->SetCurrentType(type);
998 TMVA::Results *results = dataset->GetResults(method->GetName(), type, this->fAnalysisType);
// Multiclass: report an out-of-range class index (what follows the error is not
// visible in this excerpt).
1000 UInt_t nClasses = method->DataInfo().GetNClasses();
1001 if (this->fAnalysisType == Types::kMulticlass && iClass >= nClasses) {
1002 Log() << kERROR << Form(
"Given class number (iClass = %i) does not exist. There are %i classes in dataset.",
1008 TMVA::ROCCurve *rocCurve =
nullptr;
// NOTE(review): the dynamic_cast results below are dereferenced without a
// visible null check.
1009 if (this->fAnalysisType == Types::kClassification) {
1011 std::vector<Float_t> *mvaRes =
dynamic_cast<ResultsClassification *
>(results)->GetValueVector();
1012 std::vector<Bool_t> *mvaResTypes =
dynamic_cast<ResultsClassification *
>(results)->GetValueVectorTypes();
1013 std::vector<Float_t> mvaResWeights;
// One weight per event of the selected tree type.
1015 auto eventCollection = dataset->GetEventCollection(type);
1016 mvaResWeights.reserve(eventCollection.size());
1017 for (
auto ev : eventCollection) {
1018 mvaResWeights.push_back(ev->GetWeight());
1021 rocCurve =
new TMVA::ROCCurve(*mvaRes, *mvaResTypes, mvaResWeights);
1023 }
else if (this->fAnalysisType == Types::kMulticlass) {
1024 std::vector<Float_t> mvaRes;
1025 std::vector<Bool_t> mvaResTypes;
1026 std::vector<Float_t> mvaResWeights;
1028 std::vector<std::vector<Float_t>> *rawMvaRes =
dynamic_cast<ResultsMulticlass *
>(results)->GetValueVector();
// Keep only the output column of class iClass.
1033 mvaRes.reserve(rawMvaRes->size());
1034 for (
auto item : *rawMvaRes) {
1035 mvaRes.push_back(item[iClass]);
// Per event: "signal" flag (class == iClass) and weight.
1038 auto eventCollection = dataset->GetEventCollection(type);
1039 mvaResTypes.reserve(eventCollection.size());
1040 mvaResWeights.reserve(eventCollection.size());
1041 for (
auto ev : eventCollection) {
1042 mvaResTypes.push_back(ev->GetClass() == iClass);
1043 mvaResWeights.push_back(ev->GetWeight());
1046 rocCurve =
new TMVA::ROCCurve(mvaRes, mvaResTypes, mvaResWeights);
/// Looks up the method by name and title (GetMethod creates it on first use)
/// and forwards to GetROC(MethodBase*, iClass, type).
1061 TMVA::ROCCurve *TMVA::Experimental::Classification::GetROC(TString methodname, TString methodtitle, UInt_t iClass,
1062 TMVA::Types::ETreeType type)
1064 return GetROC(GetMethod(methodname, methodtitle), iClass, type);
/// Returns the ROC integral of (methodname, methodtitle) for class iClass.
/// It is evaluated with gConfig().fVariablePlotting.fNbinsXOfROCCurve + 1 points.
/// The tree type is the default of the GetROC overload (not visible in this excerpt).
/// NOTE(review): GetROC allocates rocCurve with `new`; its release is not
/// visible in this excerpt.
1075 Double_t TMVA::Experimental::Classification::GetROCIntegral(TString methodname, TString methodtitle, UInt_t iClass)
1077 TMVA::ROCCurve *rocCurve = GetROC(methodname, methodtitle, iClass);
// Reported when no curve was created (the condition is on a line not visible in this excerpt).
1080 << Form(
"ROCCurve object was not created in MethodName = %s MethodTitle = %s not found with Dataset = %s ",
1081 methodname.Data(), methodtitle.Data(), fDataLoader->GetName())
1086 Int_t npoints = TMVA::gConfig().fVariablePlotting.fNbinsXOfROCCurve + 1;
1087 Double_t rocIntegral = rocCurve->GetROCIntegral(npoints);
/// Copies the keys of directory `src` into `file`.
/// Subdirectories are handled recursively, TTrees are cloned with
/// CloneTree(-1, "fast"), and other objects are read with ReadObj (how they are
/// written is not visible in this excerpt). Finally adir (an alias of `file`)
/// is saved with SaveSelf(kTRUE).
1094 void TMVA::Experimental::Classification::CopyFrom(TDirectory *src, TFile *file)
// Both savdir and adir alias the destination file.
1096 TFile *savdir = file;
1097 TDirectory *adir = savdir;
1101 TIter nextkey(src->GetListOfKeys());
1102 while ((key = (TKey *)nextkey())) {
1103 const Char_t *classname = key->GetClassName();
1104 TClass *cl = gROOT->GetClass(classname);
// NOTE(review): as visible, the recursive call passes `subdir`, which is
// initialised from the destination `file` rather than from the source
// subdirectory. A line in between is not visible; verify that the recursion
// reads from the intended source directory.
1107 if (cl->InheritsFrom(TDirectory::Class())) {
1108 src->cd(key->GetName());
1109 TDirectory *subdir = file;
1111 CopyFrom(subdir, file);
1113 }
else if (cl->InheritsFrom(TTree::Class())) {
1114 TTree *T = (TTree *)src->Get(key->GetName());
1116 TTree *newT = T->CloneTree(-1,
"fast");
1120 TObject *obj = key->ReadObj();
1126 adir->SaveSelf(kTRUE);
/// Merges the per-method files written by Evaluate
/// (".<data loader><method name><method title>.root") into fFile.
/// For each method it copies the "Method_<title>/<title>" directory and builds
/// combined TrainTree/TestTree, adding a branch named after the method title
/// that holds that method's MVA output. Afterwards the per-method files are
/// removed with gSystem->Unlink.
/// Several conditions and declarations in this function are on lines not
/// visible in this excerpt.
1131 void TMVA::Experimental::Classification::MergeFiles()
1134 auto dsdir = fFile->mkdir(fDataLoader->GetName());
1135 TTree *TrainTree = 0;
1136 TTree *TestTree = 0;
1139 for (UInt_t i = 0; i < fMethods.size(); i++) {
1140 auto methodname = fMethods[i].GetValue<TString>(
"MethodName");
1141 auto methodtitle = fMethods[i].GetValue<TString>(
"MethodTitle");
1142 auto fname = Form(
".%s%s%s.root", fDataLoader->GetName(), methodname.Data(), methodtitle.Data());
1143 TDirectoryFile *ds = 0;
// NOTE(review): both branches open the same file name (the selecting condition
// is not visible in this excerpt). Check that ifile/ofile are closed and deleted.
1145 ifile =
new TFile(fname);
1146 ds = (TDirectoryFile *)ifile->Get(fDataLoader->GetName());
1148 ofile =
new TFile(fname);
1149 ds = (TDirectoryFile *)ofile->Get(fDataLoader->GetName());
1151 auto tmptrain = (TTree *)ds->Get(
"TrainTree");
1152 auto tmptest = (TTree *)ds->Get(
"TestTree");
1154 fFile->cd(fDataLoader->GetName());
// Recreate "Method_<title>/<title>" under the dataset directory and copy its contents.
1156 auto methdirname = Form(
"Method_%s", methodtitle.Data());
1157 auto methdir = dsdir->mkdir(methdirname, methdirname);
1158 auto methdirbase = methdir->mkdir(methodtitle.Data(), methodtitle.Data());
1159 auto mfdir = (TDirectoryFile *)ds->Get(methdirname);
1160 auto mfdirbase = (TDirectoryFile *)mfdir->Get(methodtitle.Data());
1162 CopyFrom(mfdirbase, (TFile *)methdirbase);
// Presumably done once (for the first method) to seed the combined trees — TODO confirm.
1165 TrainTree = tmptrain->CopyTree(
"");
1166 TestTree = tmptest->CopyTree(
"");
// Append this method's MVA output, read from the per-method tree, as a new
// branch named after the method title.
// NOTE(review): TTree::GetEntries() returns a 64-bit count, so a UInt_t index
// would wrap past 2^32 entries; consider Long64_t.
1169 auto trainbranch = TrainTree->Branch(methodtitle.Data(), &mva);
1170 tmptrain->SetBranchAddress(methodtitle.Data(), &mva);
1171 auto entries = tmptrain->GetEntries();
1172 for (UInt_t ev = 0; ev < entries; ev++) {
1173 tmptrain->GetEntry(ev);
1174 trainbranch->Fill();
1176 auto testbranch = TestTree->Branch(methodtitle.Data(), &mva);
1177 tmptest->SetBranchAddress(methodtitle.Data(), &mva);
1178 entries = tmptest->GetEntries();
1179 for (UInt_t ev = 0; ev < entries; ev++) {
1180 tmptest->GetEntry(ev);
// Remove the per-method files now that they are merged.
1190 for (UInt_t i = 0; i < fMethods.size(); i++) {
1191 auto methodname = fMethods[i].GetValue<TString>(
"MethodName");
1192 auto methodtitle = fMethods[i].GetValue<TString>(
"MethodTitle");
1193 auto fname = Form(
".%s%s%s.root", fDataLoader->GetName(), methodname.Data(), methodtitle.Data());
1194 gSystem->Unlink(fname);