////////////////////////////////////////////////////////////////////////////////
/// Euclidean algorithm: largest common divider of a and b

Int_t LargestCommonDivider(Int_t a, Int_t b)
{
   if (a < b) { Int_t tmp = a; a = b; b = tmp; } // ensure a >= b
   if (b == 0) return a;
   Int_t fullFits = a / b;
   return LargestCommonDivider(b, a - b*fullFits);
}
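////////////////////////////////////////////////////////////////////////////////
/// constructor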
TMVA::DataSetFactory::DataSetFactory() :
   fVerboseLevel(TString("Info")),
   fScaleWithPreselEff(0),
   fLogger( new MsgLogger("DataSetFactory", kINFO) )
{
}
TMVA::DataSetFactory::~DataSetFactory()
{
   std::vector<TTreeFormula*>::const_iterator formIt;

   for (formIt = fInputFormulas.begin();     formIt != fInputFormulas.end();     ++formIt) if (*formIt) delete *formIt;
   for (formIt = fTargetFormulas.begin();    formIt != fTargetFormulas.end();    ++formIt) if (*formIt) delete *formIt;
   for (formIt = fCutFormulas.begin();       formIt != fCutFormulas.end();       ++formIt) if (*formIt) delete *formIt;
   for (formIt = fWeightFormula.begin();     formIt != fWeightFormula.end();     ++formIt) if (*formIt) delete *formIt;
   for (formIt = fSpectatorFormulas.begin(); formIt != fSpectatorFormulas.end(); ++formIt) if (*formIt) delete *formIt;

   delete fLogger;
}
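////////////////////////////////////////////////////////////////////////////////
/// steering of the dataset creation: build the initial dataset from the data
/// input and, if requested, compute the variable statistics and the per-class
/// correlation matrices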
TMVA::DataSet* TMVA::DataSetFactory::CreateDataSet( TMVA::DataSetInfo& dsi,
                                                    TMVA::DataInputHandler& dataInput )
{
   // build the first dataset from the data input
   DataSet * ds = BuildInitialDataSet( dsi, dataInput );

   if (ds->GetNEvents() > 1 && fComputeCorrelations ) {
      CalcMinMax(ds, dsi);

      // from the final dataset build the correlation matrix of each class
      for (UInt_t cl = 0; cl < dsi.GetNClasses(); cl++) {
         const TString className = dsi.GetClassInfo(cl)->GetName();
         dsi.SetCorrelationMatrix( className, CalcCorrelationMatrix( ds, cl ) );
         if (fCorrelations)
            dsi.PrintCorrelationMatrix(className);
      }
      Log() << kHEADER << Form("[%s] : ",dsi.GetName()) << " " << Endl << Endl;
   }

   return ds;
}
TMVA::DataSet* TMVA::DataSetFactory::BuildDynamicDataSet( TMVA::DataSetInfo& dsi )
{
   Log() << kDEBUG << Form("Dataset[%s] : ",dsi.GetName())
         << "Build DataSet consisting of one Event with dynamically changing variables" << Endl;
   DataSet* ds = new DataSet(dsi);

   // create a DataSet with one Event which uses dynamic variables
   // (pointers to variables)
   if (dsi.GetNClasses()==0){
      dsi.AddClass( "data" );
      dsi.GetClassInfo( "data" )->SetNumber(0);
   }

   std::vector<Float_t*>* evdyn = new std::vector<Float_t*>(0);

   std::vector<VariableInfo>& varinfos = dsi.GetVariableInfos();

   if (varinfos.empty())
      Log() << kFATAL << Form("Dataset[%s] : ",dsi.GetName())
            << "Dynamic data set cannot be built, since no variable information is present. Apparently no variables have been set. This should not happen, please contact the TMVA authors." << Endl;

   std::vector<VariableInfo>::iterator it = varinfos.begin(), itEnd = varinfos.end();
   for (; it != itEnd; ++it) {
      Float_t* external = (Float_t*)(*it).GetExternalLink();
      if (external == 0)
         Log() << kDEBUG << Form("Dataset[%s] : ",dsi.GetName())
               << "The link to the external variable is NULL while I am trying to build a dynamic data set. In this case fTmpEvent from MethodBase HAS TO BE USED in the method to get useful values in variables." << Endl;
      else evdyn->push_back(external);
   }

   std::vector<VariableInfo>& spectatorinfos = dsi.GetSpectatorInfos();
   it = spectatorinfos.begin();
   for (; it != spectatorinfos.end(); ++it) evdyn->push_back( (Float_t*)(*it).GetExternalLink() );

   TMVA::Event * ev = new Event((const std::vector<Float_t*>*&)evdyn, varinfos.size());
   std::vector<Event*>* newEventVector = new std::vector<Event*>;
   newEventVector->push_back(ev);

   ds->SetEventCollection(newEventVector, Types::kTraining);
   ds->SetCurrentType( Types::kTraining );
   ds->SetCurrentEvent( 0 );

   delete newEventVector;
   return ds;
}
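////////////////////////////////////////////////////////////////////////////////
/// if the input has no entries, create a dynamic DataSet (one Event with
/// pointers to external variables); otherwise register the classes, parse the
/// split options, build the event vectors and mix them into the final DataSet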
TMVA::DataSet*
TMVA::DataSetFactory::BuildInitialDataSet( DataSetInfo& dsi,
                                           DataInputHandler& dataInput )
{
   if (dataInput.GetEntries()==0) return BuildDynamicDataSet( dsi );

   // register the classes and add them to the corresponding datasets
   std::vector< TString >* classList = dataInput.GetClassList();
   for (std::vector<TString>::iterator it = classList->begin(); it < classList->end(); ++it) {
      dsi.AddClass( (*it) );
   }

   EvtStatsPerClass eventCounts(dsi.GetNClasses());
   TString normMode;
   TString splitMode;
   TString mixMode;
   UInt_t  splitSeed;

   InitOptions( dsi, eventCounts, normMode, splitSeed, splitMode, mixMode );

   // ======= build event-vector from input, apply preselection ===============
   EventVectorOfClassesOfTreeType tmpEventVector;
   BuildEventVector( dsi, dataInput, tmpEventVector, eventCounts );

   DataSet* ds = MixEvents( dsi, tmpEventVector, eventCounts,
                            splitMode, mixMode, normMode, splitSeed );

   const Bool_t showCollectedOutput = kFALSE;
   if (showCollectedOutput) {
      Int_t maxL = dsi.GetClassNameMaxLength();
      Log() << kINFO << Form("Dataset[%s] : ",dsi.GetName()) << "Collected:" << Endl;
      for (UInt_t cl = 0; cl < dsi.GetNClasses(); cl++) {
         Log() << kINFO << Form("Dataset[%s] : ",dsi.GetName()) << "    "
               << setiosflags(ios::left) << std::setw(maxL) << dsi.GetClassInfo(cl)->GetName()
               << " training entries: " << ds->GetNClassEvents( 0, cl ) << Endl;
         Log() << kINFO << Form("Dataset[%s] : ",dsi.GetName()) << "    "
               << setiosflags(ios::left) << std::setw(maxL) << dsi.GetClassInfo(cl)->GetName()
               << " testing entries: " << ds->GetNClassEvents( 1, cl ) << Endl;
      }
      Log() << kINFO << Form("Dataset[%s] : ",dsi.GetName()) << " " << Endl;
   }

   return ds;
}
Bool_t TMVA::DataSetFactory::CheckTTreeFormula( TTreeFormula* ttf,
                                                const TString& expression,
                                                Bool_t& hasDollar )
{
   Bool_t worked = kTRUE;

   if( ttf->GetNdim() <= 0 )
      Log() << kFATAL << "Expression " << expression.Data()
            << " could not be resolved to a valid formula. " << Endl;
   if( ttf->GetNdata() == 0 ){
      Log() << kWARNING << "Expression: " << expression.Data()
            << " does not provide data for this event. "
            << "This event is not taken into account. --> please check if you use as a variable "
            << "an entry of an array which is not filled for some events "
            << "(e.g. arr[4] when arr has only 3 elements)." << Endl;
      Log() << kWARNING << "If you want to take the event into account you can do something like: "
            << "\"Alt$(arr[4],0)\" where in cases where arr doesn't have a 4th element, "
            << " 0 is taken as an alternative." << Endl;
      worked = kFALSE;
   }
   if( expression.Contains("$") )
      hasDollar = kTRUE;
   else {
      for (int i = 0, iEnd = ttf->GetNcodes(); i < iEnd; ++i) {
         TLeaf* leaf = ttf->GetLeaf(i);
         if (!leaf->IsOnTerminalBranch())
            hasDollar = kTRUE;
      }
   }
   return worked;
}
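////////////////////////////////////////////////////////////////////////////////
/// While the number of variables and formulas stays constant, the TTree the
/// formulas point to changes, e.g. when a TChain switches to a new file;
/// recreate all TTreeFormula objects (inputs, targets, spectators, cuts and
/// weights) for the current tree and enable only the branches that are needed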
void TMVA::DataSetFactory::ChangeToNewTree( TreeInfo& tinfo, const DataSetInfo & dsi )
{
   TTree *tr = tinfo.GetTree()->GetTree();

   tr->ResetBranchAddresses();

   Bool_t hasDollar = kTRUE;

   // 1) the input variable formulas
   Log() << kINFO << Form("Dataset[%s] : ",dsi.GetName()) << " create input formulas for tree " << tr->GetName() << Endl;
   std::vector<TTreeFormula*>::const_iterator formIt, formItEnd;
   for (formIt = fInputFormulas.begin(), formItEnd = fInputFormulas.end(); formIt != formItEnd; ++formIt)
      if (*formIt) delete *formIt;
   fInputFormulas.clear();
   TTreeFormula* ttf = 0;
   fInputTableFormulas.clear();

   bool firstArrayVar = kTRUE;
   int firstArrayVarIndex = -1;
   Int_t arraySize = -1;
   for (UInt_t i = 0; i < dsi.GetNVariables(); i++) {

      // a variable that is an element of an array shares a single formula with
      // the other elements; the table keeps the index within the array
      if (! dsi.IsVariableFromArray(i) ) {
         ttf = new TTreeFormula(Form("Formula%s", dsi.GetVariableInfo(i).GetInternalName().Data()),
                                dsi.GetVariableInfo(i).GetExpression().Data(), tr);
         CheckTTreeFormula(ttf, dsi.GetVariableInfo(i).GetExpression(), hasDollar);
         fInputFormulas.emplace_back(ttf);
         fInputTableFormulas.emplace_back(std::make_pair(ttf, (Int_t) 0));
      }
      else {
         // variable is an element of a variable-size array
         if (firstArrayVar) {
            ttf = new TTreeFormula(Form("Formula%s", dsi.GetVariableInfo(i).GetInternalName().Data()),
                                   dsi.GetVariableInfo(i).GetExpression().Data(), tr);
            CheckTTreeFormula(ttf, dsi.GetVariableInfo(i).GetExpression(), hasDollar);
            fInputFormulas.push_back(ttf);

            arraySize = dsi.GetVarArraySize(dsi.GetVariableInfo(i).GetExpression());
            firstArrayVar = kFALSE;
            firstArrayVarIndex = i;

            Log() << kINFO << "Using variable " << dsi.GetVariableInfo(i).GetInternalName()
                  << " from array expression " << dsi.GetVariableInfo(i).GetExpression()
                  << " of size " << arraySize << Endl;
         }
         fInputTableFormulas.push_back(std::make_pair(ttf, (Int_t) i - firstArrayVarIndex));
         if ( int(i) - firstArrayVarIndex == arraySize - 1 ) {
            // reached the last element of the array
            firstArrayVar = kTRUE;
            firstArrayVarIndex = -1;
            Log() << kDEBUG << "Using Last variable from array : " << dsi.GetVariableInfo(i).GetInternalName() << Endl;
         }
      }
   }
   //
   // the target formulas (regression)
   //
   Log() << kDEBUG << Form("Dataset[%s] : ",dsi.GetName()) << "transform regression targets" << Endl;
   for (formIt = fTargetFormulas.begin(), formItEnd = fTargetFormulas.end(); formIt != formItEnd; ++formIt)
      if (*formIt) delete *formIt;
   fTargetFormulas.clear();
   for (UInt_t i=0; i<dsi.GetNTargets(); i++) {
      ttf = new TTreeFormula( Form( "Formula%s", dsi.GetTargetInfo(i).GetInternalName().Data() ),
                              dsi.GetTargetInfo(i).GetExpression().Data(), tr );
      CheckTTreeFormula( ttf, dsi.GetTargetInfo(i).GetExpression(), hasDollar );
      fTargetFormulas.push_back( ttf );
   }

   //
   // the spectator formulas
   //
   Log() << kDEBUG << Form("Dataset[%s] : ",dsi.GetName()) << "transform spectator variables" << Endl;
   for (formIt = fSpectatorFormulas.begin(), formItEnd = fSpectatorFormulas.end(); formIt != formItEnd; ++formIt)
      if (*formIt) delete *formIt;
   fSpectatorFormulas.clear();
   for (UInt_t i=0; i<dsi.GetNSpectators(); i++) {
      ttf = new TTreeFormula( Form( "Formula%s", dsi.GetSpectatorInfo(i).GetInternalName().Data() ),
                              dsi.GetSpectatorInfo(i).GetExpression().Data(), tr );
      CheckTTreeFormula( ttf, dsi.GetSpectatorInfo(i).GetExpression(), hasDollar );
      fSpectatorFormulas.push_back( ttf );
   }

   //
   // the cut formulas (one per class, may be empty)
   //
   Log() << kDEBUG << Form("Dataset[%s] : ",dsi.GetName()) << "transform cuts" << Endl;
   for (formIt = fCutFormulas.begin(), formItEnd = fCutFormulas.end(); formIt != formItEnd; ++formIt)
      if (*formIt) delete *formIt;
   fCutFormulas.clear();
   for (UInt_t clIdx=0; clIdx<dsi.GetNClasses(); clIdx++) {
      const TCut& tmpCut = dsi.GetClassInfo(clIdx)->GetCut();
      const TString tmpCutExp(tmpCut.GetTitle());

      ttf = 0;
      if (tmpCutExp != "") {
         ttf = new TTreeFormula( Form("CutClass%i",clIdx), tmpCutExp, tr );
         Bool_t worked = CheckTTreeFormula( ttf, tmpCutExp, hasDollar );
         if (!worked) {
            Log() << kWARNING << "Please check class \"" << dsi.GetClassInfo(clIdx)->GetName()
                  << "\" cut \"" << dsi.GetClassInfo(clIdx)->GetCut() << Endl;
         }
      }
      fCutFormulas.push_back( ttf );
   }
   //
   // the weight formulas (one per class, may be empty)
   //
   Log() << kDEBUG << Form("Dataset[%s] : ",dsi.GetName()) << "transform weights" << Endl;
   for (formIt = fWeightFormula.begin(), formItEnd = fWeightFormula.end(); formIt != formItEnd; ++formIt)
      if (*formIt) delete *formIt;
   fWeightFormula.clear();
   for (UInt_t clIdx=0; clIdx<dsi.GetNClasses(); clIdx++) {
      const TString tmpWeight = dsi.GetClassInfo(clIdx)->GetWeight();

      if (dsi.GetClassInfo(clIdx)->GetName() != tinfo.GetClassName() ) { // the tree is of another class
         fWeightFormula.push_back( 0 );
         continue;
      }

      ttf = 0;
      if (tmpWeight != "") {
         ttf = new TTreeFormula( "FormulaWeight", tmpWeight, tr );
         Bool_t worked = CheckTTreeFormula( ttf, tmpWeight, hasDollar );
         if (!worked) {
            Log() << kWARNING << Form("Dataset[%s] : ",dsi.GetName())
                  << "Please check class \"" << dsi.GetClassInfo(clIdx)->GetName()
                  << "\" weight \"" << dsi.GetClassInfo(clIdx)->GetWeight() << Endl;
         }
      }
      fWeightFormula.push_back( ttf );
   }

   // now enable only the branches that are needed in any input formula, target, cut or weight
   Log() << kDEBUG << Form("Dataset[%s] : ", dsi.GetName()) << "enable branches" << Endl;
   tr->SetBranchStatus("*", 0);
   Log() << kINFO << Form("Dataset[%s] : ",dsi.GetName()) << "enable branches: input variables" << Endl;
   // input variables
   for (formIt = fInputFormulas.begin(); formIt != fInputFormulas.end(); ++formIt) {
      ttf = *formIt;
      for (Int_t bi = 0; bi < ttf->GetNcodes(); bi++) {
         tr->SetBranchStatus( ttf->GetLeaf(bi)->GetBranch()->GetName(), 1 );
      }
   }
   // targets
   Log() << kINFO << Form("Dataset[%s] : ",dsi.GetName()) << "enable branches: targets" << Endl;
   for (formIt = fTargetFormulas.begin(); formIt != fTargetFormulas.end(); ++formIt) {
      ttf = *formIt;
      for (Int_t bi = 0; bi < ttf->GetNcodes(); bi++)
         tr->SetBranchStatus( ttf->GetLeaf(bi)->GetBranch()->GetName(), 1 );
   }
   // spectators
   Log() << kDEBUG << Form("Dataset[%s] : ",dsi.GetName()) << "enable branches: spectators" << Endl;
   for (formIt = fSpectatorFormulas.begin(); formIt != fSpectatorFormulas.end(); ++formIt) {
      ttf = *formIt;
      for (Int_t bi = 0; bi < ttf->GetNcodes(); bi++)
         tr->SetBranchStatus( ttf->GetLeaf(bi)->GetBranch()->GetName(), 1 );
   }
   // cuts
   Log() << kDEBUG << Form("Dataset[%s] : ",dsi.GetName()) << "enable branches: cuts" << Endl;
   for (formIt = fCutFormulas.begin(); formIt != fCutFormulas.end(); ++formIt) {
      ttf = *formIt;
      if (!ttf) continue; // classes without a cut have a null pointer here
      for (Int_t bi = 0; bi < ttf->GetNcodes(); bi++)
         tr->SetBranchStatus( ttf->GetLeaf(bi)->GetBranch()->GetName(), 1 );
   }
   // weights
   Log() << kDEBUG << Form("Dataset[%s] : ",dsi.GetName()) << "enable branches: weights" << Endl;
   for (formIt = fWeightFormula.begin(); formIt != fWeightFormula.end(); ++formIt) {
      ttf = *formIt;
      if (!ttf) continue; // classes without a weight expression have a null pointer here
      for (Int_t bi = 0; bi < ttf->GetNcodes(); bi++)
         tr->SetBranchStatus( ttf->GetLeaf(bi)->GetBranch()->GetName(), 1 );
   }

   Log() << kINFO << Form("Dataset[%s] : ",dsi.GetName()) << "tree initialized" << Endl;
}
void TMVA::DataSetFactory::CalcMinMax( DataSet* ds, TMVA::DataSetInfo& dsi )
{
   const UInt_t nvar  = ds->GetNVariables();
   const UInt_t ntgts = ds->GetNTargets();
   const UInt_t nvis  = ds->GetNSpectators();

   Float_t *min   = new Float_t[nvar];
   Float_t *max   = new Float_t[nvar];
   Float_t *tgmin = new Float_t[ntgts];
   Float_t *tgmax = new Float_t[ntgts];
   Float_t *vmin  = new Float_t[nvis];
   Float_t *vmax  = new Float_t[nvis];

   for (UInt_t ivar=0; ivar<nvar ; ivar++) { min[ivar]   = FLT_MAX; max[ivar]   = -FLT_MAX; }
   for (UInt_t ivar=0; ivar<ntgts; ivar++) { tgmin[ivar] = FLT_MAX; tgmax[ivar] = -FLT_MAX; }
   for (UInt_t ivar=0; ivar<nvis ; ivar++) { vmin[ivar]  = FLT_MAX; vmax[ivar]  = -FLT_MAX; }
   // perform the event loop
   for (Int_t i=0; i<ds->GetNEvents(); i++) {
      const Event * ev = ds->GetEvent(i);
      for (UInt_t ivar=0; ivar<nvar; ivar++) {
         Double_t v = ev->GetValue(ivar);
         if (v < min[ivar]) min[ivar] = v;
         if (v > max[ivar]) max[ivar] = v;
      }
      for (UInt_t itgt=0; itgt<ntgts; itgt++) {
         Double_t v = ev->GetTarget(itgt);
         if (v < tgmin[itgt]) tgmin[itgt] = v;
         if (v > tgmax[itgt]) tgmax[itgt] = v;
      }
      for (UInt_t ivis=0; ivis<nvis; ivis++) {
         Double_t v = ev->GetSpectator(ivis);
         if (v < vmin[ivis]) vmin[ivis] = v;
         if (v > vmax[ivis]) vmax[ivis] = v;
      }
   }
   // set the min/max in the DataSetInfo and warn about constant variables
   for (UInt_t ivar=0; ivar<nvar; ivar++) {
      dsi.GetVariableInfo(ivar).SetMin(min[ivar]);
      dsi.GetVariableInfo(ivar).SetMax(max[ivar]);
      if( TMath::Abs(max[ivar]-min[ivar]) <= FLT_MIN )
         Log() << kWARNING << Form("Dataset[%s] : ",dsi.GetName())
               << "Variable " << dsi.GetVariableInfo(ivar).GetExpression().Data()
               << " is constant. Please remove the variable." << Endl;
   }
   for (UInt_t ivar=0; ivar<ntgts; ivar++) {
      dsi.GetTargetInfo(ivar).SetMin(tgmin[ivar]);
      dsi.GetTargetInfo(ivar).SetMax(tgmax[ivar]);
      if( TMath::Abs(tgmax[ivar]-tgmin[ivar]) <= FLT_MIN )
         Log() << kFATAL << Form("Dataset[%s] : ",dsi.GetName())
               << "Target " << dsi.GetTargetInfo(ivar).GetExpression().Data()
               << " is constant. Please remove the variable." << Endl;
   }
   for (UInt_t ivar=0; ivar<nvis; ivar++) {
      dsi.GetSpectatorInfo(ivar).SetMin(vmin[ivar]);
      dsi.GetSpectatorInfo(ivar).SetMax(vmax[ivar]);
   }
   delete [] min;
   delete [] max;
   delete [] tgmin;
   delete [] tgmax;
   delete [] vmin;
   delete [] vmax;
}
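////////////////////////////////////////////////////////////////////////////////
/// computes the correlation matrix for the class given by classNumber from
/// the covariance matrix; the caller owns the returned matrix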
TMatrixD* TMVA::DataSetFactory::CalcCorrelationMatrix( DataSet* ds, const UInt_t classNumber )
{
   // first compute the covariance matrix
   TMatrixD* mat = CalcCovarianceMatrix( ds, classNumber );

   // now the correlation
   UInt_t nvar = ds->GetNVariables(), ivar, jvar;

   for (ivar=0; ivar<nvar; ivar++) {
      for (jvar=0; jvar<nvar; jvar++) {
         if (ivar != jvar) {
            Double_t d = (*mat)(ivar, ivar)*(*mat)(jvar, jvar);
            if (d > 0) (*mat)(ivar, jvar) /= sqrt(d);
            else {
               Log() << kWARNING << Form("Dataset[%s] : ",DataSetInfo().GetName())
                     << "<GetCorrelationMatrix> Zero variances for variables "
                     << "(" << ivar << ", " << jvar << ") = " << d << Endl;
               (*mat)(ivar, jvar) = 0;
            }
         }
      }
   }

   for (ivar=0; ivar<nvar; ivar++) (*mat)(ivar, ivar) = 1.0;

   return mat;
}
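////////////////////////////////////////////////////////////////////////////////
/// compute the weighted covariance matrix <xy> - <x><y> of the input
/// variables for the events of the given class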
TMatrixD* TMVA::DataSetFactory::CalcCovarianceMatrix( DataSet * ds, const UInt_t classNumber )
{
   UInt_t nvar = ds->GetNVariables();
   UInt_t ivar = 0, jvar = 0;

   TMatrixD* mat = new TMatrixD( nvar, nvar );

   // init matrices
   TVectorD vec(nvar);
   TMatrixD mat2(nvar, nvar);
   for (ivar=0; ivar<nvar; ivar++) {
      vec(ivar) = 0;
      for (jvar=0; jvar<nvar; jvar++) mat2(ivar, jvar) = 0;
   }

   // perform the event loop
   Double_t ic = 0;
   for (Int_t i=0; i<ds->GetNEvents(); i++) {

      const Event * ev = ds->GetEvent(i);
      if (ev->GetClass() != classNumber) continue;

      Double_t weight = ev->GetWeight();
      ic += weight; // sum of weights of the used events

      for (ivar=0; ivar<nvar; ivar++) {

         Double_t xi = ev->GetValue(ivar);
         vec(ivar)        += xi*weight;
         mat2(ivar, ivar) += (xi*xi*weight);

         for (jvar=ivar+1; jvar<nvar; jvar++) {
            Double_t xj = ev->GetValue(jvar);
            mat2(ivar, jvar) += (xi*xj*weight);
         }
      }
   }

   // complete the symmetric matrix
   for (ivar=0; ivar<nvar; ivar++)
      for (jvar=ivar+1; jvar<nvar; jvar++)
         mat2(jvar, ivar) = mat2(ivar, jvar);

   // compute the covariance: <xy> - <x><y>
   for (ivar=0; ivar<nvar; ivar++) {
      for (jvar=0; jvar<nvar; jvar++) {
         (*mat)(ivar, jvar) = mat2(ivar, jvar)/ic - vec(ivar)*vec(jvar)/(ic*ic);
      }
   }

   return mat;
}
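////////////////////////////////////////////////////////////////////////////////
/// the dataset splitting: declare and parse all options that steer the split
/// of the events into training and testing samples (SplitMode, MixMode,
/// SplitSeed, NormMode, per-class event numbers, verbosity)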
void
TMVA::DataSetFactory::InitOptions( TMVA::DataSetInfo& dsi,
                                   EvtStatsPerClass& nEventRequests,
                                   TString& normMode,
                                   UInt_t& splitSeed,
                                   TString& splitMode,
                                   TString& mixMode )
{
   Configurable splitSpecs( dsi.GetSplitOptions() );
   splitSpecs.SetConfigName("DataSetFactory");
   splitSpecs.SetConfigDescription( "Configuration options given in the \"PrepareForTrainingAndTesting\" call; these options define the creation of the data sets used for training and expert validation by TMVA" );

   splitMode = "Random";    // the splitting mode
   splitSpecs.DeclareOptionRef( splitMode, "SplitMode",
                                "Method of picking training and testing events (default: random)" );
   splitSpecs.AddPreDefVal(TString("Random"));
   splitSpecs.AddPreDefVal(TString("Alternate"));
   splitSpecs.AddPreDefVal(TString("Block"));

   mixMode = "SameAsSplitMode";    // the mixing mode
   splitSpecs.DeclareOptionRef( mixMode, "MixMode",
                                "Method of mixing events of different classes into one dataset (default: SameAsSplitMode)" );
   splitSpecs.AddPreDefVal(TString("SameAsSplitMode"));
   splitSpecs.AddPreDefVal(TString("Random"));
   splitSpecs.AddPreDefVal(TString("Alternate"));
   splitSpecs.AddPreDefVal(TString("Block"));

   splitSpecs.DeclareOptionRef( splitSeed, "SplitSeed",
                                "Seed for random event shuffling" );

   normMode = "EqualNumEvents";  // the weight normalisation mode
   splitSpecs.DeclareOptionRef( normMode, "NormMode",
                                "Overall renormalisation of event-by-event weights used in the training (NumEvents: average weight of 1 per event, independently for signal and background; EqualNumEvents: average weight of 1 per event for signal, and sum of weights for background equal to sum of weights for signal)" );
   splitSpecs.AddPreDefVal(TString("None"));
   splitSpecs.AddPreDefVal(TString("NumEvents"));
   splitSpecs.AddPreDefVal(TString("EqualNumEvents"));

   splitSpecs.DeclareOptionRef( fScaleWithPreselEff=kFALSE, "ScaleWithPreselEff",
                                "Scale the number of requested events by the eff. of the preselection cuts (or not)" );
   // the number of events per class
   for (UInt_t cl = 0; cl < dsi.GetNClasses(); cl++) {
      TString clName = dsi.GetClassInfo(cl)->GetName();
      TString titleTrain = TString().Format("Number of training events of class %s (default: 0 = all)",clName.Data()).Data();
      TString titleTest  = TString().Format("Number of test events of class %s (default: 0 = all)",clName.Data()).Data();
      TString titleSplit = TString().Format("Split in training and test events of class %s (default: 0 = deactivated)",clName.Data()).Data();

      splitSpecs.DeclareOptionRef( nEventRequests.at(cl).nTrainingEventsRequested, TString("nTrain_")+clName, titleTrain );
      splitSpecs.DeclareOptionRef( nEventRequests.at(cl).nTestingEventsRequested , TString("nTest_")+clName , titleTest  );
      splitSpecs.DeclareOptionRef( nEventRequests.at(cl).TrainTestSplitRequested , TString("TrainTestSplit_")+clName , titleSplit );
   }

   splitSpecs.DeclareOptionRef( fVerbose, "V", "Verbosity (default: true)" );

   splitSpecs.DeclareOptionRef( fVerboseLevel=TString("Info"), "VerboseLevel", "VerboseLevel (Debug/Verbose/Info)" );
   splitSpecs.AddPreDefVal(TString("Debug"));
   splitSpecs.AddPreDefVal(TString("Verbose"));
   splitSpecs.AddPreDefVal(TString("Info"));

   fCorrelations = kTRUE;
   splitSpecs.DeclareOptionRef(fCorrelations, "Correlations", "Boolean to show correlation output (Default: true)");
   fComputeCorrelations = kTRUE;
   splitSpecs.DeclareOptionRef(fComputeCorrelations, "CalcCorrelations", "Compute correlations and also some variable statistics, e.g. min/max (Default: true )");

   splitSpecs.ParseOptions();
   splitSpecs.CheckForUnusedOptions();
   // output logging verbosity
   if (Verbose()) fLogger->SetMinType( kVERBOSE );
   if (fVerboseLevel.CompareTo("Debug")   == 0) fLogger->SetMinType( kDEBUG );
   if (fVerboseLevel.CompareTo("Verbose") == 0) fLogger->SetMinType( kVERBOSE );
   if (fVerboseLevel.CompareTo("Info")    == 0) fLogger->SetMinType( kINFO );

   // put all modes to upper case
   splitMode.ToUpper(); mixMode.ToUpper(); normMode.ToUpper();

   // adjust the mix mode if it should be the same as the split mode
   Log() << kDEBUG << Form("Dataset[%s] : ",dsi.GetName())
         << "\tSplitmode is: \"" << splitMode << "\" the mixmode is: \"" << mixMode << "\"" << Endl;
   if (mixMode == "SAMEASSPLITMODE") mixMode = splitMode;
   else if (mixMode != splitMode)
      Log() << kINFO << Form("Dataset[%s] : ",dsi.GetName())
            << "DataSet splitmode=" << splitMode
            << " differs from mixmode=" << mixMode << Endl;
}
void
TMVA::DataSetFactory::BuildEventVector( TMVA::DataSetInfo& dsi,
                                        TMVA::DataInputHandler& dataInput,
                                        EventVectorOfClassesOfTreeType& eventsmap,
                                        EvtStatsPerClass& eventCounts)
{
   const UInt_t nclasses = dsi.GetNClasses();

   // create the type, class and event vectors
   eventsmap[ Types::kTraining ]    = EventVectorOfClasses(nclasses);
   eventsmap[ Types::kTesting ]     = EventVectorOfClasses(nclasses);
   eventsmap[ Types::kMaxTreeType ] = EventVectorOfClasses(nclasses);

   const UInt_t nvars = dsi.GetNVariables();
   const UInt_t ntgts = dsi.GetNTargets();
   const UInt_t nvis  = dsi.GetNSpectators();

   for (size_t i=0; i<nclasses; i++) {
      eventCounts[i].varAvLength = new Float_t[nvars];
      for (UInt_t ivar=0; ivar<nvars; ivar++)
         eventCounts[i].varAvLength[ivar] = 0;
   }

   // if the tree contains NaN or +-inf values:
   // => warn if used variables/cuts/weights contain a NaN (no problem if the event is cut out)
   // => fatal if the cut value is NaN, or if the event is not cut out and has NaNs somewhere
   // all these warnings/errors are collected and printed at the end
   std::map<TString, int> nanInfWarnings;
   std::map<TString, int> nanInfErrors;
   // if we work with chains we need to remember the current file: when the
   // chain jumps to a new file all formulas have to be reset
   for (UInt_t cl=0; cl<nclasses; cl++) {

      EventStats& classEventCounts = eventCounts[cl];

      // info output for weights
      Log() << kDEBUG << Form("Dataset[%s] : ",dsi.GetName())
            << "\tWeight expression for class \'" << dsi.GetClassInfo(cl)->GetName() << "\': \""
            << dsi.GetClassInfo(cl)->GetWeight() << "\"" << Endl;

      // used for chains only
      TString currentFileName("");

      std::vector<TreeInfo>::const_iterator treeIt(dataInput.begin(dsi.GetClassInfo(cl)->GetName()));
      for (; treeIt != dataInput.end(dsi.GetClassInfo(cl)->GetName()); ++treeIt) {

         // the vectors receiving the variable, target and spectator values
         std::vector<Float_t> vars(nvars);
         std::vector<Float_t> tgts(ntgts);
         std::vector<Float_t> vis(nvis);
         TreeInfo currentInfo = *treeIt;

         Log() << kINFO << "Building event vectors for type " << currentInfo.GetTreeType()
               << " " << currentInfo.GetClassName() << Endl;

         EventVector& event_v = eventsmap[currentInfo.GetTreeType()].at(cl);

         Bool_t isChain = (TString("TChain") == currentInfo.GetTree()->ClassName());
         currentInfo.GetTree()->LoadTree(0);
         // create the TTreeFormulas to evaluate later on each single event
         ChangeToNewTree( currentInfo, dsi );

         // count the number of events in the tree before the cut
         classEventCounts.nInitialEvents += currentInfo.GetTree()->GetEntries();

         // loop over the events in the ntuple
         const UInt_t nEvts = currentInfo.GetTree()->GetEntries();
         for (Long64_t evtIdx = 0; evtIdx < nEvts; evtIdx++) {
            currentInfo.GetTree()->LoadTree(evtIdx);

            // may need to reload the formulas in case of chains
            if (isChain) {
               if (currentInfo.GetTree()->GetTree()->GetDirectory()->GetFile()->GetName() != currentFileName) {
                  currentFileName = currentInfo.GetTree()->GetTree()->GetDirectory()->GetFile()->GetName();
                  ChangeToNewTree( currentInfo, dsi );
               }
            }
            currentInfo.GetTree()->GetEntry(evtIdx);
            Int_t sizeOfArrays = 1;
            Int_t prevArrExpr = 0;
            Bool_t haveAllArrayData = kFALSE;
            // ======= evaluate all formulas =================

            // first check whether some of the formulas are arrays: in that
            // case a TMVA event is not an entry of the tree but an element of
            // the array
            for (UInt_t ivar = 0; ivar < nvars; ivar++) {

               // single elements of an array variable are treated elsewhere
               if (dsi.IsVariableFromArray(ivar)) continue;
               auto inputFormula = fInputTableFormulas[ivar].first;

               Int_t ndata = inputFormula->GetNdata();

               classEventCounts.varAvLength[ivar] += ndata;
               if (ndata == 1) continue;
               haveAllArrayData = kTRUE;

               if (sizeOfArrays == 1) {
                  sizeOfArrays = ndata;
                  prevArrExpr = ivar;
               }
               else if (sizeOfArrays != ndata) {
                  Log() << kERROR << Form("Dataset[%s] : ",dsi.GetName())
                        << "ERROR while preparing training and testing trees:" << Endl;
                  Log() << Form("Dataset[%s] : ",dsi.GetName())
                        << "  multiple array-type expressions of different length were encountered" << Endl;
                  Log() << Form("Dataset[%s] : ",dsi.GetName())
                        << "  location of error: event " << evtIdx
                        << " in tree " << currentInfo.GetTree()->GetName()
                        << " of file " << currentInfo.GetTree()->GetCurrentFile()->GetName() << Endl;
                  Log() << Form("Dataset[%s] : ",dsi.GetName())
                        << "  expression " << inputFormula->GetTitle() << " has "
                        << ndata << " entries, while" << Endl;
                  Log() << Form("Dataset[%s] : ",dsi.GetName())
                        << "  expression " << fInputTableFormulas[prevArrExpr].first->GetTitle() << " has "
                        << fInputTableFormulas[prevArrExpr].first->GetNdata() << " entries" << Endl;
                  Log() << kFATAL << Form("Dataset[%s] : ",dsi.GetName()) << "Need to abort" << Endl;
               }
            }
            // now read the information
            for (Int_t idata = 0; idata < sizeOfArrays; idata++) {
               Bool_t contains_NaN_or_inf = kFALSE;

               auto checkNanInf = [&](std::map<TString, int> &msgMap, Float_t value, const char *what,
                                      const char *formulaTitle) {
                  if (TMath::IsNaN(value)) {
                     contains_NaN_or_inf = kTRUE;
                     ++msgMap[TString::Format("Dataset[%s] : %s expression resolves to indeterminate value (NaN): %s",
                                              dsi.GetName(), what, formulaTitle)];
                  }
                  else if (!TMath::Finite(value)) {
                     contains_NaN_or_inf = kTRUE;
                     ++msgMap[TString::Format("Dataset[%s] : %s expression resolves to infinite value (+inf or -inf): %s",
                                              dsi.GetName(), what, formulaTitle)];
                  }
               };

               TTreeFormula* formula = 0;

               // the cut expression
               Double_t cutVal = 1.;
               formula = fCutFormulas[cl];
               if (formula) {
                  Int_t ndata = formula->GetNdata();
                  cutVal = (ndata == 1 ?
                            formula->EvalInstance(0) :
                            formula->EvalInstance(idata));
                  checkNanInf(nanInfErrors, cutVal, "Cut", formula->GetTitle());
               }

               // if the event is cut out, the NaN messages go to the warnings instead of the errors
               auto &nanMessages = cutVal < 0.5 ? nanInfWarnings : nanInfErrors;

               // the input variables
               for (UInt_t ivar=0; ivar<nvars; ivar++) {
                  auto formulaMap = fInputTableFormulas[ivar];
                  formula = formulaMap.first;
                  int inputVarIndex = formulaMap.second;
                  formula->SetQuickLoad(true);

                  vars[ivar] = ( !haveAllArrayData ?
                                 formula->EvalInstance(inputVarIndex) :
                                 formula->EvalInstance(idata));
                  checkNanInf(nanMessages, vars[ivar], "Input", formula->GetTitle());
               }

               // the targets
               for (UInt_t itrgt=0; itrgt<ntgts; itrgt++) {
                  formula = fTargetFormulas[itrgt];
                  Int_t ndata = formula->GetNdata();
                  tgts[itrgt] = (ndata == 1 ?
                                 formula->EvalInstance(0) :
                                 formula->EvalInstance(idata));
                  checkNanInf(nanMessages, tgts[itrgt], "Target", formula->GetTitle());
               }

               // the spectators
               for (UInt_t itVis=0; itVis<nvis; itVis++) {
                  formula = fSpectatorFormulas[itVis];
                  Int_t ndata = formula->GetNdata();
                  vis[itVis] = (ndata == 1 ?
                                formula->EvalInstance(0) :
                                formula->EvalInstance(idata));
                  checkNanInf(nanMessages, vis[itVis], "Spectator", formula->GetTitle());
               }

               // the weight (tree weight times the optional weight expression)
               Float_t weight = currentInfo.GetWeight();
               formula = fWeightFormula[cl];
               if (formula) {
                  Int_t ndata = formula->GetNdata();
                  weight *= (ndata == 1 ?
                             formula->EvalInstance() :
                             formula->EvalInstance(idata));
                  checkNanInf(nanMessages, weight, "Weight", formula->GetTitle());
               }
               // count the events before rejection due to cut or NaN value
               // (weighted and unweighted)
               classEventCounts.nEvBeforeCut++;
               if (!TMath::IsNaN(weight))
                  classEventCounts.nWeEvBeforeCut += weight;

               // apply the cut; skip the rest if the cut is not fulfilled
               if (cutVal < 0.5) continue;

               // global flag if negative weights exist -> can be used by
               // classifiers who may require special data treatment
               if (weight < 0) classEventCounts.nNegWeights++;

               if (contains_NaN_or_inf) {
                  Log() << kWARNING << Form("Dataset[%s] : ",dsi.GetName()) << "NaN or +-inf in Event " << evtIdx << Endl;
                  if (sizeOfArrays > 1) Log() << kWARNING << Form("Dataset[%s] : ",dsi.GetName()) << " rejected" << Endl;
                  continue;
               }

               // count the events after rejection due to cut or NaN value
               // (weighted and unweighted)
               classEventCounts.nEvAfterCut++;
               classEventCounts.nWeEvAfterCut += weight;

               // event accepted, fill it into the event vector of this class
               event_v.push_back(new Event(vars, tgts, vis, cl, weight));
            }
         }
         currentInfo.GetTree()->ResetBranchAddresses();
      }
   }
   if (!nanInfWarnings.empty()) {
      Log() << kWARNING << "Found events with NaN and/or +-inf values" << Endl;
      for (const auto &warning : nanInfWarnings) {
         auto &log = Log() << kWARNING << warning.first;
         if (warning.second > 1) log << " (" << warning.second << " times)";
         log << Endl;
      }
      Log() << kWARNING << "These NaN and/or +-infs were all removed by the specified cut, continuing." << Endl;
      Log() << Endl;
   }

   if (!nanInfErrors.empty()) {
      Log() << kWARNING << "Found events with NaN and/or +-inf values (not removed by cut)" << Endl;
      for (const auto &error : nanInfErrors) {
         auto &log = Log() << kWARNING << error.first;
         if (error.second > 1) log << " (" << error.second << " times)";
         log << Endl;
      }
      Log() << kFATAL << "How am I supposed to train a NaN or +-inf?!" << Endl;
   }
   // for the output format, get the maximum class name length
   Int_t maxL = dsi.GetClassNameMaxLength();

   Log() << kHEADER << Form("[%s] : ",dsi.GetName()) << "Number of events in input trees" << Endl;
   Log() << kDEBUG << "(after possible flattening of arrays):" << Endl;

   for (UInt_t cl = 0; cl < dsi.GetNClasses(); cl++) {
      Log() << kDEBUG << "    "
            << setiosflags(ios::left) << std::setw(maxL) << dsi.GetClassInfo(cl)->GetName()
            << "      -- number of events       : "
            << std::setw(5) << eventCounts[cl].nEvBeforeCut
            << "  / sum of weights: " << std::setw(5) << eventCounts[cl].nWeEvBeforeCut << Endl;
   }

   for (UInt_t cl = 0; cl < dsi.GetNClasses(); cl++) {
      Log() << kDEBUG << "    " << std::setw(maxL) << dsi.GetClassInfo(cl)->GetName()
            << " tree -- total number of entries: "
            << std::setw(5) << dataInput.GetEntries(dsi.GetClassInfo(cl)->GetName()) << Endl;
   }

   if (fScaleWithPreselEff)
      Log() << kDEBUG
            << "\tPreselection: (will affect number of requested training and testing events)" << Endl;
   else
      Log() << kDEBUG
            << "\tPreselection: (will NOT affect number of requested training and testing events)" << Endl;

   if (dsi.HasCuts()) {
      for (UInt_t cl = 0; cl < dsi.GetNClasses(); cl++) {
         Log() << kINFO << Form("Dataset[%s] : ",dsi.GetName()) << "    "
               << setiosflags(ios::left) << std::setw(maxL) << dsi.GetClassInfo(cl)->GetName()
               << " requirement: \"" << dsi.GetClassInfo(cl)->GetCut() << "\"" << Endl;
         Log() << kINFO << Form("Dataset[%s] : ",dsi.GetName()) << "    "
               << setiosflags(ios::left) << std::setw(maxL) << dsi.GetClassInfo(cl)->GetName()
               << "      -- number of events passed: "
               << std::setw(5) << eventCounts[cl].nEvAfterCut
               << "  / sum of weights: " << std::setw(5) << eventCounts[cl].nWeEvAfterCut << Endl;
         Log() << kINFO << Form("Dataset[%s] : ",dsi.GetName()) << "    "
               << setiosflags(ios::left) << std::setw(maxL) << dsi.GetClassInfo(cl)->GetName()
               << "      -- efficiency             : "
               << std::setw(6) << eventCounts[cl].nWeEvAfterCut/eventCounts[cl].nWeEvBeforeCut << Endl;
      }
   }
   else Log() << kDEBUG
              << "    No preselection cuts applied on event classes" << Endl;
}
TMVA::DataSet*
TMVA::DataSetFactory::MixEvents( DataSetInfo& dsi,
                                 EventVectorOfClassesOfTreeType& tmpEventVector,
                                 EvtStatsPerClass& eventCounts,
                                 const TString& splitMode,
                                 const TString& mixMode,
                                 const TString& normMode,
                                 UInt_t splitSeed )
{
   TMVA::RandomGenerator<TRandom3> rndm(splitSeed);

   // ==== splitting of undefined events to kTraining and kTesting

   // if splitMode contains "RANDOM", the undefined events are shuffled first
   if (splitMode.Contains( "RANDOM" ) ) {
      // randomly shuffle the undefined events of each class
      for( UInt_t cls = 0; cls < dsi.GetNClasses(); ++cls ){
         EventVector& unspecifiedEvents = tmpEventVector[Types::kMaxTreeType].at(cls);
         if( ! unspecifiedEvents.empty() ) {
            Log() << kDEBUG << "randomly shuffling "
                  << unspecifiedEvents.size()
                  << " events of class " << cls
                  << " which are not yet associated to testing or training" << Endl;
            std::shuffle(unspecifiedEvents.begin(), unspecifiedEvents.end(), rndm);
         }
      }
   }
   Log() << kDEBUG << Form("Dataset[%s] : ",dsi.GetName()) << "SPLITTING ========" << Endl;
   for( UInt_t cls = 0; cls < dsi.GetNClasses(); ++cls ){
      Log() << kDEBUG << Form("Dataset[%s] : ",dsi.GetName()) << "---- class " << cls << Endl;
      Log() << kDEBUG << Form("Dataset[%s] : ",dsi.GetName())
            << "check number of training/testing events, requested and available number of events and for class " << cls << Endl;

      // check if enough or too many events are already in the training/testing event vectors of class cls
      EventVector& eventVectorTraining  = tmpEventVector[ Types::kTraining ].at(cls);
      EventVector& eventVectorTesting   = tmpEventVector[ Types::kTesting ].at(cls);
      EventVector& eventVectorUndefined = tmpEventVector[ Types::kMaxTreeType ].at(cls);

      Int_t availableTraining  = eventVectorTraining.size();
      Int_t availableTesting   = eventVectorTesting.size();
      Int_t availableUndefined = eventVectorUndefined.size();

      Float_t presel_scale;
      if (fScaleWithPreselEff) {
         presel_scale = eventCounts[cls].cutScaling();
         if (presel_scale < 1)
            Log() << kINFO << Form("Dataset[%s] : ",dsi.GetName())
                  << " you have opted for scaling the number of requested training/testing events\n to be scaled by the preselection efficiency" << Endl;
      }
      else {
         presel_scale = 1.; // the requested numbers are interpreted as numbers after the cuts
         if (eventCounts[cls].cutScaling() < 1)
            Log() << kINFO << Form("Dataset[%s] : ",dsi.GetName())
                  << " you have opted for interpreting the requested number of training/testing events\n to be the number of events AFTER your preselection cuts" << Endl;
      }

      // if TrainTestSplit_<class> is set, the number of requested training events is split*all_events;
      // the number of requested testing events is set to zero (i.e. all remaining events);
      // TrainTestSplit_<class> overwrites nTrain_<class> and nTest_<class>
      if(eventCounts[cls].TrainTestSplitRequested < 1.0 && eventCounts[cls].TrainTestSplitRequested > 0.0){
         eventCounts[cls].nTrainingEventsRequested = Int_t(eventCounts[cls].TrainTestSplitRequested*(availableTraining+availableTesting+availableUndefined));
         eventCounts[cls].nTestingEventsRequested  = Int_t(0);
      }
      else if(eventCounts[cls].TrainTestSplitRequested != 0.0)
         Log() << kFATAL << Form("The option TrainTestSplit_<class> has to be in range (0, 1) but is set to %f.",eventCounts[cls].TrainTestSplitRequested) << Endl;

      Int_t requestedTraining = Int_t(eventCounts[cls].nTrainingEventsRequested * presel_scale);
      Int_t requestedTesting  = Int_t(eventCounts[cls].nTestingEventsRequested  * presel_scale);

      Log() << kDEBUG << Form("Dataset[%s] : ",dsi.GetName()) << "events in training trees    : " << availableTraining  << Endl;
      Log() << kDEBUG << Form("Dataset[%s] : ",dsi.GetName()) << "events in testing trees     : " << availableTesting   << Endl;
      Log() << kDEBUG << Form("Dataset[%s] : ",dsi.GetName()) << "events in unspecified trees : " << availableUndefined << Endl;
      Log() << kDEBUG << Form("Dataset[%s] : ",dsi.GetName()) << "requested for training      : " << requestedTraining << Endl;
      if (presel_scale < 1)
         Log() << " ( " << eventCounts[cls].nTrainingEventsRequested
               << " * " << presel_scale << " preselection efficiency)" << Endl;
      Log() << kDEBUG << "requested for testing       : " << requestedTesting;
      if (presel_scale < 1)
         Log() << " ( " << eventCounts[cls].nTestingEventsRequested
               << " * " << presel_scale << " preselection efficiency)" << Endl;
      else
         Log() << Endl;
      // determine the sample sizes to select from
      Int_t useForTesting(0), useForTraining(0);
      Int_t allAvailable(availableUndefined + availableTraining + availableTesting);

      if( (requestedTraining == 0) && (requestedTesting == 0)){
         // case 1: no event numbers requested -> aim for two samples of equal size
         if ( availableUndefined >= TMath::Abs(availableTraining - availableTesting) ) {
            // enough unspecified events to equalise the training and testing samples
            useForTraining = useForTesting = allAvailable/2;
         }
         else {
            // all unspecified events help the smaller sample
            useForTraining = availableTraining;
            useForTesting  = availableTesting;
            if (availableTraining < availableTesting)
               useForTraining += availableUndefined;
            else
               useForTesting += availableUndefined;
         }
         requestedTraining = useForTraining;
         requestedTesting  = useForTesting;
      }
      else if (requestedTesting == 0){
         // case 2: only training events requested; testing gets the rest
         useForTraining = TMath::Max(requestedTraining,availableTraining);
         if (allAvailable < useForTraining) {
            Log() << kFATAL << Form("Dataset[%s] : ",dsi.GetName()) << "More events requested for training ("
                  << requestedTraining << ") than available ("
                  << allAvailable << ")!" << Endl;
         }
         useForTesting = allAvailable - useForTraining; // the rest
         requestedTesting = useForTesting;
      }
      else if (requestedTraining == 0){
         // case 3: only testing events requested; training gets the rest
         useForTesting = TMath::Max(requestedTesting,availableTesting);
         if (allAvailable < useForTesting) {
            Log() << kFATAL << Form("Dataset[%s] : ",dsi.GetName()) << "More events requested for testing ("
                  << requestedTesting << ") than available ("
                  << allAvailable << ")!" << Endl;
         }
         useForTraining = allAvailable - useForTesting; // the rest
         requestedTraining = useForTraining;
      }
      else {
         // case 4: both requested; distribute the free events evenly
         Int_t stillNeedForTraining = TMath::Max(requestedTraining-availableTraining,0);
         Int_t stillNeedForTesting  = TMath::Max(requestedTesting-availableTesting,0);

         int NFree = availableUndefined - stillNeedForTraining - stillNeedForTesting;
         if (NFree < 0) NFree = 0;
         useForTraining = TMath::Max(requestedTraining,availableTraining) + NFree/2;
         useForTesting  = allAvailable - useForTraining; // the rest
      }

      Log() << kDEBUG << Form("Dataset[%s] : ",dsi.GetName())
            << "determined event sample size to select training sample from=" << useForTraining << Endl;
      Log() << kDEBUG << Form("Dataset[%s] : ",dsi.GetName())
            << "determined event sample size to select test sample from=" << useForTesting << Endl;
      // associate the undefined events to training/testing
      if( splitMode == "ALTERNATE" ){
         Log() << kDEBUG << Form("Dataset[%s] : ",dsi.GetName()) << "split 'ALTERNATE'" << Endl;
         Int_t nTraining = availableTraining;
         Int_t nTesting  = availableTesting;
         for( EventVector::iterator it = eventVectorUndefined.begin(), itEnd = eventVectorUndefined.end(); it != itEnd; ){
            ++nTraining;
            if( nTraining <= requestedTraining ){
               eventVectorTraining.insert( eventVectorTraining.end(), (*it) );
               ++it;
            }
            if( it != itEnd ){
               ++nTesting;
               eventVectorTesting.insert( eventVectorTesting.end(), (*it) );
               ++it;
            }
         }
      }
      else {
         Log() << kDEBUG << Form("Dataset[%s] : ",dsi.GetName()) << "split '" << splitMode << "'" << Endl;

         // test if enough events are available
         Log() << kDEBUG << Form("Dataset[%s] : ",dsi.GetName()) << "availableundefined : " << availableUndefined << Endl;
         Log() << kDEBUG << Form("Dataset[%s] : ",dsi.GetName()) << "useForTraining     : " << useForTraining << Endl;
         Log() << kDEBUG << Form("Dataset[%s] : ",dsi.GetName()) << "useForTesting      : " << useForTesting << Endl;
         Log() << kDEBUG << Form("Dataset[%s] : ",dsi.GetName()) << "availableTraining  : " << availableTraining << Endl;
         Log() << kDEBUG << Form("Dataset[%s] : ",dsi.GetName()) << "availableTesting   : " << availableTesting << Endl;

         if( availableUndefined<(useForTraining-availableTraining) ||
             availableUndefined<(useForTesting -availableTesting ) ||
             availableUndefined<(useForTraining+useForTesting-availableTraining-availableTesting ) ){
            Log() << kFATAL << Form("Dataset[%s] : ",dsi.GetName()) << "More events requested than available!" << Endl;
         }

         // select the events
         if (useForTraining>availableTraining){
            eventVectorTraining.insert( eventVectorTraining.end(), eventVectorUndefined.begin(), eventVectorUndefined.begin() + useForTraining - availableTraining );
            eventVectorUndefined.erase( eventVectorUndefined.begin(), eventVectorUndefined.begin() + useForTraining - availableTraining );
         }
         if (useForTesting>availableTesting){
            eventVectorTesting.insert( eventVectorTesting.end(), eventVectorUndefined.begin(), eventVectorUndefined.begin() + useForTesting - availableTesting );
         }
      }
      eventVectorUndefined.clear();
      // finally: if the training or testing vector contains more events than
      // requested, remove the surplus events (randomly chosen for "RANDOM"
      // split mode, from the end otherwise)
      if (splitMode.Contains( "RANDOM" )){
         UInt_t sizeTraining = eventVectorTraining.size();
         if( sizeTraining > UInt_t(requestedTraining) ){
            std::vector<UInt_t> indicesTraining( sizeTraining );
            // fill the indices
            std::generate( indicesTraining.begin(), indicesTraining.end(), TMVA::Increment<UInt_t>(0) );
            // shuffle the indices
            std::shuffle(indicesTraining.begin(), indicesTraining.end(), rndm);
            // erase the indices of the events which are kept
            indicesTraining.erase( indicesTraining.begin()+sizeTraining-UInt_t(requestedTraining), indicesTraining.end() );
            // delete all events with the remaining indices
            for( std::vector<UInt_t>::iterator it = indicesTraining.begin(), itEnd = indicesTraining.end(); it != itEnd; ++it ){
               delete eventVectorTraining.at( (*it) ); // delete the event
               eventVectorTraining.at( (*it) ) = NULL; // set the pointer to NULL
            }
            // remove all NULL pointers
            eventVectorTraining.erase( std::remove( eventVectorTraining.begin(), eventVectorTraining.end(), (void*)NULL ), eventVectorTraining.end() );
         }

         UInt_t sizeTesting = eventVectorTesting.size();
         if( sizeTesting > UInt_t(requestedTesting) ){
            std::vector<UInt_t> indicesTesting( sizeTesting );
            // fill the indices
            std::generate( indicesTesting.begin(), indicesTesting.end(), TMVA::Increment<UInt_t>(0) );
            // shuffle the indices
            std::shuffle(indicesTesting.begin(), indicesTesting.end(), rndm);
            // erase the indices of the events which are kept
            indicesTesting.erase( indicesTesting.begin()+sizeTesting-UInt_t(requestedTesting), indicesTesting.end() );
            // delete all events with the remaining indices
            for( std::vector<UInt_t>::iterator it = indicesTesting.begin(), itEnd = indicesTesting.end(); it != itEnd; ++it ){
               delete eventVectorTesting.at( (*it) ); // delete the event
               eventVectorTesting.at( (*it) ) = NULL; // set the pointer to NULL
            }
            // remove all NULL pointers
            eventVectorTesting.erase( std::remove( eventVectorTesting.begin(), eventVectorTesting.end(), (void*)NULL ), eventVectorTesting.end() );
         }
      }
      else { // erase at the end
         if( eventVectorTraining.size() < UInt_t(requestedTraining) )
            Log() << kWARNING << Form("Dataset[%s] : ",dsi.GetName())
                  << "DataSetFactory/requested number of training samples larger than size of eventVectorTraining.\n"
                  << "There is probably an issue. Please contact the TMVA developers." << Endl;
         std::for_each( eventVectorTraining.begin()+requestedTraining, eventVectorTraining.end(), DeleteFunctor<Event>() );
         eventVectorTraining.erase(eventVectorTraining.begin()+requestedTraining,eventVectorTraining.end());

         if( eventVectorTesting.size() < UInt_t(requestedTesting) )
            Log() << kWARNING << Form("Dataset[%s] : ",dsi.GetName())
                  << "DataSetFactory/requested number of testing samples larger than size of eventVectorTesting.\n"
                  << "There is probably an issue. Please contact the TMVA developers." << Endl;
         std::for_each( eventVectorTesting.begin()+requestedTesting, eventVectorTesting.end(), DeleteFunctor<Event>() );
         eventVectorTesting.erase(eventVectorTesting.begin()+requestedTesting,eventVectorTesting.end());
      }
   }
   TMVA::DataSetFactory::RenormEvents( dsi, tmpEventVector, eventCounts, normMode );

   Int_t trainingSize = 0;
   Int_t testingSize  = 0;

   // sum up the number of training and testing events
   for( UInt_t cls = 0; cls < dsi.GetNClasses(); ++cls ){
      trainingSize += tmpEventVector[Types::kTraining].at(cls).size();
      testingSize  += tmpEventVector[Types::kTesting].at(cls).size();
   }

   // the output event vectors (the mixed vectors)
   EventVector* trainingEventVector = new EventVector();
   EventVector* testingEventVector  = new EventVector();

   trainingEventVector->reserve( trainingSize );
   testingEventVector->reserve( testingSize );

   // insert the events
   Log() << kDEBUG << " MIXING ============= " << Endl;
   if( mixMode == "ALTERNATE" ){
      // inform the user if the alternate mix mode is used with classes of
      // different event numbers: this works, but the alternation stops at the
      // last event of the smaller class
      for( UInt_t cls = 1; cls < dsi.GetNClasses(); ++cls ){
         if (tmpEventVector[Types::kTraining].at(cls).size() != tmpEventVector[Types::kTraining].at(0).size()){
            Log() << kINFO << Form("Dataset[%s] : ",dsi.GetName())
                  << "Training sample: You are trying to mix events in alternate mode although the classes have different event numbers. This works but the alternation stops at the last event of the smaller class." << Endl;
         }
         if (tmpEventVector[Types::kTesting].at(cls).size() != tmpEventVector[Types::kTesting].at(0).size()){
            Log() << kINFO << Form("Dataset[%s] : ",dsi.GetName())
                  << "Testing sample: You are trying to mix events in alternate mode although the classes have different event numbers. This works but the alternation stops at the last event of the smaller class." << Endl;
         }
      }
      typedef EventVector::iterator EvtVecIt;
      EvtVecIt itEvent, itEventEnd;

      // insert the first class
      Log() << kDEBUG << "insert class 0 into training and test vector" << Endl;
      trainingEventVector->insert( trainingEventVector->end(), tmpEventVector[Types::kTraining].at(0).begin(), tmpEventVector[Types::kTraining].at(0).end() );
      testingEventVector->insert( testingEventVector->end(), tmpEventVector[Types::kTesting].at(0).begin(), tmpEventVector[Types::kTesting].at(0).end() );

      // insert the other classes
      EvtVecIt itTarget;
      for( UInt_t cls = 1; cls < dsi.GetNClasses(); ++cls ){
         Log() << kDEBUG << Form("Dataset[%s] : ",dsi.GetName()) << "insert class " << cls << Endl;
         // insert the events of class cls into the training vector
         itTarget = trainingEventVector->begin() - 1; // start one before the beginning
         for( itEvent = tmpEventVector[Types::kTraining].at(cls).begin(), itEventEnd = tmpEventVector[Types::kTraining].at(cls).end(); itEvent != itEventEnd; ++itEvent ){
            // if the target vector is too short, append the rest without mixing
            if( (trainingEventVector->end() - itTarget) < Int_t(cls+1) ) {
               itTarget = trainingEventVector->end();
               trainingEventVector->insert( itTarget, itEvent, itEventEnd );
               break;
            }
            else {
               itTarget += cls+1;
               trainingEventVector->insert( itTarget, (*itEvent) );
            }
         }
         // insert the events of class cls into the testing vector
         itTarget = testingEventVector->begin() - 1;
         for( itEvent = tmpEventVector[Types::kTesting].at(cls).begin(), itEventEnd = tmpEventVector[Types::kTesting].at(cls).end(); itEvent != itEventEnd; ++itEvent ){
            // if the target vector is too short, append the rest without mixing
            if( ( testingEventVector->end() - itTarget ) < Int_t(cls+1) ) {
               itTarget = testingEventVector->end();
               testingEventVector->insert( itTarget, itEvent, itEventEnd );
               break;
            }
            else {
               itTarget += cls+1;
               testingEventVector->insert( itTarget, (*itEvent) );
            }
         }
      }
   }
   else {
      for( UInt_t cls = 0; cls < dsi.GetNClasses(); ++cls ){
         trainingEventVector->insert( trainingEventVector->end(), tmpEventVector[Types::kTraining].at(cls).begin(), tmpEventVector[Types::kTraining].at(cls).end() );
         testingEventVector->insert ( testingEventVector->end(),  tmpEventVector[Types::kTesting].at(cls).begin(),  tmpEventVector[Types::kTesting].at(cls).end()  );
      }
   }
   // clear the temporary vectors (but not the events they point to)
   tmpEventVector[Types::kTraining].clear();
   tmpEventVector[Types::kTesting].clear();
   tmpEventVector[Types::kMaxTreeType].clear();
   if (mixMode == "RANDOM") {
      Log() << kDEBUG << Form("Dataset[%s] : ",dsi.GetName()) << "shuffling events" << Endl;

      std::shuffle(trainingEventVector->begin(), trainingEventVector->end(), rndm);
      std::shuffle(testingEventVector->begin(), testingEventVector->end(), rndm);
   }

   Log() << kDEBUG << Form("Dataset[%s] : ",dsi.GetName()) << "trainingEventVector " << trainingEventVector->size() << Endl;
   Log() << kDEBUG << Form("Dataset[%s] : ",dsi.GetName()) << "testingEventVector  " << testingEventVector->size() << Endl;

   // create the dataset
   DataSet* ds = new DataSet(dsi);

   ds->SetEventCollection(trainingEventVector, Types::kTraining );
   ds->SetEventCollection(testingEventVector,  Types::kTesting  );

   if (ds->GetNTrainingEvents() < 1){
      Log() << kFATAL << "Dataset " << std::string(dsi.GetName())
            << " does not have any training events, I better stop here and let you fix that one first " << Endl;
   }

   if (ds->GetNTestEvents() < 1) {
      Log() << kERROR << "Dataset " << std::string(dsi.GetName())
            << " does not have any testing events, guess that will cause problems later..but for now, I continue " << Endl;
   }

   delete trainingEventVector;
   delete testingEventVector;
   return ds;
}
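////////////////////////////////////////////////////////////////////////////////
/// renormalisation of the TRAINING event weights:
///  - None           : no renormalisation
///  - NumEvents      : renormalises such that the sum of the weights equals
///                     the number of training events per class
///  - EqualNumEvents : renormalises such that the weighted number of events
///                     of each class equals the number of training events of
///                     the reference class (class 0)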
void
TMVA::DataSetFactory::RenormEvents( TMVA::DataSetInfo& dsi,
                                    EventVectorOfClassesOfTreeType& tmpEventVector,
                                    const EvtStatsPerClass& eventCounts,
                                    const TString& normMode )
{
   // compute the sizes and the sums of weights
   Int_t trainingSize = 0;
   Int_t testingSize  = 0;

   ValuePerClass trainingSumWeightsPerClass( dsi.GetNClasses() );
   ValuePerClass testingSumWeightsPerClass( dsi.GetNClasses() );

   NumberPerClass trainingSizePerClass( dsi.GetNClasses() );
   NumberPerClass testingSizePerClass( dsi.GetNClasses() );

   Double_t trainingSumSignalWeights = 0;
   Double_t trainingSumBackgrWeights = 0; // background includes all non-signal classes
   Double_t testingSumSignalWeights  = 0;
   Double_t testingSumBackgrWeights  = 0;
   for( UInt_t cls = 0, clsEnd = dsi.GetNClasses(); cls < clsEnd; ++cls ){
      trainingSizePerClass.at(cls) = tmpEventVector[Types::kTraining].at(cls).size();
      testingSizePerClass.at(cls)  = tmpEventVector[Types::kTesting].at(cls).size();

      trainingSize += trainingSizePerClass.at(cls);
      testingSize  += testingSizePerClass.at(cls);

      // sum up the original weights per class
      trainingSumWeightsPerClass.at(cls) =
         std::accumulate(tmpEventVector[Types::kTraining].at(cls).begin(),
                         tmpEventVector[Types::kTraining].at(cls).end(),
                         Double_t(0), [](Double_t w, const TMVA::Event *E) { return w + E->GetOriginalWeight(); });

      testingSumWeightsPerClass.at(cls) =
         std::accumulate(tmpEventVector[Types::kTesting].at(cls).begin(),
                         tmpEventVector[Types::kTesting].at(cls).end(),
                         Double_t(0), [](Double_t w, const TMVA::Event *E) { return w + E->GetOriginalWeight(); });

      if ( cls == dsi.GetSignalClassIndex()){
         trainingSumSignalWeights += trainingSumWeightsPerClass.at(cls);
         testingSumSignalWeights  += testingSumWeightsPerClass.at(cls);
      }
      else {
         trainingSumBackgrWeights += trainingSumWeightsPerClass.at(cls);
         testingSumBackgrWeights  += testingSumWeightsPerClass.at(cls);
      }
   }
   // compute the renormalisation factors
   ValuePerClass renormFactor( dsi.GetNClasses() );

   // for information purposes
   dsi.SetNormalization( normMode );
   // these sums will be overwritten by the rescaled ones if NormMode != None
   dsi.SetTrainingSumSignalWeights(trainingSumSignalWeights);
   dsi.SetTrainingSumBackgrWeights(trainingSumBackgrWeights);
   dsi.SetTestingSumSignalWeights(testingSumSignalWeights);
   dsi.SetTestingSumBackgrWeights(testingSumBackgrWeights);

   if (normMode == "NONE") {
      Log() << kINFO << Form("Dataset[%s] : ",dsi.GetName())
            << "No weight renormalisation applied: use original global and event weights" << Endl;
      return;
   }
   else if (normMode == "NUMEVENTS") {
      Log() << kDEBUG
            << "\tWeight renormalisation mode: \"NumEvents\": renormalises all event classes " << Endl;
      Log() << kDEBUG
            << " such that the effective (weighted) number of events in each class equals the respective " << Endl;
      Log() << kDEBUG
            << " number of events (entries) that you demanded in PrepareTrainingAndTestTree(\"\",\"nTrain_Signal=.. )" << Endl;
      Log() << kDEBUG
            << " ... i.e. such that Sum[i=1..N_j]{w_i} = N_j, j=0,1,2..." << Endl;
      Log() << kDEBUG
            << " ... (note that N_j is the sum of TRAINING events (nTrain_j...with j=Signal,Background.." << Endl;
      Log() << kDEBUG
            << " ..... Testing events are not renormalised nor included in the renormalisation factor! )" << Endl;

      for( UInt_t cls = 0, clsEnd = dsi.GetNClasses(); cls < clsEnd; ++cls ){
         renormFactor.at(cls) = ((Float_t)trainingSizePerClass.at(cls) )/
            (trainingSumWeightsPerClass.at(cls)) ;
      }
   }
   else if (normMode == "EQUALNUMEVENTS") {
      Log() << kINFO << Form("Dataset[%s] : ",dsi.GetName())
            << "Weight renormalisation mode: \"EqualNumEvents\": renormalises all event classes ..." << Endl;
      Log() << kINFO << Form("Dataset[%s] : ",dsi.GetName())
            << " such that the effective (weighted) number of events in each class is the same " << Endl;
      Log() << kINFO << Form("Dataset[%s] : ",dsi.GetName())
            << " (and equals the number of events (entries) given for class=0 )" << Endl;
      Log() << kINFO << Form("Dataset[%s] : ",dsi.GetName())
            << "... i.e. such that Sum[i=1..N_j]{w_i} = N_classA, j=classA, classB, ..." << Endl;
      Log() << kINFO << Form("Dataset[%s] : ",dsi.GetName())
            << "... (note that N_j is the sum of TRAINING events" << Endl;
      Log() << kINFO << Form("Dataset[%s] : ",dsi.GetName())
            << " ..... Testing events are not renormalised nor included in the renormalisation factor!)" << Endl;

      // normalise to the number of events of the reference class (class 0)
      UInt_t referenceClass = 0;
      for (UInt_t cls = 0, clsEnd = dsi.GetNClasses(); cls < clsEnd; ++cls ) {
         renormFactor.at(cls) = Float_t(trainingSizePerClass.at(referenceClass))/
            (trainingSumWeightsPerClass.at(cls));
      }
   }
   else {
      Log() << kFATAL << Form("Dataset[%s] : ",dsi.GetName())
            << "<PrepareForTrainingAndTesting> Unknown NormMode: " << normMode << Endl;
   }
   // now apply the normalisation factors to the training events
   Int_t maxL = dsi.GetClassNameMaxLength();
   for (UInt_t cls = 0, clsEnd = dsi.GetNClasses(); cls < clsEnd; ++cls) {
      Log() << kDEBUG
            << "--> Rescale " << setiosflags(ios::left) << std::setw(maxL)
            << dsi.GetClassInfo(cls)->GetName() << " event weights by factor: " << renormFactor.at(cls) << Endl;
      for (EventVector::iterator it = tmpEventVector[Types::kTraining].at(cls).begin(),
              itEnd = tmpEventVector[Types::kTraining].at(cls).end(); it != itEnd; ++it){
         (*it)->SetWeight ((*it)->GetWeight() * renormFactor.at(cls));
      }
   }
   // print out the result
   Log() << kINFO
         << "Number of training and testing events" << Endl;
   Log() << kDEBUG << "\tafter rescaling:" << Endl;
   Log() << kINFO
         << "---------------------------------------------------------------------------" << Endl;

   trainingSumSignalWeights = 0;
   trainingSumBackgrWeights = 0;
   testingSumSignalWeights  = 0;
   testingSumBackgrWeights  = 0;

   for( UInt_t cls = 0, clsEnd = dsi.GetNClasses(); cls < clsEnd; ++cls ){
      trainingSumWeightsPerClass.at(cls) =
         std::accumulate(tmpEventVector[Types::kTraining].at(cls).begin(),
                         tmpEventVector[Types::kTraining].at(cls).end(),
                         Double_t(0), [](Double_t w, const TMVA::Event *E) { return w + E->GetOriginalWeight(); });

      testingSumWeightsPerClass.at(cls) =
         std::accumulate(tmpEventVector[Types::kTesting].at(cls).begin(),
                         tmpEventVector[Types::kTesting].at(cls).end(),
                         Double_t(0), [](Double_t w, const TMVA::Event *E) { return w + E->GetOriginalWeight(); });

      if ( cls == dsi.GetSignalClassIndex()){
         trainingSumSignalWeights += trainingSumWeightsPerClass.at(cls);
         testingSumSignalWeights  += testingSumWeightsPerClass.at(cls);
      }
      else {
         trainingSumBackgrWeights += trainingSumWeightsPerClass.at(cls);
         testingSumBackgrWeights  += testingSumWeightsPerClass.at(cls);
      }

      // output statistics
      Log() << kINFO
            << setiosflags(ios::left) << std::setw(maxL)
            << dsi.GetClassInfo(cls)->GetName() << " -- "
            << "training events            : " << trainingSizePerClass.at(cls) << Endl;
      Log() << kDEBUG << "\t(sum of weights: " << trainingSumWeightsPerClass.at(cls) << ")"
            << " - requested were " << eventCounts[cls].nTrainingEventsRequested << " events" << Endl;
      Log() << kINFO
            << setiosflags(ios::left) << std::setw(maxL)
            << dsi.GetClassInfo(cls)->GetName() << " -- "
            << "testing events             : " << testingSizePerClass.at(cls) << Endl;
      Log() << kDEBUG << "\t(sum of weights: " << testingSumWeightsPerClass.at(cls) << ")"
            << " - requested were " << eventCounts[cls].nTestingEventsRequested << " events" << Endl;
      Log() << kINFO
            << setiosflags(ios::left) << std::setw(maxL)
            << dsi.GetClassInfo(cls)->GetName() << " -- "
            << "training and testing events: "
            << (trainingSizePerClass.at(cls)+testingSizePerClass.at(cls)) << Endl;
      Log() << kDEBUG << "\t(sum of weights: "
            << (trainingSumWeightsPerClass.at(cls)+testingSumWeightsPerClass.at(cls)) << ")" << Endl;
      if(eventCounts[cls].nEvAfterCut < eventCounts[cls].nEvBeforeCut) {
         Log() << kINFO << Form("Dataset[%s] : ",dsi.GetName()) << setiosflags(ios::left) << std::setw(maxL)
               << dsi.GetClassInfo(cls)->GetName() << " -- "
               << "due to the preselection a scaling factor has been applied to the numbers of requested events: "
               << eventCounts[cls].cutScaling() << Endl;
      }
   }
   Log() << kINFO << Endl;

   // for information purposes
   dsi.SetTrainingSumSignalWeights(trainingSumSignalWeights);
   dsi.SetTrainingSumBackgrWeights(trainingSumBackgrWeights);
   dsi.SetTestingSumSignalWeights(testingSumSignalWeights);
   dsi.SetTestingSumBackgrWeights(testingSumBackgrWeights);
}