Logo ROOT   6.30.04
Reference Guide
 All Namespaces Files Pages
HyperParameterOptimisation.cxx
Go to the documentation of this file.
1 // @(#)root/tmva $Id$
2 // Author: Omar Zapata, Thomas James Stevenson.
3 
5 
6 #include "TMVA/Configurable.h"
7 #include "TMVA/CvSplit.h"
8 #include "TMVA/DataSet.h"
9 #include "TMVA/Event.h"
10 #include "TMVA/MethodBase.h"
12 #include "TMVA/Types.h"
13 
14 #include "TGraph.h"
15 #include "TMultiGraph.h"
16 #include "TString.h"
17 #include "TSystem.h"
18 
19 #include <iostream>
20 #include <memory>
21 #include <vector>
22 
23 /*! \class TMVA::HyperParameterOptimisationResult
24 \ingroup TMVA
25 
26 */
27 
28 /*! \class TMVA::HyperParameterOptimisation
29 \ingroup TMVA
30 
31 */
32 
33 //_______________________________________________________________________
34 TMVA::HyperParameterOptimisationResult::HyperParameterOptimisationResult()
35  : fROCAVG(0.0), fROCCurves(std::make_shared<TMultiGraph>())
36 {
37 }
38 
39 //_______________________________________________________________________
40 TMVA::HyperParameterOptimisationResult::~HyperParameterOptimisationResult()
41 {
42 }
43 
44 //_______________________________________________________________________
45 TMultiGraph *TMVA::HyperParameterOptimisationResult::GetROCCurves(Bool_t /* fLegend */)
46 {
47 
48  return fROCCurves.get();
49 }
50 
51 //_______________________________________________________________________
52 void TMVA::HyperParameterOptimisationResult::Print() const
53 {
54  TMVA::MsgLogger::EnableOutput();
55  TMVA::gConfig().SetSilent(kFALSE);
56 
57  MsgLogger fLogger("HyperParameterOptimisation");
58 
59  for(UInt_t j=0; j<fFoldParameters.size(); ++j) {
60  fLogger<<kHEADER<< "===========================================================" << Endl;
61  fLogger<<kINFO<< "Optimisation for " << fMethodName << " fold " << j+1 << Endl;
62 
63  for(auto &it : fFoldParameters.at(j)) {
64  fLogger<<kINFO<< it.first << " " << it.second << Endl;
65  }
66  }
67 
68  TMVA::gConfig().SetSilent(kTRUE);
69 
70 }
71 
72 //_______________________________________________________________________
73 TMVA::HyperParameterOptimisation::HyperParameterOptimisation(TMVA::DataLoader *dataloader):Envelope("HyperParameterOptimisation",dataloader),
74  fFomType("Separation"),
75  fFitType("Minuit"),
76  fNumFolds(5),
77  fResults(),
78  fClassifier(new TMVA::Factory("HyperParameterOptimisation","!V:!ROC:Silent:!ModelPersistence:!Color:!DrawProgressBar:AnalysisType=Classification"))
79 {
80  fFoldStatus=kFALSE;
81 }
82 
83 //_______________________________________________________________________
84 TMVA::HyperParameterOptimisation::~HyperParameterOptimisation()
85 {
86  fClassifier=nullptr;
87 }
88 
89 //_______________________________________________________________________
90 void TMVA::HyperParameterOptimisation::SetNumFolds(UInt_t i)
91 {
92  fNumFolds = i;
93  // fDataLoader->MakeKFoldDataSet(fNumFolds);
94  fFoldStatus = kFALSE;
95 }
96 
97 //_______________________________________________________________________
98 void TMVA::HyperParameterOptimisation::Evaluate()
99 {
100  for (auto &meth : fMethods) {
101  TString methodName = meth.GetValue<TString>("MethodName");
102  TString methodTitle = meth.GetValue<TString>("MethodTitle");
103  TString methodOptions = meth.GetValue<TString>("MethodOptions");
104 
105  CvSplitKFolds split{fNumFolds, "", kFALSE, 0};
106  if (!fFoldStatus) {
107  fDataLoader->MakeKFoldDataSet(split);
108  fFoldStatus = kTRUE;
109  }
110  fResults.fMethodName = methodName;
111 
112  for (UInt_t i = 0; i < fNumFolds; ++i) {
113  TString foldTitle = methodTitle;
114  foldTitle += "_opt";
115  foldTitle += i + 1;
116 
117  Event::SetIsTraining(kTRUE);
118  fDataLoader->PrepareFoldDataSet(split, i, TMVA::Types::kTraining);
119 
120  auto smethod = fClassifier->BookMethod(fDataLoader.get(), methodName, methodTitle, methodOptions);
121 
122  auto params = smethod->OptimizeTuningParameters(fFomType, fFitType);
123  fResults.fFoldParameters.push_back(params);
124 
125  smethod->Data()->DeleteResults(smethod->GetMethodName(), Types::kTraining, Types::kClassification);
126 
127  fClassifier->DeleteAllMethods();
128 
129  fClassifier->fMethodsMap.clear();
130  }
131  }
132 }