// Default constructor of the optimisation result.
// Initialises the ROC average (fROCAVG) to 0.0 and fROCCurves with a
// TMultiGraph created through std::make_shared.
// NOTE(review): the constructor body is not visible in this excerpt.
34 TMVA::HyperParameterOptimisationResult::HyperParameterOptimisationResult()
35 : fROCAVG(0.0), fROCCurves(std::make_shared<TMultiGraph>())
// Destructor of the optimisation result.
// NOTE(review): the body is not visible in this excerpt, so any cleanup it
// performs cannot be confirmed from here.
40 TMVA::HyperParameterOptimisationResult::~HyperParameterOptimisationResult()
// Returns the TMultiGraph held in fROCCurves as a raw pointer.
// The pointer is obtained with .get(), so it is non-owning: the graph stays
// owned by fROCCurves. The Bool_t parameter is unnamed and therefore cannot
// be used by the body.
45 TMultiGraph *TMVA::HyperParameterOptimisationResult::GetROCCurves(Bool_t )
48 return fROCCurves.get();
// Prints the stored parameters of every fold through a MsgLogger.
// For each entry j of fFoldParameters it emits a separator header, a line
// naming fMethodName and the 1-based fold number, and then one line per
// element of fFoldParameters.at(j) showing its .first and .second members.
52 void TMVA::HyperParameterOptimisationResult::Print()
const
// Switch logger output on and leave silent mode before printing.
54 TMVA::MsgLogger::EnableOutput();
55 TMVA::gConfig().SetSilent(kFALSE);
// Local logger labelled "HyperParameterOptimisation" (a local variable despite the f-prefix).
57 MsgLogger fLogger(
"HyperParameterOptimisation");
59 for(UInt_t j=0; j<fFoldParameters.size(); ++j) {
60 fLogger<<kHEADER<<
"===========================================================" << Endl;
// j+1: folds are reported starting from 1.
61 fLogger<<kINFO<<
"Optimisation for " << fMethodName <<
" fold " << j+1 << Endl;
// One output line per element stored for fold j.
63 for(
auto &it : fFoldParameters.at(j)) {
64 fLogger<<kINFO<< it.first <<
" " << it.second << Endl;
// NOTE(review): silent mode is set to kTRUE unconditionally; the visible lines
// do not save or restore the value it had before Print() was called.
// Presumably this runs after both loops (closing braces are not visible in
// this excerpt) -- confirm against the full file.
68 TMVA::gConfig().SetSilent(kTRUE);
// Constructs the optimisation envelope for the given data loader.
// Forwards the name "HyperParameterOptimisation" and `dataloader` to the
// Envelope base class, sets fFomType to "Separation" and initialises
// fClassifier with a newly allocated TMVA::Factory configured by the option
// string below.
// NOTE(review): other member initialisers and the constructor body are not
// visible in this excerpt. How fClassifier owns the `new`-allocated Factory
// depends on its declared type, which is not shown here -- confirm.
73 TMVA::HyperParameterOptimisation::HyperParameterOptimisation(TMVA::DataLoader *dataloader):Envelope(
"HyperParameterOptimisation",dataloader),
74 fFomType(
"Separation"),
78 fClassifier(new TMVA::Factory(
"HyperParameterOptimisation",
"!V:!ROC:Silent:!ModelPersistence:!Color:!DrawProgressBar:AnalysisType=Classification"))
// Destructor of the optimisation envelope.
// NOTE(review): the body is not visible in this excerpt; whether and how
// fClassifier is released cannot be confirmed from here.
84 TMVA::HyperParameterOptimisation::~HyperParameterOptimisation()
// Setter taking the number of folds `i`.
// NOTE(review): the body is not visible in this excerpt; presumably it stores
// `i` in fNumFolds (the member read by Evaluate()) -- confirm.
90 void TMVA::HyperParameterOptimisation::SetNumFolds(UInt_t i)
// Runs the tuning-parameter optimisation for every entry of fMethods.
// For each method it builds a k-fold split, then for each of the fNumFolds
// folds it prepares the training fold, books the method on fClassifier, calls
// OptimizeTuningParameters and appends the returned parameters to
// fResults.fFoldParameters.
// NOTE(review): brace-only lines, and possibly some statements, are missing
// from this excerpt; the comments below describe only the visible lines.
98 void TMVA::HyperParameterOptimisation::Evaluate()
100 for (
auto &meth : fMethods) {
// Read the booking information stored for this method.
101 TString methodName = meth.GetValue<TString>(
"MethodName");
102 TString methodTitle = meth.GetValue<TString>(
"MethodTitle");
103 TString methodOptions = meth.GetValue<TString>(
"MethodOptions");
// Split description using fNumFolds folds; the meaning of the remaining
// constructor arguments ("", kFALSE, 0) is not visible here.
105 CvSplitKFolds split{fNumFolds,
"", kFALSE, 0};
107 fDataLoader->MakeKFoldDataSet(split);
// NOTE(review): fMethodName is overwritten per method, while fFoldParameters
// is only appended to in the visible lines; with several methods the results
// would accumulate under the last method name -- confirm intended.
110 fResults.fMethodName = methodName;
112 for (UInt_t i = 0; i < fNumFolds; ++i) {
// NOTE(review): foldTitle is not used by any visible line; BookMethod below
// receives methodTitle.
113 TString foldTitle = methodTitle;
// Select fold i as the training set before booking the method.
117 Event::SetIsTraining(kTRUE);
118 fDataLoader->PrepareFoldDataSet(split, i, TMVA::Types::kTraining);
120 auto smethod = fClassifier->BookMethod(fDataLoader.get(), methodName, methodTitle, methodOptions);
// Optimise with the configured figure of merit (fFomType) and fit type
// (fFitType), then store the parameters returned for this fold.
122 auto params = smethod->OptimizeTuningParameters(fFomType, fFitType);
123 fResults.fFoldParameters.push_back(params);
// Drop the training/classification results of this method, then remove all
// booked methods and clear the factory's method map.
125 smethod->Data()->DeleteResults(smethod->GetMethodName(), Types::kTraining, Types::kClassification);
127 fClassifier->DeleteAllMethods();
129 fClassifier->fMethodsMap.clear();