// Invokes the ClassImp macro for TMVA::DataLoader.
// NOTE(review): presumably this hooks the class into ROOT's class-implementation
// bookkeeping — confirm against the macro definition.
61 ClassImp(TMVA::DataLoader);
// Constructor: builds a DataLoader whose name is taken from `thedlName`.
// Visible initializers: fDataSetManager starts as NULL, a DataInputHandler is
// heap-allocated into fDataInputHandler, fTransformations starts as "I" and
// fDataAssignType as kAssignEvents.
// NOTE(review): the embedded numbering skips 67, 71 and 73, so the initializer
// list (including its leading ':') looks incomplete in this view — confirm
// against the full file before editing.
66 TMVA::DataLoader::DataLoader( TString thedlName)
68 fDataSetManager ( NULL ),
69 fDataInputHandler ( new DataInputHandler ),
70 fTransformations (
"I" ),
72 fDataAssignType ( kAssignEvents ),
// The data-set manager is created in the body, on top of the input handler
// allocated above.
75 fDataSetManager =
new DataSetManager( *fDataInputHandler );
// Apply the requested name and set the logger source to "DataLoader".
76 SetName(thedlName.Data());
77 fLogger->SetSource(
"DataLoader");
82 TMVA::DataLoader::~DataLoader(
void )
86 std::vector<TMVA::VariableTransformBase*>::iterator trfIt = fDefaultTrfs.begin();
87 for (;trfIt != fDefaultTrfs.end(); ++trfIt)
delete (*trfIt);
89 delete fDataInputHandler;
93 delete fDataSetManager;
105 TMVA::DataSetInfo& TMVA::DataLoader::AddDataSet( DataSetInfo &dsi )
107 return fDataSetManager->AddDataSetInfo(dsi);
112 TMVA::DataSetInfo& TMVA::DataLoader::AddDataSet(
const TString& dsiName )
114 DataSetInfo* dsi = fDataSetManager->GetDataSetInfo(dsiName);
116 if (dsi!=0)
return *dsi;
118 return fDataSetManager->AddDataSetInfo(*(
new DataSetInfo(dsiName)));
123 TMVA::DataSetInfo& TMVA::DataLoader::GetDataSetInfo()
125 return DefaultDataSetInfo();
132 TMVA::DataLoader* TMVA::DataLoader::VarTransform(TString trafoDefinition)
134 TString trOptions =
"0";
135 TString trName =
"None";
136 if (trafoDefinition.Contains(
"(")) {
139 Ssiz_t parStart = trafoDefinition.Index(
"(" );
140 Ssiz_t parLen = trafoDefinition.Index(
")", parStart )-parStart+1;
142 trName = trafoDefinition(0,parStart);
143 trOptions = trafoDefinition(parStart,parLen);
144 trOptions.Remove(parLen-1,1);
145 trOptions.Remove(0,1);
148 trName = trafoDefinition;
150 VarTransformHandler* handler =
new VarTransformHandler(
this);
152 if (trName ==
"VT") {
155 Double_t threshold = 0.0;
156 if (!trOptions.IsFloat()){
157 Log() << kFATAL <<
" VT transformation must be passed a floating threshold value" << Endl;
162 threshold = trOptions.Atof();
163 TMVA::DataLoader *transformedLoader = handler->VarianceThreshold(threshold);
165 return transformedLoader;
168 Log() << kFATAL <<
"Incorrect transformation string provided, please check" << Endl;
170 Log() << kINFO <<
"No transformation applied, returning original loader" << Endl;
180 TTree* TMVA::DataLoader::CreateEventAssignTrees(
const TString& name )
182 TTree * assignTree =
new TTree( name, name );
183 assignTree->SetDirectory(0);
184 assignTree->Branch(
"type", &fATreeType,
"ATreeType/I" );
185 assignTree->Branch(
"weight", &fATreeWeight,
"ATreeWeight/F" );
187 std::vector<VariableInfo>& vars = DefaultDataSetInfo().GetVariableInfos();
188 std::vector<VariableInfo>& tgts = DefaultDataSetInfo().GetTargetInfos();
189 std::vector<VariableInfo>& spec = DefaultDataSetInfo().GetSpectatorInfos();
191 if (fATreeEvent.size()==0) fATreeEvent.resize(vars.size()+tgts.size()+spec.size());
193 for (UInt_t ivar=0; ivar<vars.size(); ivar++) {
194 TString vname = vars[ivar].GetExpression();
195 assignTree->Branch( vname, &fATreeEvent[ivar], vname +
"/F" );
198 for (UInt_t itgt=0; itgt<tgts.size(); itgt++) {
199 TString vname = tgts[itgt].GetExpression();
200 assignTree->Branch( vname, &fATreeEvent[vars.size()+itgt], vname +
"/F" );
203 for (UInt_t ispc=0; ispc<spec.size(); ispc++) {
204 TString vname = spec[ispc].GetExpression();
205 assignTree->Branch( vname, &fATreeEvent[vars.size()+tgts.size()+ispc], vname +
"/F" );
213 void TMVA::DataLoader::AddSignalTrainingEvent(
const std::vector<Double_t>& event, Double_t weight )
215 AddEvent(
"Signal", Types::kTraining, event, weight );
221 void TMVA::DataLoader::AddSignalTestEvent(
const std::vector<Double_t>& event, Double_t weight )
223 AddEvent(
"Signal", Types::kTesting, event, weight );
229 void TMVA::DataLoader::AddBackgroundTrainingEvent(
const std::vector<Double_t>& event, Double_t weight )
231 AddEvent(
"Background", Types::kTraining, event, weight );
237 void TMVA::DataLoader::AddBackgroundTestEvent(
const std::vector<Double_t>& event, Double_t weight )
239 AddEvent(
"Background", Types::kTesting, event, weight );
245 void TMVA::DataLoader::AddTrainingEvent(
const TString& className,
const std::vector<Double_t>& event, Double_t weight )
247 AddEvent( className, Types::kTraining, event, weight );
253 void TMVA::DataLoader::AddTestEvent(
const TString& className,
const std::vector<Double_t>& event, Double_t weight )
255 AddEvent( className, Types::kTesting, event, weight );
262 void TMVA::DataLoader::AddEvent(
const TString& className, Types::ETreeType tt,
263 const std::vector<Double_t>& event, Double_t weight )
265 ClassInfo* theClass = DefaultDataSetInfo().AddClass(className);
266 UInt_t clIndex = theClass->GetNumber();
270 if( fAnalysisType == Types::kNoAnalysisType && DefaultDataSetInfo().GetNClasses() > 2 )
271 fAnalysisType = Types::kMulticlass;
274 if (clIndex>=fTrainAssignTree.size()) {
275 fTrainAssignTree.resize(clIndex+1, 0);
276 fTestAssignTree.resize(clIndex+1, 0);
279 if (fTrainAssignTree[clIndex]==0) {
280 fTrainAssignTree[clIndex] = CreateEventAssignTrees( Form(
"TrainAssignTree_%s", className.Data()) );
281 fTestAssignTree[clIndex] = CreateEventAssignTrees( Form(
"TestAssignTree_%s", className.Data()) );
284 fATreeType = clIndex;
285 fATreeWeight = weight;
286 for (UInt_t ivar=0; ivar<
event.size(); ivar++) fATreeEvent[ivar] = event[ivar];
288 if(tt==Types::kTraining) fTrainAssignTree[clIndex]->Fill();
289 else fTestAssignTree[clIndex]->Fill();
296 Bool_t TMVA::DataLoader::UserAssignEvents(UInt_t clIndex)
298 return fTrainAssignTree[clIndex]!=0;
304 void TMVA::DataLoader::SetInputTreesFromEventAssignTrees()
306 UInt_t size = fTrainAssignTree.size();
307 for(UInt_t i=0; i<size; i++) {
308 if(!UserAssignEvents(i))
continue;
309 const TString& className = DefaultDataSetInfo().GetClassInfo(i)->GetName();
310 SetWeightExpression(
"weight", className );
311 AddTree(fTrainAssignTree[i], className, 1.0, TCut(
""), Types::kTraining );
312 AddTree(fTestAssignTree[i], className, 1.0, TCut(
""), Types::kTesting );
319 void TMVA::DataLoader::AddTree( TTree* tree,
const TString& className, Double_t weight,
320 const TCut& cut,
const TString& treetype )
322 Types::ETreeType tt = Types::kMaxTreeType;
323 TString tmpTreeType = treetype; tmpTreeType.ToLower();
324 if (tmpTreeType.Contains(
"train" ) && tmpTreeType.Contains(
"test" )) tt = Types::kMaxTreeType;
325 else if (tmpTreeType.Contains(
"train" )) tt = Types::kTraining;
326 else if (tmpTreeType.Contains(
"test" )) tt = Types::kTesting;
328 Log() << kFATAL <<
"<AddTree> cannot interpret tree type: \"" << treetype
329 <<
"\" should be \"Training\" or \"Test\" or \"Training and Testing\"" << Endl;
331 AddTree( tree, className, weight, cut, tt );
336 void TMVA::DataLoader::AddTree( TTree* tree,
const TString& className, Double_t weight,
337 const TCut& cut, Types::ETreeType tt )
340 Log() << kFATAL <<
"Tree does not exist (empty pointer)." << Endl;
342 DefaultDataSetInfo().AddClass( className );
345 if( fAnalysisType == Types::kNoAnalysisType && DefaultDataSetInfo().GetNClasses() > 2 )
346 fAnalysisType = Types::kMulticlass;
348 Log() << kINFO<<
"Add Tree " << tree->GetName() <<
" of type " << className
349 <<
" with " << tree->GetEntries() <<
" events" << Endl;
350 DataInput().AddTree( tree, className, weight, cut, tt );
356 void TMVA::DataLoader::AddSignalTree( TTree* signal, Double_t weight, Types::ETreeType treetype )
358 AddTree( signal,
"Signal", weight, TCut(
""), treetype );
364 void TMVA::DataLoader::AddSignalTree( TString datFileS, Double_t weight, Types::ETreeType treetype )
367 TTree* signalTree =
new TTree(
"TreeS",
"Tree (S)" );
368 signalTree->ReadFile( datFileS );
370 Log() << kINFO <<
"Create TTree objects from ASCII input files ... \n- Signal file : \""
374 AddTree( signalTree,
"Signal", weight, TCut(
""), treetype );
379 void TMVA::DataLoader::AddSignalTree( TTree* signal, Double_t weight,
const TString& treetype )
381 AddTree( signal,
"Signal", weight, TCut(
""), treetype );
387 void TMVA::DataLoader::AddBackgroundTree( TTree* signal, Double_t weight, Types::ETreeType treetype )
389 AddTree( signal,
"Background", weight, TCut(
""), treetype );
395 void TMVA::DataLoader::AddBackgroundTree( TString datFileB, Double_t weight, Types::ETreeType treetype )
398 TTree* bkgTree =
new TTree(
"TreeB",
"Tree (B)" );
399 bkgTree->ReadFile( datFileB );
401 Log() << kINFO <<
"Create TTree objects from ASCII input files ... \n- Background file : \""
405 AddTree( bkgTree,
"Background", weight, TCut(
""), treetype );
410 void TMVA::DataLoader::AddBackgroundTree( TTree* signal, Double_t weight,
const TString& treetype )
412 AddTree( signal,
"Background", weight, TCut(
""), treetype );
417 void TMVA::DataLoader::SetSignalTree( TTree* tree, Double_t weight )
419 AddTree( tree,
"Signal", weight );
424 void TMVA::DataLoader::SetBackgroundTree( TTree* tree, Double_t weight )
426 AddTree( tree,
"Background", weight );
432 void TMVA::DataLoader::SetTree( TTree* tree,
const TString& className, Double_t weight )
434 AddTree( tree, className, weight, TCut(
""), Types::kMaxTreeType );
440 void TMVA::DataLoader::SetInputTrees( TTree* signal, TTree* background,
441 Double_t signalWeight, Double_t backgroundWeight )
443 AddTree( signal,
"Signal", signalWeight, TCut(
""), Types::kMaxTreeType );
444 AddTree( background,
"Background", backgroundWeight, TCut(
""), Types::kMaxTreeType );
449 void TMVA::DataLoader::SetInputTrees(
const TString& datFileS,
const TString& datFileB,
450 Double_t signalWeight, Double_t backgroundWeight )
452 DataInput().AddTree( datFileS,
"Signal", signalWeight );
453 DataInput().AddTree( datFileB,
"Background", backgroundWeight );
461 void TMVA::DataLoader::SetInputTrees( TTree* inputTree,
const TCut& SigCut,
const TCut& BgCut )
463 AddTree( inputTree,
"Signal", 1.0, SigCut, Types::kMaxTreeType );
464 AddTree( inputTree,
"Background", 1.0, BgCut , Types::kMaxTreeType );
470 void TMVA::DataLoader::AddVariable(
const TString& expression,
const TString& title,
const TString& unit,
471 char type, Double_t min, Double_t max )
473 DefaultDataSetInfo().AddVariable( expression, title, unit, min, max, type );
479 void TMVA::DataLoader::AddVariable(
const TString& expression,
char type,
480 Double_t min, Double_t max )
482 DefaultDataSetInfo().AddVariable( expression,
"",
"", min, max, type );
489 void TMVA::DataLoader::AddVariablesArray(
const TString &expression,
int size,
char type,
490 Double_t min, Double_t max)
492 DefaultDataSetInfo().AddVariablesArray(expression, size,
"",
"", min, max, type);
497 void TMVA::DataLoader::AddTarget(
const TString& expression,
const TString& title,
const TString& unit,
498 Double_t min, Double_t max )
500 if( fAnalysisType == Types::kNoAnalysisType )
501 fAnalysisType = Types::kRegression;
503 DefaultDataSetInfo().AddTarget( expression, title, unit, min, max );
509 void TMVA::DataLoader::AddSpectator(
const TString& expression,
const TString& title,
const TString& unit,
510 Double_t min, Double_t max )
512 DefaultDataSetInfo().AddSpectator( expression, title, unit, min, max );
518 TMVA::DataSetInfo& TMVA::DataLoader::DefaultDataSetInfo()
520 return AddDataSet( fName );
526 void TMVA::DataLoader::SetInputVariables( std::vector<TString>* theVariables )
528 for (std::vector<TString>::iterator it=theVariables->begin();
529 it!=theVariables->end(); ++it) AddVariable(*it);
534 void TMVA::DataLoader::SetSignalWeightExpression(
const TString& variable)
536 DefaultDataSetInfo().SetWeightExpression(variable,
"Signal");
541 void TMVA::DataLoader::SetBackgroundWeightExpression(
const TString& variable)
543 DefaultDataSetInfo().SetWeightExpression(variable,
"Background");
548 void TMVA::DataLoader::SetWeightExpression(
const TString& variable,
const TString& className )
552 SetSignalWeightExpression(variable);
553 SetBackgroundWeightExpression(variable);
555 else DefaultDataSetInfo().SetWeightExpression( variable, className );
560 void TMVA::DataLoader::SetCut(
const TString& cut,
const TString& className ) {
561 SetCut( TCut(cut), className );
566 void TMVA::DataLoader::SetCut(
const TCut& cut,
const TString& className )
568 DefaultDataSetInfo().SetCut( cut, className );
573 void TMVA::DataLoader::AddCut(
const TString& cut,
const TString& className )
575 AddCut( TCut(cut), className );
579 void TMVA::DataLoader::AddCut(
const TCut& cut,
const TString& className )
581 DefaultDataSetInfo().AddCut( cut, className );
587 void TMVA::DataLoader::PrepareTrainingAndTestTree(
const TCut& cut,
588 Int_t NsigTrain, Int_t NbkgTrain, Int_t NsigTest, Int_t NbkgTest,
589 const TString& otherOpt )
591 SetInputTreesFromEventAssignTrees();
595 DefaultDataSetInfo().SetSplitOptions( Form(
"nTrain_Signal=%i:nTrain_Background=%i:nTest_Signal=%i:nTest_Background=%i:%s",
596 NsigTrain, NbkgTrain, NsigTest, NbkgTest, otherOpt.Data()) );
603 void TMVA::DataLoader::PrepareTrainingAndTestTree(
const TCut& cut, Int_t Ntrain, Int_t Ntest )
605 SetInputTreesFromEventAssignTrees();
609 DefaultDataSetInfo().SetSplitOptions( Form(
"nTrain_Signal=%i:nTrain_Background=%i:nTest_Signal=%i:nTest_Background=%i:SplitMode=Random:EqualTrainSample:!V",
610 Ntrain, Ntrain, Ntest, Ntest) );
617 void TMVA::DataLoader::PrepareTrainingAndTestTree(
const TCut& cut,
const TString& opt )
619 SetInputTreesFromEventAssignTrees();
621 DefaultDataSetInfo().PrintClasses();
623 DefaultDataSetInfo().SetSplitOptions( opt );
629 void TMVA::DataLoader::PrepareTrainingAndTestTree( TCut sigcut, TCut bkgcut,
const TString& splitOpt )
632 SetInputTreesFromEventAssignTrees();
635 AddCut( sigcut,
"Signal" );
636 AddCut( bkgcut,
"Background" );
638 DefaultDataSetInfo().SetSplitOptions( splitOpt );
647 void TMVA::DataLoader::MakeKFoldDataSet(CvSplit & s)
649 s.MakeKFoldDataSet( DefaultDataSetInfo() );
655 void TMVA::DataLoader::PrepareFoldDataSet(CvSplit & s, UInt_t foldNumber, Types::ETreeType tt)
657 s.PrepareFoldDataSet( DefaultDataSetInfo(), foldNumber, tt );
668 void TMVA::DataLoader::RecombineKFoldDataSet(CvSplit & s, Types::ETreeType tt)
670 s.RecombineKFoldDataSet( DefaultDataSetInfo(), tt );
676 TMVA::DataLoader* TMVA::DataLoader::MakeCopy(TString name)
678 TMVA::DataLoader* des=
new TMVA::DataLoader(name);
679 DataLoaderCopy(des,
this);
686 void TMVA::DataLoaderCopy(TMVA::DataLoader* des, TMVA::DataLoader* src)
688 for( std::vector<TreeInfo>::const_iterator treeinfo=src->DataInput().Sbegin();treeinfo!=src->DataInput().Send();++treeinfo)
690 des->AddSignalTree( (*treeinfo).GetTree(), (*treeinfo).GetWeight(),(*treeinfo).GetTreeType());
693 for( std::vector<TreeInfo>::const_iterator treeinfo=src->DataInput().Bbegin();treeinfo!=src->DataInput().Bend();++treeinfo)
695 des->AddBackgroundTree( (*treeinfo).GetTree(), (*treeinfo).GetWeight(),(*treeinfo).GetTreeType());
702 TH2* TMVA::DataLoader::GetCorrelationMatrix(
const TString& className)
704 const TMatrixD * m = DefaultDataSetInfo().CorrelationMatrix(className);
705 return DefaultDataSetInfo().CreateCorrelationMatrixHist(m,
706 "CorrelationMatrix"+className,
"Correlation Matrix ("+className+
")");