/// Constructor: builds a DataSetInfo for the dataset called `name`.
///
/// Visible initializers: no dataset manager attached, the "needs rebuilding"
/// flag set, normalization mode "NONE", all four training/testing weight sums
/// set to -1 (the Get*Sum*Weights accessors in this file report a fatal error
/// while a sum is still negative), no multiclass target buffer, and a
/// heap-allocated MsgLogger named "DataSetInfo" at kINFO level.
///
/// NOTE(review): only part of the initializer list is visible in this excerpt.
61 TMVA::DataSetInfo::DataSetInfo(
const TString& name)
63 fDataSetManager(NULL),
66 fNeedsRebuilding( kTRUE ),
71 fNormalization(
"NONE" ),
// -1 marks the weight sums as "not computed yet".
73 fTrainingSumSignalWeights(-1),
74 fTrainingSumBackgrWeights(-1),
75 fTestingSumSignalWeights (-1),
76 fTestingSumBackgrWeights (-1),
// The multiclass target buffer is created on first use in GetTargetsForMulticlass().
80 fTargetsForMulticlass(0),
81 fLogger( new MsgLogger(
"DataSetInfo", kINFO) )
// NOTE(review): this trace goes straight to std::cout rather than through
// the logger, so it cannot be filtered by message type - confirm intended.
83 std::cout <<
"create data set info " << name << std::endl;
/// Destructor: walks over the registered classes and frees the multiclass
/// target buffer.
///
/// NOTE(review): the body of the loop over fClasses is not visible in this
/// excerpt; presumably it deletes each ClassInfo allocated in AddClass() - confirm.
/// NOTE(review): fLogger is allocated with `new` in the constructor and no
/// matching delete is visible here - confirm it is released.
89 TMVA::DataSetInfo::~DataSetInfo()
93 for(UInt_t i=0, iEnd = fClasses.size(); i<iEnd; ++i) {
// Deleting a null pointer is a no-op, so this is safe if the buffer was never created.
97 delete fTargetsForMulticlass;
/// Deletes the currently held DataSet, if any, and resets the pointer to 0.
///
/// NOTE(review): the method is const yet assigns fDataSet; presumably that
/// member is declared mutable - confirm in the header.
104 void TMVA::DataSetInfo::ClearDataSet()
const
106 if(fDataSet!=0) {
delete fDataSet; fDataSet=0; }
/// Sets the minimum message type `t` on this object's logger.
///
/// NOTE(review): the return type is not visible in this excerpt.
112 TMVA::DataSetInfo::SetMsgType( EMsgType t )
const
114 fLogger->SetMinType(t);
/// Returns the ClassInfo registered under `className`, creating and
/// registering a new one if no class with that name exists yet.
///
/// A newly created class is appended to fClasses and numbered with its index
/// in that vector. If the new class is called "Signal", its index is also
/// stored in fSignalClass.
///
/// @param className name of the class to look up or create
/// @return pointer to the existing or newly created ClassInfo
119 TMVA::ClassInfo* TMVA::DataSetInfo::AddClass(
const TString& className )
// Re-use an already registered class instead of creating a duplicate.
121 ClassInfo* theClass = GetClassInfo(className);
122 if (theClass)
return theClass;
125 fClasses.push_back(
new ClassInfo(className) );
// The class number is its position in fClasses.
126 fClasses.back()->SetNumber(fClasses.size()-1);
130 Log() << kHEADER << Form(
"[%s] : ",fName.Data()) <<
"Added class \"" << className <<
"\""<< Endl;
132 Log() << kDEBUG <<
"\t with internal class number " << fClasses.back()->GetNumber() << Endl;
// Remember which class index is the signal class; IsSignal() compares against it.
135 if (className ==
"Signal") fSignalClass = fClasses.size()-1;
137 return fClasses.back();
/// Looks up a class by name with a linear scan over fClasses and returns the
/// first ClassInfo whose name equals `name`.
///
/// NOTE(review): the value returned when no class matches is not visible in
/// this excerpt; AddClass() and CorrelationMatrix() test the result against
/// null, so presumably 0 is returned - confirm.
/// NOTE(review): the method is const but iterates with a non-const iterator;
/// presumably fClasses is declared mutable - confirm in the header.
142 TMVA::ClassInfo* TMVA::DataSetInfo::GetClassInfo(
const TString& name )
const
144 for (std::vector<ClassInfo*>::iterator it = fClasses.begin(); it < fClasses.end(); ++it) {
145 if ((*it)->GetName() == name)
return (*it);
/// Returns the ClassInfo stored at index `cls` in fClasses.
///
/// NOTE(review): std::vector::at() throws std::out_of_range for an invalid
/// index; any handling around this call is not visible in this excerpt -
/// confirm what callers get for an out-of-range `cls`.
152 TMVA::ClassInfo* TMVA::DataSetInfo::GetClassInfo( Int_t cls )
const
155 return fClasses.at(cls);
/// Logs, at kINFO level, one line per registered class giving its index and
/// name, each prefixed with the dataset name.
164 void TMVA::DataSetInfo::PrintClasses()
const
166 for (UInt_t cls = 0; cls < GetNClasses() ; cls++) {
167 Log() << kINFO << Form(
"Dataset[%s] : ",fName.Data()) <<
"Class index : " << cls <<
" name : " << GetClassInfo(cls)->GetName() << Endl;
/// Tells whether `ev` belongs to the signal class, i.e. whether its class
/// index equals fSignalClass (set in AddClass() for the class named "Signal").
///
/// @param ev event to test; dereferenced without a null check
173 Bool_t TMVA::DataSetInfo::IsSignal(
const TMVA::Event* ev )
const
175 return (ev->GetClass() == fSignalClass);
/// Builds the multiclass target vector for `ev`: one entry per class, all 0.0
/// except the entry at the event's class index, which is 1.0.
///
/// The returned pointer refers to a buffer owned by this object: it is
/// allocated on the first call, overwritten on every later call and deleted
/// in the destructor, so callers must not delete it or keep it across calls.
///
/// @param ev event whose class index selects the non-zero entry
/// @return pointer to the internal, reused target buffer
180 std::vector<Float_t>* TMVA::DataSetInfo::GetTargetsForMulticlass(
const TMVA::Event* ev )
// Lazily create the buffer on first use.
182 if( !fTargetsForMulticlass ) fTargetsForMulticlass =
new std::vector<Float_t>( GetNClasses() );
// assign() also resizes, so classes added after the first call are covered.
184 fTargetsForMulticlass->assign( GetNClasses(), 0.0 );
// at() is bounds-checked and throws if the event's class index is not < GetNClasses().
185 fTargetsForMulticlass->at( ev->GetClass() ) = 1.0;
186 return fTargetsForMulticlass;
/// Checks whether any registered class has a non-empty cut expression.
///
/// The loop visits every class and sets the flag when a cut differs from the
/// empty string; it does not stop at the first hit.
/// NOTE(review): the return statement is not visible in this excerpt;
/// presumably `hasCuts` is returned - confirm.
192 Bool_t TMVA::DataSetInfo::HasCuts()
const
194 Bool_t hasCuts = kFALSE;
195 for (std::vector<ClassInfo*>::iterator it = fClasses.begin(); it < fClasses.end(); ++it) {
196 if( TString((*it)->GetCut()) != TString(
"") ) hasCuts = kTRUE;
/// Returns the correlation matrix stored for the class called `className`,
/// or 0 when the class lookup yields a null pointer.
///
/// @param className name of the class whose matrix is requested
203 const TMatrixD* TMVA::DataSetInfo::CorrelationMatrix(
const TString& className )
const
205 ClassInfo* ptr = GetClassInfo(className);
// Guard against an unknown class before dereferencing.
206 return ptr?ptr->GetCorrelationMatrix():0;
/// Registers a new input variable built from `expression` and returns a
/// reference to the stored VariableInfo.
///
/// All blanks are removed from the expression before it is stored. The new
/// entry receives the counter value fVariables.size()+1 (1-based position),
/// and the dataset is flagged as needing a rebuild.
///
/// NOTE(review): part of the parameter list (unit, varType, normalized,
/// external) is not visible in this excerpt.
/// NOTE(review): the returned reference points into fVariables; if that is a
/// std::vector (it is used with reserve/emplace_back elsewhere in this file),
/// later insertions may invalidate it - confirm callers do not hold on to it.
213 TMVA::VariableInfo& TMVA::DataSetInfo::AddVariable(
const TString& expression,
214 const TString& title,
216 Double_t min, Double_t max,
// Work on a copy so the caller's expression is left untouched.
221 TString regexpr = expression;
222 regexpr.ReplaceAll(
" ",
"" );
223 fVariables.push_back(VariableInfo( regexpr, title, unit,
224 fVariables.size()+1, varType, external, min, max, normalized ));
// The variable list changed, so a cached DataSet is stale (see GetDataSet()).
225 fNeedsRebuilding = kTRUE;
226 return fVariables.back();
/// Registers a copy of an existing VariableInfo as a new input variable,
/// flags the dataset as needing a rebuild and returns a reference to the
/// stored copy.
///
/// @param varInfo variable description to copy
232 TMVA::VariableInfo& TMVA::DataSetInfo::AddVariable(
const VariableInfo& varInfo){
233 fVariables.push_back(VariableInfo( varInfo ));
234 fNeedsRebuilding = kTRUE;
235 return fVariables.back();
/// Registers `size` input variables that all share one (blank-stripped)
/// expression, one per array element.
///
/// For element i the title is `title` followed by "[i]", the entry is flagged
/// with kIsArrayVariable and its internal name gets the suffix "[i]". The
/// array size is recorded in fVarArrays under the stripped expression, and
/// the dataset is flagged as needing a rebuild.
///
/// NOTE(review): the tail of the parameter list (e.g. `external`) is not
/// visible in this excerpt.
241 void TMVA::DataSetInfo::AddVariablesArray(
const TString &expression, Int_t size,
const TString &title,
const TString &unit,
242 Double_t min, Double_t max,
char varType, Bool_t normalized,
245 TString regexpr = expression;
246 regexpr.ReplaceAll(
" ",
"");
// Reserve once up front so the loop below does not reallocate repeatedly.
247 fVariables.reserve(fVariables.size() + size);
248 for (
int i = 0; i < size; ++i) {
249 TString newTitle = title + TString::Format(
"[%d]", i);
251 fVariables.emplace_back(regexpr, newTitle, unit, fVariables.size() + 1, varType, external, min, max, normalized);
253 fVariables.back().SetBit(kIsArrayVariable);
// Make the internal name unique per element by appending the index.
254 TString newVarName = fVariables.back().GetInternalName() + TString::Format(
"[%d]", i);
255 fVariables.back().SetInternalName(newVarName);
// Remember how many elements belong to this array expression.
257 fVarArrays[regexpr] = size;
258 fNeedsRebuilding = kTRUE;
/// Registers a new target built from `expression` and returns a reference to
/// the stored VariableInfo.
///
/// Blanks are stripped from the expression, the entry receives the counter
/// value fTargets.size()+1, and the dataset is flagged as needing a rebuild.
///
/// NOTE(review): parts of the parameter list and the end of the VariableInfo
/// constructor call are not visible in this excerpt.
265 TMVA::VariableInfo& TMVA::DataSetInfo::AddTarget(
const TString& expression,
266 const TString& title,
268 Double_t min, Double_t max,
272 TString regexpr = expression;
273 regexpr.ReplaceAll(
" ",
"" );
275 fTargets.push_back(VariableInfo( regexpr, title, unit,
276 fTargets.size()+1, type, external, min,
278 fNeedsRebuilding = kTRUE;
279 return fTargets.back();
/// Registers a copy of an existing VariableInfo as a new target, flags the
/// dataset as needing a rebuild and returns a reference to the stored copy.
///
/// @param varInfo target description to copy
285 TMVA::VariableInfo& TMVA::DataSetInfo::AddTarget(
const VariableInfo& varInfo){
286 fTargets.push_back(VariableInfo( varInfo ));
287 fNeedsRebuilding = kTRUE;
288 return fTargets.back();
/// Registers a new spectator built from `expression` and returns a reference
/// to the stored VariableInfo.
///
/// Blanks are stripped from the expression, the entry receives the counter
/// value fSpectators.size()+1, and the dataset is flagged as needing a rebuild.
///
/// NOTE(review): the `unit` parameter used below is declared on a line that
/// is not visible in this excerpt.
295 TMVA::VariableInfo& TMVA::DataSetInfo::AddSpectator(
const TString& expression,
296 const TString& title,
298 Double_t min, Double_t max,
char type,
299 Bool_t normalized,
void* external )
301 TString regexpr = expression;
302 regexpr.ReplaceAll(
" ",
"" );
303 fSpectators.push_back(VariableInfo( regexpr, title, unit,
304 fSpectators.size()+1, type, external, min, max, normalized ));
305 fNeedsRebuilding = kTRUE;
306 return fSpectators.back();
/// Registers a copy of an existing VariableInfo as a new spectator, flags the
/// dataset as needing a rebuild and returns a reference to the stored copy.
///
/// @param varInfo spectator description to copy
312 TMVA::VariableInfo& TMVA::DataSetInfo::AddSpectator(
const VariableInfo& varInfo){
313 fSpectators.push_back(VariableInfo( varInfo ));
314 fNeedsRebuilding = kTRUE;
315 return fSpectators.back();
/// Returns the index of the input variable whose internal name equals `var`.
///
/// If no variable matches, the internal names of all variables are logged at
/// kINFO level, followed by a kFATAL "not found" message.
///
/// NOTE(review): what is returned after the fatal message is not visible in
/// this excerpt - confirm (relevant if kFATAL does not terminate execution).
321 Int_t TMVA::DataSetInfo::FindVarIndex(
const TString& var)
const
323 for (UInt_t ivar=0; ivar<GetNVariables(); ivar++)
324 if (var == GetVariableInfo(ivar).GetInternalName())
return ivar;
// Not found: list the known internal names to help diagnose the failed lookup.
326 for (UInt_t ivar=0; ivar<GetNVariables(); ivar++)
327 Log() << kINFO << Form(
"Dataset[%s] : ",fName.Data()) << GetVariableInfo(ivar).GetInternalName() << Endl;
329 Log() << kFATAL << Form(
"Dataset[%s] : ",fName.Data()) <<
"<FindVarIndex> Variable \'" << var <<
"\' not found." << Endl;
/// Sets the weight expression `expr` for one class or for all classes.
///
/// With a non-empty `className` the class is obtained through AddClass()
/// (which creates it if it does not exist) and receives the expression.
/// The remaining visible code warns when no classes are registered and
/// applies the expression to every class in fClasses.
///
/// NOTE(review): the braces/else separating the two paths are not visible in
/// this excerpt; presumably the second part runs only for an empty
/// `className` - confirm.
339 void TMVA::DataSetInfo::SetWeightExpression(
const TString& expr,
const TString& className )
341 if (className !=
"") {
342 TMVA::ClassInfo* ci = AddClass(className);
343 ci->SetWeight( expr );
347 if (fClasses.empty()) {
348 Log() << kWARNING << Form(
"Dataset[%s] : ",fName.Data()) <<
"No classes registered yet, cannot specify weight expression!" << Endl;
350 for (std::vector<ClassInfo*>::iterator it = fClasses.begin(); it < fClasses.end(); ++it) {
351 (*it)->SetWeight( expr );
/// Stores `matrix` as the correlation matrix of the class called `className`.
///
/// NOTE(review): the result of GetClassInfo() is dereferenced without a null
/// check, whereas AddClass() and CorrelationMatrix() in this file do test it
/// against null - confirm `className` is always registered before this call.
/// NOTE(review): ownership of `matrix` after the call cannot be determined
/// from this excerpt.
358 void TMVA::DataSetInfo::SetCorrelationMatrix(
const TString& className, TMatrixD* matrix )
360 GetClassInfo(className)->SetCorrelationMatrix(matrix);
/// Sets the selection cut for one class or for all classes.
///
/// With an empty `className` the cut is set on every registered class.
/// Otherwise the class is obtained through AddClass() (which creates it if it
/// does not exist).
///
/// NOTE(review): the statement that applies `cut` to `ci` is not visible in
/// this excerpt; presumably ci->SetCut(cut) - confirm.
366 void TMVA::DataSetInfo::SetCut(
const TCut& cut,
const TString& className )
368 if (className ==
"") {
369 for (std::vector<ClassInfo*>::iterator it = fClasses.begin(); it < fClasses.end(); ++it) {
370 (*it)->SetCut( cut );
374 TMVA::ClassInfo* ci = AddClass(className);
/// Combines `cut` with the cut already set on one class or on all classes.
///
/// With an empty `className` every registered class gets `oldCut + cut`.
/// Otherwise the class is obtained through AddClass() (created if missing)
/// and its cut becomes `ci->GetCut() + cut`.
///
/// NOTE(review): the combination relies on TCut's operator+; presumably this
/// forms the logical AND of both cuts - confirm against the TCut documentation.
382 void TMVA::DataSetInfo::AddCut(
const TCut& cut,
const TString& className )
384 if (className ==
"") {
385 for (std::vector<ClassInfo*>::iterator it = fClasses.begin(); it < fClasses.end(); ++it) {
386 const TCut& oldCut = (*it)->GetCut();
387 (*it)->SetCut( oldCut+cut );
391 TMVA::ClassInfo* ci = AddClass(className);
392 ci->SetCut( ci->GetCut()+cut );
/// Collects the internal names of all input variables, in the order in which
/// GetVariableInfos() stores them.
///
/// NOTE(review): the return statement is not visible in this excerpt;
/// presumably `vNames` is returned - confirm.
399 std::vector<TString> TMVA::DataSetInfo::GetListOfVariables()
const
401 std::vector<TString> vNames;
402 std::vector<TMVA::VariableInfo>::const_iterator viIt = GetVariableInfos().begin();
403 for(;viIt != GetVariableInfos().end(); ++viIt) vNames.push_back( (*viIt).GetInternalName() );
/// Logs a heading naming the class and then prints that class's correlation
/// matrix, labelled with the variable names, through the logger.
///
/// NOTE(review): the beginning of the log statement is not visible in this
/// excerpt.
/// NOTE(review): CorrelationMatrix() returns 0 when the class lookup fails,
/// and the result is dereferenced here without a check - confirm the class
/// always exists at this point.
412 void TMVA::DataSetInfo::PrintCorrelationMatrix(
const TString& className )
416 <<
"Correlation matrix (" << className <<
"):" << Endl;
417 gTools().FormattedOutput( *CorrelationMatrix( className ), GetListOfVariables(), Log() );
/// Turns the correlation matrix `m` into a 2D histogram named `hName` with
/// title `hTitle`, with one bin per variable pair and axis bins labelled with
/// the variable titles.
///
/// Steps visible below: copy `m` into a float matrix, construct a TH2F from
/// it, label the axes, truncate every bin content to an integer value, apply
/// display settings and fix the value range to [-100, +100].
///
/// NOTE(review): `m` is dereferenced in the copy loop; any null check on it
/// is not visible in this excerpt.
/// NOTE(review): `tm` is allocated with `new` and no delete is visible here -
/// confirm it is released after the histogram has been built.
/// NOTE(review): the return statement is not visible; presumably `h2` is
/// returned - confirm, along with who owns the histogram.
422 TH2* TMVA::DataSetInfo::CreateCorrelationMatrixHist(
const TMatrixD* m,
423 const TString& hName,
424 const TString& hTitle )
const
428 const UInt_t nvar = GetNVariables();
// Copy the double-precision matrix element by element into a float matrix
// for the TH2F constructor used below.
432 TMatrixF* tm =
new TMatrixF( nvar, nvar );
433 for (UInt_t ivar=0; ivar<nvar; ivar++) {
434 for (UInt_t jvar=0; jvar<nvar; jvar++) {
435 (*tm)(ivar, jvar) = (*m)(ivar,jvar);
439 TH2F* h2 =
new TH2F( *tm );
440 h2->SetNameTitle( hName, hTitle );
// Histogram bins are 1-based, hence ivar+1.
442 for (UInt_t ivar=0; ivar<nvar; ivar++) {
443 h2->GetXaxis()->SetBinLabel( ivar+1, GetVariableInfo(ivar).GetTitle() );
444 h2->GetYaxis()->SetBinLabel( ivar+1, GetVariableInfo(ivar).GetTitle() );
// The Int_t cast drops the fractional part of each bin content (truncation
// toward zero), so the stored values are whole numbers.
450 for (UInt_t ibin=1; ibin<=nvar; ibin++) {
451 for (UInt_t jbin=1; jbin<=nvar; jbin++) {
452 h2->SetBinContent( ibin, jbin, Int_t(h2->GetBinContent( ibin, jbin )) );
// Display settings only; none of these change the bin contents.
457 const Float_t labelSize = 0.055;
459 h2->GetXaxis()->SetLabelSize( labelSize );
460 h2->GetYaxis()->SetLabelSize( labelSize );
461 h2->SetMarkerSize( 1.5 );
462 h2->SetMarkerColor( 0 );
463 h2->LabelsOption(
"d" );
464 h2->SetLabelOffset( 0.011 );
// NOTE(review): the fixed +-100 range suggests the matrix entries are
// expected in percent - confirm against the producer of `m`.
465 h2->SetMinimum( -100.0 );
466 h2->SetMaximum( +100.0 );
478 Log() << kDEBUG << Form(
"Dataset[%s] : ",fName.Data()) <<
"Created correlation matrix as 2D histogram: " << h2->GetName() << Endl;
/// Provides the DataSet described by this object, (re)creating it on demand.
///
/// A new DataSet is requested from the DataSetManager when none is held yet
/// or when fNeedsRebuilding is set (the Add* methods in this file set it);
/// a previously held DataSet is cleared first and the flag is reset afterwards.
///
/// NOTE(review): when no manager is set, a kFATAL message is logged and the
/// next visible statement dereferences fDataSetManager; this is only safe if
/// kFATAL logging does not return - confirm.
/// NOTE(review): the method is const but assigns fDataSet and
/// fNeedsRebuilding; presumably both are mutable - confirm in the header.
/// NOTE(review): the return statement is not visible; presumably fDataSet is
/// returned - confirm.
486 TMVA::DataSet* TMVA::DataSetInfo::GetDataSet()
const
488 if (fDataSet==0 || fNeedsRebuilding) {
// Drop the stale DataSet before building a fresh one.
489 if(fDataSet!=0) ClearDataSet();
491 if( !fDataSetManager )
492 Log() << kFATAL << Form(
"Dataset[%s] : ",fName.Data()) <<
"DataSetManager has not been set in DataSetInfo (GetDataSet() )." << Endl;
493 fDataSet = fDataSetManager->CreateDataSet(GetName());
495 fNeedsRebuilding = kFALSE;
/// Returns the number of spectators: either the total number stored in
/// fSpectators, or only those whose variable type is not 'C'.
///
/// NOTE(review): the condition selecting between the two paths, the
/// declaration of `nsp` and the final return are not visible in this excerpt;
/// presumably `all == true` yields the total and otherwise `nsp` is
/// returned - confirm.
502 UInt_t TMVA::DataSetInfo::GetNSpectators(
bool all)
const
505 return fSpectators.size();
// Count only spectators whose type differs from 'C'.
507 for(std::vector<VariableInfo>::const_iterator spit=fSpectators.begin(); spit!=fSpectators.end(); ++spit) {
508 if(spit->GetVarType()!=
'C') nsp++;
/// Determines the length of the longest class name among the registered
/// classes by keeping a running maximum in `maxL`.
///
/// NOTE(review): the declaration/initial value of `maxL` and the return
/// statement are not visible in this excerpt.
515 Int_t TMVA::DataSetInfo::GetClassNameMaxLength()
const
518 for (UInt_t cl = 0; cl < GetNClasses(); cl++) {
519 if (TString(GetClassInfo(cl)->GetName()).Length() > maxL) maxL = TString(GetClassInfo(cl)->GetName()).Length();
/// Determines the length of the longest variable expression among the input
/// variables by keeping a running maximum in `maxL`. Note that the length is
/// taken from GetExpression(), not from the title or internal name.
///
/// NOTE(review): the declaration/initial value of `maxL` and the return
/// statement are not visible in this excerpt.
527 Int_t TMVA::DataSetInfo::GetVariableNameMaxLength()
const
530 for (UInt_t i = 0; i < GetNVariables(); i++) {
531 if (TString(GetVariableInfo(i).GetExpression()).Length() > maxL) maxL = TString(GetVariableInfo(i).GetExpression()).Length();
/// Determines the length of the longest target expression among the targets
/// by keeping a running maximum in `maxL`. The length is taken from
/// GetExpression().
///
/// NOTE(review): the declaration/initial value of `maxL` and the return
/// statement are not visible in this excerpt.
539 Int_t TMVA::DataSetInfo::GetTargetNameMaxLength()
const
542 for (UInt_t i = 0; i < GetNTargets(); i++) {
543 if (TString(GetTargetInfo(i).GetExpression()).Length() > maxL) maxL = TString(GetTargetInfo(i).GetExpression()).Length();
/// Returns the stored sum of training signal event weights.
///
/// A negative value means the sum has not been set yet (the constructor
/// initializes it to -1); in that case a kFATAL message is logged before the
/// value is returned.
551 Double_t TMVA::DataSetInfo::GetTrainingSumSignalWeights(){
552 if (fTrainingSumSignalWeights<0) Log() << kFATAL << Form(
"Dataset[%s] : ",fName.Data()) <<
" asking for the sum of training signal event weights which is not initialized yet" << Endl;
553 return fTrainingSumSignalWeights;
/// Returns the stored sum of training background event weights.
///
/// A negative value means the sum has not been set yet (the constructor
/// initializes it to -1); in that case a kFATAL message is logged before the
/// value is returned.
558 Double_t TMVA::DataSetInfo::GetTrainingSumBackgrWeights(){
559 if (fTrainingSumBackgrWeights<0) Log() << kFATAL << Form(
"Dataset[%s] : ",fName.Data()) <<
" asking for the sum of training backgr event weights which is not initialized yet" << Endl;
560 return fTrainingSumBackgrWeights;
/// Returns the stored sum of testing signal event weights.
///
/// A negative value means the sum has not been set yet (the constructor
/// initializes it to -1); in that case a kFATAL message is logged before the
/// value is returned.
565 Double_t TMVA::DataSetInfo::GetTestingSumSignalWeights (){
566 if (fTestingSumSignalWeights<0) Log() << kFATAL << Form(
"Dataset[%s] : ",fName.Data()) <<
" asking for the sum of testing signal event weights which is not initialized yet" << Endl;
567 return fTestingSumSignalWeights ;
/// Returns the stored sum of testing background event weights.
///
/// A negative value means the sum has not been set yet (the constructor
/// initializes it to -1); in that case a kFATAL message is logged before the
/// value is returned.
572 Double_t TMVA::DataSetInfo::GetTestingSumBackgrWeights (){
573 if (fTestingSumBackgrWeights<0) Log() << kFATAL << Form(
"Dataset[%s] : ",fName.Data()) <<
" asking for the sum of testing backgr event weights which is not initialized yet" << Endl;
574 return fTestingSumBackgrWeights ;