Logo ROOT   6.30.04
Reference Guide
 All Namespaces Files Pages
CvSplit.cxx
Go to the documentation of this file.
1 // @(#)root/tmva $Id$
2 // Author: Kim Albertsson
3 
4 /*************************************************************************
5  * Copyright (C) 2018, Rene Brun and Fons Rademakers. *
6  * All rights reserved. *
7  * *
8  * For the licensing terms see $ROOTSYS/LICENSE. *
9  * For the list of contributors see $ROOTSYS/README/CREDITS. *
10  *************************************************************************/
11 
12 #include "TMVA/CvSplit.h"
13 
14 #include "TMVA/DataSet.h"
15 #include "TMVA/DataSetFactory.h"
16 #include "TMVA/DataSetInfo.h"
17 #include "TMVA/Event.h"
18 #include "TMVA/MsgLogger.h"
19 #include "TMVA/Tools.h"
20 
21 #include <TString.h>
22 #include <TFormula.h>
23 
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <limits>
#include <numeric>
#include <stdexcept>
#include <vector>
27 
// ROOT ClassImp macros for TMVA::CvSplit and TMVA::CvSplitKFolds.
28 ClassImp(TMVA::CvSplit);
29 ClassImp(TMVA::CvSplitKFolds);
30 
31 /* =============================================================================
32  TMVA::CvSplit
33 ============================================================================= */
34 
35 ////////////////////////////////////////////////////////////////////////////////
36 ///
37 
38 TMVA::CvSplit::CvSplit(UInt_t numFolds) : fNumFolds(numFolds), fMakeFoldDataSet(kFALSE) {}
39 
40 ////////////////////////////////////////////////////////////////////////////////
41 /// \brief Set training and test set vectors of dataset described by `dsi`.
42 /// \param[in] dsi DataSetInfo for data set to be split
43 /// \param[in] foldNumber Ordinal of fold to prepare
44 /// \param[in] tt The set used to prepare fold. If equal to `Types::kTraining`
45 /// splitting will be based off the original train set. If instead
46 /// equal to `Types::kTesting` the test set will be used.
47 /// The original training/test set is the set as defined by
48 /// `DataLoader::PrepareTrainingAndTestSet`.
49 ///
50 /// Sets the training and test set vectors of the DataSet described by `dsi` as
51 /// defined by the split. If `tt` is eqal to `Types::kTraining` the split will
52 /// be based off of the original training set.
53 ///
54 /// Note: Requires `MakeKFoldDataSet` to have been called first.
55 ///
56 
57 void TMVA::CvSplit::PrepareFoldDataSet(DataSetInfo &dsi, UInt_t foldNumber, Types::ETreeType tt)
58 {
59  if (foldNumber >= fNumFolds) {
60  Log() << kFATAL << "DataSet prepared for \"" << fNumFolds << "\" folds, requested fold \"" << foldNumber
61  << "\" is outside of range." << Endl;
62  return;
63  }
64 
65  auto prepareDataSetInternal = [this, &dsi, foldNumber](std::vector<std::vector<Event *>> vec) {
66  UInt_t numFolds = fTrainEvents.size();
67 
68  // Events in training set (excludes current fold)
69  UInt_t nTotal = std::accumulate(vec.begin(), vec.end(), 0,
70  [&](UInt_t sum, std::vector<TMVA::Event *> v) { return sum + v.size(); });
71 
72  UInt_t nTrain = nTotal - vec.at(foldNumber).size();
73  UInt_t nTest = vec.at(foldNumber).size();
74 
75  std::vector<Event *> tempTrain;
76  std::vector<Event *> tempTest;
77 
78  tempTrain.reserve(nTrain);
79  tempTest.reserve(nTest);
80 
81  // Insert data into training set
82  for (UInt_t i = 0; i < numFolds; ++i) {
83  if (i == foldNumber) {
84  continue;
85  }
86 
87  tempTrain.insert(tempTrain.end(), vec.at(i).begin(), vec.at(i).end());
88  }
89 
90  // Insert data into test set
91  tempTest.insert(tempTest.end(), vec.at(foldNumber).begin(), vec.at(foldNumber).end());
92 
93  Log() << kDEBUG << "Fold prepared, num events in training set: " << tempTrain.size() << Endl;
94  Log() << kDEBUG << "Fold prepared, num events in test set: " << tempTest.size() << Endl;
95 
96  // Assign the vectors of the events to rebuild the dataset
97  dsi.GetDataSet()->SetEventCollection(&tempTrain, Types::kTraining, false);
98  dsi.GetDataSet()->SetEventCollection(&tempTest, Types::kTesting, false);
99  };
100 
101  if (tt == Types::kTraining) {
102  prepareDataSetInternal(fTrainEvents);
103  } else if (tt == Types::kTesting) {
104  prepareDataSetInternal(fTestEvents);
105  } else {
106  Log() << kFATAL << "PrepareFoldDataSet can only work with training and testing data sets." << std::endl;
107  return;
108  }
109 }
110 
111 ////////////////////////////////////////////////////////////////////////////////
112 ///
113 
114 void TMVA::CvSplit::RecombineKFoldDataSet(DataSetInfo &dsi, Types::ETreeType tt)
115 {
116  if (tt != Types::kTraining) {
117  Log() << kFATAL << "Only kTraining is supported for CvSplit::RecombineKFoldDataSet currently." << std::endl;
118  }
119 
120  std::vector<Event *> *tempVec = new std::vector<Event *>;
121 
122  for (UInt_t i = 0; i < fNumFolds; ++i) {
123  tempVec->insert(tempVec->end(), fTrainEvents.at(i).begin(), fTrainEvents.at(i).end());
124  }
125 
126  dsi.GetDataSet()->SetEventCollection(tempVec, Types::kTraining, false);
127  dsi.GetDataSet()->SetEventCollection(tempVec, Types::kTesting, false);
128 
129  delete tempVec;
130 }
131 
132 /* =============================================================================
133  TMVA::CvSplitKFoldsExpr
134 ============================================================================= */
135 
136 ////////////////////////////////////////////////////////////////////////////////
137 ///
138 
139 TMVA::CvSplitKFoldsExpr::CvSplitKFoldsExpr(DataSetInfo &dsi, TString expr)
140  : fDsi(dsi), fIdxFormulaParNumFolds(std::numeric_limits<UInt_t>::max()), fSplitFormula("", expr),
141  fParValues(fSplitFormula.GetNpar())
142 {
143  if (!fSplitFormula.IsValid()) {
144  throw std::runtime_error("Split expression \"" + std::string(fSplitExpr.Data()) + "\" is not a valid TFormula.");
145  }
146 
147  for (Int_t iFormulaPar = 0; iFormulaPar < fSplitFormula.GetNpar(); ++iFormulaPar) {
148  TString name = fSplitFormula.GetParName(iFormulaPar);
149 
150  // std::cout << "Found variable with name \"" << name << "\"." << std::endl;
151 
152  if (name == "NumFolds" || name == "numFolds") {
153  // std::cout << "NumFolds|numFolds is a reserved variable! Adding to context." << std::endl;
154  fIdxFormulaParNumFolds = iFormulaPar;
155  } else {
156  fFormulaParIdxToDsiSpecIdx.push_back(std::make_pair(iFormulaPar, GetSpectatorIndexForName(fDsi, name)));
157  }
158  }
159 }
160 
161 ////////////////////////////////////////////////////////////////////////////////
162 ///
163 
164 UInt_t TMVA::CvSplitKFoldsExpr::Eval(UInt_t numFolds, const Event *ev)
165 {
166  for (auto &p : fFormulaParIdxToDsiSpecIdx) {
167  auto iFormulaPar = p.first;
168  auto iSpectator = p.second;
169 
170  fParValues.at(iFormulaPar) = ev->GetSpectator(iSpectator);
171  }
172 
173  if (fIdxFormulaParNumFolds < fSplitFormula.GetNpar()) {
174  fParValues[fIdxFormulaParNumFolds] = numFolds;
175  }
176 
177  // NOTE: We are using a double to represent an integer here. This _will_
178  // lead to problems if the norm of the double grows too large. A quick test
179  // with python suggests that problems arise at a magnitude of ~1e16.
180  Double_t iFold_d = fSplitFormula.EvalPar(nullptr, &fParValues[0]);
181 
182  if (iFold_d < 0) {
183  throw std::runtime_error("Output of splitExpr must be non-negative.");
184  }
185 
186  UInt_t iFold = std::lround(iFold_d);
187  if (iFold >= numFolds) {
188  throw std::runtime_error("Output of splitExpr should be a non-negative"
189  "integer between 0 and numFolds-1 inclusive.");
190  }
191 
192  return iFold;
193 }
194 
195 ////////////////////////////////////////////////////////////////////////////////
196 ///
197 
198 Bool_t TMVA::CvSplitKFoldsExpr::Validate(TString expr)
199 {
200  return TFormula("", expr).IsValid();
201 }
202 
203 ////////////////////////////////////////////////////////////////////////////////
204 ///
205 
206 UInt_t TMVA::CvSplitKFoldsExpr::GetSpectatorIndexForName(DataSetInfo &dsi, TString name)
207 {
208  std::vector<VariableInfo> spectatorInfos = dsi.GetSpectatorInfos();
209 
210  for (UInt_t iSpectator = 0; iSpectator < spectatorInfos.size(); ++iSpectator) {
211  VariableInfo vi = spectatorInfos[iSpectator];
212  if (vi.GetName() == name) {
213  return iSpectator;
214  } else if (vi.GetLabel() == name) {
215  return iSpectator;
216  } else if (vi.GetExpression() == name) {
217  return iSpectator;
218  }
219  }
220 
221  throw std::runtime_error("Spectator \"" + std::string(name.Data()) + "\" not found.");
222 }
223 
224 /* =============================================================================
225  TMVA::CvSplitKFolds
226 ============================================================================= */
227 
228 ////////////////////////////////////////////////////////////////////////////////
229 /// \brief Splits a dataset into k folds, ready for use in cross validation.
230 /// \param numFolds[in] Number of folds to split data into
231 /// \param stratified[in] If true, use stratified splitting, balancing the
232 /// number of events across classes and folds. If false,
233 /// no such balancing is done. For
234 /// \param splitExpr[in] Expression used to split data into folds. If `""` a
235 /// random assignment will be done. Otherwise the
236 /// expression is fed into a TFormula and evaluated per
237 /// event. The resulting value is the the fold assignment.
238 /// \param seed[in] Used only when using random splitting (i.e. when
239 /// `splitExpr` is `""`). Seed is used to initialise the random
240 /// number generator when assigning events to folds.
241 ///
242 
243 TMVA::CvSplitKFolds::CvSplitKFolds(UInt_t numFolds, TString splitExpr, Bool_t stratified, UInt_t seed)
244  : CvSplit(numFolds), fSeed(seed), fSplitExprString(splitExpr), fStratified(stratified)
245 {
246  if (!CvSplitKFoldsExpr::Validate(fSplitExprString) && (splitExpr != TString(""))) {
247  Log() << kFATAL << "Split expression \"" << fSplitExprString << "\" is not a valid TFormula." << Endl;
248  }
249 
250 }
251 
252 ////////////////////////////////////////////////////////////////////////////////
253 /// \brief Prepares a DataSet for cross validation
254 
255 void TMVA::CvSplitKFolds::MakeKFoldDataSet(DataSetInfo &dsi)
256 {
257  // Validate spectator
258  // fSpectatorIdx = GetSpectatorIndexForName(dsi, fSpectatorName);
259 
260  if (fSplitExprString != TString("")) {
261  fSplitExpr = std::unique_ptr<CvSplitKFoldsExpr>(new CvSplitKFoldsExpr(dsi, fSplitExprString));
262  }
263 
264  // No need to do it again if the sets have already been split.
265  if (fMakeFoldDataSet) {
266  Log() << kINFO << "Splitting in k-folds has been already done" << Endl;
267  return;
268  }
269 
270  fMakeFoldDataSet = kTRUE;
271 
272  UInt_t numClasses = dsi.GetNClasses();
273 
274  // Get the original event vectors for testing and training from the dataset.
275  std::vector<Event *> trainData = dsi.GetDataSet()->GetEventCollection(Types::kTraining);
276  std::vector<Event *> testData = dsi.GetDataSet()->GetEventCollection(Types::kTesting);
277 
278  // Split the sets into the number of folds.
279  fTrainEvents = SplitSets(trainData, fNumFolds, numClasses);
280  fTestEvents = SplitSets(testData, fNumFolds, numClasses);
281 }
282 
283 ////////////////////////////////////////////////////////////////////////////////
284 /// \brief Generates a vector of fold assignments
285 /// \param nEntires[in] Number of events in range
286 /// \param numFolds[in] Number of folds to split data into
287 /// \param seed[in] Random seed
288 ///
289 /// Randomly assigns events to `numFolds` folds. Each fold will hold at most
290 /// `nEntries / numFolds + 1` events.
291 ///
292 
293 std::vector<UInt_t> TMVA::CvSplitKFolds::GetEventIndexToFoldMapping(UInt_t nEntries, UInt_t numFolds, UInt_t seed)
294 {
295  // Generate assignment of the pattern `0, 1, 2, 0, 1, 2, 0, 1 ...` for
296  // `numFolds = 3`.
297  std::vector<UInt_t> fOrigToFoldMapping;
298  fOrigToFoldMapping.reserve(nEntries);
299 
300  for (UInt_t iEvent = 0; iEvent < nEntries; ++iEvent) {
301  fOrigToFoldMapping.push_back(iEvent % numFolds);
302  }
303 
304  // Shuffle assignment
305  TMVA::RandomGenerator<TRandom3> rng(seed);
306  std::shuffle(fOrigToFoldMapping.begin(), fOrigToFoldMapping.end(), rng);
307 
308  return fOrigToFoldMapping;
309 }
310 
311 
312 ////////////////////////////////////////////////////////////////////////////////
313 /// \brief Split sets for into k-folds
314 /// \param oldSet[in] Original, unsplit, events
315 /// \param numFolds[in] Number of folds to split data into
316 ///
317 
318 std::vector<std::vector<TMVA::Event *>>
319 TMVA::CvSplitKFolds::SplitSets(std::vector<TMVA::Event *> &oldSet, UInt_t numFolds, UInt_t numClasses)
320 {
321  const ULong64_t nEntries = oldSet.size();
322  const ULong64_t foldSize = nEntries / numFolds;
323 
324  std::vector<std::vector<Event *>> tempSets;
325  tempSets.reserve(fNumFolds);
326  for (UInt_t iFold = 0; iFold < numFolds; ++iFold) {
327  tempSets.emplace_back();
328  tempSets.at(iFold).reserve(foldSize);
329  }
330 
331  Bool_t useSplitExpr = !(fSplitExpr == nullptr || fSplitExprString == "");
332 
333  if (useSplitExpr) {
334  // Deterministic split
335  for (ULong64_t i = 0; i < nEntries; i++) {
336  TMVA::Event *ev = oldSet[i];
337  UInt_t iFold = fSplitExpr->Eval(numFolds, ev);
338  tempSets.at((UInt_t)iFold).push_back(ev);
339  }
340  } else {
341  if(!fStratified){
342  // Random split
343  std::vector<UInt_t> fOrigToFoldMapping;
344  fOrigToFoldMapping = GetEventIndexToFoldMapping(nEntries, numFolds, fSeed);
345 
346  for (UInt_t iEvent = 0; iEvent < nEntries; ++iEvent) {
347  UInt_t iFold = fOrigToFoldMapping[iEvent];
348  TMVA::Event *ev = oldSet[iEvent];
349  tempSets.at(iFold).push_back(ev);
350 
351  fEventToFoldMapping[ev] = iFold;
352  }
353  } else {
354  // Stratified Split
355  std::vector<std::vector<TMVA::Event *>> oldSets;
356  oldSets.reserve(numClasses);
357 
358  for(UInt_t iClass = 0; iClass < numClasses; iClass++){
359  oldSets.emplace_back();
360  //find a way to get number of events in each class
361  oldSets.reserve(nEntries);
362  }
363 
364  for(UInt_t iEvent = 0; iEvent < nEntries; ++iEvent){
365  // check the class of event and add to its vector of events
366  TMVA::Event *ev = oldSet[iEvent];
367  UInt_t iClass = ev->GetClass();
368  oldSets.at(iClass).push_back(ev);
369  }
370 
371  for(UInt_t i = 0; i<numClasses; ++i){
372  // Shuffle each vector individually
373  TMVA::RandomGenerator<TRandom3> rng(fSeed);
374  std::shuffle(oldSets.at(i).begin(), oldSets.at(i).end(), rng);
375  }
376 
377  for(UInt_t i = 0; i<numClasses; ++i) {
378  std::vector<UInt_t> fOrigToFoldMapping;
379  fOrigToFoldMapping = GetEventIndexToFoldMapping(oldSets.at(i).size(), numFolds, fSeed);
380 
381  for (UInt_t iEvent = 0; iEvent < oldSets.at(i).size(); ++iEvent) {
382  UInt_t iFold = fOrigToFoldMapping[iEvent];
383  TMVA::Event *ev = oldSets.at(i)[iEvent];
384  tempSets.at(iFold).push_back(ev);
385  fEventToFoldMapping[ev] = iFold;
386  }
387  }
388  }
389  }
390  return tempSets;
391 }