Logo ROOT   6.30.04
Reference Guide
 All Namespaces Files Pages
CrossValidation.h
Go to the documentation of this file.
1 // @(#)root/tmva $Id$
2 // Author: Omar Zapata, Thomas James Stevenson, Pourya Vakilipourtakalou, Kim Albertsson
3 
4 /*************************************************************************
5  * Copyright (C) 2018, Rene Brun and Fons Rademakers. *
6  * All rights reserved. *
7  * *
8  * For the licensing terms see $ROOTSYS/LICENSE. *
9  * For the list of contributors see $ROOTSYS/README/CREDITS. *
10  *************************************************************************/
11 
12 #ifndef ROOT_TMVA_CROSS_EVALUATION
13 #define ROOT_TMVA_CROSS_EVALUATION
14 
15 #include "TGraph.h"
16 #include "TMultiGraph.h"
17 #include "TString.h"
18 
19 #include "TMVA/IMethod.h"
20 #include "TMVA/Configurable.h"
21 #include "TMVA/Types.h"
22 #include "TMVA/DataSet.h"
23 #include "TMVA/Event.h"
24 #include <TMVA/Results.h>
25 #include <TMVA/Factory.h>
26 #include <TMVA/DataLoader.h>
27 #include <TMVA/OptionMap.h>
28 #include <TMVA/Envelope.h>
29 
30 /*! \class TMVA::CrossValidationResult
31  * Class to save the results of cross validation,
32  * the metric for classification is the ROC; you can retrieve the ROC curves,
33  * ROC integrals, ROC average and ROC standard deviation.
34 \ingroup TMVA
35 */
36 
37 /*! \class TMVA::CrossValidation
38  * Class to perform cross validation, splitting the dataloader into folds.
39 \ingroup TMVA
40 */
41 
42 namespace TMVA {
43 
44 class CvSplitKFolds; // Forward declaration: only referenced through a std::unique_ptr member in this header.
45 
46 using EventCollection_t = std::vector<Event *>; // Sequence of Event pointers; ownership is not expressed by the type.
47 using EventTypes_t = std::vector<Bool_t>; // One boolean per event (presumably a signal/background flag -- TODO confirm).
48 using EventOutputs_t = std::vector<Float_t>; // One float output value per event.
49 using EventOutputsMulticlass_t = std::vector<std::vector<Float_t>>; // A vector of float outputs per entry (presumably per event, one value per class -- TODO confirm).
50 
51 class CrossValidationFoldResult {
52 public:
53  CrossValidationFoldResult() {} // For multi-proc serialisation
54  CrossValidationFoldResult(UInt_t iFold)
55  : fFold(iFold)
56  {}
57 
58  UInt_t fFold;
59 
60  Float_t fROCIntegral;
61  TGraph fROC;
62 
63  Double_t fSig;
64  Double_t fSep;
65  Double_t fEff01;
66  Double_t fEff10;
67  Double_t fEff30;
68  Double_t fEffArea;
69  Double_t fTrainEff01;
70  Double_t fTrainEff10;
71  Double_t fTrainEff30;
72 };
73 
74 // Used internally to keep per-fold aggregate statistics
75 // such as ROC curves, ROC integrals and efficiencies.
76 class CrossValidationResult {
77  friend class CrossValidation;
78 
79 private:
80  std::map<UInt_t, Float_t> fROCs;
81  std::shared_ptr<TMultiGraph> fROCCurves;
82 
83  std::vector<Double_t> fSigs;
84  std::vector<Double_t> fSeps;
85  std::vector<Double_t> fEff01s;
86  std::vector<Double_t> fEff10s;
87  std::vector<Double_t> fEff30s;
88  std::vector<Double_t> fEffAreas;
89  std::vector<Double_t> fTrainEff01s;
90  std::vector<Double_t> fTrainEff10s;
91  std::vector<Double_t> fTrainEff30s;
92 
93 public:
94  CrossValidationResult(UInt_t numFolds);
95  CrossValidationResult(const CrossValidationResult &);
96  ~CrossValidationResult() { fROCCurves = nullptr; }
97 
98  std::map<UInt_t, Float_t> GetROCValues() const { return fROCs; }
99  Float_t GetROCAverage() const;
100  Float_t GetROCStandardDeviation() const;
101  TMultiGraph *GetROCCurves(Bool_t fLegend = kTRUE);
102  TGraph *GetAvgROCCurve(UInt_t numSamples = 100) const;
103  void Print() const;
104 
105  TCanvas *Draw(const TString name = "CrossValidation") const;
106  TCanvas *DrawAvgROCCurve(Bool_t drawFolds=kFALSE, TString title="") const;
107 
108  std::vector<Double_t> GetSigValues() const { return fSigs; }
109  std::vector<Double_t> GetSepValues() const { return fSeps; }
110  std::vector<Double_t> GetEff01Values() const { return fEff01s; }
111  std::vector<Double_t> GetEff10Values() const { return fEff10s; }
112  std::vector<Double_t> GetEff30Values() const { return fEff30s; }
113  std::vector<Double_t> GetEffAreaValues() const { return fEffAreas; }
114  std::vector<Double_t> GetTrainEff01Values() const { return fTrainEff01s; }
115  std::vector<Double_t> GetTrainEff10Values() const { return fTrainEff10s; }
116  std::vector<Double_t> GetTrainEff30Values() const { return fTrainEff30s; }
117 
118 private:
119  void Fill(CrossValidationFoldResult const & fr);
120 };
121 
122 class CrossValidation : public Envelope {
123 
124 public:
125  explicit CrossValidation(TString jobName, TMVA::DataLoader *dataloader, TString options);
126  explicit CrossValidation(TString jobName, TMVA::DataLoader *dataloader, TFile *outputFile, TString options);
127  ~CrossValidation();
128 
129  void InitOptions();
130  void ParseOptions();
131 
132  void SetNumFolds(UInt_t i);
133  void SetSplitExpr(TString splitExpr);
134 
135  UInt_t GetNumFolds() { return fNumFolds; }
136  TString GetSplitExpr() { return fSplitExprString; }
137 
138  Factory &GetFactory() { return *fFactory; }
139 
140  const std::vector<CrossValidationResult> &GetResults() const;
141 
142  void Evaluate();
143 
144 private:
145  CrossValidationFoldResult ProcessFold(UInt_t iFold, const OptionMap & methodInfo);
146 
147  Types::EAnalysisType fAnalysisType;
148  TString fAnalysisTypeStr;
149  TString fSplitTypeStr;
150  Bool_t fCorrelations;
151  TString fCvFactoryOptions;
152  Bool_t fDrawProgressBar;
153  Bool_t fFoldFileOutput; //! If true: generate output file for each fold
154  Bool_t fFoldStatus; //! If true: dataset is prepared
155  TString fJobName;
156  UInt_t fNumFolds; //! Number of folds to prepare
157  UInt_t fNumWorkerProcs; //! Number of processes to use for fold evaluation.
158  //!(Default, no parallel evaluation)
159  TString fOutputFactoryOptions;
160  TString fOutputEnsembling; //! How to combine output of individual folds
161  TFile *fOutputFile;
162  Bool_t fSilent;
163  TString fSplitExprString;
164  std::vector<CrossValidationResult> fResults; //!
165  Bool_t fROC;
166  TString fTransformations;
167  Bool_t fVerbose;
168  TString fVerboseLevel;
169 
170  std::unique_ptr<Factory> fFoldFactory;
171  std::unique_ptr<Factory> fFactory;
172  std::unique_ptr<CvSplitKFolds> fSplit;
173 
174  ClassDef(CrossValidation, 0);
175  };
176 
177 } // namespace TMVA
178 
179 #endif // ROOT_TMVA_CROSS_EVALUATION