12 #ifndef ROOT_TMVA_CvSplit
13 #define ROOT_TMVA_CvSplit
27 class CrossValidation;
// NOTE(review): every code line in this block begins with a stray integer
// (e.g. "35", "37") that looks like a line number pasted from a source
// viewer. No access specifiers (public:/protected:) and no closing "};" are
// present, and fNumFolds is used but not declared here. As written this is
// not valid C++; restore the declaration from the upstream header.
//
// CvSplit: abstract base class (MakeKFoldDataSet is pure virtual) for
// splitting a DataSetInfo into k folds. It derives from Configurable. The
// forward declaration of CrossValidation suggests it is used for
// cross-validation.
35 class CvSplit :
public Configurable {
// Construct a splitter for `numFolds` folds.
// NOTE(review): single-argument constructor is not `explicit`, so a UInt_t
// converts implicitly to CvSplit; consider marking it explicit.
37 CvSplit(UInt_t numFolds);
// Build the per-fold data sets for `dsi`. Pure virtual: each concrete
// splitter defines how events are assigned to folds.
40 virtual void MakeKFoldDataSet(DataSetInfo &dsi) = 0;
// Prepare `dsi` for fold `foldNumber` as tree type `tt`. Implemented out of
// line (not shown). Presumably installs that fold's events into `dsi`;
// confirm against the implementation file.
41 virtual void PrepareFoldDataSet(DataSetInfo &dsi, UInt_t foldNumber, Types::ETreeType tt);
// Per its name, merges the folds back into `dsi`. `tt` defaults to
// Types::kTraining. Implementation not shown; confirm behavior for other
// tree types.
42 virtual void RecombineKFoldDataSet(DataSetInfo &dsi, Types::ETreeType tt = Types::kTraining);
// Returns fNumFolds (member declared outside the lines shown here).
// NOTE(review): does not modify state; could be a const member function.
44 UInt_t GetNumFolds() {
return fNumFolds; }
// Returns fMakeFoldDataSet. The name suggests it reports whether the fold
// data sets must be (re)built before use.
// NOTE(review): could be a const member function.
45 Bool_t NeedsRebuild() {
return fMakeFoldDataSet; }
// Flag reported by NeedsRebuild().
49 Bool_t fMakeFoldDataSet;
// Training / test event lists, one inner vector per group (presumably one
// per fold; confirm). These hold raw, non-const Event pointers, and who owns
// the pointed-to events is not visible here.
51 std::vector<std::vector<TMVA::Event *>> fTrainEvents;
52 std::vector<std::vector<TMVA::Event *>> fTestEvents;
// NOTE(review): every code line in this block begins with a stray integer
// that looks like a pasted line number. No access specifiers and no closing
// "};" are present. As written this is not valid C++; restore the
// declaration from the upstream header.
//
// CvSplitKFoldsExpr: takes a split-expression string in its constructor,
// holds a TFormula (fSplitFormula), and computes a UInt_t per event via
// Eval().
62 class CvSplitKFoldsExpr {
// Construct from the dataset info `dsi` and the split expression `expr`.
// NOTE(review): TString is taken by value here and in Validate /
// GetSpectatorIndexForName; `const TString &` would avoid a copy.
64 CvSplitKFoldsExpr(DataSetInfo &dsi, TString expr);
// Empty user-provided destructor.
// NOTE(review): could be `= default`, or omitted entirely (Rule of Zero).
65 ~CvSplitKFoldsExpr() {}
// Evaluate the split expression for event `ev` given `numFolds` folds. The
// UInt_t result is presumably the fold index of `ev`. Confirm in the
// implementation how results outside [0, numFolds) are handled.
67 UInt_t Eval(UInt_t numFolds,
const Event *ev);
// Static check of a split-expression string. Presumably returns true when
// `expr` is usable; confirm the exact criteria.
69 static Bool_t Validate(TString expr);
// Per its name, returns the index of the spectator named `name` in `dsi`.
// Behavior for an unknown name is not visible here.
72 UInt_t GetSpectatorIndexForName(DataSetInfo &dsi, TString name);
// Per its name, pairs of (formula parameter index, DataSetInfo spectator
// index). Confirm the pair order against the implementation.
77 std::vector<std::pair<Int_t, Int_t>>
78 fFormulaParIdxToDsiSpecIdx;
// Presumably the index of the formula parameter that receives numFolds;
// confirm.
79 Int_t fIdxFormulaParNumFolds;
// Formula object for the split expression.
81 TFormula fSplitFormula;
// Parameter-value buffer, presumably filled before each formula evaluation.
83 std::vector<Double_t> fParValues;
// NOTE(review): every code line in this block begins with a stray integer
// that looks like a pasted line number. No access specifiers and no closing
// "};" are present. As written this is not valid C++; restore the
// declaration from the upstream header.
//
// CvSplitKFolds: concrete CvSplit that overrides MakeKFoldDataSet. Its
// constructor accepts an optional split expression, a `stratified` flag and
// a `seed`.
90 class CvSplitKFolds :
public CvSplit {
// Gives CrossValidation access to this class's non-public members.
92 friend CrossValidation;
// numFolds:   number of folds.
// splitExpr:  optional split expression; defaults to "" (presumably meaning
//             "no expression"; confirm in the implementation).
// stratified: defaults to kTRUE.
// seed:       defaults to 100 (the same default as
//             GetEventIndexToFoldMapping).
95 CvSplitKFolds(UInt_t numFolds, TString splitExpr =
"", Bool_t stratified = kTRUE, UInt_t seed = 100);
// Empty destructor, marked override.
// NOTE(review): could be `= default` or omitted.
96 ~CvSplitKFolds()
override {}
// Assign the events of `dsi` to folds; implements CvSplit's pure virtual.
98 void MakeKFoldDataSet(DataSetInfo &dsi)
override;
// Split `oldSet` into groups of events, returned as one vector per group.
// `numClasses` is presumably used for stratified splitting; confirm.
101 std::vector<std::vector<Event *>> SplitSets(std::vector<TMVA::Event *> &oldSet, UInt_t numFolds, UInt_t numClasses);
// Returns a vector<UInt_t>. Presumably element i is the fold assigned to
// entry i of `nEntries`, derived from `seed` (default 100); confirm.
102 std::vector<UInt_t> GetEventIndexToFoldMapping(UInt_t nEntries, UInt_t numFolds, UInt_t seed = 100);
// Split-expression text (presumably the constructor's `splitExpr`).
106 TString fSplitExprString;
// Owning pointer to the expression evaluator. It is presumably null when no
// split expression is used; confirm before dereferencing.
107 std::unique_ptr<CvSplitKFoldsExpr> fSplitExpr;
// Per its name, the fold number assigned to each event, keyed by event
// pointer.
112 std::map<const TMVA::Event *, UInt_t> fEventToFoldMapping;
// ROOT class-dictionary macro, class version 0. In ROOT, version 0
// conventionally marks a class that is not written to file; confirm this is
// intended.
115 ClassDefOverride(CvSplitKFolds, 0);