// Minimum number of training events required before a method is trained;
// methods whose training set is smaller are skipped with a warning
// (see TrainAllMethods / OptimizeAllMethods below).
99 const Int_t MinNoTrainingEvents = 10;

// ROOT dictionary generation for TMVA::Factory.
102 ClassImp(TMVA::Factory);

// Weight files are written/read in XML format (rather than plain text).
104 #define READXML kTRUE
////////////////////////////////////////////////////////////////////////////////
/// Standard constructor: builds a Factory bound to an output ROOT file.
/// Parses the option string (verbosity, colour, transformations, silent/batch
/// flags, analysis type) and propagates the results to the global TMVA config.
/// NOTE(review): this extraction is missing lines (e.g. the preprocessor
/// guards around the batch-mode defaults), which is why `color` and
/// `drawProgressBar` appear declared twice below — confirm against the
/// full source.
119 TMVA::Factory::Factory( TString jobName, TFile* theTargetFile, TString theOption )
120 : Configurable ( theOption ),
121 fTransformations (
"I" ),
123 fVerboseLevel ( kINFO ),
124 fCorrelations ( kFALSE ),
126 fSilentFile ( theTargetFile == nullptr ),
127 fJobName ( jobName ),
128 fAnalysisType ( Types::kClassification ),
129 fModelPersistence (kTRUE)
// Remember the target file globally and tag log output with this object's name.
132 fgTargetFile = theTargetFile;
133 fLogger->SetSource(fName.Data());
// Inhibit all screen output when the "Silent" option is present.
136 if (gTools().CheckForSilentOption( GetOptions() )) Log().InhibitOutput();
140 SetConfigDescription(
"Configuration options for Factory running" );
141 SetConfigName( GetName() );
// Defaults; in batch mode colour and the progress bar are disabled.
146 Bool_t silent = kFALSE;
149 Bool_t color = kFALSE;
150 Bool_t drawProgressBar = kFALSE;
152 Bool_t color = !gROOT->IsBatch();
153 Bool_t drawProgressBar = kTRUE;
// Declare all user-settable options and their allowed (predefined) values.
155 DeclareOptionRef( fVerbose,
"V",
"Verbose flag" );
156 DeclareOptionRef( fVerboseLevel=TString(
"Info"),
"VerboseLevel",
"VerboseLevel (Debug/Verbose/Info)" );
157 AddPreDefVal(TString(
"Debug"));
158 AddPreDefVal(TString(
"Verbose"));
159 AddPreDefVal(TString(
"Info"));
160 DeclareOptionRef( color,
"Color",
"Flag for coloured screen output (default: True, if in batch mode: False)" );
161 DeclareOptionRef( fTransformations,
"Transformations",
"List of transformations to test; formatting example: \"Transformations=I;D;P;U;G,D\", for identity, decorrelation, PCA, Uniform and Gaussianisation followed by decorrelation transformations" );
162 DeclareOptionRef( fCorrelations,
"Correlations",
"boolean to show correlation in output" );
163 DeclareOptionRef( fROC,
"ROC",
"boolean to show ROC in output" );
164 DeclareOptionRef( silent,
"Silent",
"Batch mode: boolean silent flag inhibiting any output from TMVA after the creation of the factory class object (default: False)" );
165 DeclareOptionRef( drawProgressBar,
166 "DrawProgressBar",
"Draw progress bar to display training, testing and evaluation schedule (default: True)" );
167 DeclareOptionRef( fModelPersistence,
169 "Option to save the trained model in xml file or using serialization");
171 TString analysisType(
"Auto");
172 DeclareOptionRef( analysisType,
173 "AnalysisType",
"Set the analysis type (Classification, Regression, Multiclass, Auto) (default: Auto)" );
174 AddPreDefVal(TString(
"Classification"));
175 AddPreDefVal(TString(
"Regression"));
176 AddPreDefVal(TString(
"Multiclass"));
177 AddPreDefVal(TString(
"Auto"));
// Parse the option string; warn about anything unrecognised.
180 CheckForUnusedOptions();
// Map the requested verbosity level onto the logger's minimum message type.
182 if (Verbose()) fLogger->SetMinType( kVERBOSE );
183 if (fVerboseLevel.CompareTo(
"Debug") ==0) fLogger->SetMinType( kDEBUG );
184 if (fVerboseLevel.CompareTo(
"Verbose") ==0) fLogger->SetMinType( kVERBOSE );
185 if (fVerboseLevel.CompareTo(
"Info") ==0) fLogger->SetMinType( kINFO );
// Propagate the parsed flags to the global TMVA configuration.
188 gConfig().SetUseColor( color );
189 gConfig().SetSilent( silent );
190 gConfig().SetDrawProgressBar( drawProgressBar );
// Translate the (case-insensitive) analysis-type string into the enum;
// "auto" defers the decision until the data set is inspected in BookMethod.
192 analysisType.ToLower();
193 if ( analysisType ==
"classification" ) fAnalysisType = Types::kClassification;
194 else if( analysisType ==
"regression" ) fAnalysisType = Types::kRegression;
195 else if( analysisType ==
"multiclass" ) fAnalysisType = Types::kMulticlass;
196 else if( analysisType ==
"auto" ) fAnalysisType = Types::kNoAnalysisType;
////////////////////////////////////////////////////////////////////////////////
/// Constructor without an output target file: the Factory runs in "silent
/// file" mode (fSilentFile = kTRUE, fgTargetFile = nullptr), so no evaluation
/// histograms are written to disk. Option parsing is otherwise identical to
/// the file-based constructor above.
/// NOTE(review): extraction is missing lines (preprocessor guards), hence the
/// apparent double declaration of `color`/`drawProgressBar` — confirm against
/// the full source.
204 TMVA::Factory::Factory( TString jobName, TString theOption )
205 : Configurable ( theOption ),
206 fTransformations (
"I" ),
208 fCorrelations ( kFALSE ),
210 fSilentFile ( kTRUE ),
211 fJobName ( jobName ),
212 fAnalysisType ( Types::kClassification ),
213 fModelPersistence (kTRUE)
// No output file in this mode.
216 fgTargetFile =
nullptr;
217 fLogger->SetSource(fName.Data());
221 if (gTools().CheckForSilentOption( GetOptions() )) Log().InhibitOutput();
225 SetConfigDescription(
"Configuration options for Factory running" );
226 SetConfigName( GetName() );
// Keep newly created histograms out of the current ROOT directory.
230 TH1::AddDirectory(kFALSE);
// Defaults; in batch mode colour and the progress bar are disabled.
231 Bool_t silent = kFALSE;
234 Bool_t color = kFALSE;
235 Bool_t drawProgressBar = kFALSE;
237 Bool_t color = !gROOT->IsBatch();
238 Bool_t drawProgressBar = kTRUE;
// Declare all user-settable options and their allowed (predefined) values.
240 DeclareOptionRef( fVerbose,
"V",
"Verbose flag" );
241 DeclareOptionRef( fVerboseLevel=TString(
"Info"),
"VerboseLevel",
"VerboseLevel (Debug/Verbose/Info)" );
242 AddPreDefVal(TString(
"Debug"));
243 AddPreDefVal(TString(
"Verbose"));
244 AddPreDefVal(TString(
"Info"));
245 DeclareOptionRef( color,
"Color",
"Flag for coloured screen output (default: True, if in batch mode: False)" );
246 DeclareOptionRef( fTransformations,
"Transformations",
"List of transformations to test; formatting example: \"Transformations=I;D;P;U;G,D\", for identity, decorrelation, PCA, Uniform and Gaussianisation followed by decorrelation transformations" );
247 DeclareOptionRef( fCorrelations,
"Correlations",
"boolean to show correlation in output" );
248 DeclareOptionRef( fROC,
"ROC",
"boolean to show ROC in output" );
249 DeclareOptionRef( silent,
"Silent",
"Batch mode: boolean silent flag inhibiting any output from TMVA after the creation of the factory class object (default: False)" );
250 DeclareOptionRef( drawProgressBar,
251 "DrawProgressBar",
"Draw progress bar to display training, testing and evaluation schedule (default: True)" );
252 DeclareOptionRef( fModelPersistence,
254 "Option to save the trained model in xml file or using serialization");
256 TString analysisType(
"Auto");
257 DeclareOptionRef( analysisType,
258 "AnalysisType",
"Set the analysis type (Classification, Regression, Multiclass, Auto) (default: Auto)" );
259 AddPreDefVal(TString(
"Classification"));
260 AddPreDefVal(TString(
"Regression"));
261 AddPreDefVal(TString(
"Multiclass"));
262 AddPreDefVal(TString(
"Auto"));
// Parse the option string; warn about anything unrecognised.
265 CheckForUnusedOptions();
// Map the requested verbosity level onto the logger's minimum message type.
267 if (Verbose()) fLogger->SetMinType( kVERBOSE );
268 if (fVerboseLevel.CompareTo(
"Debug") ==0) fLogger->SetMinType( kDEBUG );
269 if (fVerboseLevel.CompareTo(
"Verbose") ==0) fLogger->SetMinType( kVERBOSE );
270 if (fVerboseLevel.CompareTo(
"Info") ==0) fLogger->SetMinType( kINFO );
// Propagate the parsed flags to the global TMVA configuration.
273 gConfig().SetUseColor( color );
274 gConfig().SetSilent( silent );
275 gConfig().SetDrawProgressBar( drawProgressBar );
// Translate the (case-insensitive) analysis-type string into the enum;
// "auto" defers the decision until the data set is inspected in BookMethod.
277 analysisType.ToLower();
278 if ( analysisType ==
"classification" ) fAnalysisType = Types::kClassification;
279 else if( analysisType ==
"regression" ) fAnalysisType = Types::kRegression;
280 else if( analysisType ==
"multiclass" ) fAnalysisType = Types::kMulticlass;
281 else if( analysisType ==
"auto" ) fAnalysisType = Types::kNoAnalysisType;
290 void TMVA::Factory::Greetings()
292 gTools().ROOTVersionMessage( Log() );
293 gTools().TMVAWelcomeMessage( Log(), gTools().kLogoWelcomeMsg );
294 gTools().TMVAVersionMessage( Log() ); Log() << Endl;
////////////////////////////////////////////////////////////////////////////////
/// Destructor: releases the default variable transformations and all booked
/// methods. NOTE(review): the extraction truncates the body here — any
/// further cleanup (e.g. of the output file) is not visible; confirm against
/// the full source.
300 TMVA::Factory::~Factory(
void )
// Delete every default transformation owned by the factory.
302 std::vector<TMVA::VariableTransformBase*>::iterator trfIt = fDefaultTrfs.begin();
303 for (;trfIt != fDefaultTrfs.end(); ++trfIt)
delete (*trfIt);
// Delete all booked methods (per-dataset method vectors).
305 this->DeleteAllMethods();
////////////////////////////////////////////////////////////////////////////////
/// Delete all methods booked in fMethodsMap, iterating over every dataset's
/// method vector. NOTE(review): only the debug-logging part of the loop is
/// visible in this extraction — the actual `delete` of each method and the
/// clearing of the map are presumably in the missing lines; confirm against
/// the full source.
318 void TMVA::Factory::DeleteAllMethods(
void )
320 std::map<TString,MVector*>::iterator itrMap;
// One entry per dataset: TString dataset name -> vector of booked methods.
322 for(itrMap = fMethodsMap.begin();itrMap != fMethodsMap.end();++itrMap)
324 MVector *methods=itrMap->second;
326 MVector::iterator itrMethod = methods->begin();
327 for (; itrMethod != methods->end(); ++itrMethod) {
328 Log() << kDEBUG <<
"Delete method: " << (*itrMethod)->GetName() << Endl;
////////////////////////////////////////////////////////////////////////////////
/// Set the factory's verbose flag. NOTE(review): the body is not visible in
/// this extraction — presumably it assigns `v` to fVerbose; confirm against
/// the full source.
338 void TMVA::Factory::SetVerbose( Bool_t v )
////////////////////////////////////////////////////////////////////////////////
/// Book an MVA method by name for the given DataLoader. Auto-detects the
/// analysis type if it is still kNoAnalysisType, refuses duplicate method
/// titles per dataset, optionally wraps the method in a MethodBoost, wires up
/// weight-file directories / output file / persistence flags, and registers
/// the configured method in fMethodsMap. Returns the booked MethodBase.
346 TMVA::MethodBase* TMVA::Factory::BookMethod( TMVA::DataLoader *loader, TString theMethodName, TString methodTitle, TString theOption )
// With model persistence, each dataset gets its own directory on disk.
348 if(fModelPersistence) gSystem->MakeDirectory(loader->GetName());
350 TString datasetname=loader->GetName();
// Analysis type "Auto": infer it from the dataset (2 classes named
// Signal/Background -> classification, otherwise >=2 classes -> multiclass).
352 if( fAnalysisType == Types::kNoAnalysisType ){
353 if( loader->GetDataSetInfo().GetNClasses()==2
354 && loader->GetDataSetInfo().GetClassInfo(
"Signal") != NULL
355 && loader->GetDataSetInfo().GetClassInfo(
"Background") != NULL
357 fAnalysisType = Types::kClassification;
358 }
else if( loader->GetDataSetInfo().GetNClasses() >= 2 ){
359 fAnalysisType = Types::kMulticlass;
361 Log() << kFATAL <<
"No analysis type for " << loader->GetDataSetInfo().GetNClasses() <<
" classes and "
362 << loader->GetDataSetInfo().GetNTargets() <<
" regression targets." << Endl;
// A method title must be unique within a dataset: abort on duplicates.
368 if(fMethodsMap.find(datasetname)!=fMethodsMap.end())
370 if (GetMethod( datasetname,methodTitle ) != 0) {
371 Log() << kFATAL <<
"Booking failed since method with title <"
372 << methodTitle <<
"> already exists "<<
"in with DataSet Name <"<< loader->GetName()<<
"> "
378 Log() << kHEADER <<
"Booking method: " << gTools().Color(
"bold") << methodTitle
380 << gTools().Color(
"reset") << Endl << Endl;
// Pre-parse only the "Boost_num" option to decide whether to wrap the
// method in a boosted classifier.
384 TMVA::Configurable* conf =
new TMVA::Configurable( theOption );
385 conf->DeclareOptionRef( boostNum = 0,
"Boost_num",
386 "Number of times the classifier will be boosted" );
387 conf->ParseOptions();
// Compose the weight-file directory: [prefix/]<dataset>/<weight dir>.
391 if(fModelPersistence)
394 TString prefix = gConfig().GetIONames().fWeightFileDirPrefix;
396 if (!prefix.IsNull())
397 if (fileDir[fileDir.Length()-1] !=
'/') fileDir +=
"/";
398 fileDir += loader->GetName();
399 fileDir +=
"/" + gConfig().GetIONames().fWeightFileDir;
// Non-boosted path: create the method directly from the classifier factory.
404 im = ClassifierFactory::Instance().Create(theMethodName.Data(), fJobName, methodTitle,
405 loader->GetDataSetInfo(), theOption);
// Boosted path: create a MethodBoost wrapper that in turn books theMethodName.
409 Log() << kDEBUG <<
"Boost Number is " << boostNum <<
" > 0: train boosted classifier" << Endl;
410 im = ClassifierFactory::Instance().Create(
"Boost", fJobName, methodTitle, loader->GetDataSetInfo(), theOption);
411 MethodBoost *methBoost =
dynamic_cast<MethodBoost *
>(im);
// NOTE(review): this message says "MethodCategory" but the failed cast is
// to MethodBoost — looks like a copy/paste slip in the error text.
413 Log() << kFATAL <<
"Method with type kBoost cannot be casted to MethodCategory. /Factory" << Endl;
416 if (fModelPersistence) methBoost->SetWeightFileDir(fileDir);
417 methBoost->SetModelPersistence(fModelPersistence);
418 methBoost->SetBoostedMethodName(theMethodName);
419 methBoost->fDataSetManager = loader->GetDataSetInfo().GetDataSetManager();
420 methBoost->SetFile(fgTargetFile);
421 methBoost->SetSilentFile(IsSilentFile());
424 MethodBase *method =
dynamic_cast<MethodBase*
>(im);
425 if (method==0)
return 0;
// Category methods need the dataset manager wired in explicitly.
428 if (method->GetMethodType() == Types::kCategory) {
429 MethodCategory *methCat = (
dynamic_cast<MethodCategory*
>(im));
431 Log() << kFATAL <<
"Method with type kCategory cannot be casted to MethodCategory. /Factory" << Endl;
434 if(fModelPersistence) methCat->SetWeightFileDir(fileDir);
435 methCat->SetModelPersistence(fModelPersistence);
436 methCat->fDataSetManager = loader->GetDataSetInfo().GetDataSetManager();
437 methCat->SetFile(fgTargetFile);
438 methCat->SetSilentFile(IsSilentFile());
// Warn if the method cannot handle the requested analysis type /
// class count / target count.
442 if (!method->HasAnalysisType( fAnalysisType,
443 loader->GetDataSetInfo().GetNClasses(),
444 loader->GetDataSetInfo().GetNTargets() )) {
445 Log() << kWARNING <<
"Method " << method->GetMethodTypeName() <<
" is not capable of handling " ;
446 if (fAnalysisType == Types::kRegression) {
447 Log() <<
"regression with " << loader->GetDataSetInfo().GetNTargets() <<
" targets." << Endl;
449 else if (fAnalysisType == Types::kMulticlass ) {
450 Log() <<
"multiclass classification with " << loader->GetDataSetInfo().GetNClasses() <<
" classes." << Endl;
453 Log() <<
"classification with " << loader->GetDataSetInfo().GetNClasses() <<
" classes." << Endl;
// Final configuration: persistence, analysis type, option parsing and setup.
458 if(fModelPersistence) method->SetWeightFileDir(fileDir);
459 method->SetModelPersistence(fModelPersistence);
460 method->SetAnalysisType( fAnalysisType );
461 method->SetupMethod();
462 method->ParseOptions();
463 method->ProcessSetup();
464 method->SetFile(fgTargetFile);
465 method->SetSilentFile(IsSilentFile());
// Method-specific consistency checks after setup.
468 method->CheckSetup();
// Register the method under its dataset, creating the vector on first use.
470 if(fMethodsMap.find(datasetname)==fMethodsMap.end())
472 MVector *mvector=
new MVector;
473 fMethodsMap[datasetname]=mvector;
475 fMethodsMap[datasetname]->push_back( method );
485 TMVA::MethodBase* TMVA::Factory::BookMethod(TMVA::DataLoader *loader, Types::EMVA theMethod, TString methodTitle, TString theOption )
487 return BookMethod(loader, Types::Instance().GetMethodName( theMethod ), methodTitle, theOption );
////////////////////////////////////////////////////////////////////////////////
/// Book a method by re-creating it from an existing weight file instead of
/// training it. Category methods are rejected; the restored method is set up,
/// read back from the file, checked for a duplicate title and registered in
/// fMethodsMap. Returns the restored MethodBase (nullptr on cast failure).
498 TMVA::MethodBase* TMVA::Factory::BookMethodWeightfile(DataLoader *loader, TMVA::Types::EMVA methodType,
const TString &weightfile)
500 TString datasetname = loader->GetName();
501 std::string methodTypeName = std::string(Types::Instance().GetMethodName(methodType).Data());
502 DataSetInfo &dsi = loader->GetDataSetInfo();
// Create the method directly from the weight file via the classifier factory.
504 IMethod *im = ClassifierFactory::Instance().Create(methodTypeName, dsi, weightfile );
505 MethodBase *method = (
dynamic_cast<MethodBase*
>(im));
507 if (method ==
nullptr)
return nullptr;
509 if( method->GetMethodType() == Types::kCategory ){
510 Log() << kERROR <<
"Cannot handle category methods for now." << Endl;
// Compose the weight-file directory, as in BookMethod.
514 if(fModelPersistence) {
516 TString prefix = gConfig().GetIONames().fWeightFileDirPrefix;
518 if (!prefix.IsNull())
519 if (fileDir[fileDir.Length() - 1] !=
'/')
// NOTE(review): plain `=` here discards whatever fileDir/prefix handling
// precedes it (BookMethod uses `+=` at the equivalent spot) — looks
// suspicious; confirm against the full source / upstream.
521 fileDir=loader->GetName();
522 fileDir+=
"/"+gConfig().GetIONames().fWeightFileDir;
525 if(fModelPersistence) method->SetWeightFileDir(fileDir);
526 method->SetModelPersistence(fModelPersistence);
527 method->SetAnalysisType( fAnalysisType );
528 method->SetupMethod();
529 method->SetFile(fgTargetFile);
530 method->SetSilentFile(IsSilentFile());
// Allow reading of older weight-file formats.
532 method->DeclareCompatibilityOptions();
// Restore the trained state from the weight file.
535 method->ReadStateFromFile();
// The restored method's name must be unique within the dataset.
539 TString methodTitle = method->GetName();
540 if (HasMethod(datasetname, methodTitle) != 0) {
541 Log() << kFATAL <<
"Booking failed since method with title <"
542 << methodTitle <<
"> already exists "<<
"in with DataSet Name <"<< loader->GetName()<<
"> "
546 Log() << kINFO <<
"Booked classifier \"" << method->GetMethodName()
547 <<
"\" of type: \"" << method->GetMethodTypeName() <<
"\"" << Endl;
// Register under the dataset, creating the vector on first use.
549 if(fMethodsMap.count(datasetname) == 0) {
550 MVector *mvector =
new MVector;
551 fMethodsMap[datasetname] = mvector;
554 fMethodsMap[datasetname]->push_back( method );
////////////////////////////////////////////////////////////////////////////////
/// Return the booked method with the given title in the given dataset, or 0
/// if the dataset is unknown (the not-found return for a known dataset is in
/// lines not visible in this extraction).
562 TMVA::IMethod* TMVA::Factory::GetMethod(
const TString& datasetname,
const TString &methodTitle )
const
564 if(fMethodsMap.find(datasetname)==fMethodsMap.end())
return 0;
566 MVector *methods=fMethodsMap.find(datasetname)->second;
568 MVector::const_iterator itrMethod;
// Linear scan over the dataset's methods, matching on the method name.
570 for (itrMethod = methods->begin(); itrMethod != methods->end(); ++itrMethod) {
571 MethodBase* mva =
dynamic_cast<MethodBase*
>(*itrMethod);
572 if ( (mva->GetMethodName())==methodTitle )
return mva;
////////////////////////////////////////////////////////////////////////////////
/// Return kTRUE if a method with the given title is booked for the given
/// dataset, kFALSE otherwise (including when the dataset itself is unknown).
580 Bool_t TMVA::Factory::HasMethod(
const TString& datasetname,
const TString &methodTitle )
const
582 if(fMethodsMap.find(datasetname)==fMethodsMap.end())
return 0;
584 std::string methodName = methodTitle.Data();
// Predicate comparing a booked method's name against the requested title.
585 auto isEqualToMethodName = [&methodName](TMVA::IMethod * m) {
586 return ( 0 == methodName.compare( m->GetName() ) );
589 TMVA::Factory::MVector * methods = this->fMethodsMap.at(datasetname);
590 Bool_t isMethodNameExisting = std::any_of( methods->begin(), methods->end(), isEqualToMethodName);
592 return isMethodNameExisting;
////////////////////////////////////////////////////////////////////////////////
/// Write dataset-level information to the output file: per-class (or
/// signal/background/regression) correlation-matrix histograms, then run the
/// configured input-variable transformations on the dataset's events so their
/// control output lands in the dataset's directory.
597 void TMVA::Factory::WriteDataInformation(DataSetInfo& fDataSetInfo)
// Create (if needed) and enter a subdirectory named after the dataset.
601 if(!RootBaseDir()->GetDirectory(fDataSetInfo.GetName())) RootBaseDir()->mkdir(fDataSetInfo.GetName());
604 RootBaseDir()->cd(fDataSetInfo.GetName());
// Force the dataset to be built before writing anything about it.
605 fDataSetInfo.GetDataSet();
609 const TMatrixD* m(0);
// Multiclass: one correlation-matrix histogram per class.
612 if(fAnalysisType == Types::kMulticlass){
613 for (UInt_t cls = 0; cls < fDataSetInfo.GetNClasses() ; cls++) {
614 m = fDataSetInfo.CorrelationMatrix(fDataSetInfo.GetClassInfo(cls)->GetName());
615 h = fDataSetInfo.CreateCorrelationMatrixHist(m, TString(
"CorrelationMatrix")+fDataSetInfo.GetClassInfo(cls)->GetName(),
616 TString(
"Correlation Matrix (")+ fDataSetInfo.GetClassInfo(cls)->GetName() +TString(
")"));
// Two-class / regression: dedicated signal, background and regression
// correlation matrices.
624 m = fDataSetInfo.CorrelationMatrix(
"Signal" );
625 h = fDataSetInfo.CreateCorrelationMatrixHist(m,
"CorrelationMatrixS",
"Correlation Matrix (signal)");
631 m = fDataSetInfo.CorrelationMatrix(
"Background" );
632 h = fDataSetInfo.CreateCorrelationMatrixHist(m,
"CorrelationMatrixB",
"Correlation Matrix (background)");
638 m = fDataSetInfo.CorrelationMatrix(
"Regression" );
639 h = fDataSetInfo.CreateCorrelationMatrixHist(m,
"CorrelationMatrix",
"Correlation Matrix");
// Determine which transformations to run: the "Transformations" option
// if set, otherwise just the identity.
648 TString processTrfs =
"I";
651 processTrfs = fTransformations;
654 std::vector<TMVA::TransformationHandler*> trfs;
655 TransformationHandler* identityTrHandler = 0;
// The transformation list is semicolon-separated, e.g. "I;D;P;G,D".
657 std::vector<TString> trfsDef = gTools().SplitString(processTrfs,
';');
658 std::vector<TString>::iterator trfsDefIt = trfsDef.begin();
659 for (; trfsDefIt!=trfsDef.end(); ++trfsDefIt) {
660 trfs.push_back(
new TMVA::TransformationHandler(fDataSetInfo,
"Factory"));
661 TString trfS = (*trfsDefIt);
664 Log() << kDEBUG <<
"current transformation string: '" << trfS.Data() <<
"'" << Endl;
665 TMVA::CreateVariableTransforms( trfS,
// Remember the identity handler so its variable ranking can be printed.
670 if (trfS.BeginsWith(
'I')) identityTrHandler = trfs.back();
// Calculate each transformation on the full event collection, writing its
// control output into the dataset's directory.
673 const std::vector<Event*>& inputEvents = fDataSetInfo.GetDataSet()->GetEventCollection();
676 std::vector<TMVA::TransformationHandler*>::iterator trfIt = trfs.begin();
678 for (;trfIt != trfs.end(); ++trfIt) {
680 (*trfIt)->SetRootDir(RootBaseDir()->GetDirectory(fDataSetInfo.GetName()));
681 (*trfIt)->CalcTransformations(inputEvents);
683 if(identityTrHandler) identityTrHandler->PrintVariableRanking();
// Clean up the temporary transformation handlers.
686 for (trfIt = trfs.begin(); trfIt != trfs.end(); ++trfIt)
delete *trfIt;
////////////////////////////////////////////////////////////////////////////////
/// Run hyper-parameter optimization (OptimizeTuningParameters) on every
/// booked method of every dataset, skipping methods with fewer than
/// MinNoTrainingEvents training events. Returns the tuned parameters of the
/// last optimized method.
695 std::map<TString,Double_t> TMVA::Factory::OptimizeAllMethods(TString fomType, TString fitType)
698 std::map<TString,MVector*>::iterator itrMap;
699 std::map<TString,Double_t> TunedParameters;
// Outer loop: datasets; inner loop: methods booked for that dataset.
700 for(itrMap = fMethodsMap.begin();itrMap != fMethodsMap.end();++itrMap)
702 MVector *methods=itrMap->second;
704 MVector::iterator itrMethod;
707 for( itrMethod = methods->begin(); itrMethod != methods->end(); ++itrMethod ) {
708 Event::SetIsTraining(kTRUE);
709 MethodBase* mva =
dynamic_cast<MethodBase*
>(*itrMethod);
711 Log() << kFATAL <<
"Dynamic cast to MethodBase failed" <<Endl;
712 return TunedParameters;
// Skip methods with too little training data (warn, don't abort).
715 if (mva->Data()->GetNTrainingEvents() < MinNoTrainingEvents) {
716 Log() << kWARNING <<
"Method " << mva->GetMethodName()
717 <<
" not trained (training tree has less entries ["
718 << mva->Data()->GetNTrainingEvents()
719 <<
"] than required [" << MinNoTrainingEvents <<
"]" << Endl;
723 Log() << kINFO <<
"Optimize method: " << mva->GetMethodName() <<
" for "
724 << (fAnalysisType == Types::kRegression ?
"Regression" :
725 (fAnalysisType == Types::kMulticlass ?
"Multiclass classification" :
"Classification")) << Endl;
// Delegate the actual tuning to the method; keep its result.
727 TunedParameters = mva->OptimizeTuningParameters(fomType,fitType);
728 Log() << kINFO <<
"Optimization of tuning parameters finished for Method:"<<mva->GetName() << Endl;
732 return TunedParameters;
744 TMVA::ROCCurve *TMVA::Factory::GetROC(TMVA::DataLoader *loader, TString theMethodName, UInt_t iClass,
745 Types::ETreeType type)
747 return GetROC((TString)loader->GetName(), theMethodName, iClass, type);
////////////////////////////////////////////////////////////////////////////////
/// Build and return a ROCCurve for the given dataset/method/class and tree
/// type. Works for classification (MVA response + truth flags + weights) and
/// multiclass (one-vs-rest for class `iClass`). Emits kERROR and presumably
/// returns null for unknown dataset/method or unsupported analysis type (the
/// return statements of the error branches are not visible in this
/// extraction). The caller owns the returned ROCCurve.
758 TMVA::ROCCurve *TMVA::Factory::GetROC(TString datasetname, TString theMethodName, UInt_t iClass, Types::ETreeType type)
// Validate dataset, method and analysis type before touching any results.
760 if (fMethodsMap.find(datasetname) == fMethodsMap.end()) {
761 Log() << kERROR << Form(
"DataSet = %s not found in methods map.", datasetname.Data()) << Endl;
765 if (!this->HasMethod(datasetname, theMethodName)) {
766 Log() << kERROR << Form(
"Method = %s not found with Dataset = %s ", theMethodName.Data(), datasetname.Data())
771 std::set<Types::EAnalysisType> allowedAnalysisTypes = {Types::kClassification, Types::kMulticlass};
772 if (allowedAnalysisTypes.count(this->fAnalysisType) == 0) {
773 Log() << kERROR << Form(
"Can only generate ROC curves for analysis type kClassification and kMulticlass.")
// Fetch the method's results for the requested tree type (training/testing).
778 TMVA::MethodBase *method =
dynamic_cast<TMVA::MethodBase *
>(this->GetMethod(datasetname, theMethodName));
779 TMVA::DataSet *dataset = method->Data();
780 dataset->SetCurrentType(type);
781 TMVA::Results *results = dataset->GetResults(theMethodName, type, this->fAnalysisType);
// For multiclass, iClass must name an existing class.
783 UInt_t nClasses = method->DataInfo().GetNClasses();
784 if (this->fAnalysisType == Types::kMulticlass && iClass >= nClasses) {
785 Log() << kERROR << Form(
"Given class number (iClass = %i) does not exist. There are %i classes in dataset.",
791 TMVA::ROCCurve *rocCurve =
nullptr;
// Binary classification: MVA values, per-event truth flags, event weights.
792 if (this->fAnalysisType == Types::kClassification) {
794 std::vector<Float_t> *mvaRes =
dynamic_cast<ResultsClassification *
>(results)->GetValueVector();
795 std::vector<Bool_t> *mvaResTypes =
dynamic_cast<ResultsClassification *
>(results)->GetValueVectorTypes();
796 std::vector<Float_t> mvaResWeights;
798 auto eventCollection = dataset->GetEventCollection(type);
799 mvaResWeights.reserve(eventCollection.size());
800 for (
auto ev : eventCollection) {
801 mvaResWeights.push_back(ev->GetWeight());
804 rocCurve =
new TMVA::ROCCurve(*mvaRes, *mvaResTypes, mvaResWeights);
// Multiclass: take the response for class iClass out of the per-event
// response vector and build one-vs-rest truth flags.
806 }
else if (this->fAnalysisType == Types::kMulticlass) {
807 std::vector<Float_t> mvaRes;
808 std::vector<Bool_t> mvaResTypes;
809 std::vector<Float_t> mvaResWeights;
811 std::vector<std::vector<Float_t>> *rawMvaRes =
dynamic_cast<ResultsMulticlass *
>(results)->GetValueVector();
816 mvaRes.reserve(rawMvaRes->size());
817 for (
auto item : *rawMvaRes) {
818 mvaRes.push_back(item[iClass]);
821 auto eventCollection = dataset->GetEventCollection(type);
822 mvaResTypes.reserve(eventCollection.size());
823 mvaResWeights.reserve(eventCollection.size());
824 for (
auto ev : eventCollection) {
825 mvaResTypes.push_back(ev->GetClass() == iClass);
826 mvaResWeights.push_back(ev->GetWeight());
829 rocCurve =
new TMVA::ROCCurve(mvaRes, mvaResTypes, mvaResWeights);
843 Double_t TMVA::Factory::GetROCIntegral(TMVA::DataLoader *loader, TString theMethodName, UInt_t iClass)
845 return GetROCIntegral((TString)loader->GetName(), theMethodName, iClass);
////////////////////////////////////////////////////////////////////////////////
/// Compute the ROC integral (AUC) for the given dataset/method/class via a
/// freshly constructed ROCCurve. Emits kERROR for unknown dataset/method or
/// unsupported analysis type, and kFATAL if the ROCCurve could not be built
/// (the error-branch return values are not visible in this extraction).
856 Double_t TMVA::Factory::GetROCIntegral(TString datasetname, TString theMethodName, UInt_t iClass)
// Validate dataset, method and analysis type first.
858 if (fMethodsMap.find(datasetname) == fMethodsMap.end()) {
859 Log() << kERROR << Form(
"DataSet = %s not found in methods map.", datasetname.Data()) << Endl;
863 if ( ! this->HasMethod(datasetname, theMethodName) ) {
864 Log() << kERROR << Form(
"Method = %s not found with Dataset = %s ", theMethodName.Data(), datasetname.Data()) << Endl;
868 std::set<Types::EAnalysisType> allowedAnalysisTypes = {Types::kClassification, Types::kMulticlass};
869 if ( allowedAnalysisTypes.count(this->fAnalysisType) == 0 ) {
870 Log() << kERROR << Form(
"Can only generate ROC integral for analysis type kClassification. and kMulticlass.")
875 TMVA::ROCCurve *rocCurve = GetROC(datasetname, theMethodName, iClass);
877 Log() << kFATAL << Form(
"ROCCurve object was not created in Method = %s not found with Dataset = %s ",
878 theMethodName.Data(), datasetname.Data())
// Integrate using the configured number of ROC-curve bins (+1 points).
883 Int_t npoints = TMVA::gConfig().fVariablePlotting.fNbinsXOfROCCurve + 1;
884 Double_t rocIntegral = rocCurve->GetROCIntegral(npoints);
904 TGraph* TMVA::Factory::GetROCCurve(DataLoader *loader, TString theMethodName, Bool_t setTitles, UInt_t iClass)
906 return GetROCCurve( (TString)loader->GetName(), theMethodName, setTitles, iClass );
////////////////////////////////////////////////////////////////////////////////
/// Return the ROC curve for the given dataset/method/class as a standalone
/// TGraph (a clone of the ROCCurve's internal graph, owned by the caller).
/// When `setTitles` is true, axis titles and a graph title are applied.
/// Emits kERROR/kFATAL on unknown dataset/method, unsupported analysis type
/// or a failed ROCCurve construction (error-branch returns are not visible
/// in this extraction).
923 TGraph* TMVA::Factory::GetROCCurve(TString datasetname, TString theMethodName, Bool_t setTitles, UInt_t iClass)
// Validate dataset, method and analysis type first.
925 if (fMethodsMap.find(datasetname) == fMethodsMap.end()) {
926 Log() << kERROR << Form(
"DataSet = %s not found in methods map.", datasetname.Data()) << Endl;
930 if ( ! this->HasMethod(datasetname, theMethodName) ) {
931 Log() << kERROR << Form(
"Method = %s not found with Dataset = %s ", theMethodName.Data(), datasetname.Data()) << Endl;
935 std::set<Types::EAnalysisType> allowedAnalysisTypes = {Types::kClassification, Types::kMulticlass};
936 if ( allowedAnalysisTypes.count(this->fAnalysisType) == 0 ) {
937 Log() << kERROR << Form(
"Can only generate ROC curves for analysis type kClassification and kMulticlass.") << Endl;
941 TMVA::ROCCurve *rocCurve = GetROC(datasetname, theMethodName, iClass);
942 TGraph *graph =
nullptr;
945 Log() << kFATAL << Form(
"ROCCurve object was not created in Method = %s not found with Dataset = %s ", theMethodName.Data(), datasetname.Data()) << Endl;
// Clone so the returned graph is independent of the ROCCurve's lifetime.
949 graph = (TGraph *)rocCurve->GetROCCurve()->Clone();
// Optionally decorate the graph with axis titles and a descriptive title.
953 graph->GetYaxis()->SetTitle(
"Background rejection (Specificity)");
954 graph->GetXaxis()->SetTitle(
"Signal efficiency (Sensitivity)");
955 graph->SetTitle(Form(
"Signal efficiency vs. Background rejection (%s)", theMethodName.Data()));
973 TMultiGraph* TMVA::Factory::GetROCCurveAsMultiGraph(DataLoader *loader, UInt_t iClass)
975 return GetROCCurveAsMultiGraph((TString)loader->GetName(), iClass);
////////////////////////////////////////////////////////////////////////////////
/// Build a TMultiGraph containing one ROC TGraph per booked method of the
/// given dataset (for class `iClass` in the multiclass case), with distinct
/// line colours per method. Emits kERROR if `iClass` is out of range or the
/// multigraph ends up empty. The caller owns the returned multigraph.
990 TMultiGraph* TMVA::Factory::GetROCCurveAsMultiGraph(TString datasetname, UInt_t iClass)
// Successive methods get successive line colours starting at 1.
992 UInt_t line_color = 1;
994 TMultiGraph *multigraph =
new TMultiGraph();
996 MVector *methods = fMethodsMap[datasetname.Data()];
997 for (
auto * method_raw : *methods) {
998 TMVA::MethodBase *method =
dynamic_cast<TMVA::MethodBase *
>(method_raw);
// Skip entries that are not MethodBase instances.
999 if (method ==
nullptr) {
continue; }
1001 TString methodName = method->GetMethodName();
1002 UInt_t nClasses = method->DataInfo().GetNClasses();
1004 if ( this->fAnalysisType == Types::kMulticlass && iClass >= nClasses ) {
1005 Log() << kERROR << Form(
"Given class number (iClass = %i) does not exist. There are %i classes in dataset.", iClass, nClasses) << Endl;
1009 TString className = method->DataInfo().GetClassInfo(iClass)->GetName();
// Fetch the undecorated per-method graph; styling is applied here instead.
1011 TGraph *graph = this->GetROCCurve(datasetname, methodName,
false, iClass);
1012 graph->SetTitle(methodName);
1014 graph->SetLineWidth(2);
1015 graph->SetLineColor(line_color++);
1016 graph->SetFillColor(10);
1018 multigraph->Add(graph);
// NOTE(review): "metohds" below is a typo in the runtime error message
// ("methods") — left untouched here since this edit changes comments only.
1021 if ( multigraph->GetListOfGraphs() == nullptr ) {
1022 Log() << kERROR << Form(
"No metohds have class %i defined.", iClass) << Endl;
1041 TCanvas * TMVA::Factory::GetROCCurve(TMVA::DataLoader *loader, UInt_t iClass)
1043 return GetROCCurve((TString)loader->GetName(), iClass);
////////////////////////////////////////////////////////////////////////////////
/// Draw the ROC curves of all booked methods of the given dataset (for class
/// `iClass`) onto a new TCanvas, with axis titles, a (multiclass-aware)
/// title and a legend. Emits kERROR for an unknown dataset. The caller owns
/// the returned canvas.
1057 TCanvas * TMVA::Factory::GetROCCurve(TString datasetname, UInt_t iClass)
1059 if (fMethodsMap.find(datasetname) == fMethodsMap.end()) {
1060 Log() << kERROR << Form(
"DataSet = %s not found in methods map.", datasetname.Data()) << Endl;
// Canvas name encodes dataset and class so repeated calls stay distinct.
1064 TString name = Form(
"ROCCurve %s class %i", datasetname.Data(), iClass);
1065 TCanvas *canvas =
new TCanvas(name,
"ROC Curve", 200, 10, 700, 500);
// The per-method curves are assembled by GetROCCurveAsMultiGraph.
1068 TMultiGraph *multigraph = this->GetROCCurveAsMultiGraph(datasetname, iClass);
1071 multigraph->Draw(
"AL");
1073 multigraph->GetYaxis()->SetTitle(
"Background rejection (Specificity)");
1074 multigraph->GetXaxis()->SetTitle(
"Signal efficiency (Sensitivity)");
// For multiclass, append the class index to the title.
1076 TString titleString = Form(
"Signal efficiency vs. Background rejection");
1077 if (this->fAnalysisType == Types::kMulticlass) {
1078 titleString = Form(
"%s (Class=%i)", titleString.Data(), iClass);
1082 multigraph->GetHistogram()->SetTitle( titleString );
1083 multigraph->SetTitle( titleString );
1085 canvas->BuildLegend(0.15, 0.15, 0.35, 0.3,
"MVA Method");
////////////////////////////////////////////////////////////////////////////////
/// Train all booked methods of all datasets: sanity-check the data (input
/// entries, targets for regression, >=2 classes for classification), skip
/// methods with fewer than MinNoTrainingEvents training events, train the
/// rest, print per-method variable rankings (classification only), save the
/// training history, and — when model persistence is enabled — destroy and
/// re-create every method from its weight file so that testing runs on the
/// persisted state.
1094 void TMVA::Factory::TrainAllMethods()
1096 Log() << kHEADER << gTools().Color(
"bold") <<
"Train all methods" << gTools().Color(
"reset") << Endl;
// Nothing booked: report and bail out.
1101 if (fMethodsMap.empty()) {
1102 Log() << kINFO <<
"...nothing found to train" << Endl;
1108 Log() << kDEBUG <<
"Train all methods for "
1109 << (fAnalysisType == Types::kRegression ?
"Regression" :
1110 (fAnalysisType == Types::kMulticlass ?
"Multiclass" :
"Classification") ) <<
" ..." << Endl;
1112 std::map<TString,MVector*>::iterator itrMap;
// Outer loop: datasets; inner loop: methods booked for that dataset.
1114 for(itrMap = fMethodsMap.begin();itrMap != fMethodsMap.end();++itrMap)
1116 MVector *methods=itrMap->second;
1117 MVector::iterator itrMethod;
1120 for( itrMethod = methods->begin(); itrMethod != methods->end(); ++itrMethod ) {
1121 Event::SetIsTraining(kTRUE);
1122 MethodBase* mva =
dynamic_cast<MethodBase*
>(*itrMethod);
1124 if(mva==0)
continue;
// Hard sanity checks: input data present and consistent with the
// requested analysis type.
1126 if(mva->DataInfo().GetDataSetManager()->DataInput().GetEntries() <=1) {
1127 Log() << kFATAL <<
"No input data for the training provided!" << Endl;
1130 if(fAnalysisType == Types::kRegression && mva->DataInfo().GetNTargets() < 1 )
1131 Log() << kFATAL <<
"You want to do regression training without specifying a target." << Endl;
1132 else if( (fAnalysisType == Types::kMulticlass || fAnalysisType == Types::kClassification)
1133 && mva->DataInfo().GetNClasses() < 2 )
1134 Log() << kFATAL <<
"You want to do classification training, but specified less than two classes." << Endl;
// Correlation matrices / transformations are only written with a file.
1137 if(!IsSilentFile()) WriteDataInformation(mva->fDataSetInfo);
// Skip methods with too little training data (warn, don't abort).
1140 if (mva->Data()->GetNTrainingEvents() < MinNoTrainingEvents) {
1141 Log() << kWARNING <<
"Method " << mva->GetMethodName()
1142 <<
" not trained (training tree has less entries ["
1143 << mva->Data()->GetNTrainingEvents()
1144 <<
"] than required [" << MinNoTrainingEvents <<
"]" << Endl;
1148 Log() << kHEADER <<
"Train method: " << mva->GetMethodName() <<
" for "
1149 << (fAnalysisType == Types::kRegression ?
"Regression" :
1150 (fAnalysisType == Types::kMulticlass ?
"Multiclass classification" :
"Classification")) << Endl << Endl;
1152 Log() << kHEADER <<
"Training finished" << Endl << Endl;
// Variable ranking is only printed for (multi)classification.
1155 if (fAnalysisType != Types::kRegression) {
1159 Log() << kINFO <<
"Ranking input variables (method specific)..." << Endl;
1160 for (itrMethod = methods->begin(); itrMethod != methods->end(); ++itrMethod) {
1161 MethodBase* mva =
dynamic_cast<MethodBase*
>(*itrMethod);
1162 if (mva && mva->Data()->GetNTrainingEvents() >= MinNoTrainingEvents) {
1165 const Ranking* ranking = (*itrMethod)->CreateRanking();
1166 if (ranking != 0) ranking->Print();
1167 else Log() << kINFO <<
"No variable ranking supplied by classifier: "
1168 <<
dynamic_cast<MethodBase*
>(*itrMethod)->GetMethodName() << Endl;
// Persist each method's training history into the output file.
1174 if (!IsSilentFile()) {
1175 for (UInt_t i=0; i<methods->size(); i++) {
1176 MethodBase* m =
dynamic_cast<MethodBase*
>((*methods)[i]);
1179 m->fTrainHistory.SaveHistory(m->GetMethodName());
// With model persistence, rebuild every method from its weight file so
// testing/evaluation exercise exactly what was written to disk.
1187 if (fModelPersistence) {
1189 Log() << kHEADER <<
"=== Destroy and recreate all methods via weight files for testing ===" << Endl << Endl;
1191 if(!IsSilentFile())RootBaseDir()->cd();
1194 for (UInt_t i=0; i<methods->size(); i++) {
1196 MethodBase *m =
dynamic_cast<MethodBase *
>((*methods)[i]);
// Remember type/weight file/titles so the method can be re-created.
1200 TMVA::Types::EMVA methodType = m->GetMethodType();
1201 TString weightfile = m->GetWeightFileName();
// Weight files are written as XML (see READXML above).
1205 weightfile.ReplaceAll(
".txt",
".xml");
1207 DataSetInfo &dataSetInfo = m->DataInfo();
1208 TString testvarName = m->GetTestvarName();
// Re-create the method from its weight file.
1212 m =
dynamic_cast<MethodBase *
>(ClassifierFactory::Instance().Create(
1213 Types::Instance().GetMethodName(methodType).Data(), dataSetInfo, weightfile));
1214 if (m->GetMethodType() == Types::kCategory) {
1215 MethodCategory *methCat = (
dynamic_cast<MethodCategory *
>(m));
1217 Log() << kFATAL <<
"Method with type kCategory cannot be casted to MethodCategory. /Factory" << Endl;
1219 methCat->fDataSetManager = m->DataInfo().GetDataSetManager();
// Restore the per-dataset weight-file directory and the method state.
1223 TString wfileDir = m->DataInfo().GetName();
1224 wfileDir +=
"/" + gConfig().GetIONames().fWeightFileDir;
1225 m->SetWeightFileDir(wfileDir);
1226 m->SetModelPersistence(fModelPersistence);
1227 m->SetSilentFile(IsSilentFile());
1228 m->SetAnalysisType(fAnalysisType);
1230 m->ReadStateFromFile();
1231 m->SetTestvarName(testvarName);
////////////////////////////////////////////////////////////////////////////////
/// Evaluate all booked methods of all datasets on the test sample: switch
/// events to non-training mode and call AddOutput(Types::kTesting, ...) for
/// each method, using the method's own analysis type.
1245 void TMVA::Factory::TestAllMethods()
1247 Log() << kHEADER << gTools().Color(
"bold") <<
"Test all methods" << gTools().Color(
"reset") << Endl;
// Nothing booked: report and bail out.
1250 if (fMethodsMap.empty()) {
1251 Log() << kINFO <<
"...nothing found to test" << Endl;
1254 std::map<TString,MVector*>::iterator itrMap;
// Outer loop: datasets; inner loop: methods booked for that dataset.
1256 for(itrMap = fMethodsMap.begin();itrMap != fMethodsMap.end();++itrMap)
1258 MVector *methods=itrMap->second;
1259 MVector::iterator itrMethod;
1262 for (itrMethod = methods->begin(); itrMethod != methods->end(); ++itrMethod) {
1263 Event::SetIsTraining(kFALSE);
1264 MethodBase *mva =
dynamic_cast<MethodBase *
>(*itrMethod);
// Use the method's own analysis type (may differ from the factory's).
1267 Types::EAnalysisType analysisType = mva->GetAnalysisType();
1268 Log() << kHEADER <<
"Test method: " << mva->GetMethodName() <<
" for "
1269 << (analysisType == Types::kRegression
1271 : (analysisType == Types::kMulticlass ?
"Multiclass classification" :
"Classification"))
1272 <<
" performance" << Endl << Endl;
1273 mva->AddOutput(Types::kTesting, analysisType);
////////////////////////////////////////////////////////////////////////////////
/// Generate standalone C++ response classes. With a non-empty `methodTitle`,
/// only that method's class is generated (warning if not found); otherwise
/// classes are generated for every method booked in the given dataset.
1280 void TMVA::Factory::MakeClass(
const TString& datasetname ,
const TString& methodTitle )
const
// Single-method variant.
1282 if (methodTitle !=
"") {
1283 IMethod* method = GetMethod(datasetname, methodTitle);
1284 if (method) method->MakeClass();
1286 Log() << kWARNING <<
"<MakeClass> Could not find classifier \"" << methodTitle
1287 <<
"\" in list" << Endl;
// All-methods variant: iterate over the dataset's booked methods.
1293 MVector *methods=fMethodsMap.find(datasetname)->second;
1294 MVector::const_iterator itrMethod;
1295 for (itrMethod = methods->begin(); itrMethod != methods->end(); ++itrMethod) {
1296 MethodBase* method =
dynamic_cast<MethodBase*
>(*itrMethod);
1297 if(method==0)
continue;
1298 Log() << kINFO <<
"Make response class for classifier: " << method->GetMethodName() << Endl;
1299 method->MakeClass();
1308 void TMVA::Factory::PrintHelpMessage(
const TString& datasetname ,
const TString& methodTitle )
const
1310 if (methodTitle !=
"") {
1311 IMethod* method = GetMethod(datasetname , methodTitle );
1312 if (method) method->PrintHelpMessage();
1314 Log() << kWARNING <<
"<PrintHelpMessage> Could not find classifier \"" << methodTitle
1315 <<
"\" in list" << Endl;
1321 MVector *methods=fMethodsMap.find(datasetname)->second;
1322 MVector::const_iterator itrMethod ;
1323 for (itrMethod = methods->begin(); itrMethod != methods->end(); ++itrMethod) {
1324 MethodBase* method =
dynamic_cast<MethodBase*
>(*itrMethod);
1325 if(method==0)
continue;
1326 Log() << kINFO <<
"Print help message for classifier: " << method->GetMethodName() << Endl;
1327 method->PrintHelpMessage();
1335 void TMVA::Factory::EvaluateAllVariables(DataLoader *loader, TString options )
1337 Log() << kINFO <<
"Evaluating all variables..." << Endl;
1338 Event::SetIsTraining(kFALSE);
1340 for (UInt_t i=0; i<loader->GetDataSetInfo().GetNVariables(); i++) {
1341 TString s = loader->GetDataSetInfo().GetVariableInfo(i).GetLabel();
1342 if (options.Contains(
"V")) s +=
":V";
1343 this->BookMethod(loader,
"Variable", s );
////////////////////////////////////////////////////////////////////////////////
// Iterates over all booked methods of every dataset: runs regression,
// multiclass or classification evaluation, sorts the per-method results,
// computes inter-MVA correlation and overlap matrices, and prints the
// ranking / overtraining tables. Finally writes the test/train trees.
// NOTE(review): this listing is an elided scrape of the original source --
// many statements (braces, guards, `rho`/`err`/`isel` declarations, stream
// terminators) are missing below; the comments describe only what the
// visible code shows.
1350 void TMVA::Factory::EvaluateAllMethods(
void )
1352 Log() << kHEADER << gTools().Color(
"bold") <<
"Evaluate all methods" << gTools().Color(
"reset") << Endl;
// nothing booked: report (an early return is presumably elided here)
1355 if (fMethodsMap.empty()) {
1356 Log() << kINFO <<
"...nothing found to evaluate" << Endl;
1359 std::map<TString,MVector*>::iterator itrMap;
// loop over all datasets; each map entry holds the methods booked for it
1361 for(itrMap = fMethodsMap.begin();itrMap != fMethodsMap.end();++itrMap)
1363 MVector *methods=itrMap->second;
// per-dataset result containers: index 0 = real MVA methods,
// index 1 = "Variable" pseudo-methods (see `isel` selection further down)
1373 Int_t nmeth_used[2] = {0,0};
1375 std::vector<std::vector<TString> > mname(2);
1376 std::vector<std::vector<Double_t> > sig(2), sep(2), roc(2);
1377 std::vector<std::vector<Double_t> > eff01(2), eff10(2), eff30(2), effArea(2);
1378 std::vector<std::vector<Double_t> > eff01err(2), eff10err(2), eff30err(2);
1379 std::vector<std::vector<Double_t> > trainEff01(2), trainEff10(2), trainEff30(2);
1381 std::vector<std::vector<Float_t> > multiclass_testEff;
1382 std::vector<std::vector<Float_t> > multiclass_trainEff;
1383 std::vector<std::vector<Float_t> > multiclass_testPur;
1384 std::vector<std::vector<Float_t> > multiclass_trainPur;
1386 std::vector<std::vector<Float_t> > train_history;
// multiclass confusion matrices at background-efficiency working points
// B=0.01 / 0.10 / 0.30, for training and testing samples
1389 std::vector<TMatrixD> multiclass_trainConfusionEffB01;
1390 std::vector<TMatrixD> multiclass_trainConfusionEffB10;
1391 std::vector<TMatrixD> multiclass_trainConfusionEffB30;
1392 std::vector<TMatrixD> multiclass_testConfusionEffB01;
1393 std::vector<TMatrixD> multiclass_testConfusionEffB10;
1394 std::vector<TMatrixD> multiclass_testConfusionEffB30;
// regression quality measures (plain and, below, "_T" = 2-sigma truncated)
1396 std::vector<std::vector<Double_t> > biastrain(1);
1397 std::vector<std::vector<Double_t> > biastest(1);
1398 std::vector<std::vector<Double_t> > devtrain(1);
1399 std::vector<std::vector<Double_t> > devtest(1);
1400 std::vector<std::vector<Double_t> > rmstrain(1);
1401 std::vector<std::vector<Double_t> > rmstest(1);
1402 std::vector<std::vector<Double_t> > minftrain(1);
1403 std::vector<std::vector<Double_t> > minftest(1);
1404 std::vector<std::vector<Double_t> > rhotrain(1);
1405 std::vector<std::vector<Double_t> > rhotest(1);
1408 std::vector<std::vector<Double_t> > biastrainT(1);
1409 std::vector<std::vector<Double_t> > biastestT(1);
1410 std::vector<std::vector<Double_t> > devtrainT(1);
1411 std::vector<std::vector<Double_t> > devtestT(1);
1412 std::vector<std::vector<Double_t> > rmstrainT(1);
1413 std::vector<std::vector<Double_t> > rmstestT(1);
1414 std::vector<std::vector<Double_t> > minftrainT(1);
1415 std::vector<std::vector<Double_t> > minftestT(1);
// methods other than kCuts, used below for correlation/overlap studies
1418 MVector methodsNoCuts;
1420 Bool_t doRegression = kFALSE;
1421 Bool_t doMulticlass = kFALSE;
// main evaluation loop over the dataset's methods
1424 for (MVector::iterator itrMethod =methods->begin(); itrMethod != methods->end(); ++itrMethod) {
1425 Event::SetIsTraining(kFALSE);
1426 MethodBase* theMethod =
dynamic_cast<MethodBase*
>(*itrMethod);
1427 if(theMethod==0)
continue;
1428 theMethod->SetFile(fgTargetFile);
1429 theMethod->SetSilentFile(IsSilentFile());
1430 if (theMethod->GetMethodType() != Types::kCuts) methodsNoCuts.push_back( *itrMethod );
// --- regression branch ---------------------------------------------------
1432 if (theMethod->DoRegression()) {
1433 doRegression = kTRUE;
1435 Log() << kINFO <<
"Evaluate regression method: " << theMethod->GetMethodName() << Endl;
1436 Double_t bias, dev, rms, mInf;
1437 Double_t biasT, devT, rmsT, mInfT;
// NOTE(review): the declaration of `rho` is elided from this listing
1440 Log() << kINFO <<
"TestRegression (testing)" << Endl;
1441 theMethod->TestRegression( bias, biasT, dev, devT, rms, rmsT, mInf, mInfT, rho, TMVA::Types::kTesting );
1442 biastest[0] .push_back( bias );
1443 devtest[0] .push_back( dev );
1444 rmstest[0] .push_back( rms );
1445 minftest[0] .push_back( mInf );
1446 rhotest[0] .push_back( rho );
1447 biastestT[0] .push_back( biasT );
1448 devtestT[0] .push_back( devT );
1449 rmstestT[0] .push_back( rmsT );
1450 minftestT[0] .push_back( mInfT );
1452 Log() << kINFO <<
"TestRegression (training)" << Endl;
1453 theMethod->TestRegression( bias, biasT, dev, devT, rms, rmsT, mInf, mInfT, rho, TMVA::Types::kTraining );
1454 biastrain[0] .push_back( bias );
1455 devtrain[0] .push_back( dev );
1456 rmstrain[0] .push_back( rms );
1457 minftrain[0] .push_back( mInf );
1458 rhotrain[0] .push_back( rho );
1459 biastrainT[0].push_back( biasT );
1460 devtrainT[0] .push_back( devT );
1461 rmstrainT[0] .push_back( rmsT );
1462 minftrainT[0].push_back( mInfT );
1464 mname[0].push_back( theMethod->GetMethodName() );
1466 if (!IsSilentFile()) {
1467 Log() << kDEBUG <<
"\tWrite evaluation histograms to file" << Endl;
1468 theMethod->WriteEvaluationHistosToFile(Types::kTesting);
1469 theMethod->WriteEvaluationHistosToFile(Types::kTraining);
// --- multiclass branch ---------------------------------------------------
1471 }
else if (theMethod->DoMulticlass()) {
1475 doMulticlass = kTRUE;
1476 Log() << kINFO <<
"Evaluate multiclass classification method: " << theMethod->GetMethodName() << Endl;
1484 theMethod->TestMulticlass();
// confusion matrices at the three background-efficiency working points
1487 multiclass_trainConfusionEffB01.push_back(theMethod->GetMulticlassConfusionMatrix(0.01, Types::kTraining));
1488 multiclass_trainConfusionEffB10.push_back(theMethod->GetMulticlassConfusionMatrix(0.10, Types::kTraining));
1489 multiclass_trainConfusionEffB30.push_back(theMethod->GetMulticlassConfusionMatrix(0.30, Types::kTraining));
1491 multiclass_testConfusionEffB01.push_back(theMethod->GetMulticlassConfusionMatrix(0.01, Types::kTesting));
1492 multiclass_testConfusionEffB10.push_back(theMethod->GetMulticlassConfusionMatrix(0.10, Types::kTesting));
1493 multiclass_testConfusionEffB30.push_back(theMethod->GetMulticlassConfusionMatrix(0.30, Types::kTesting));
1495 if (!IsSilentFile()) {
1496 Log() << kDEBUG <<
"\tWrite evaluation histograms to file" << Endl;
1497 theMethod->WriteEvaluationHistosToFile(Types::kTesting);
1498 theMethod->WriteEvaluationHistosToFile(Types::kTraining);
1502 mname[0].push_back(theMethod->GetMethodName());
// --- binary classification branch ----------------------------------------
1505 Log() << kHEADER <<
"Evaluate classifier: " << theMethod->GetMethodName() << Endl << Endl;
// isel selects the result bucket: 1 for "Variable" pseudo-methods, 0 otherwise
1506 isel = (theMethod->GetMethodTypeName().Contains(
"Variable")) ? 1 : 0;
1509 theMethod->TestClassification();
1512 mname[isel].push_back(theMethod->GetMethodName());
1513 sig[isel].push_back(theMethod->GetSignificance());
1514 sep[isel].push_back(theMethod->GetSeparation());
1515 roc[isel].push_back(theMethod->GetROCIntegral());
// signal efficiencies at fixed background efficiencies, with errors
// (NOTE(review): the declaration of `err` is elided from this listing)
1518 eff01[isel].push_back(theMethod->GetEfficiency(
"Efficiency:0.01", Types::kTesting, err));
1519 eff01err[isel].push_back(err);
1520 eff10[isel].push_back(theMethod->GetEfficiency(
"Efficiency:0.10", Types::kTesting, err));
1521 eff10err[isel].push_back(err);
1522 eff30[isel].push_back(theMethod->GetEfficiency(
"Efficiency:0.30", Types::kTesting, err));
1523 eff30err[isel].push_back(err);
1524 effArea[isel].push_back(theMethod->GetEfficiency(
"", Types::kTesting, err));
1526 trainEff01[isel].push_back(theMethod->GetTrainingEfficiency(
"Efficiency:0.01"));
1527 trainEff10[isel].push_back(theMethod->GetTrainingEfficiency(
"Efficiency:0.10"));
1528 trainEff30[isel].push_back(theMethod->GetTrainingEfficiency(
"Efficiency:0.30"));
1532 if (!IsSilentFile()) {
1533 Log() << kDEBUG <<
"\tWrite evaluation histograms to file" << Endl;
1534 theMethod->WriteEvaluationHistosToFile(Types::kTesting);
1535 theMethod->WriteEvaluationHistosToFile(Types::kTraining);
// --- sort regression results by test deviation ----------------------------
1541 std::vector<TString> vtemps = mname[0];
1542 std::vector< std::vector<Double_t> > vtmp;
1543 vtmp.push_back( devtest[0] );
1544 vtmp.push_back( devtrain[0] );
1545 vtmp.push_back( biastest[0] );
1546 vtmp.push_back( biastrain[0] );
1547 vtmp.push_back( rmstest[0] );
1548 vtmp.push_back( rmstrain[0] );
1549 vtmp.push_back( minftest[0] );
1550 vtmp.push_back( minftrain[0] );
1551 vtmp.push_back( rhotest[0] );
1552 vtmp.push_back( rhotrain[0] );
1553 vtmp.push_back( devtestT[0] );
1554 vtmp.push_back( devtrainT[0] );
1555 vtmp.push_back( biastestT[0] );
1556 vtmp.push_back( biastrainT[0]);
1557 vtmp.push_back( rmstestT[0] );
1558 vtmp.push_back( rmstrainT[0] );
1559 vtmp.push_back( minftestT[0] );
1560 vtmp.push_back( minftrainT[0]);
1561 gTools().UsefulSortAscending( vtmp, &vtemps );
1563 devtest[0] = vtmp[0];
1564 devtrain[0] = vtmp[1];
1565 biastest[0] = vtmp[2];
1566 biastrain[0] = vtmp[3];
1567 rmstest[0] = vtmp[4];
1568 rmstrain[0] = vtmp[5];
1569 minftest[0] = vtmp[6];
1570 minftrain[0] = vtmp[7];
1571 rhotest[0] = vtmp[8];
1572 rhotrain[0] = vtmp[9];
1573 devtestT[0] = vtmp[10];
1574 devtrainT[0] = vtmp[11];
1575 biastestT[0] = vtmp[12];
1576 biastrainT[0] = vtmp[13];
1577 rmstestT[0] = vtmp[14];
1578 rmstrainT[0] = vtmp[15];
1579 minftestT[0] = vtmp[16];
1580 minftrainT[0] = vtmp[17];
1581 }
else if (doMulticlass) {
// --- sort classification results, descending by performance ---------------
1589 for (Int_t k=0; k<2; k++) {
1590 std::vector< std::vector<Double_t> > vtemp;
1591 vtemp.push_back( effArea[k] );
1592 vtemp.push_back( eff10[k] );
1593 vtemp.push_back( eff01[k] );
1594 vtemp.push_back( eff30[k] );
1595 vtemp.push_back( eff10err[k] );
1596 vtemp.push_back( eff01err[k] );
1597 vtemp.push_back( eff30err[k] );
1598 vtemp.push_back( trainEff10[k] );
1599 vtemp.push_back( trainEff01[k] );
1600 vtemp.push_back( trainEff30[k] );
1601 vtemp.push_back( sig[k] );
1602 vtemp.push_back( sep[k] );
1603 vtemp.push_back( roc[k] );
1604 std::vector<TString> vtemps = mname[k];
1605 gTools().UsefulSortDescending( vtemp, &vtemps );
1606 effArea[k] = vtemp[0];
1607 eff10[k] = vtemp[1];
1608 eff01[k] = vtemp[2];
1609 eff30[k] = vtemp[3];
1610 eff10err[k] = vtemp[4];
1611 eff01err[k] = vtemp[5];
1612 eff30err[k] = vtemp[6];
1613 trainEff10[k] = vtemp[7];
1614 trainEff01[k] = vtemp[8];
1615 trainEff30[k] = vtemp[9];
// --- inter-MVA correlation and overlap (non-cut methods only) -------------
1631 const Int_t nmeth = methodsNoCuts.size();
1632 MethodBase* method =
dynamic_cast<MethodBase*
>(methods[0][0]);
1633 const Int_t nvar = method->fDataSetInfo.GetNVariables();
1634 if (!doRegression && !doMulticlass ) {
1639 Double_t *dvec =
new Double_t[nmeth+nvar];
1640 std::vector<Double_t> rvec;
// principal-component machinery used to get covariance matrices
1643 TPrincipal* tpSig =
new TPrincipal( nmeth+nvar,
"" );
1644 TPrincipal* tpBkg =
new TPrincipal( nmeth+nvar,
"" );
1648 std::vector<TString>* theVars =
new std::vector<TString>;
1649 std::vector<ResultsClassification*> mvaRes;
1650 for (MVector::iterator itrMethod = methodsNoCuts.begin(); itrMethod != methodsNoCuts.end(); ++itrMethod, ++ivar) {
1651 MethodBase* m =
dynamic_cast<MethodBase*
>(*itrMethod);
1653 theVars->push_back( m->GetTestvarName() );
1654 rvec.push_back( m->GetSignalReferenceCut() );
1655 theVars->back().ReplaceAll(
"MVA_",
"" );
1656 mvaRes.push_back( dynamic_cast<ResultsClassification*>( m->Data()->GetResults( m->GetMethodName(),
1658 Types::kMaxAnalysisType) ) );
1662 TMatrixD* overlapS =
new TMatrixD( nmeth, nmeth );
1663 TMatrixD* overlapB =
new TMatrixD( nmeth, nmeth );
// fill PCA rows and overlap counts from every testing event
1668 DataSet* defDs = method->fDataSetInfo.GetDataSet();
1669 defDs->SetCurrentType(Types::kTesting);
1670 for (Int_t ievt=0; ievt<defDs->GetNEvents(); ievt++) {
1671 const Event* ev = defDs->GetEvent(ievt);
1674 TMatrixD* theMat = 0;
1675 for (Int_t im=0; im<nmeth; im++) {
1677 Double_t retval = (Double_t)(*mvaRes[im])[ievt][0];
1678 if (TMath::IsNaN(retval)) {
1679 Log() << kWARNING <<
"Found NaN return value in event: " << ievt
1680 <<
" for method \"" << methodsNoCuts[im]->GetName() <<
"\"" << Endl;
1683 else dvec[im] = retval;
1685 for (Int_t iv=0; iv<nvar; iv++) dvec[iv+nmeth] = (Double_t)ev->GetValue(iv);
1686 if (method->fDataSetInfo.IsSignal(ev)) { tpSig->AddRow( dvec ); theMat = overlapS; }
1687 else { tpBkg->AddRow( dvec ); theMat = overlapB; }
// two methods "overlap" on an event when both sit on the same side of their cut
1690 for (Int_t im=0; im<nmeth; im++) {
1691 for (Int_t jm=im; jm<nmeth; jm++) {
1692 if ((dvec[im] - rvec[im])*(dvec[jm] - rvec[jm]) > 0) {
1694 if (im != jm) (*theMat)(jm,im)++;
// normalize overlap counts to fractions of the test sample
1701 (*overlapS) *= (1.0/defDs->GetNEvtSigTest());
1702 (*overlapB) *= (1.0/defDs->GetNEvtBkgdTest());
1704 tpSig->MakePrincipals();
1705 tpBkg->MakePrincipals();
1707 const TMatrixD* covMatS = tpSig->GetCovarianceMatrix();
1708 const TMatrixD* covMatB = tpBkg->GetCovarianceMatrix();
1710 const TMatrixD* corrMatS = gTools().GetCorrelationMatrix( covMatS );
1711 const TMatrixD* corrMatB = gTools().GetCorrelationMatrix( covMatB );
1714 if (corrMatS != 0 && corrMatB != 0) {
// MVA-vs-MVA sub-block of the correlation matrix
1717 TMatrixD mvaMatS(nmeth,nmeth);
1718 TMatrixD mvaMatB(nmeth,nmeth);
1719 for (Int_t im=0; im<nmeth; im++) {
1720 for (Int_t jm=0; jm<nmeth; jm++) {
1721 mvaMatS(im,jm) = (*corrMatS)(im,jm);
1722 mvaMatB(im,jm) = (*corrMatB)(im,jm);
// variable-vs-MVA sub-block of the correlation matrix
1727 std::vector<TString> theInputVars;
1728 TMatrixD varmvaMatS(nvar,nmeth);
1729 TMatrixD varmvaMatB(nvar,nmeth);
1730 for (Int_t iv=0; iv<nvar; iv++) {
1731 theInputVars.push_back( method->fDataSetInfo.GetVariableInfo( iv ).GetLabel() );
1732 for (Int_t jm=0; jm<nmeth; jm++) {
1733 varmvaMatS(iv,jm) = (*corrMatS)(nmeth+iv,jm);
1734 varmvaMatB(iv,jm) = (*corrMatB)(nmeth+iv,jm);
// print the four correlation tables
1739 Log() << kINFO << Endl;
1740 Log() << kINFO <<Form(
"Dataset[%s] : ",method->fDataSetInfo.GetName())<<
"Inter-MVA correlation matrix (signal):" << Endl;
1741 gTools().FormattedOutput( mvaMatS, *theVars, Log() );
1742 Log() << kINFO << Endl;
1744 Log() << kINFO <<Form(
"Dataset[%s] : ",method->fDataSetInfo.GetName())<<
"Inter-MVA correlation matrix (background):" << Endl;
1745 gTools().FormattedOutput( mvaMatB, *theVars, Log() );
1746 Log() << kINFO << Endl;
1749 Log() << kINFO <<Form(
"Dataset[%s] : ",method->fDataSetInfo.GetName())<<
"Correlations between input variables and MVA response (signal):" << Endl;
1750 gTools().FormattedOutput( varmvaMatS, theInputVars, *theVars, Log() );
1751 Log() << kINFO << Endl;
1753 Log() << kINFO <<Form(
"Dataset[%s] : ",method->fDataSetInfo.GetName())<<
"Correlations between input variables and MVA response (background):" << Endl;
1754 gTools().FormattedOutput( varmvaMatB, theInputVars, *theVars, Log() );
1755 Log() << kINFO << Endl;
1757 else Log() << kWARNING <<Form(
"Dataset[%s] : ",method->fDataSetInfo.GetName())<<
"<TestAllMethods> cannot compute correlation matrices" << Endl;
// print the overlap matrices and the cut values they are based on
1760 Log() << kINFO <<Form(
"Dataset[%s] : ",method->fDataSetInfo.GetName())<<
"The following \"overlap\" matrices contain the fraction of events for which " << Endl;
1761 Log() << kINFO <<Form(
"Dataset[%s] : ",method->fDataSetInfo.GetName())<<
"the MVAs 'i' and 'j' have returned conform answers about \"signal-likeness\"" << Endl;
1762 Log() << kINFO <<Form(
"Dataset[%s] : ",method->fDataSetInfo.GetName())<<
"An event is signal-like, if its MVA output exceeds the following value:" << Endl;
1763 gTools().FormattedOutput( rvec, *theVars,
"Method" ,
"Cut value", Log() );
1764 Log() << kINFO <<Form(
"Dataset[%s] : ",method->fDataSetInfo.GetName())<<
"which correspond to the working point: eff(signal) = 1 - eff(background)" << Endl;
1767 if (nmeth != (Int_t)methods->size())
1768 Log() << kINFO <<Form(
"Dataset[%s] : ",method->fDataSetInfo.GetName())<<
"Note: no correlations and overlap with cut method are provided at present" << Endl;
1771 Log() << kINFO << Endl;
1772 Log() << kINFO <<Form(
"Dataset[%s] : ",method->fDataSetInfo.GetName())<<
"Inter-MVA overlap matrix (signal):" << Endl;
1773 gTools().FormattedOutput( *overlapS, *theVars, Log() );
1774 Log() << kINFO << Endl;
1776 Log() << kINFO <<Form(
"Dataset[%s] : ",method->fDataSetInfo.GetName())<<
"Inter-MVA overlap matrix (background):" << Endl;
1777 gTools().FormattedOutput( *overlapB, *theVars, Log() );
// --- regression summary tables --------------------------------------------
1799 Log() << kINFO << Endl;
1800 TString hLine =
"--------------------------------------------------------------------------------------------------";
1801 Log() << kINFO <<
"Evaluation results ranked by smallest RMS on test sample:" << Endl;
1802 Log() << kINFO <<
"(\"Bias\" quotes the mean deviation of the regression from true target." << Endl;
1803 Log() << kINFO <<
" \"MutInf\" is the \"Mutual Information\" between regression and target." << Endl;
1804 Log() << kINFO <<
" Indicated by \"_T\" are the corresponding \"truncated\" quantities ob-" << Endl;
1805 Log() << kINFO <<
" tained when removing events deviating more than 2sigma from average.)" << Endl;
1806 Log() << kINFO << hLine << Endl;
1808 Log() << kINFO << hLine << Endl;
1810 for (Int_t i=0; i<nmeth_used[0]; i++) {
1811 MethodBase* theMethod =
dynamic_cast<MethodBase*
>((*methods)[i]);
1812 if(theMethod==0)
continue;
1814 Log() << kINFO << Form(
"%-20s %-15s:%#9.3g%#9.3g%#9.3g%#9.3g | %#5.3f %#5.3f",
1815 theMethod->fDataSetInfo.GetName(),
1816 (
const char*)mname[0][i],
1817 biastest[0][i], biastestT[0][i],
1818 rmstest[0][i], rmstestT[0][i],
1819 minftest[0][i], minftestT[0][i] )
1822 Log() << kINFO << hLine << Endl;
1823 Log() << kINFO << Endl;
1824 Log() << kINFO <<
"Evaluation results ranked by smallest RMS on training sample:" << Endl;
1825 Log() << kINFO <<
"(overtraining check)" << Endl;
1826 Log() << kINFO << hLine << Endl;
1827 Log() << kINFO <<
"DataSet Name: MVA Method: <Bias> <Bias_T> RMS RMS_T | MutInf MutInf_T" << Endl;
1828 Log() << kINFO << hLine << Endl;
1830 for (Int_t i=0; i<nmeth_used[0]; i++) {
1831 MethodBase* theMethod =
dynamic_cast<MethodBase*
>((*methods)[i]);
1832 if(theMethod==0)
continue;
1833 Log() << kINFO << Form(
"%-20s %-15s:%#9.3g%#9.3g%#9.3g%#9.3g | %#5.3f %#5.3f",
1834 theMethod->fDataSetInfo.GetName(),
1835 (
const char*)mname[0][i],
1836 biastrain[0][i], biastrainT[0][i],
1837 rmstrain[0][i], rmstrainT[0][i],
1838 minftrain[0][i], minftrainT[0][i] )
1841 Log() << kINFO << hLine << Endl;
1842 Log() << kINFO << Endl;
// --- multiclass summary tables --------------------------------------------
1843 }
else if (doMulticlass) {
1849 "-------------------------------------------------------------------------------------------------------";
1888 TString header1 = Form(
"%-15s%-15s%-15s%-15s%-15s%-15s",
"Dataset",
"MVA Method",
"ROC AUC",
"Sig eff@B=0.01",
1889 "Sig eff@B=0.10",
"Sig eff@B=0.30");
1890 TString header2 = Form(
"%-15s%-15s%-15s%-15s%-15s%-15s",
"Name:",
"/ Class:",
"test (train)",
"test (train)",
1891 "test (train)",
"test (train)");
1892 Log() << kINFO << Endl;
1893 Log() << kINFO <<
"1-vs-rest performance metrics per class" << Endl;
1894 Log() << kINFO << hLine << Endl;
1895 Log() << kINFO << Endl;
1896 Log() << kINFO <<
"Considers the listed class as signal and the other classes" << Endl;
1897 Log() << kINFO <<
"as background, reporting the resulting binary performance." << Endl;
1898 Log() << kINFO <<
"A score of 0.820 (0.850) means 0.820 was acheived on the" << Endl;
1899 Log() << kINFO <<
"test set and 0.850 on the training set." << Endl;
1901 Log() << kINFO << Endl;
1902 Log() << kINFO << header1 << Endl;
1903 Log() << kINFO << header2 << Endl;
1904 for (Int_t k = 0; k < 2; k++) {
1905 for (Int_t i = 0; i < nmeth_used[k]; i++) {
1907 mname[k][i].ReplaceAll(
"Variable_",
"");
1910 const TString datasetName = itrMap->first;
1911 const TString mvaName = mname[k][i];
1913 MethodBase *theMethod =
dynamic_cast<MethodBase *
>(GetMethod(datasetName, mvaName));
1914 if (theMethod == 0) {
1918 Log() << kINFO << Endl;
1919 TString row = Form(
"%-15s%-15s", datasetName.Data(), mvaName.Data());
1920 Log() << kINFO << row << Endl;
1921 Log() << kINFO <<
"------------------------------" << Endl;
1923 UInt_t numClasses = theMethod->fDataSetInfo.GetNClasses();
1924 for (UInt_t iClass = 0; iClass < numClasses; ++iClass) {
// one-vs-rest ROC curves for this class, train and test
1926 ROCCurve *rocCurveTrain = GetROC(datasetName, mvaName, iClass, Types::kTraining);
1927 ROCCurve *rocCurveTest = GetROC(datasetName, mvaName, iClass, Types::kTesting);
1929 const TString className = theMethod->DataInfo().GetClassInfo(iClass)->GetName();
1930 const Double_t rocaucTrain = rocCurveTrain->GetROCIntegral();
1931 const Double_t effB01Train = rocCurveTrain->GetEffSForEffB(0.01);
1932 const Double_t effB10Train = rocCurveTrain->GetEffSForEffB(0.10);
1933 const Double_t effB30Train = rocCurveTrain->GetEffSForEffB(0.30);
1934 const Double_t rocaucTest = rocCurveTest->GetROCIntegral();
1935 const Double_t effB01Test = rocCurveTest->GetEffSForEffB(0.01);
1936 const Double_t effB10Test = rocCurveTest->GetEffSForEffB(0.10);
1937 const Double_t effB30Test = rocCurveTest->GetEffSForEffB(0.30);
1938 const TString rocaucCmp = Form(
"%5.3f (%5.3f)", rocaucTest, rocaucTrain);
1939 const TString effB01Cmp = Form(
"%5.3f (%5.3f)", effB01Test, effB01Train);
1940 const TString effB10Cmp = Form(
"%5.3f (%5.3f)", effB10Test, effB10Train);
1941 const TString effB30Cmp = Form(
"%5.3f (%5.3f)", effB30Test, effB30Train);
1942 row = Form(
"%-15s%-15s%-15s%-15s%-15s%-15s",
"", className.Data(), rocaucCmp.Data(), effB01Cmp.Data(),
1943 effB10Cmp.Data(), effB30Cmp.Data());
1944 Log() << kINFO << row << Endl;
1946 delete rocCurveTrain;
1947 delete rocCurveTest;
1951 Log() << kINFO << Endl;
1952 Log() << kINFO << hLine << Endl;
1953 Log() << kINFO << Endl;
// helper lambda: print a test-vs-train confusion matrix side by side
1957 auto printMatrix = [](TMatrixD
const &matTraining, TMatrixD
const &matTesting, std::vector<TString> classnames,
1958 UInt_t numClasses, MsgLogger &stream) {
1964 TString header = Form(
" %-14s",
" ");
1965 TString headerInfo = Form(
" %-14s",
" ");
1967 for (UInt_t iCol = 0; iCol < numClasses; ++iCol) {
1968 header += Form(
" %-14s", classnames[iCol].Data());
1969 headerInfo += Form(
" %-14s",
" test (train)");
1971 stream << kINFO << header << Endl;
1972 stream << kINFO << headerInfo << Endl;
1974 for (UInt_t iRow = 0; iRow < numClasses; ++iRow) {
1975 stream << kINFO << Form(
" %-14s", classnames[iRow].Data());
1977 for (UInt_t iCol = 0; iCol < numClasses; ++iCol) {
// the diagonal (class vs itself) carries no information -- print "-"
1979 stream << kINFO << Form(
" %-14s",
"-");
1981 Double_t trainValue = matTraining[iRow][iCol];
1982 Double_t testValue = matTesting[iRow][iCol];
1983 TString entry = Form(
"%-5.3f (%-5.3f)", testValue, trainValue);
1984 stream << kINFO << Form(
" %-14s", entry.Data());
1987 stream << kINFO << Endl;
1991 Log() << kINFO << Endl;
1992 Log() << kINFO <<
"Confusion matrices for all methods" << Endl;
1993 Log() << kINFO << hLine << Endl;
1994 Log() << kINFO << Endl;
1995 Log() << kINFO <<
"Does a binary comparison between the two classes given by a " << Endl;
1996 Log() << kINFO <<
"particular row-column combination. In each case, the class " << Endl;
1997 Log() << kINFO <<
"given by the row is considered signal while the class given " << Endl;
1998 Log() << kINFO <<
"by the column index is considered background." << Endl;
1999 Log() << kINFO << Endl;
2000 for (UInt_t iMethod = 0; iMethod < methods->size(); ++iMethod) {
2001 MethodBase *theMethod =
dynamic_cast<MethodBase *
>(methods->at(iMethod));
2002 if (theMethod ==
nullptr) {
2005 UInt_t numClasses = theMethod->fDataSetInfo.GetNClasses();
2007 std::vector<TString> classnames;
2008 for (UInt_t iCls = 0; iCls < numClasses; ++iCls) {
2009 classnames.push_back(theMethod->fDataSetInfo.GetClassInfo(iCls)->GetName());
2012 <<
"=== Showing confusion matrix for method : " << Form(
"%-15s", (
const char *)mname[0][iMethod])
2014 Log() << kINFO <<
"(Signal Efficiency for Background Efficiency 0.01%)" << Endl;
2015 Log() << kINFO <<
"---------------------------------------------------" << Endl;
2016 printMatrix(multiclass_testConfusionEffB01[iMethod], multiclass_trainConfusionEffB01[iMethod], classnames,
2018 Log() << kINFO << Endl;
2020 Log() << kINFO <<
"(Signal Efficiency for Background Efficiency 0.10%)" << Endl;
2021 Log() << kINFO <<
"---------------------------------------------------" << Endl;
2022 printMatrix(multiclass_testConfusionEffB10[iMethod], multiclass_trainConfusionEffB10[iMethod], classnames,
2024 Log() << kINFO << Endl;
2026 Log() << kINFO <<
"(Signal Efficiency for Background Efficiency 0.30%)" << Endl;
2027 Log() << kINFO <<
"---------------------------------------------------" << Endl;
2028 printMatrix(multiclass_testConfusionEffB30[iMethod], multiclass_trainConfusionEffB30[iMethod], classnames,
2030 Log() << kINFO << Endl;
2032 Log() << kINFO << hLine << Endl;
2033 Log() << kINFO << Endl;
// --- binary classification summary tables ---------------------------------
2038 Log().EnableOutput();
2039 gConfig().SetSilent(kFALSE);
2041 TString hLine =
"------------------------------------------------------------------------------------------"
2042 "-------------------------";
2043 Log() << kINFO <<
"Evaluation results ranked by best signal efficiency and purity (area)" << Endl;
2044 Log() << kINFO << hLine << Endl;
2045 Log() << kINFO <<
"DataSet MVA " << Endl;
2046 Log() << kINFO <<
"Name: Method: ROC-integ" << Endl;
2051 Log() << kDEBUG << hLine << Endl;
2052 for (Int_t k = 0; k < 2; k++) {
2053 if (k == 1 && nmeth_used[k] > 0) {
2054 Log() << kINFO << hLine << Endl;
2055 Log() << kINFO <<
"Input Variables: " << Endl << hLine << Endl;
2057 for (Int_t i = 0; i < nmeth_used[k]; i++) {
2058 TString datasetName = itrMap->first;
2059 TString methodName = mname[k][i];
2062 methodName.ReplaceAll(
"Variable_",
"");
2065 MethodBase *theMethod =
dynamic_cast<MethodBase *
>(GetMethod(datasetName, methodName));
2066 if (theMethod == 0) {
2070 TMVA::DataSet *dataset = theMethod->Data();
2071 TMVA::Results *results = dataset->GetResults(methodName, Types::kTesting, this->fAnalysisType);
2072 std::vector<Bool_t> *mvaResType =
2073 dynamic_cast<ResultsClassification *
>(results)->GetValueVectorTypes();
2075 Double_t rocIntegral = 0.0;
2076 if (mvaResType->size() != 0) {
2077 rocIntegral = GetROCIntegral(datasetName, methodName);
// fall back to the stored efficiency area when separation/significance is invalid
2080 if (sep[k][i] < 0 || sig[k][i] < 0) {
2082 Log() << kINFO << Form(
"%-13s %-15s: %#1.3f", datasetName.Data(), methodName.Data(), effArea[k][i])
2094 Log() << kINFO << Form(
"%-13s %-15s: %#1.3f", datasetName.Data(), methodName.Data(), rocIntegral)
2108 Log() << kINFO << hLine << Endl;
2109 Log() << kINFO << Endl;
2110 Log() << kINFO <<
"Testing efficiency compared to training efficiency (overtraining check)" << Endl;
2111 Log() << kINFO << hLine << Endl;
2113 <<
"DataSet MVA Signal efficiency: from test sample (from training sample) "
2115 Log() << kINFO <<
"Name: Method: @B=0.01 @B=0.10 @B=0.30 "
2117 Log() << kINFO << hLine << Endl;
2118 for (Int_t k = 0; k < 2; k++) {
2119 if (k == 1 && nmeth_used[k] > 0) {
2120 Log() << kINFO << hLine << Endl;
2121 Log() << kINFO <<
"Input Variables: " << Endl << hLine << Endl;
2123 for (Int_t i = 0; i < nmeth_used[k]; i++) {
2124 if (k == 1) mname[k][i].ReplaceAll(
"Variable_",
"");
2125 MethodBase *theMethod =
dynamic_cast<MethodBase *
>((*methods)[i]);
2126 if (theMethod == 0)
continue;
2128 Log() << kINFO << Form(
"%-20s %-15s: %#1.3f (%#1.3f) %#1.3f (%#1.3f) %#1.3f (%#1.3f)",
2129 theMethod->fDataSetInfo.GetName(), (
const char *)mname[k][i], eff01[k][i],
2130 trainEff01[k][i], eff10[k][i], trainEff10[k][i], eff30[k][i], trainEff30[k][i])
2134 Log() << kINFO << hLine << Endl;
2135 Log() << kINFO << Endl;
2137 if (gTools().CheckForSilentOption(GetOptions())) Log().InhibitOutput();
// --- persist the train/test trees, once per dataset -----------------------
2142 std::list<TString> datasets;
2143 for (Int_t k=0; k<2; k++) {
2144 for (Int_t i=0; i<nmeth_used[k]; i++) {
2145 MethodBase* theMethod =
dynamic_cast<MethodBase*
>((*methods)[i]);
2146 if(theMethod==0)
continue;
2148 RootBaseDir()->cd(theMethod->fDataSetInfo.GetName());
2149 if(std::find(datasets.begin(), datasets.end(), theMethod->fDataSetInfo.GetName()) == datasets.end())
2151 theMethod->fDataSetInfo.GetDataSet()->GetTree(Types::kTesting)->Write(
"", TObject::kOverwrite );
2152 theMethod->fDataSetInfo.GetDataSet()->GetTree(Types::kTraining)->Write(
"", TObject::kOverwrite );
2153 datasets.push_back(theMethod->fDataSetInfo.GetName());
2160 gTools().TMVACitation( Log(), Tools::kHtmlLink );
2166 TH1F* TMVA::Factory::EvaluateImportance(DataLoader *loader,VIType vitype, Types::EMVA theMethod, TString methodTitle,
const char *theOption)
2168 fModelPersistence=kFALSE;
2172 const int nbits = loader->GetDataSetInfo().GetNVariables();
2173 if(vitype==VIType::kShort)
2174 return EvaluateImportanceShort(loader,theMethod,methodTitle,theOption);
2175 else if(vitype==VIType::kAll)
2176 return EvaluateImportanceAll(loader,theMethod,methodTitle,theOption);
2177 else if(vitype==VIType::kRandom&&nbits>10)
2179 return EvaluateImportanceRandom(loader,pow(2,nbits),theMethod,methodTitle,theOption);
2182 std::cerr<<
"Error in Variable Importance: Random mode require more that 10 variables in the dataset."<<std::endl;
////////////////////////////////////////////////////////////////////////////////
// Variable-importance scan over ALL non-empty variable subsets: for each
// subset a dedicated DataLoader (named after the subset's bit pattern) is
// built, the method is booked and evaluated, and the subset's ROC integral
// is stored; importances are then accumulated from ROC differences between
// each seed set and its one-bit-removed sub-seeds.
// NOTE(review): this listing is an elided scrape -- the declarations of the
// loop counters x/y, the Train/Test calls and several cleanup statements
// are missing below; comments describe only what the visible code shows.
2189 TH1F* TMVA::Factory::EvaluateImportanceAll(DataLoader *loader, Types::EMVA theMethod, TString methodTitle,
const char *theOption)
2196 const int nbits = loader->GetDataSetInfo().GetNVariables();
2197 std::vector<TString> varNames = loader->GetDataSetInfo().GetListOfVariables();
// number of subsets: 2^nbits (subset 0 = empty set, skipped in the loop)
2199 uint64_t range = pow(2, nbits);
2202 std::vector<Double_t> importances(nbits);
2204 std::vector<Double_t> ROC(range);
2206 for (
int i = 0; i < nbits; i++)importances[i] = 0;
2208 Double_t SROC, SSROC;
// first pass: evaluate the booked method on every non-empty subset x
2209 for ( x = 1; x <range ; x++) {
2211 std::bitset<VIBITS> xbitset(x);
2212 if (x == 0)
continue;
// fresh loader holding only the variables whose bit is set in x
2215 TMVA::DataLoader *seedloader =
new TMVA::DataLoader(xbitset.to_string());
2218 for (
int index = 0; index < nbits; index++) {
2219 if (xbitset[index]) seedloader->AddVariable(varNames[index],
'F');
2222 DataLoaderCopy(seedloader,loader);
2223 seedloader->PrepareTrainingAndTestTree(loader->GetDataSetInfo().GetCut(
"Signal"), loader->GetDataSetInfo().GetCut(
"Background"), loader->GetDataSetInfo().GetSplitOptions());
2226 BookMethod(seedloader, theMethod, methodTitle, theOption);
2231 EvaluateAllMethods();
2234 ROC[x] = GetROCIntegral(xbitset.to_string(), methodTitle);
2237 TMVA::MethodBase *smethod=
dynamic_cast<TMVA::MethodBase*
>(fMethodsMap[xbitset.to_string().c_str()][0][0]);
2238 TMVA::ResultsClassification *sresults = (TMVA::ResultsClassification*)smethod->Data()->GetResults(smethod->GetMethodName(), Types::kTesting, Types::kClassification);
// NOTE(review): seedloader/sresults are not freed in the visible lines --
// the deletes are presumably elided; confirm against the full source
2241 this->DeleteAllMethods();
2243 fMethodsMap.clear();
// second pass: importance of variable i accumulates ROC(x) - ROC(x \ bit i)
2248 for ( x = 0; x <range ; x++)
2251 for (uint32_t i = 0; i < VIBITS; ++i) {
2254 std::bitset<VIBITS> ybitset(y);
// ny = log2(x - y), i.e. the index of the removed bit (0.693147 ~ ln 2)
2258 Double_t ny = log(x - y) / 0.693147;
// y == 0: single-variable seed, importance measured against ROC 0.5
2260 importances[ny] = SROC - 0.5;
2266 importances[ny] += SROC - SSROC;
2272 std::cout<<
"--- Variable Importance Results (All)"<<std::endl;
2273 return GetImportance(nbits,importances,varNames);
////////////////////////////////////////////////////////////////////////////////
/// Helper for the "Short" variable-importance scan: number of non-empty
/// subsets reachable from \p i variables, i.e. 2^0 + 2^1 + ... + 2^(i-1)
/// = 2^i - 1.
/// \param i  number of variables (bits)
/// \return   2^i - 1 (0 for i <= 0)
/// NOTE(review): reconstructed from an elided listing (accumulator
/// declaration and return were missing); pow(2,n) on doubles was replaced
/// by an exact integer shift, avoiding float rounding and a libm call.

static long int sum(long int i)
{
   long int total = 0;
   for (long int n = 0; n < i; n++) total += 1L << n;   // exact integer 2^n
   return total;
}
////////////////////////////////////////////////////////////////////////////////
// "Short" variable-importance scan: evaluates one seed (visible code logs
// kFATAL if the seed is empty) and then, for each set bit, re-evaluates the
// sub-seed with that variable removed; the importance of the removed
// variable accumulates the ROC difference SROC - SSROC.
// NOTE(review): this listing is an elided scrape -- the declarations of
// x/y, the seed-selection statement, the Train/Test calls and several
// cleanup statements are missing; comments describe only the visible code.
2285 TH1F* TMVA::Factory::EvaluateImportanceShort(DataLoader *loader, Types::EMVA theMethod, TString methodTitle,
const char *theOption)
2291 const int nbits = loader->GetDataSetInfo().GetNVariables();
2292 std::vector<TString> varNames = loader->GetDataSetInfo().GetListOfVariables();
// number of candidate sub-seeds: sum(nbits) = 2^nbits - 1
2294 long int range = sum(nbits);
2297 std::vector<Double_t> importances(nbits);
2298 for (
int i = 0; i < nbits; i++)importances[i] = 0;
2300 Double_t SROC, SSROC;
2304 std::bitset<VIBITS> xbitset(x);
2305 if (x == 0) Log()<<kFATAL<<
"Error: need at least one variable.";
// loader for the full seed: all variables whose bit is set in x
2309 TMVA::DataLoader *seedloader =
new TMVA::DataLoader(xbitset.to_string());
2312 for (
int index = 0; index < nbits; index++) {
2313 if (xbitset[index]) seedloader->AddVariable(varNames[index],
'F');
2317 DataLoaderCopy(seedloader,loader);
2320 BookMethod(seedloader, theMethod, methodTitle, theOption);
2325 EvaluateAllMethods();
2328 SROC = GetROCIntegral(xbitset.to_string(), methodTitle);
2331 TMVA::MethodBase *smethod=
dynamic_cast<TMVA::MethodBase*
>(fMethodsMap[xbitset.to_string().c_str()][0][0]);
2332 TMVA::ResultsClassification *sresults = (TMVA::ResultsClassification*)smethod->Data()->GetResults(smethod->GetMethodName(), Types::kTesting, Types::kClassification);
// NOTE(review): seedloader/sresults are not freed in the visible lines --
// deletes presumably elided; confirm against the full source
2335 this->DeleteAllMethods();
2336 fMethodsMap.clear();
// for each variable in the seed, evaluate the sub-seed with it removed
2340 for (uint32_t i = 0; i < VIBITS; ++i) {
2343 std::bitset<VIBITS> ybitset(y);
// ny = log2(x - y), the index of the removed variable (0.693147 ~ ln 2)
2347 Double_t ny = log(x - y) / 0.693147;
// empty sub-seed: importance measured against a random classifier (ROC 0.5)
2349 importances[ny] = SROC - 0.5;
2354 TMVA::DataLoader *subseedloader =
new TMVA::DataLoader(ybitset.to_string());
2356 for (
int index = 0; index < nbits; index++) {
2357 if (ybitset[index]) subseedloader->AddVariable(varNames[index],
'F');
2361 DataLoaderCopy(subseedloader,loader);
2364 BookMethod(subseedloader, theMethod, methodTitle, theOption);
2369 EvaluateAllMethods();
2372 SSROC = GetROCIntegral(ybitset.to_string(), methodTitle);
2373 importances[ny] += SROC - SSROC;
2376 TMVA::MethodBase *ssmethod=
dynamic_cast<TMVA::MethodBase*
>(fMethodsMap[ybitset.to_string().c_str()][0][0]);
2377 TMVA::ResultsClassification *ssresults = (TMVA::ResultsClassification*)ssmethod->Data()->GetResults(ssmethod->GetMethodName(), Types::kTesting, Types::kClassification);
2379 delete subseedloader;
2380 this->DeleteAllMethods();
2381 fMethodsMap.clear();
2384 std::cout<<
"--- Variable Importance Results (Short)"<<std::endl;
2385 return GetImportance(nbits,importances,varNames);
////////////////////////////////////////////////////////////////////////////////
/// Variable Importance, "Random" strategy.
/// Draws `nseeds` random variable subsets; for each drawn seed it trains and
/// evaluates, then removes one variable at a time and credits that variable
/// with the ROC drop (SROC - SSROC). Returns the histogram built by
/// GetImportance().
/// NOTE(review): scraped source view — braces and some statements (the y
/// assignment, TrainAllMethods()/TestAllMethods() calls, `delete seedloader`,
/// `delete rangen`) appear elided by the extraction; hedged comments mark
/// those spots.
2390 TH1F* TMVA::Factory::EvaluateImportanceRandom(DataLoader *loader, UInt_t nseeds, Types::EMVA theMethod, TString methodTitle,
 const char *theOption)
// Heap-allocated ROOT RNG, constructed with seed 0; no `delete rangen` is
// visible in this listing (NOTE(review): potential leak — confirm).
2392 TRandom3 *rangen =
new TRandom3(0);
// Number of input variables and their names, from the loader's dataset info.
2398 const int nbits = loader->GetDataSetInfo().GetNVariables();
2399 std::vector<TString> varNames = loader->GetDataSetInfo().GetListOfVariables();
// Exclusive upper bound for the random subset draw: 2^nbits.
// NOTE(review): floating-point pow() for an integer power of two, and
// inconsistent with the Short variant which uses sum(nbits) = 2^nbits - 1.
2401 long int range = pow(2, nbits);
// Per-variable importance accumulators, explicitly zeroed.
2404 std::vector<Double_t> importances(nbits);
// NOTE(review): importances_norm is accumulated below but never read in the
// visible code.
2405 Double_t importances_norm = 0;
2406 for (
int i = 0; i < nbits; i++)importances[i] = 0;
// ROC of the drawn seed (SROC) and of each reduced sub-seed (SSROC).
2408 Double_t SROC, SSROC;
// Repeat for the requested number of random seeds.
2409 for (UInt_t n = 0; n < nseeds; n++) {
// Uniform draw in [0, range): the bit pattern of the variable subset.
2410 x = rangen -> Integer(range);
// NOTE(review): hard-coded 32-bit bitset here, while other variants use
// the VIBITS constant — confirm they agree.
2412 std::bitset<32> xbitset(x);
// Empty subset: nothing to train, draw the next seed.
2413 if (x == 0)
continue;
// Dedicated DataLoader for the seed, named after its bit pattern.
2417 TMVA::DataLoader *seedloader =
new TMVA::DataLoader(xbitset.to_string());
// Register every variable whose bit is set in the seed.
2420 for (
int index = 0; index < nbits; index++) {
2421 if (xbitset[index]) seedloader->AddVariable(varNames[index],
'F');
// Copy the event data and cuts from the user's loader into the seed loader.
2425 DataLoaderCopy(seedloader,loader);
2428 BookMethod(seedloader, theMethod, methodTitle, theOption);
2433 EvaluateAllMethods();
// ROC integral of the drawn seed.
2436 SROC = GetROCIntegral(xbitset.to_string(), methodTitle);
2440 TMVA::MethodBase *smethod=
dynamic_cast<TMVA::MethodBase*
>(fMethodsMap[xbitset.to_string().c_str()][0][0]);
// NOTE(review): sresults is fetched but never used in the visible code.
2441 TMVA::ResultsClassification *sresults = (TMVA::ResultsClassification*)smethod->Data()->GetResults(smethod->GetMethodName(), Types::kTesting, Types::kClassification);
// Reset the factory so the next (sub-)seed starts from a clean method map.
2444 this->DeleteAllMethods();
2445 fMethodsMap.clear();
// For each bit i of the seed, evaluate the sub-seed y = x with bit i
// cleared (the line assigning y is not visible in this listing — confirm).
2449 for (uint32_t i = 0; i < 32; ++i) {
2452 std::bitset<32> ybitset(y);
// Presumably x - y is the single removed bit, so ny = log2(x - y) is its
// index; 0.693147 ~= ln(2). NOTE(review): ny is a Double_t used as an
// array index below — relies on truncation landing on the exact integer.
2456 Double_t ny = log(x - y) / 0.693147;
// Baseline credit: seed ROC above 0.5 (presumably the random-guess level).
2458 importances[ny] = SROC - 0.5;
2459 importances_norm += importances[ny];
2465 TMVA::DataLoader *subseedloader =
new TMVA::DataLoader(ybitset.to_string());
2467 for (
int index = 0; index < nbits; index++) {
2468 if (ybitset[index]) subseedloader->AddVariable(varNames[index],
'F');
2472 DataLoaderCopy(subseedloader,loader);
2475 BookMethod(subseedloader, theMethod, methodTitle, theOption);
2480 EvaluateAllMethods();
2483 SSROC = GetROCIntegral(ybitset.to_string(), methodTitle);
// Importance of variable ny grows by the ROC lost when it is removed.
2484 importances[ny] += SROC - SSROC;
2487 TMVA::MethodBase *ssmethod=
dynamic_cast<TMVA::MethodBase*
>(fMethodsMap[ybitset.to_string().c_str()][0][0]);
// NOTE(review): ssresults is likewise unused in the visible code.
2488 TMVA::ResultsClassification *ssresults = (TMVA::ResultsClassification*)ssmethod->Data()->GetResults(ssmethod->GetMethodName(), Types::kTesting, Types::kClassification);
2490 delete subseedloader;
2491 this->DeleteAllMethods();
2492 fMethodsMap.clear();
2496 std::cout<<
"--- Variable Importance Results (Random)"<<std::endl;
2497 return GetImportance(nbits,importances,varNames);
2502 TH1F* TMVA::Factory::GetImportance(
const int nbits,std::vector<Double_t> importances,std::vector<TString> varNames)
2504 TH1F *vih1 =
new TH1F(
"vih1",
"", nbits, 0, nbits);
2506 gStyle->SetOptStat(000000);
2508 Float_t normalization = 0.0;
2509 for (
int i = 0; i < nbits; i++) {
2510 normalization = normalization + importances[i];
2515 gStyle->SetTitleXOffset(0.4);
2516 gStyle->SetTitleXOffset(1.2);
2519 std::vector<Double_t> x_ie(nbits), y_ie(nbits);
2520 for (Int_t i = 1; i < nbits + 1; i++) {
2521 x_ie[i - 1] = (i - 1) * 1.;
2522 roc = 100.0 * importances[i - 1] / normalization;
2524 std::cout<<
"--- "<<varNames[i-1]<<
" = "<<roc<<
" %"<<std::endl;
2525 vih1->GetXaxis()->SetBinLabel(i, varNames[i - 1].Data());
2526 vih1->SetBinContent(i, roc);
2528 TGraph *g_ie =
new TGraph(nbits + 2, &x_ie[0], &y_ie[0]);
2531 vih1->LabelsOption(
"v >",
"X");
2532 vih1->SetBarWidth(0.97);
2533 Int_t ca = TColor::GetColor(
"#006600");
2534 vih1->SetFillColor(ca);
2537 vih1->GetYaxis()->SetTitle(
"Importance (%)");
2538 vih1->GetYaxis()->SetTitleSize(0.045);
2539 vih1->GetYaxis()->CenterTitle();
2540 vih1->GetYaxis()->SetTitleOffset(1.24);
2542 vih1->GetYaxis()->SetRangeUser(-7, 50);
2543 vih1->SetDirectory(0);