// ROOT ClassImp macro invocations for the cross-validation splitter classes
// defined in this file (CvSplit and its k-fold specialisation).
28 ClassImp(TMVA::CvSplit);
29 ClassImp(TMVA::CvSplitKFolds);
38 TMVA::CvSplit::CvSplit(UInt_t numFolds) : fNumFolds(numFolds), fMakeFoldDataSet(kFALSE) {}
57 void TMVA::CvSplit::PrepareFoldDataSet(DataSetInfo &dsi, UInt_t foldNumber, Types::ETreeType tt)
59 if (foldNumber >= fNumFolds) {
60 Log() << kFATAL <<
"DataSet prepared for \"" << fNumFolds <<
"\" folds, requested fold \"" << foldNumber
61 <<
"\" is outside of range." << Endl;
65 auto prepareDataSetInternal = [
this, &dsi, foldNumber](std::vector<std::vector<Event *>> vec) {
66 UInt_t numFolds = fTrainEvents.size();
69 UInt_t nTotal = std::accumulate(vec.begin(), vec.end(), 0,
70 [&](UInt_t sum, std::vector<TMVA::Event *> v) {
return sum + v.size(); });
72 UInt_t nTrain = nTotal - vec.at(foldNumber).size();
73 UInt_t nTest = vec.at(foldNumber).size();
75 std::vector<Event *> tempTrain;
76 std::vector<Event *> tempTest;
78 tempTrain.reserve(nTrain);
79 tempTest.reserve(nTest);
82 for (UInt_t i = 0; i < numFolds; ++i) {
83 if (i == foldNumber) {
87 tempTrain.insert(tempTrain.end(), vec.at(i).begin(), vec.at(i).end());
91 tempTest.insert(tempTest.end(), vec.at(foldNumber).begin(), vec.at(foldNumber).end());
93 Log() << kDEBUG <<
"Fold prepared, num events in training set: " << tempTrain.size() << Endl;
94 Log() << kDEBUG <<
"Fold prepared, num events in test set: " << tempTest.size() << Endl;
97 dsi.GetDataSet()->SetEventCollection(&tempTrain, Types::kTraining,
false);
98 dsi.GetDataSet()->SetEventCollection(&tempTest, Types::kTesting,
false);
101 if (tt == Types::kTraining) {
102 prepareDataSetInternal(fTrainEvents);
103 }
else if (tt == Types::kTesting) {
104 prepareDataSetInternal(fTestEvents);
106 Log() << kFATAL <<
"PrepareFoldDataSet can only work with training and testing data sets." << std::endl;
114 void TMVA::CvSplit::RecombineKFoldDataSet(DataSetInfo &dsi, Types::ETreeType tt)
116 if (tt != Types::kTraining) {
117 Log() << kFATAL <<
"Only kTraining is supported for CvSplit::RecombineKFoldDataSet currently." << std::endl;
120 std::vector<Event *> *tempVec =
new std::vector<Event *>;
122 for (UInt_t i = 0; i < fNumFolds; ++i) {
123 tempVec->insert(tempVec->end(), fTrainEvents.at(i).begin(), fTrainEvents.at(i).end());
126 dsi.GetDataSet()->SetEventCollection(tempVec, Types::kTraining,
false);
127 dsi.GetDataSet()->SetEventCollection(tempVec, Types::kTesting,
false);
139 TMVA::CvSplitKFoldsExpr::CvSplitKFoldsExpr(DataSetInfo &dsi, TString expr)
140 : fDsi(dsi), fIdxFormulaParNumFolds(std::numeric_limits<UInt_t>::max()), fSplitFormula(
"", expr),
141 fParValues(fSplitFormula.GetNpar())
143 if (!fSplitFormula.IsValid()) {
144 throw std::runtime_error(
"Split expression \"" + std::string(fSplitExpr.Data()) +
"\" is not a valid TFormula.");
147 for (Int_t iFormulaPar = 0; iFormulaPar < fSplitFormula.GetNpar(); ++iFormulaPar) {
148 TString name = fSplitFormula.GetParName(iFormulaPar);
152 if (name ==
"NumFolds" || name ==
"numFolds") {
154 fIdxFormulaParNumFolds = iFormulaPar;
156 fFormulaParIdxToDsiSpecIdx.push_back(std::make_pair(iFormulaPar, GetSpectatorIndexForName(fDsi, name)));
164 UInt_t TMVA::CvSplitKFoldsExpr::Eval(UInt_t numFolds,
const Event *ev)
166 for (
auto &p : fFormulaParIdxToDsiSpecIdx) {
167 auto iFormulaPar = p.first;
168 auto iSpectator = p.second;
170 fParValues.at(iFormulaPar) = ev->GetSpectator(iSpectator);
173 if (fIdxFormulaParNumFolds < fSplitFormula.GetNpar()) {
174 fParValues[fIdxFormulaParNumFolds] = numFolds;
180 Double_t iFold_d = fSplitFormula.EvalPar(
nullptr, &fParValues[0]);
183 throw std::runtime_error(
"Output of splitExpr must be non-negative.");
186 UInt_t iFold = std::lround(iFold_d);
187 if (iFold >= numFolds) {
188 throw std::runtime_error(
"Output of splitExpr should be a non-negative"
189 "integer between 0 and numFolds-1 inclusive.");
198 Bool_t TMVA::CvSplitKFoldsExpr::Validate(TString expr)
200 return TFormula(
"", expr).IsValid();
206 UInt_t TMVA::CvSplitKFoldsExpr::GetSpectatorIndexForName(DataSetInfo &dsi, TString name)
208 std::vector<VariableInfo> spectatorInfos = dsi.GetSpectatorInfos();
210 for (UInt_t iSpectator = 0; iSpectator < spectatorInfos.size(); ++iSpectator) {
211 VariableInfo vi = spectatorInfos[iSpectator];
212 if (vi.GetName() == name) {
214 }
else if (vi.GetLabel() == name) {
216 }
else if (vi.GetExpression() == name) {
221 throw std::runtime_error(
"Spectator \"" + std::string(name.Data()) +
"\" not found.");
243 TMVA::CvSplitKFolds::CvSplitKFolds(UInt_t numFolds, TString splitExpr, Bool_t stratified, UInt_t seed)
244 : CvSplit(numFolds), fSeed(seed), fSplitExprString(splitExpr), fStratified(stratified)
246 if (!CvSplitKFoldsExpr::Validate(fSplitExprString) && (splitExpr != TString(
""))) {
247 Log() << kFATAL <<
"Split expression \"" << fSplitExprString <<
"\" is not a valid TFormula." << Endl;
255 void TMVA::CvSplitKFolds::MakeKFoldDataSet(DataSetInfo &dsi)
260 if (fSplitExprString != TString(
"")) {
261 fSplitExpr = std::unique_ptr<CvSplitKFoldsExpr>(
new CvSplitKFoldsExpr(dsi, fSplitExprString));
265 if (fMakeFoldDataSet) {
266 Log() << kINFO <<
"Splitting in k-folds has been already done" << Endl;
270 fMakeFoldDataSet = kTRUE;
272 UInt_t numClasses = dsi.GetNClasses();
275 std::vector<Event *> trainData = dsi.GetDataSet()->GetEventCollection(Types::kTraining);
276 std::vector<Event *> testData = dsi.GetDataSet()->GetEventCollection(Types::kTesting);
279 fTrainEvents = SplitSets(trainData, fNumFolds, numClasses);
280 fTestEvents = SplitSets(testData, fNumFolds, numClasses);
293 std::vector<UInt_t> TMVA::CvSplitKFolds::GetEventIndexToFoldMapping(UInt_t nEntries, UInt_t numFolds, UInt_t seed)
297 std::vector<UInt_t> fOrigToFoldMapping;
298 fOrigToFoldMapping.reserve(nEntries);
300 for (UInt_t iEvent = 0; iEvent < nEntries; ++iEvent) {
301 fOrigToFoldMapping.push_back(iEvent % numFolds);
305 TMVA::RandomGenerator<TRandom3> rng(seed);
306 std::shuffle(fOrigToFoldMapping.begin(), fOrigToFoldMapping.end(), rng);
308 return fOrigToFoldMapping;
318 std::vector<std::vector<TMVA::Event *>>
319 TMVA::CvSplitKFolds::SplitSets(std::vector<TMVA::Event *> &oldSet, UInt_t numFolds, UInt_t numClasses)
321 const ULong64_t nEntries = oldSet.size();
322 const ULong64_t foldSize = nEntries / numFolds;
324 std::vector<std::vector<Event *>> tempSets;
325 tempSets.reserve(fNumFolds);
326 for (UInt_t iFold = 0; iFold < numFolds; ++iFold) {
327 tempSets.emplace_back();
328 tempSets.at(iFold).reserve(foldSize);
331 Bool_t useSplitExpr = !(fSplitExpr ==
nullptr || fSplitExprString ==
"");
335 for (ULong64_t i = 0; i < nEntries; i++) {
336 TMVA::Event *ev = oldSet[i];
337 UInt_t iFold = fSplitExpr->Eval(numFolds, ev);
338 tempSets.at((UInt_t)iFold).push_back(ev);
343 std::vector<UInt_t> fOrigToFoldMapping;
344 fOrigToFoldMapping = GetEventIndexToFoldMapping(nEntries, numFolds, fSeed);
346 for (UInt_t iEvent = 0; iEvent < nEntries; ++iEvent) {
347 UInt_t iFold = fOrigToFoldMapping[iEvent];
348 TMVA::Event *ev = oldSet[iEvent];
349 tempSets.at(iFold).push_back(ev);
351 fEventToFoldMapping[ev] = iFold;
355 std::vector<std::vector<TMVA::Event *>> oldSets;
356 oldSets.reserve(numClasses);
358 for(UInt_t iClass = 0; iClass < numClasses; iClass++){
359 oldSets.emplace_back();
361 oldSets.reserve(nEntries);
364 for(UInt_t iEvent = 0; iEvent < nEntries; ++iEvent){
366 TMVA::Event *ev = oldSet[iEvent];
367 UInt_t iClass = ev->GetClass();
368 oldSets.at(iClass).push_back(ev);
371 for(UInt_t i = 0; i<numClasses; ++i){
373 TMVA::RandomGenerator<TRandom3> rng(fSeed);
374 std::shuffle(oldSets.at(i).begin(), oldSets.at(i).end(), rng);
377 for(UInt_t i = 0; i<numClasses; ++i) {
378 std::vector<UInt_t> fOrigToFoldMapping;
379 fOrigToFoldMapping = GetEventIndexToFoldMapping(oldSets.at(i).size(), numFolds, fSeed);
381 for (UInt_t iEvent = 0; iEvent < oldSets.at(i).size(); ++iEvent) {
382 UInt_t iFold = fOrigToFoldMapping[iEvent];
383 TMVA::Event *ev = oldSets.at(i)[iEvent];
384 tempSets.at(iFold).push_back(ev);
385 fEventToFoldMapping[ev] = iFold;