12 #ifndef ROOT_TMVA_CROSS_EVALUATION
13 #define ROOT_TMVA_CROSS_EVALUATION
46 using EventCollection_t = std::vector<Event *>;
47 using EventTypes_t = std::vector<Bool_t>;
48 using EventOutputs_t = std::vector<Float_t>;
49 using EventOutputsMulticlass_t = std::vector<std::vector<Float_t>>;
51 class CrossValidationFoldResult {
// Default constructor: empty body and no member-initializer list.
// NOTE(review): presumably required so fold results can be default-constructed
// (e.g. for serialisation or container use) — TODO confirm against callers.
CrossValidationFoldResult() {}
54 CrossValidationFoldResult(UInt_t iFold)
76 class CrossValidationResult {
// CrossValidation fills this object directly, hence the friendship.
friend class CrossValidation;

// Stored results. NOTE(review): the access specifier for these members is not
// visible in this excerpt; metric meanings below are inferred from names only.
std::map<UInt_t, Float_t> fROCs;         // ROC values; copied out by GetROCValues(). Key presumably the fold index — confirm.
std::shared_ptr<TMultiGraph> fROCCurves; // ROC-curve multigraph held via shared ownership

std::vector<Double_t> fSigs;        // returned by GetSigValues()
std::vector<Double_t> fSeps;        // returned by GetSepValues()
std::vector<Double_t> fEff01s;      // returned by GetEff01Values()
std::vector<Double_t> fEff10s;      // returned by GetEff10Values()
std::vector<Double_t> fEff30s;      // returned by GetEff30Values()
std::vector<Double_t> fEffAreas;    // returned by GetEffAreaValues()
std::vector<Double_t> fTrainEff01s; // returned by GetTrainEff01Values()
std::vector<Double_t> fTrainEff10s; // returned by GetTrainEff10Values()
std::vector<Double_t> fTrainEff30s; // returned by GetTrainEff30Values()
94 CrossValidationResult(UInt_t numFolds);
95 CrossValidationResult(
const CrossValidationResult &);
96 ~CrossValidationResult() { fROCCurves =
nullptr; }
98 std::map<UInt_t, Float_t> GetROCValues()
const {
return fROCs; }
/// Average of the ROC values.
/// NOTE(review): definition not visible; presumably the mean over fROCs — confirm.
Float_t GetROCAverage() const;

/// Standard deviation of the ROC values.
/// NOTE(review): definition not visible; sample vs. population form cannot be told from here.
Float_t GetROCStandardDeviation() const;

/// Returns the ROC-curve multigraph as a raw pointer.
/// NOTE(review): ownership of the returned pointer is not visible here (presumably
/// it refers to fROCCurves — confirm). The parameter name is kept as-is because
/// keyword-argument callers may depend on it.
/// @param fLegend presumably toggles a legend; defaults to kTRUE.
TMultiGraph *GetROCCurves(Bool_t fLegend = kTRUE);

/// Builds an averaged ROC curve.
/// @param numSamples number of sample points; defaults to 100.
/// NOTE(review): ownership of the returned TGraph is not visible here — confirm.
TGraph *GetAvgROCCurve(UInt_t numSamples = 100) const;

/// Draws the results onto a canvas.
/// @param name canvas name; defaults to "CrossValidation".
TCanvas *Draw(const TString name = "CrossValidation") const;

/// Draws the averaged ROC curve.
/// @param drawFolds presumably also draws the per-fold curves; defaults to kFALSE.
/// @param title     plot title; defaults to "".
TCanvas *DrawAvgROCCurve(Bool_t drawFolds=kFALSE, TString title="") const;
108 std::vector<Double_t> GetSigValues()
const {
return fSigs; }
109 std::vector<Double_t> GetSepValues()
const {
return fSeps; }
110 std::vector<Double_t> GetEff01Values()
const {
return fEff01s; }
111 std::vector<Double_t> GetEff10Values()
const {
return fEff10s; }
112 std::vector<Double_t> GetEff30Values()
const {
return fEff30s; }
113 std::vector<Double_t> GetEffAreaValues()
const {
return fEffAreas; }
114 std::vector<Double_t> GetTrainEff01Values()
const {
return fTrainEff01s; }
115 std::vector<Double_t> GetTrainEff10Values()
const {
return fTrainEff10s; }
116 std::vector<Double_t> GetTrainEff30Values()
const {
return fTrainEff30s; }
/// Records the outcome of one fold.
/// NOTE(review): definition not visible; presumably stores `fr`'s metrics into
/// this object's per-fold containers — confirm.
void Fill(CrossValidationFoldResult const & fr);
122 class CrossValidation :
public Envelope {
/// Constructs a cross-validation job from a job name, a DataLoader and an option string.
/// NOTE(review): ownership of `dataloader` is not visible here (presumably non-owning) — confirm.
explicit CrossValidation(TString jobName, TMVA::DataLoader *dataloader, TString options);

/// Same as above, additionally taking an output file.
/// NOTE(review): ownership of `outputFile` is not visible here — confirm.
explicit CrossValidation(TString jobName, TMVA::DataLoader *dataloader, TFile *outputFile, TString options);

/// Sets the number of folds.
/// NOTE(review): definition not visible; presumably stored in fNumFolds.
void SetNumFolds(UInt_t i);

/// Sets the expression used to split events into folds.
/// NOTE(review): definition not visible; presumably stored in fSplitExprString.
void SetSplitExpr(TString splitExpr);
135 UInt_t GetNumFolds() {
return fNumFolds; }
136 TString GetSplitExpr() {
return fSplitExprString; }
138 Factory &GetFactory() {
return *fFactory; }
/// Read-only access to the accumulated per-method results, returned by const
/// reference (no copy). NOTE(review): presumably refers to fResults; definition not visible.
const std::vector<CrossValidationResult> &GetResults() const;

/// Processes fold `iFold` for the method described by `methodInfo` and returns that fold's result.
/// NOTE(review): inferred from name and signature; definition not visible.
CrossValidationFoldResult ProcessFold(UInt_t iFold, const OptionMap & methodInfo);
// Job configuration and state.
// NOTE(review): several members are missing from this excerpt, and most
// meanings below are inferred from names only — confirm against the .cxx.
Types::EAnalysisType fAnalysisType;  // analysis type as an enum
TString fAnalysisTypeStr;            // analysis type as text (presumably parsed from options)
TString fSplitTypeStr;               // split type as text
Bool_t fCorrelations;                // presumably toggles correlation output
TString fCvFactoryOptions;           // option string for the cross-validation factory
Bool_t fDrawProgressBar;             // presumably toggles a progress bar
Bool_t fFoldFileOutput;              // presumably enables per-fold output files
UInt_t fNumWorkerProcs;              // presumably number of worker processes for fold evaluation
TString fOutputFactoryOptions;       // option string for the output factory
TString fOutputEnsembling;           // presumably how per-fold outputs are combined
TString fSplitExprString;            // split expression; returned by GetSplitExpr()
std::vector<CrossValidationResult> fResults; // presumably exposed by GetResults()
TString fTransformations;            // transformation specification string
TString fVerboseLevel;               // verbosity level as text
std::unique_ptr<Factory> fFoldFactory;  // owned Factory (presumably used per fold)
std::unique_ptr<Factory> fFactory;      // owned Factory; returned by reference from GetFactory()
std::unique_ptr<CvSplitKFolds> fSplit;  // owned fold-splitting helper
// ROOT dictionary macro; class version 0 (in ROOT this conventionally disables I/O streaming — confirm).
ClassDef(CrossValidation, 0);
179 #endif // ROOT_TMVA_CROSS_EVALUATION