Logo ROOT   6.30.04
Reference Guide
 All Namespaces Files Pages
CvSplit.h
Go to the documentation of this file.
1 // @(#)root/tmva $Id$
2 // Author: Kim Albertsson
3 
4 /*************************************************************************
5  * Copyright (C) 2018, Rene Brun and Fons Rademakers. *
6  * All rights reserved. *
7  * *
8  * For the licensing terms see $ROOTSYS/LICENSE. *
9  * For the list of contributors see $ROOTSYS/README/CREDITS. *
10  *************************************************************************/
11 
12 #ifndef ROOT_TMVA_CvSplit
13 #define ROOT_TMVA_CvSplit
14 
#include "TMVA/Configurable.h"
#include "TMVA/Types.h"

#include <Rtypes.h>
#include <TFormula.h>

#include <map>
#include <memory>
#include <utility>
#include <vector>
22 
23 class TString;
24 
25 namespace TMVA {
26 
27 class CrossValidation;
28 class DataSetInfo;
29 class Event;
30 
31 /* =============================================================================
32  TMVA::CvSplit
33 ============================================================================= */
34 
35 class CvSplit : public Configurable {
36 public:
37  CvSplit(UInt_t numFolds);
38  virtual ~CvSplit() {}
39 
40  virtual void MakeKFoldDataSet(DataSetInfo &dsi) = 0;
41  virtual void PrepareFoldDataSet(DataSetInfo &dsi, UInt_t foldNumber, Types::ETreeType tt);
42  virtual void RecombineKFoldDataSet(DataSetInfo &dsi, Types::ETreeType tt = Types::kTraining);
43 
44  UInt_t GetNumFolds() { return fNumFolds; }
45  Bool_t NeedsRebuild() { return fMakeFoldDataSet; }
46 
47 protected:
48  UInt_t fNumFolds;
49  Bool_t fMakeFoldDataSet;
50 
51  std::vector<std::vector<TMVA::Event *>> fTrainEvents;
52  std::vector<std::vector<TMVA::Event *>> fTestEvents;
53 
54 protected:
55  ClassDef(CvSplit, 0);
56 };
57 
58 /* =============================================================================
59  TMVA::CvSplitKFoldsExpr
60 ============================================================================= */
61 
62 class CvSplitKFoldsExpr {
63 public:
64  CvSplitKFoldsExpr(DataSetInfo &dsi, TString expr);
65  ~CvSplitKFoldsExpr() {}
66 
67  UInt_t Eval(UInt_t numFolds, const Event *ev);
68 
69  static Bool_t Validate(TString expr);
70 
71 private:
72  UInt_t GetSpectatorIndexForName(DataSetInfo &dsi, TString name);
73 
74 private:
75  DataSetInfo &fDsi;
76 
77  std::vector<std::pair<Int_t, Int_t>>
78  fFormulaParIdxToDsiSpecIdx; //! Maps parameter indicies in splitExpr to their spectator index in the datasetinfo.
79  Int_t fIdxFormulaParNumFolds; //! Keeps track of the index of reserved par "NumFolds" in splitExpr.
80  TString fSplitExpr; //! Expression used to split data into folds. Should output values between 0 and numFolds.
81  TFormula fSplitFormula; //! TFormula for splitExpr.
82 
83  std::vector<Double_t> fParValues;
84 };
85 
86 /* =============================================================================
87  TMVA::CvSplitKFolds
88 ============================================================================= */
89 
90 class CvSplitKFolds : public CvSplit {
91 
92  friend CrossValidation;
93 
94 public:
95  CvSplitKFolds(UInt_t numFolds, TString splitExpr = "", Bool_t stratified = kTRUE, UInt_t seed = 100);
96  ~CvSplitKFolds() override {}
97 
98  void MakeKFoldDataSet(DataSetInfo &dsi) override;
99 
100 private:
101  std::vector<std::vector<Event *>> SplitSets(std::vector<TMVA::Event *> &oldSet, UInt_t numFolds, UInt_t numClasses);
102  std::vector<UInt_t> GetEventIndexToFoldMapping(UInt_t nEntries, UInt_t numFolds, UInt_t seed = 100);
103 
104 private:
105  UInt_t fSeed;
106  TString fSplitExprString; //! Expression used to split data into folds. Should output values between 0 and numFolds.
107  std::unique_ptr<CvSplitKFoldsExpr> fSplitExpr;
108  Bool_t fStratified; // If true, use stratified split. (Balance class presence in each fold).
109 
110  // Used for CrossValidation with random splits (not using the
111  // CVSplitKFoldsExpr functionality) to communicate Event to fold mapping.
112  std::map<const TMVA::Event *, UInt_t> fEventToFoldMapping;
113 
114 private:
115  ClassDefOverride(CvSplitKFolds, 0);
116 };
117 
118 } // end namespace TMVA
119 
120 #endif