Logo ROOT   6.30.04
Reference Guide
 All Namespaces Files Pages
DataLoader.cxx
Go to the documentation of this file.
1 // @(#)root/tmva $Id$
2 // Author: Omar Zapata
3 // Mentors: Lorenzo Moneta, Sergei Gleyzer
4 //NOTE: Based on TMVA::Factory
5 
6 /**********************************************************************************
7  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
8  * Package: TMVA *
9  * Class : DataLoader *
10  * Web : http://tmva.sourceforge.net *
11  * *
12  * Description: *
13  * This is a class to load datasets into every booked method *
14  * *
15  * Authors (alphabetical): *
16  * Lorenzo Moneta <Lorenzo.Moneta@cern.ch> - CERN, Switzerland *
17  * Omar Zapata <Omar.Zapata@cern.ch> - ITM/UdeA, Colombia *
18  * Sergei Gleyzer<sergei.gleyzer@cern.ch> - CERN, Switzerland *
19  * *
20  * Copyright (c) 2005-2015: *
21  * CERN, Switzerland *
22  * ITM/UdeA, Colombia *
23  * *
24  * Redistribution and use in source and binary forms, with or without *
25  * modification, are permitted according to the terms listed in LICENSE *
26  * (http://tmva.sourceforge.net/LICENSE) *
27  **********************************************************************************/
28 
29 
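/*! \class TMVA::DataLoader
\ingroup TMVA

Collects the input data for TMVA methods: input trees (or individual events
assigned one by one by the user), the variables, targets and spectators to
read from them, per-class weight expressions and cuts, and the options used
to split the data into training and test samples.
*/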
34 
35 #include "TFile.h"
36 #include "TTree.h"
37 #include "TH2.h"
38 #include "TMath.h"
39 #include "TMatrixD.h"
40 
41 #include "TMVA/DataLoader.h"
42 #include "TMVA/Config.h"
43 #include "TMVA/CvSplit.h"
44 #include "TMVA/Tools.h"
45 #include "TMVA/IMethod.h"
46 #include "TMVA/MethodBase.h"
47 #include "TMVA/DataInputHandler.h"
48 #include "TMVA/DataSetManager.h"
49 #include "TMVA/DataSetInfo.h"
50 #include "TMVA/MethodBoost.h"
51 #include "TMVA/MethodCategory.h"
52 
53 #include "TMVA/VariableInfo.h"
60 
// ROOT ClassImp macro for TMVA::DataLoader (class implementation registration;
// presumably paired with a ClassDef in TMVA/DataLoader.h — confirm there).
ClassImp(TMVA::DataLoader);
62 
63 
64 ////////////////////////////////////////////////////////////////////////////////
65 
66 TMVA::DataLoader::DataLoader( TString thedlName)
67 : Configurable( ),
68  fDataSetManager ( NULL ), //DSMTEST
69  fDataInputHandler ( new DataInputHandler ),
70  fTransformations ( "I" ),
71  fVerbose ( kFALSE ),
72  fDataAssignType ( kAssignEvents ),
73  fATreeEvent (0)
74 {
75  fDataSetManager = new DataSetManager( *fDataInputHandler ); // DSMTEST
76  SetName(thedlName.Data());
77  fLogger->SetSource("DataLoader");
78 }
79 
80 ////////////////////////////////////////////////////////////////////////////////
81 
82 TMVA::DataLoader::~DataLoader( void )
83 {
84  // destructor
85 
86  std::vector<TMVA::VariableTransformBase*>::iterator trfIt = fDefaultTrfs.begin();
87  for (;trfIt != fDefaultTrfs.end(); ++trfIt) delete (*trfIt);
88 
89  delete fDataInputHandler;
90 
91  // destroy singletons
92  // DataSetManager::DestroyInstance(); // DSMTEST replaced by following line
93  delete fDataSetManager; // DSMTEST
94 
95  // problem with call of REGISTER_METHOD macro ...
96  // ClassifierDataLoader::DestroyInstance();
97  // Types::DestroyInstance();
98  //Tools::DestroyInstance();
99  //Config::DestroyInstance();
100 }
101 
102 
103 ////////////////////////////////////////////////////////////////////////////////
104 
105 TMVA::DataSetInfo& TMVA::DataLoader::AddDataSet( DataSetInfo &dsi )
106 {
107  return fDataSetManager->AddDataSetInfo(dsi); // DSMTEST
108 }
109 
110 ////////////////////////////////////////////////////////////////////////////////
111 
112 TMVA::DataSetInfo& TMVA::DataLoader::AddDataSet( const TString& dsiName )
113 {
114  DataSetInfo* dsi = fDataSetManager->GetDataSetInfo(dsiName); // DSMTEST
115 
116  if (dsi!=0) return *dsi;
117 
118  return fDataSetManager->AddDataSetInfo(*(new DataSetInfo(dsiName))); // DSMTEST
119 }
120 
121 ////////////////////////////////////////////////////////////////////////////////
122 
123 TMVA::DataSetInfo& TMVA::DataLoader::GetDataSetInfo()
124 {
125  return DefaultDataSetInfo(); // DSMTEST
126 }
127 
128 ////////////////////////////////////////////////////////////////////////////////
129 /// Transforms the variables and return a new DataLoader with the transformed
130 /// variables
131 
132 TMVA::DataLoader* TMVA::DataLoader::VarTransform(TString trafoDefinition)
133 {
134  TString trOptions = "0";
135  TString trName = "None";
136  if (trafoDefinition.Contains("(")) {
137 
138  // contains transformation parameters
139  Ssiz_t parStart = trafoDefinition.Index( "(" );
140  Ssiz_t parLen = trafoDefinition.Index( ")", parStart )-parStart+1;
141 
142  trName = trafoDefinition(0,parStart);
143  trOptions = trafoDefinition(parStart,parLen);
144  trOptions.Remove(parLen-1,1);
145  trOptions.Remove(0,1);
146  }
147  else
148  trName = trafoDefinition;
149 
150  VarTransformHandler* handler = new VarTransformHandler(this);
151  // variance threshold variable transformation
152  if (trName == "VT") {
153 
154  // find threshold value from given input
155  Double_t threshold = 0.0;
156  if (!trOptions.IsFloat()){
157  Log() << kFATAL << " VT transformation must be passed a floating threshold value" << Endl;
158  delete handler;
159  return this;
160  }
161  else
162  threshold = trOptions.Atof();
163  TMVA::DataLoader *transformedLoader = handler->VarianceThreshold(threshold);
164  delete handler;
165  return transformedLoader;
166  }
167  else {
168  Log() << kFATAL << "Incorrect transformation string provided, please check" << Endl;
169  }
170  Log() << kINFO << "No transformation applied, returning original loader" << Endl;
171  return this;
172 }
173 
174 ////////////////////////////////////////////////////////////////////////////////
175 // the next functions are to assign events directly
176 
177 ////////////////////////////////////////////////////////////////////////////////
178 /// create the data assignment tree (for event-wise data assignment by user)
179 
180 TTree* TMVA::DataLoader::CreateEventAssignTrees( const TString& name )
181 {
182  TTree * assignTree = new TTree( name, name );
183  assignTree->SetDirectory(0);
184  assignTree->Branch( "type", &fATreeType, "ATreeType/I" );
185  assignTree->Branch( "weight", &fATreeWeight, "ATreeWeight/F" );
186 
187  std::vector<VariableInfo>& vars = DefaultDataSetInfo().GetVariableInfos();
188  std::vector<VariableInfo>& tgts = DefaultDataSetInfo().GetTargetInfos();
189  std::vector<VariableInfo>& spec = DefaultDataSetInfo().GetSpectatorInfos();
190 
191  if (fATreeEvent.size()==0) fATreeEvent.resize(vars.size()+tgts.size()+spec.size());
192  // add variables
193  for (UInt_t ivar=0; ivar<vars.size(); ivar++) {
194  TString vname = vars[ivar].GetExpression();
195  assignTree->Branch( vname, &fATreeEvent[ivar], vname + "/F" );
196  }
197  // add targets
198  for (UInt_t itgt=0; itgt<tgts.size(); itgt++) {
199  TString vname = tgts[itgt].GetExpression();
200  assignTree->Branch( vname, &fATreeEvent[vars.size()+itgt], vname + "/F" );
201  }
202  // add spectators
203  for (UInt_t ispc=0; ispc<spec.size(); ispc++) {
204  TString vname = spec[ispc].GetExpression();
205  assignTree->Branch( vname, &fATreeEvent[vars.size()+tgts.size()+ispc], vname + "/F" );
206  }
207  return assignTree;
208 }
209 
210 ////////////////////////////////////////////////////////////////////////////////
211 /// add signal training event
212 
213 void TMVA::DataLoader::AddSignalTrainingEvent( const std::vector<Double_t>& event, Double_t weight )
214 {
215  AddEvent( "Signal", Types::kTraining, event, weight );
216 }
217 
218 ////////////////////////////////////////////////////////////////////////////////
219 /// add signal testing event
220 
221 void TMVA::DataLoader::AddSignalTestEvent( const std::vector<Double_t>& event, Double_t weight )
222 {
223  AddEvent( "Signal", Types::kTesting, event, weight );
224 }
225 
226 ////////////////////////////////////////////////////////////////////////////////
227 /// add signal training event
228 
229 void TMVA::DataLoader::AddBackgroundTrainingEvent( const std::vector<Double_t>& event, Double_t weight )
230 {
231  AddEvent( "Background", Types::kTraining, event, weight );
232 }
233 
234 ////////////////////////////////////////////////////////////////////////////////
235 /// add signal training event
236 
237 void TMVA::DataLoader::AddBackgroundTestEvent( const std::vector<Double_t>& event, Double_t weight )
238 {
239  AddEvent( "Background", Types::kTesting, event, weight );
240 }
241 
242 ////////////////////////////////////////////////////////////////////////////////
243 /// add signal training event
244 
245 void TMVA::DataLoader::AddTrainingEvent( const TString& className, const std::vector<Double_t>& event, Double_t weight )
246 {
247  AddEvent( className, Types::kTraining, event, weight );
248 }
249 
250 ////////////////////////////////////////////////////////////////////////////////
251 /// add signal test event
252 
253 void TMVA::DataLoader::AddTestEvent( const TString& className, const std::vector<Double_t>& event, Double_t weight )
254 {
255  AddEvent( className, Types::kTesting, event, weight );
256 }
257 
258 ////////////////////////////////////////////////////////////////////////////////
259 /// add event
260 /// vector event : the order of values is: variables + targets + spectators
261 
262 void TMVA::DataLoader::AddEvent( const TString& className, Types::ETreeType tt,
263  const std::vector<Double_t>& event, Double_t weight )
264 {
265  ClassInfo* theClass = DefaultDataSetInfo().AddClass(className); // returns class (creates it if necessary)
266  UInt_t clIndex = theClass->GetNumber();
267 
268 
269  // set analysistype to "kMulticlass" if more than two classes and analysistype == kNoAnalysisType
270  if( fAnalysisType == Types::kNoAnalysisType && DefaultDataSetInfo().GetNClasses() > 2 )
271  fAnalysisType = Types::kMulticlass;
272 
273 
274  if (clIndex>=fTrainAssignTree.size()) {
275  fTrainAssignTree.resize(clIndex+1, 0);
276  fTestAssignTree.resize(clIndex+1, 0);
277  }
278 
279  if (fTrainAssignTree[clIndex]==0) { // does not exist yet
280  fTrainAssignTree[clIndex] = CreateEventAssignTrees( Form("TrainAssignTree_%s", className.Data()) );
281  fTestAssignTree[clIndex] = CreateEventAssignTrees( Form("TestAssignTree_%s", className.Data()) );
282  }
283 
284  fATreeType = clIndex;
285  fATreeWeight = weight;
286  for (UInt_t ivar=0; ivar<event.size(); ivar++) fATreeEvent[ivar] = event[ivar];
287 
288  if(tt==Types::kTraining) fTrainAssignTree[clIndex]->Fill();
289  else fTestAssignTree[clIndex]->Fill();
290 
291 }
292 
293 ////////////////////////////////////////////////////////////////////////////////
294 ///
295 
296 Bool_t TMVA::DataLoader::UserAssignEvents(UInt_t clIndex)
297 {
298  return fTrainAssignTree[clIndex]!=0;
299 }
300 
301 ////////////////////////////////////////////////////////////////////////////////
302 /// assign event-wise local trees to data set
303 
304 void TMVA::DataLoader::SetInputTreesFromEventAssignTrees()
305 {
306  UInt_t size = fTrainAssignTree.size();
307  for(UInt_t i=0; i<size; i++) {
308  if(!UserAssignEvents(i)) continue;
309  const TString& className = DefaultDataSetInfo().GetClassInfo(i)->GetName();
310  SetWeightExpression( "weight", className );
311  AddTree(fTrainAssignTree[i], className, 1.0, TCut(""), Types::kTraining );
312  AddTree(fTestAssignTree[i], className, 1.0, TCut(""), Types::kTesting );
313  }
314 }
315 
316 ////////////////////////////////////////////////////////////////////////////////
317 /// number of signal events (used to compute significance)
318 
319 void TMVA::DataLoader::AddTree( TTree* tree, const TString& className, Double_t weight,
320  const TCut& cut, const TString& treetype )
321 {
322  Types::ETreeType tt = Types::kMaxTreeType;
323  TString tmpTreeType = treetype; tmpTreeType.ToLower();
324  if (tmpTreeType.Contains( "train" ) && tmpTreeType.Contains( "test" )) tt = Types::kMaxTreeType;
325  else if (tmpTreeType.Contains( "train" )) tt = Types::kTraining;
326  else if (tmpTreeType.Contains( "test" )) tt = Types::kTesting;
327  else {
328  Log() << kFATAL << "<AddTree> cannot interpret tree type: \"" << treetype
329  << "\" should be \"Training\" or \"Test\" or \"Training and Testing\"" << Endl;
330  }
331  AddTree( tree, className, weight, cut, tt );
332 }
333 
334 ////////////////////////////////////////////////////////////////////////////////
335 
336 void TMVA::DataLoader::AddTree( TTree* tree, const TString& className, Double_t weight,
337  const TCut& cut, Types::ETreeType tt )
338 {
339  if(!tree)
340  Log() << kFATAL << "Tree does not exist (empty pointer)." << Endl;
341 
342  DefaultDataSetInfo().AddClass( className );
343 
344  // set analysistype to "kMulticlass" if more than two classes and analysistype == kNoAnalysisType
345  if( fAnalysisType == Types::kNoAnalysisType && DefaultDataSetInfo().GetNClasses() > 2 )
346  fAnalysisType = Types::kMulticlass;
347 
348  Log() << kINFO<< "Add Tree " << tree->GetName() << " of type " << className
349  << " with " << tree->GetEntries() << " events" << Endl;
350  DataInput().AddTree( tree, className, weight, cut, tt );
351 }
352 
353 ////////////////////////////////////////////////////////////////////////////////
354 /// number of signal events (used to compute significance)
355 
356 void TMVA::DataLoader::AddSignalTree( TTree* signal, Double_t weight, Types::ETreeType treetype )
357 {
358  AddTree( signal, "Signal", weight, TCut(""), treetype );
359 }
360 
361 ////////////////////////////////////////////////////////////////////////////////
362 /// add signal tree from text file
363 
364 void TMVA::DataLoader::AddSignalTree( TString datFileS, Double_t weight, Types::ETreeType treetype )
365 {
366  // create trees from these ascii files
367  TTree* signalTree = new TTree( "TreeS", "Tree (S)" );
368  signalTree->ReadFile( datFileS );
369 
370  Log() << kINFO << "Create TTree objects from ASCII input files ... \n- Signal file : \""
371  << datFileS << Endl;
372 
373  // number of signal events (used to compute significance)
374  AddTree( signalTree, "Signal", weight, TCut(""), treetype );
375 }
376 
377 ////////////////////////////////////////////////////////////////////////////////
378 
379 void TMVA::DataLoader::AddSignalTree( TTree* signal, Double_t weight, const TString& treetype )
380 {
381  AddTree( signal, "Signal", weight, TCut(""), treetype );
382 }
383 
384 ////////////////////////////////////////////////////////////////////////////////
385 /// number of signal events (used to compute significance)
386 
387 void TMVA::DataLoader::AddBackgroundTree( TTree* signal, Double_t weight, Types::ETreeType treetype )
388 {
389  AddTree( signal, "Background", weight, TCut(""), treetype );
390 }
391 
392 ////////////////////////////////////////////////////////////////////////////////
393 /// add background tree from text file
394 
395 void TMVA::DataLoader::AddBackgroundTree( TString datFileB, Double_t weight, Types::ETreeType treetype )
396 {
397  // create trees from these ascii files
398  TTree* bkgTree = new TTree( "TreeB", "Tree (B)" );
399  bkgTree->ReadFile( datFileB );
400 
401  Log() << kINFO << "Create TTree objects from ASCII input files ... \n- Background file : \""
402  << datFileB << Endl;
403 
404  // number of signal events (used to compute significance)
405  AddTree( bkgTree, "Background", weight, TCut(""), treetype );
406 }
407 
408 ////////////////////////////////////////////////////////////////////////////////
409 
410 void TMVA::DataLoader::AddBackgroundTree( TTree* signal, Double_t weight, const TString& treetype )
411 {
412  AddTree( signal, "Background", weight, TCut(""), treetype );
413 }
414 
415 ////////////////////////////////////////////////////////////////////////////////
416 
417 void TMVA::DataLoader::SetSignalTree( TTree* tree, Double_t weight )
418 {
419  AddTree( tree, "Signal", weight );
420 }
421 
422 ////////////////////////////////////////////////////////////////////////////////
423 
424 void TMVA::DataLoader::SetBackgroundTree( TTree* tree, Double_t weight )
425 {
426  AddTree( tree, "Background", weight );
427 }
428 
429 ////////////////////////////////////////////////////////////////////////////////
430 /// set background tree
431 
432 void TMVA::DataLoader::SetTree( TTree* tree, const TString& className, Double_t weight )
433 {
434  AddTree( tree, className, weight, TCut(""), Types::kMaxTreeType );
435 }
436 
437 ////////////////////////////////////////////////////////////////////////////////
438 /// define the input trees for signal and background; no cuts are applied
439 
440 void TMVA::DataLoader::SetInputTrees( TTree* signal, TTree* background,
441  Double_t signalWeight, Double_t backgroundWeight )
442 {
443  AddTree( signal, "Signal", signalWeight, TCut(""), Types::kMaxTreeType );
444  AddTree( background, "Background", backgroundWeight, TCut(""), Types::kMaxTreeType );
445 }
446 
447 ////////////////////////////////////////////////////////////////////////////////
448 
449 void TMVA::DataLoader::SetInputTrees( const TString& datFileS, const TString& datFileB,
450  Double_t signalWeight, Double_t backgroundWeight )
451 {
452  DataInput().AddTree( datFileS, "Signal", signalWeight );
453  DataInput().AddTree( datFileB, "Background", backgroundWeight );
454 }
455 
456 ////////////////////////////////////////////////////////////////////////////////
457 /// define the input trees for signal and background from single input tree,
458 /// containing both signal and background events distinguished by the type
459 /// identifiers: SigCut and BgCut
460 
461 void TMVA::DataLoader::SetInputTrees( TTree* inputTree, const TCut& SigCut, const TCut& BgCut )
462 {
463  AddTree( inputTree, "Signal", 1.0, SigCut, Types::kMaxTreeType );
464  AddTree( inputTree, "Background", 1.0, BgCut , Types::kMaxTreeType );
465 }
466 
467 ////////////////////////////////////////////////////////////////////////////////
468 /// user inserts discriminating variable in data set info
469 
470 void TMVA::DataLoader::AddVariable( const TString& expression, const TString& title, const TString& unit,
471  char type, Double_t min, Double_t max )
472 {
473  DefaultDataSetInfo().AddVariable( expression, title, unit, min, max, type );
474 }
475 
476 ////////////////////////////////////////////////////////////////////////////////
477 /// user inserts discriminating variable in data set info
478 
479 void TMVA::DataLoader::AddVariable( const TString& expression, char type,
480  Double_t min, Double_t max )
481 {
482  DefaultDataSetInfo().AddVariable( expression, "", "", min, max, type );
483 }
484 
485 ////////////////////////////////////////////////////////////////////////////////
486 /// user inserts discriminating array of variables in data set info
487 /// in case input tree provides an array of values
488 
489 void TMVA::DataLoader::AddVariablesArray(const TString &expression, int size, char type,
490  Double_t min, Double_t max)
491 {
492  DefaultDataSetInfo().AddVariablesArray(expression, size, "", "", min, max, type);
493 }
494 ////////////////////////////////////////////////////////////////////////////////
495 /// user inserts target in data set info
496 
497 void TMVA::DataLoader::AddTarget( const TString& expression, const TString& title, const TString& unit,
498  Double_t min, Double_t max )
499 {
500  if( fAnalysisType == Types::kNoAnalysisType )
501  fAnalysisType = Types::kRegression;
502 
503  DefaultDataSetInfo().AddTarget( expression, title, unit, min, max );
504 }
505 
506 ////////////////////////////////////////////////////////////////////////////////
507 /// user inserts target in data set info
508 
509 void TMVA::DataLoader::AddSpectator( const TString& expression, const TString& title, const TString& unit,
510  Double_t min, Double_t max )
511 {
512  DefaultDataSetInfo().AddSpectator( expression, title, unit, min, max );
513 }
514 
515 ////////////////////////////////////////////////////////////////////////////////
516 /// default creation
517 
518 TMVA::DataSetInfo& TMVA::DataLoader::DefaultDataSetInfo()
519 {
520  return AddDataSet( fName );
521 }
522 
523 ////////////////////////////////////////////////////////////////////////////////
524 /// fill input variables in data set
525 
526 void TMVA::DataLoader::SetInputVariables( std::vector<TString>* theVariables )
527 {
528  for (std::vector<TString>::iterator it=theVariables->begin();
529  it!=theVariables->end(); ++it) AddVariable(*it);
530 }
531 
532 ////////////////////////////////////////////////////////////////////////////////
533 
534 void TMVA::DataLoader::SetSignalWeightExpression( const TString& variable)
535 {
536  DefaultDataSetInfo().SetWeightExpression(variable, "Signal");
537 }
538 
539 ////////////////////////////////////////////////////////////////////////////////
540 
541 void TMVA::DataLoader::SetBackgroundWeightExpression( const TString& variable)
542 {
543  DefaultDataSetInfo().SetWeightExpression(variable, "Background");
544 }
545 
546 ////////////////////////////////////////////////////////////////////////////////
547 
548 void TMVA::DataLoader::SetWeightExpression( const TString& variable, const TString& className )
549 {
550  //Log() << kWarning << DefaultDataSetInfo().GetNClasses() /*fClasses.size()*/ << Endl;
551  if (className=="") {
552  SetSignalWeightExpression(variable);
553  SetBackgroundWeightExpression(variable);
554  }
555  else DefaultDataSetInfo().SetWeightExpression( variable, className );
556 }
557 
558 ////////////////////////////////////////////////////////////////////////////////
559 
560 void TMVA::DataLoader::SetCut( const TString& cut, const TString& className ) {
561  SetCut( TCut(cut), className );
562 }
563 
564 ////////////////////////////////////////////////////////////////////////////////
565 
566 void TMVA::DataLoader::SetCut( const TCut& cut, const TString& className )
567 {
568  DefaultDataSetInfo().SetCut( cut, className );
569 }
570 
571 ////////////////////////////////////////////////////////////////////////////////
572 
573 void TMVA::DataLoader::AddCut( const TString& cut, const TString& className )
574 {
575  AddCut( TCut(cut), className );
576 }
577 
578 ////////////////////////////////////////////////////////////////////////////////
579 void TMVA::DataLoader::AddCut( const TCut& cut, const TString& className )
580 {
581  DefaultDataSetInfo().AddCut( cut, className );
582 }
583 
584 ////////////////////////////////////////////////////////////////////////////////
585 /// prepare the training and test trees
586 
587 void TMVA::DataLoader::PrepareTrainingAndTestTree( const TCut& cut,
588  Int_t NsigTrain, Int_t NbkgTrain, Int_t NsigTest, Int_t NbkgTest,
589  const TString& otherOpt )
590 {
591  SetInputTreesFromEventAssignTrees();
592 
593  AddCut( cut );
594 
595  DefaultDataSetInfo().SetSplitOptions( Form("nTrain_Signal=%i:nTrain_Background=%i:nTest_Signal=%i:nTest_Background=%i:%s",
596  NsigTrain, NbkgTrain, NsigTest, NbkgTest, otherOpt.Data()) );
597 }
598 
599 ////////////////////////////////////////////////////////////////////////////////
600 /// prepare the training and test trees
601 /// kept for backward compatibility
602 
603 void TMVA::DataLoader::PrepareTrainingAndTestTree( const TCut& cut, Int_t Ntrain, Int_t Ntest )
604 {
605  SetInputTreesFromEventAssignTrees();
606 
607  AddCut( cut );
608 
609  DefaultDataSetInfo().SetSplitOptions( Form("nTrain_Signal=%i:nTrain_Background=%i:nTest_Signal=%i:nTest_Background=%i:SplitMode=Random:EqualTrainSample:!V",
610  Ntrain, Ntrain, Ntest, Ntest) );
611 }
612 
613 ////////////////////////////////////////////////////////////////////////////////
614 /// prepare the training and test trees
615 /// -> same cuts for signal and background
616 
617 void TMVA::DataLoader::PrepareTrainingAndTestTree( const TCut& cut, const TString& opt )
618 {
619  SetInputTreesFromEventAssignTrees();
620 
621  DefaultDataSetInfo().PrintClasses();
622  AddCut( cut );
623  DefaultDataSetInfo().SetSplitOptions( opt );
624 }
625 
626 ////////////////////////////////////////////////////////////////////////////////
627 /// prepare the training and test trees
628 
629 void TMVA::DataLoader::PrepareTrainingAndTestTree( TCut sigcut, TCut bkgcut, const TString& splitOpt )
630 {
631  // if event-wise data assignment, add local trees to dataset first
632  SetInputTreesFromEventAssignTrees();
633 
634  //Log() << kINFO <<"Preparing trees for training and testing..."<< Endl;
635  AddCut( sigcut, "Signal" );
636  AddCut( bkgcut, "Background" );
637 
638  DefaultDataSetInfo().SetSplitOptions( splitOpt );
639 }
640 
641 ////////////////////////////////////////////////////////////////////////////////
642 /// Function required to split the training and testing datasets into a
643 /// number of folds. Required by the CrossValidation and HyperParameterOptimisation
644 /// classes. The option to split the training dataset into a training set and
645 /// a validation set is implemented but not currently used.
646 
647 void TMVA::DataLoader::MakeKFoldDataSet(CvSplit & s)
648 {
649  s.MakeKFoldDataSet( DefaultDataSetInfo() );
650 }
651 
652 ////////////////////////////////////////////////////////////////////////////////
653 /// Function for assigning the correct folds to the testing or training set.
654 
655 void TMVA::DataLoader::PrepareFoldDataSet(CvSplit & s, UInt_t foldNumber, Types::ETreeType tt)
656 {
657  s.PrepareFoldDataSet( DefaultDataSetInfo(), foldNumber, tt );
658 }
659 
660 
661 ////////////////////////////////////////////////////////////////////////////////
662 /// Recombines the dataset. The precise semantics depend on the actual split.
663 ///
664 /// Similar to the inverse operation of `MakeKFoldDataSet` but _will_ differ.
665 /// See documentation for each particular split for more information.
666 ///
667 
668 void TMVA::DataLoader::RecombineKFoldDataSet(CvSplit & s, Types::ETreeType tt)
669 {
670  s.RecombineKFoldDataSet( DefaultDataSetInfo(), tt );
671 }
672 
673 ////////////////////////////////////////////////////////////////////////////////
674 /// Copy method use in VI and CV
675 
676 TMVA::DataLoader* TMVA::DataLoader::MakeCopy(TString name)
677 {
678  TMVA::DataLoader* des=new TMVA::DataLoader(name);
679  DataLoaderCopy(des,this);
680  return des;
681 }
682 
683 ////////////////////////////////////////////////////////////////////////////////
684 ///Loading Dataset from DataInputHandler for subseed
685 
686 void TMVA::DataLoaderCopy(TMVA::DataLoader* des, TMVA::DataLoader* src)
687 {
688  for( std::vector<TreeInfo>::const_iterator treeinfo=src->DataInput().Sbegin();treeinfo!=src->DataInput().Send();++treeinfo)
689  {
690  des->AddSignalTree( (*treeinfo).GetTree(), (*treeinfo).GetWeight(),(*treeinfo).GetTreeType());
691  }
692 
693  for( std::vector<TreeInfo>::const_iterator treeinfo=src->DataInput().Bbegin();treeinfo!=src->DataInput().Bend();++treeinfo)
694  {
695  des->AddBackgroundTree( (*treeinfo).GetTree(), (*treeinfo).GetWeight(),(*treeinfo).GetTreeType());
696  }
697 }
698 
699 ////////////////////////////////////////////////////////////////////////////////
700 /// returns the correlation matrix of datasets
701 
702 TH2* TMVA::DataLoader::GetCorrelationMatrix(const TString& className)
703 {
704  const TMatrixD * m = DefaultDataSetInfo().CorrelationMatrix(className);
705  return DefaultDataSetInfo().CreateCorrelationMatrixHist(m,
706  "CorrelationMatrix"+className, "Correlation Matrix ("+className+")");
707 }