// ROOT dictionary registration for TMVA::MethodBase.
128 ClassImp(TMVA::MethodBase);
// File-scope switch: use splines when interpolating efficiency curves.
134 const Bool_t Use_Splines_for_Eff_ = kTRUE;
// High-resolution bin count for fine-binned histograms (assigned to fNbinsH in InitBase).
137 const Int_t NBIN_HIST_HIGH = 10000;
// MSVC only: silence warning C4355 ("'this' used in base member initializer list").
141 #pragma warning ( disable : 4355 )
// Standard constructor: allocates the TMultiGraph that collects the per-metric
// training-progress graphs (individual TGraphs are added later in Init()).
151 TMVA::IPythonInteractive::IPythonInteractive() : fMultiGraph(new TMultiGraph())
// Destructor: resets the multigraph pointer.
// NOTE(review): the actual `delete fMultiGraph` (and cleanup of the owned
// TGraphs) is in lines elided from this view — confirm ownership handling there.
159 TMVA::IPythonInteractive::~IPythonInteractive()
163 fMultiGraph =
nullptr;
// Creates one TGraph per entry in graphTitles (title doubles as the ROOT object
// name), styles it, and registers it in fMultiGraph. Emits an error to stderr
// and presumably returns early if called twice (guard body partially elided).
174 void TMVA::IPythonInteractive::Init(std::vector<TString>& graphTitles)
177 std::cerr << kERROR <<
"IPythonInteractive::Init: already initialized..." << std::endl;
181 for(
auto& title : graphTitles){
182 fGraphs.push_back(
new TGraph() );
183 fGraphs.back()->SetTitle(title);
184 fGraphs.back()->SetName(title);
// NOTE(review): `color` is defined in elided lines — presumably cycled/incremented
// per graph so each series gets a distinct color; confirm in the full source.
185 fGraphs.back()->SetFillColor(color);
186 fGraphs.back()->SetLineColor(color);
187 fGraphs.back()->SetMarkerColor(color);
188 fMultiGraph->Add(fGraphs.back());
// Resets all registered graphs (loop body elided from this view; presumably
// clears each graph's points so a new training run starts from an empty plot).
198 void TMVA::IPythonInteractive::ClearGraphs()
200 for(Int_t i=0; i<fNumGraphs; i++){
// Two-series convenience overload: grows both graphs by one slot and stores
// (x, y1) in graph 0 and (x, y2) in graph 1 at position fIndex.
// NOTE(review): the fIndex increment is in an elided line — confirm it follows.
212 void TMVA::IPythonInteractive::AddPoint(Double_t x, Double_t y1, Double_t y2)
214 fGraphs[0]->Set(fIndex+1);
215 fGraphs[1]->Set(fIndex+1);
216 fGraphs[0]->SetPoint(fIndex, x, y1);
217 fGraphs[1]->SetPoint(fIndex, x, y2);
// General overload: dat[0] is the common abscissa, dat[i+1] is the ordinate for
// graph i. Assumes dat.size() >= fNumGraphs+1 — not checked here.
228 void TMVA::IPythonInteractive::AddPoint(std::vector<Double_t>& dat)
230 for(Int_t i=0; i<fNumGraphs;i++){
231 fGraphs[i]->Set(fIndex+1);
232 fGraphs[i]->SetPoint(fIndex, dat[0], dat[i+1]);
// Standard constructor, used when a method is booked for training.
// NOTE(review): the `dsi` (DataSetInfo&) parameter and several initializer-list
// entries are in lines elided from this view. Only jobName/methodType/
// methodTitle/theOption are visible here.
242 TMVA::MethodBase::MethodBase(
const TString& jobName,
243 Types::EMVA methodType,
244 const TString& methodTitle,
246 const TString& theOption) :
248 Configurable ( theOption ),
252 fAnalysisType ( Types::kNoAnalysisType ),
253 fRegressionReturnVal ( 0 ),
254 fMulticlassReturnVal ( 0 ),
255 fDataSetInfo ( dsi ),
256 fSignalReferenceCut ( 0.5 ),
257 fSignalReferenceCutOrientation( 1. ),
258 fVariableTransformType ( Types::kSignal ),
259 fJobName ( jobName ),
260 fMethodName ( methodTitle ),
261 fMethodType ( methodType ),
// Record the TMVA/ROOT versions this method was trained with (written to the
// weight file later, and cross-checked when reading one back).
263 fTMVATrainingVersion ( TMVA_VERSION_CODE ),
264 fROOTTrainingVersion ( ROOT_VERSION_CODE ),
265 fConstructedFromWeightFile ( kFALSE ),
267 fMethodBaseDir ( 0 ),
269 fSilentFile (kFALSE),
270 fModelPersistence (kTRUE),
281 fSplTrainEffBvsS ( 0 ),
282 fVarTransformString (
"None" ),
283 fTransformationPointer ( 0 ),
284 fTransformation ( dsi, methodTitle ),
286 fVerbosityLevelString (
"Default" ),
288 fHasMVAPdfs ( kFALSE ),
289 fIgnoreNegWeightsInTraining( kFALSE ),
291 fBackgroundClass ( 0 ),
296 fSetupCompleted (kFALSE)
299 fLogger->SetSource(GetName());
// Constructor used when re-instantiating a method from an existing weight file
// (note fConstructedFromWeightFile = kTRUE and zeroed training-version fields —
// the real versions are read back from the weight file later).
308 TMVA::MethodBase::MethodBase( Types::EMVA methodType,
310 const TString& weightFile ) :
316 fAnalysisType ( Types::kNoAnalysisType ),
317 fRegressionReturnVal ( 0 ),
318 fMulticlassReturnVal ( 0 ),
319 fDataSetInfo ( dsi ),
320 fSignalReferenceCut ( 0.5 ),
321 fVariableTransformType ( Types::kSignal ),
323 fMethodName (
"MethodBase" ),
324 fMethodType ( methodType ),
326 fTMVATrainingVersion ( 0 ),
327 fROOTTrainingVersion ( 0 ),
328 fConstructedFromWeightFile ( kTRUE ),
330 fMethodBaseDir ( 0 ),
332 fSilentFile (kFALSE),
333 fModelPersistence (kTRUE),
334 fWeightFile ( weightFile ),
344 fSplTrainEffBvsS ( 0 ),
345 fVarTransformString (
"None" ),
346 fTransformationPointer ( 0 ),
347 fTransformation ( dsi,
"" ),
349 fVerbosityLevelString (
"Default" ),
351 fHasMVAPdfs ( kFALSE ),
352 fIgnoreNegWeightsInTraining( kFALSE ),
354 fBackgroundClass ( 0 ),
359 fSetupCompleted (kFALSE)
361 fLogger->SetSource(GetName());
// Destructor: releases all heap members owned by the base class (input-variable
// list, ranking, PDFs, splines, cached event collections, MVA return vectors).
// Aborts via kFATAL if the method was never set up.
369 TMVA::MethodBase::~MethodBase(
void )
372 if (!fSetupCompleted) Log() << kFATAL <<Form(
"Dataset[%s] : ",DataInfo().GetName())<<
"Calling destructor of method which got never setup" << Endl;
375 if (fInputVars != 0) { fInputVars->clear();
delete fInputVars; }
376 if (fRanking != 0)
delete fRanking;
// PDFs created in ProcessBaseOptions / TestClassification.
379 if (fDefaultPDF!= 0) {
delete fDefaultPDF; fDefaultPDF = 0; }
380 if (fMVAPdfS != 0) {
delete fMVAPdfS; fMVAPdfS = 0; }
381 if (fMVAPdfB != 0) {
delete fMVAPdfB; fMVAPdfB = 0; }
// Spline caches for efficiency/reference curves.
384 if (fSplS) {
delete fSplS; fSplS = 0; }
385 if (fSplB) {
delete fSplB; fSplB = 0; }
386 if (fSpleffBvsS) {
delete fSpleffBvsS; fSpleffBvsS = 0; }
387 if (fSplRefS) {
delete fSplRefS; fSplRefS = 0; }
388 if (fSplRefB) {
delete fSplRefB; fSplRefB = 0; }
389 if (fSplTrainRefS) {
delete fSplTrainRefS; fSplTrainRefS = 0; }
390 if (fSplTrainRefB) {
delete fSplTrainRefB; fSplTrainRefB = 0; }
391 if (fSplTrainEffBvsS) {
delete fSplTrainEffBvsS; fSplTrainEffBvsS = 0; }
// Delete the two cached event collections (training/testing); the per-event
// delete inside the inner loop is in an elided line.
393 for (Int_t i = 0; i < 2; i++ ) {
394 if (fEventCollections.at(i)) {
395 for (std::vector<Event*>::const_iterator it = fEventCollections.at(i)->begin();
396 it != fEventCollections.at(i)->end(); ++it) {
399 delete fEventCollections.at(i);
400 fEventCollections.at(i) = 0;
404 if (fRegressionReturnVal)
delete fRegressionReturnVal;
405 if (fMulticlassReturnVal)
delete fMulticlassReturnVal;
// One-time setup: declares base + derived options, then marks the method as set
// up. Calling it twice is a fatal error (guards the fSetupCompleted invariant
// checked in the destructor).
411 void TMVA::MethodBase::SetupMethod()
415 if (fSetupCompleted) Log() << kFATAL <<Form(
"Dataset[%s] : ",DataInfo().GetName())<<
"Calling SetupMethod for the second time" << Endl;
417 DeclareBaseOptions();
420 fSetupCompleted = kTRUE;
// Processes the parsed option strings (base options here; derived-class
// processing is in lines elided from this view).
428 void TMVA::MethodBase::ProcessSetup()
430 ProcessBaseOptions();
// Final sanity pass after setup: warn about any options the user set that no
// declared option consumed.
438 void TMVA::MethodBase::CheckSetup()
440 CheckForUnusedOptions();
// Default initialisation common to all methods: binning from the global config,
// input-variable labels, event-collection slots, and signal/background class
// indices looked up by name in the DataSetInfo.
446 void TMVA::MethodBase::InitBase()
448 SetConfigDescription(
"Configuration options for classifier architecture and tuning" );
450 fNbins = gConfig().fVariablePlotting.fNbinsXOfROCCurve;
451 fNbinsMVAoutput = gConfig().fVariablePlotting.fNbinsMVAoutput;
452 fNbinsH = NBIN_HIST_HIGH;
456 fSplTrainEffBvsS = 0;
463 fTxtWeightsOnly = kTRUE;
// Cache the label of every input variable for later weight-file writing.
473 fInputVars =
new std::vector<TString>;
474 for (UInt_t ivar=0; ivar<GetNvar(); ivar++) {
475 fInputVars->push_back(DataInfo().GetVariableInfo(ivar).GetLabel());
477 fRegressionReturnVal = 0;
478 fMulticlassReturnVal = 0;
// Slot 0 = training events, slot 1 = testing events (filled lazily elsewhere).
480 fEventCollections.resize( 2 );
481 fEventCollections.at(0) = 0;
482 fEventCollections.at(1) = 0;
// Resolve class numbers by conventional class names, if present.
485 if (DataInfo().GetClassInfo(
"Signal") != 0) {
486 fSignalClass = DataInfo().GetClassInfo(
"Signal")->GetNumber();
488 if (DataInfo().GetClassInfo(
"Background") != 0) {
489 fBackgroundClass = DataInfo().GetClassInfo(
"Background")->GetNumber();
492 SetConfigDescription(
"Configuration options for MVA method" );
493 SetConfigName( TString(
"Method") + GetMethodTypeName() );
// Declares the configuration options common to every MVA method (verbosity,
// variable transformations, help flag, MVA-PDF creation, negative-weight
// handling) together with the allowed values for the enumerated ones.
514 void TMVA::MethodBase::DeclareBaseOptions()
516 DeclareOptionRef( fVerbose,
"V",
"Verbose output (short form of \"VerbosityLevel\" below - overrides the latter one)" );
518 DeclareOptionRef( fVerbosityLevelString=
"Default",
"VerbosityLevel",
"Verbosity level" );
// Allowed values for VerbosityLevel (mapped to MsgLogger levels in
// ProcessBaseOptions).
519 AddPreDefVal( TString(
"Default") );
520 AddPreDefVal( TString(
"Debug") );
521 AddPreDefVal( TString(
"Verbose") );
522 AddPreDefVal( TString(
"Info") );
523 AddPreDefVal( TString(
"Warning") );
524 AddPreDefVal( TString(
"Error") );
525 AddPreDefVal( TString(
"Fatal") );
529 fTxtWeightsOnly = kTRUE;
532 DeclareOptionRef( fVarTransformString,
"VarTransform",
"List of variable transformations performed before training, e.g., \"D_Background,P_Signal,G,N_AllClasses\" for: \"Decorrelation, PCA-transformation, Gaussianisation, Normalisation, each for the given class of events ('AllClasses' denotes all events of all classes, if no class indication is given, 'All' is assumed)\"" );
534 DeclareOptionRef( fHelp,
"H",
"Print method-specific help message" );
536 DeclareOptionRef( fHasMVAPdfs,
"CreateMVAPdfs",
"Create PDFs for classifier outputs (signal and background)" );
538 DeclareOptionRef( fIgnoreNegWeightsInTraining,
"IgnoreNegWeightsInTraining",
539 "Events with negative weights are ignored in the training (but are included for testing and performance evaluation)" );
// Consumes the base options declared above: builds the MVA-output PDF objects
// (when CreateMVAPdfs was requested), instantiates the requested variable
// transformations, and translates VerbosityLevel into a MsgLogger level.
545 void TMVA::MethodBase::ProcessBaseOptions()
// The default PDF carries the option defaults; signal/background PDFs chain
// from it so user overrides propagate.
551 fDefaultPDF =
new PDF( TString(GetName())+
"_PDF", GetOptions(),
"MVAPdf" );
552 fDefaultPDF->DeclareOptions();
553 fDefaultPDF->ParseOptions();
554 fDefaultPDF->ProcessOptions();
555 fMVAPdfB =
new PDF( TString(GetName())+
"_PDFBkg", fDefaultPDF->GetOptions(),
"MVAPdfBkg", fDefaultPDF );
556 fMVAPdfB->DeclareOptions();
557 fMVAPdfB->ParseOptions();
558 fMVAPdfB->ProcessOptions();
559 fMVAPdfS =
new PDF( TString(GetName())+
"_PDFSig", fMVAPdfB->GetOptions(),
"MVAPdfSig", fDefaultPDF );
560 fMVAPdfS->DeclareOptions();
561 fMVAPdfS->ParseOptions();
562 fMVAPdfS->ProcessOptions();
// Keep the fully-processed option string (with PDF sub-options resolved).
565 SetOptions( fMVAPdfS->GetOptions() );
568 TMVA::CreateVariableTransforms( fVarTransformString,
570 GetTransformationHandler(),
// Branch for the case where PDFs were NOT requested: drop any stale PDF objects.
574 if (fDefaultPDF!= 0) {
delete fDefaultPDF; fDefaultPDF = 0; }
575 if (fMVAPdfS != 0) {
delete fMVAPdfS; fMVAPdfS = 0; }
576 if (fMVAPdfB != 0) {
delete fMVAPdfB; fMVAPdfB = 0; }
// Map the VerbosityLevel string onto the logger's minimum message level;
// anything not in the predefined list is fatal.
580 fVerbosityLevelString = TString(
"Verbose");
581 Log().SetMinType( kVERBOSE );
583 else if (fVerbosityLevelString ==
"Debug" ) Log().SetMinType( kDEBUG );
584 else if (fVerbosityLevelString ==
"Verbose" ) Log().SetMinType( kVERBOSE );
585 else if (fVerbosityLevelString ==
"Info" ) Log().SetMinType( kINFO );
586 else if (fVerbosityLevelString ==
"Warning" ) Log().SetMinType( kWARNING );
587 else if (fVerbosityLevelString ==
"Error" ) Log().SetMinType( kERROR );
588 else if (fVerbosityLevelString ==
"Fatal" ) Log().SetMinType( kFATAL );
589 else if (fVerbosityLevelString !=
"Default" ) {
590 Log() << kFATAL <<
"<ProcessOptions> Verbosity level type '"
591 << fVerbosityLevelString <<
"' unknown." << Endl;
593 Event::SetIgnoreNegWeightsInTraining(fIgnoreNegWeightsInTraining);
// Declares options kept only for backward compatibility with weight files
// written by older TMVA versions; new code should not rely on these.
601 void TMVA::MethodBase::DeclareCompatibilityOptions()
603 DeclareOptionRef( fNormalise=kFALSE,
"Normalise",
"Normalise input variables" );
604 DeclareOptionRef( fUseDecorr=kFALSE,
"D",
"Use-decorrelated-variables flag" );
605 DeclareOptionRef( fVariableTransformTypeString=
"Signal",
"VarTransformType",
606 "Use signal or background events to derive for variable transformation (the transformation is applied on both types of, course)" );
607 AddPreDefVal( TString(
"Signal") );
608 AddPreDefVal( TString(
"Background") );
609 DeclareOptionRef( fTxtWeightsOnly=kTRUE,
"TxtWeightFilesOnly",
"If True: write all training results (weights) as text files (False: some are written in ROOT format)" );
619 DeclareOptionRef( fNbinsMVAPdf = 60,
"NbinsMVAPdf",
"Number of bins used for the PDFs of classifier outputs" );
620 DeclareOptionRef( fNsmoothMVAPdf = 2,
"NsmoothMVAPdf",
"Number of smoothing iterations for classifier PDFs" );
// Default (no-op) hyper-parameter optimisation: warns that tuning is not
// implemented for this method and returns an empty parameter map. Methods that
// support tuning override this.
628 std::map<TString,Double_t> TMVA::MethodBase::OptimizeTuningParameters(TString , TString )
634 Log() << kWARNING <<Form(
"Dataset[%s] : ",DataInfo().GetName())<<
"Parameter optimization is not yet implemented for method "
635 << GetName() << Endl;
636 Log() << kWARNING <<Form(
"Dataset[%s] : ",DataInfo().GetName())<<
"Currently we need to set hardcoded which parameter is tuned in which ranges"<<Endl;
638 std::map<TString,Double_t> tunedParameters;
// NOTE(review): this size() call discards its result — it is a no-op, likely
// left over to silence an unused-variable warning.
639 tunedParameters.size();
640 return tunedParameters;
// Default (empty) hook: methods supporting tuning override this to apply the
// parameter map produced by OptimizeTuningParameters.
649 void TMVA::MethodBase::SetTuneParameters(std::map<TString,Double_t> )
// Full training driver: switches the dataset to training mode, computes the
// variable transformations, runs (elided) Train(), then fills the training-
// sample MVA output for the active analysis type, optionally creates PDFs,
// and persists weights / standalone class / monitoring histograms.
655 void TMVA::MethodBase::TrainMethod()
657 Data()->SetCurrentType(Types::kTraining);
658 Event::SetIsTraining(kTRUE);
661 if (Help()) PrintHelpMessage();
// All output histograms go into the method's directory unless silent-file mode.
664 if(!IsSilentFile()) BaseDir()->cd();
// Transformations are derived from the training events only.
668 GetTransformationHandler().CalcTransformations(Data()->GetEventCollection());
672 <<
"Begin training" << Endl;
673 Long64_t nEvents = Data()->GetNEvents();
674 Timer traintimer( nEvents, GetName(), kTRUE );
// NOTE(review): the actual Train() call sits in an elided line between the
// "Begin training" and "End of training" log messages.
677 <<
"\tEnd of training " << Endl;
678 SetTrainTime(traintimer.ElapsedSeconds());
680 <<
"Elapsed time for training with " << nEvents <<
" events: "
681 << traintimer.GetElapsedTime() <<
" " << Endl;
684 <<
"\tCreate MVA output for ";
// Fill the training-sample response according to analysis type.
687 if (DoMulticlass()) {
688 Log() <<Form(
"[%s] : ",DataInfo().GetName())<<
"Multiclass classification on training sample" << Endl;
689 AddMulticlassOutput(Types::kTraining);
691 else if (!DoRegression()) {
693 Log() <<Form(
"[%s] : ",DataInfo().GetName())<<
"classification on training sample" << Endl;
694 AddClassifierOutput(Types::kTraining);
697 AddClassifierOutputProb(Types::kTraining);
702 Log() <<Form(
"Dataset[%s] : ",DataInfo().GetName())<<
"regression on training sample" << Endl;
703 AddRegressionOutput( Types::kTraining );
706 Log() <<Form(
"Dataset[%s] : ",DataInfo().GetName())<<
"Create PDFs" << Endl;
// Persist trained state only when model persistence is enabled.
713 if (fModelPersistence ) WriteStateToFile();
716 if ((!DoRegression()) && (fModelPersistence)) MakeClass();
723 WriteMonitoringHistosToFile();
// Computes the regression deviation for target `tgtNum`: `stddev` over all
// events and `stddev90Percent` truncated at the 90% quantile of the quadratic
// deviation. Fatal if called on a non-regression job.
729 void TMVA::MethodBase::GetRegressionDeviation(UInt_t tgtNum, Types::ETreeType type, Double_t& stddev, Double_t& stddev90Percent )
const
731 if (!DoRegression()) Log() << kFATAL <<Form(
"Dataset[%s] : ",DataInfo().GetName())<<
"Trying to use GetRegressionDeviation() with a classification job" << Endl;
732 Log() << kINFO <<Form(
"Dataset[%s] : ",DataInfo().GetName())<<
"Create results for " << (type==Types::kTraining?
"training":
"testing") << Endl;
// NOTE(review): results are fetched for Types::kTesting regardless of the
// `type` argument (which is only used in the log message) — confirm intended.
733 ResultsRegression* regRes = (ResultsRegression*)Data()->GetResults(GetMethodName(), Types::kTesting, Types::kRegression);
734 bool truncate =
false;
735 TH1F* h1 = regRes->QuadraticDeviation( tgtNum , truncate, 1.);
736 stddev = sqrt(h1->GetMean());
// 90% quantile of the squared deviation defines the truncation cut.
738 Double_t yq[1], xq[]={0.9};
739 h1->GetQuantiles(1,yq,xq);
740 TH1F* h2 = regRes->QuadraticDeviation( tgtNum , truncate, yq[0]);
741 stddev90Percent = sqrt(h2->GetMean());
// Evaluates the regression response for every event of the requested tree type,
// stores it in the ResultsRegression container, and books the deviation
// histograms. The progress bar is throttled to ~100 redraws.
749 void TMVA::MethodBase::AddRegressionOutput(Types::ETreeType type)
751 Data()->SetCurrentType(type);
753 Log() << kINFO <<Form(
"Dataset[%s] : ",DataInfo().GetName())<<
"Create results for " << (type==Types::kTraining?
"training":
"testing") << Endl;
755 ResultsRegression* regRes = (ResultsRegression*)Data()->GetResults(GetMethodName(), type, Types::kRegression);
757 Long64_t nEvents = Data()->GetNEvents();
760 Timer timer( nEvents, GetName(), kTRUE );
761 Log() << kINFO <<Form(
"Dataset[%s] : ",DataInfo().GetName()) <<
"Evaluation of " << GetMethodName() <<
" on "
762 << (type==Types::kTraining?
"training":
"testing") <<
" sample" << Endl;
764 regRes->Resize( nEvents );
// Throttle progress-bar updates: at most ~totalProgressDraws redraws.
769 Int_t totalProgressDraws = 100;
770 Int_t drawProgressEvery = 1;
771 if(nEvents >= totalProgressDraws) drawProgressEvery = nEvents/totalProgressDraws;
773 for (Int_t ievt=0; ievt<nEvents; ievt++) {
775 Data()->SetCurrentEvent(ievt);
776 std::vector< Float_t > vals = GetRegressionValues();
777 regRes->SetValue( vals, ievt );
780 if(ievt % drawProgressEvery == 0 || ievt==nEvents-1) timer.DrawProgressBar( ievt );
783 Log() << kINFO <<Form(
"Dataset[%s] : ",DataInfo().GetName())
784 <<
"Elapsed time for evaluation of " << nEvents <<
" events: "
785 << timer.GetElapsedTime() <<
" " << Endl;
// Evaluation time on the test sample is recorded for the summary tables.
788 if (type==Types::kTesting)
789 SetTestTime(timer.ElapsedSeconds());
791 TString histNamePrefix(GetTestvarName());
792 histNamePrefix += (type==Types::kTraining?
"train":
"test");
793 regRes->CreateDeviationHistograms( histNamePrefix );
// Evaluates the multiclass response for every event of the requested tree type,
// stores it in the ResultsMulticlass container, and books the per-class output
// and performance histograms.
799 void TMVA::MethodBase::AddMulticlassOutput(Types::ETreeType type)
801 Data()->SetCurrentType(type);
803 Log() << kINFO <<Form(
"Dataset[%s] : ",DataInfo().GetName())<<
"Create results for " << (type==Types::kTraining?
"training":
"testing") << Endl;
805 ResultsMulticlass* resMulticlass =
dynamic_cast<ResultsMulticlass*
>(Data()->GetResults(GetMethodName(), type, Types::kMulticlass));
806 if (!resMulticlass) Log() << kFATAL<<Form(
"Dataset[%s] : ",DataInfo().GetName())<<
"unable to create pointer in AddMulticlassOutput, exiting."<<Endl;
808 Long64_t nEvents = Data()->GetNEvents();
811 Timer timer( nEvents, GetName(), kTRUE );
813 Log() << kINFO <<Form(
"Dataset[%s] : ",DataInfo().GetName())<<
"Multiclass evaluation of " << GetMethodName() <<
" on "
814 << (type==Types::kTraining?
"training":
"testing") <<
" sample" << Endl;
816 resMulticlass->Resize( nEvents );
817 for (Int_t ievt=0; ievt<nEvents; ievt++) {
818 Data()->SetCurrentEvent(ievt);
819 std::vector< Float_t > vals = GetMulticlassValues();
820 resMulticlass->SetValue( vals, ievt );
821 timer.DrawProgressBar( ievt );
824 Log() << kINFO <<Form(
"Dataset[%s] : ",DataInfo().GetName())
825 <<
"Elapsed time for evaluation of " << nEvents <<
" events: "
826 << timer.GetElapsedTime() <<
" " << Endl;
829 if (type==Types::kTesting)
830 SetTestTime(timer.ElapsedSeconds());
832 TString histNamePrefix(GetTestvarName());
833 histNamePrefix += (type==Types::kTraining?
"_Train":
"_Test");
835 resMulticlass->CreateMulticlassHistos( histNamePrefix, fNbinsMVAoutput, fNbinsH );
836 resMulticlass->CreateMulticlassPerformanceHistos(histNamePrefix);
// Helper for methods without error estimation: flags the error outputs as
// "not available" by setting them to -1.
// NOTE(review): the matching `if (err) *err=-1;` line is elided from this view.
841 void TMVA::MethodBase::NoErrorCalc(Double_t*
const err, Double_t*
const errUpper) {
843 if (errUpper) *errUpper=-1;
// Event-pointer overload: evaluates the MVA response for `ev` by delegating to
// the state-based GetMvaValue(err, errUpper). The code that makes `ev` the
// current event, and the return statement, are in lines elided from this view.
848 Double_t TMVA::MethodBase::GetMvaValue(
const Event*
const ev, Double_t* err, Double_t* errUpper ) {
850 Double_t val = GetMvaValue(err, errUpper);
// True when the current event's MVA value is on the signal side of the
// reference cut; the orientation factor (+1/-1) handles classifiers where
// smaller output means more signal-like.
859 Bool_t TMVA::MethodBase::IsSignalLike() {
860 return GetMvaValue()*GetSignalReferenceCutOrientation() > GetSignalReferenceCut()*GetSignalReferenceCutOrientation() ? kTRUE : kFALSE;
// Same signal-side test as IsSignalLike() but for a caller-supplied MVA value
// (avoids re-evaluating the classifier).
866 Bool_t TMVA::MethodBase::IsSignalLike(Double_t mvaVal) {
867 return mvaVal*GetSignalReferenceCutOrientation() > GetSignalReferenceCut()*GetSignalReferenceCutOrientation() ? kTRUE : kFALSE;
// Evaluates the classifier on every event of the requested tree type (batch
// evaluation via GetMvaValues) and copies the values into the
// ResultsClassification container.
873 void TMVA::MethodBase::AddClassifierOutput( Types::ETreeType type )
875 Data()->SetCurrentType(type);
877 ResultsClassification* clRes =
878 (ResultsClassification*)Data()->GetResults(GetMethodName(), type, Types::kClassification );
880 Long64_t nEvents = Data()->GetNEvents();
881 clRes->Resize( nEvents );
// Batch-evaluate all events with progress logging enabled.
884 Timer timer( nEvents, GetName(), kTRUE );
885 std::vector<Double_t> mvaValues = GetMvaValues(0, nEvents,
true);
888 if (type==Types::kTesting)
889 SetTestTime(timer.ElapsedSeconds());
892 for (Int_t ievt=0; ievt<nEvents; ievt++) {
893 clRes->SetValue( mvaValues[ievt], ievt );
// Returns the MVA values for events [firstEvt, lastEvt); out-of-range bounds
// are clamped to the dataset size. When logProgress is set, a progress bar and
// timing summary are emitted.
899 std::vector<Double_t> TMVA::MethodBase::GetMvaValues(Long64_t firstEvt, Long64_t lastEvt, Bool_t logProgress)
902 Long64_t nEvents = Data()->GetNEvents();
903 if (firstEvt > lastEvt || lastEvt > nEvents) lastEvt = nEvents;
904 if (firstEvt < 0) firstEvt = 0;
905 std::vector<Double_t> values(lastEvt-firstEvt);
907 nEvents = values.size();
910 Timer timer( nEvents, GetName(), kTRUE );
913 Log() << kHEADER << Form(
"[%s] : ",DataInfo().GetName())
914 <<
"Evaluation of " << GetMethodName() <<
" on "
915 << (Data()->GetCurrentType() == Types::kTraining ?
"training" :
"testing")
916 <<
" sample (" << nEvents <<
" events)" << Endl;
918 for (Int_t ievt=firstEvt; ievt<lastEvt; ievt++) {
919 Data()->SetCurrentEvent(ievt);
// NOTE(review): `values` has size lastEvt-firstEvt but is indexed with the
// absolute event index `ievt` — out of bounds whenever firstEvt > 0. Confirm
// all callers pass firstEvt == 0 (AddClassifierOutput does).
920 values[ievt] = GetMvaValue();
// Redraw the progress bar roughly once per percent.
924 Int_t modulo = Int_t(nEvents/100);
925 if (modulo <= 0 ) modulo = 1;
926 if (ievt%modulo == 0) timer.DrawProgressBar( ievt );
931 <<
"Elapsed time for evaluation of " << nEvents <<
" events: "
932 << timer.GetElapsedTime() <<
" " << Endl;
// Fills the "prob_<method>" results container with the signal probability
// GetProba(mva, 0.5) for every event of the requested tree type.
941 void TMVA::MethodBase::AddClassifierOutputProb( Types::ETreeType type )
943 Data()->SetCurrentType(type);
945 ResultsClassification* mvaProb =
946 (ResultsClassification*)Data()->GetResults(TString(
"prob_")+GetMethodName(), type, Types::kClassification );
948 Long64_t nEvents = Data()->GetNEvents();
951 Timer timer( nEvents, GetName(), kTRUE );
953 Log() << kINFO <<Form(
"Dataset[%s] : ",DataInfo().GetName()) <<
"Evaluation of " << GetMethodName() <<
" on "
954 << (type==Types::kTraining?
"training":
"testing") <<
" sample" << Endl;
956 mvaProb->Resize( nEvents );
957 for (Int_t ievt=0; ievt<nEvents; ievt++) {
959 Data()->SetCurrentEvent(ievt);
960 Float_t proba = ((Float_t)GetProba( GetMvaValue(), 0.5 ));
// A negative probability signals "no PDFs available".
// NOTE(review): the `break` aborts the whole loop, leaving the remaining
// (resized) entries unset — confirm this early-out is intended vs `continue`.
961 if (proba < 0)
break;
962 mvaProb->SetValue( proba, ievt );
// Redraw the progress bar roughly once per percent.
965 Int_t modulo = Int_t(nEvents/100);
966 if (modulo <= 0 ) modulo = 1;
967 if (ievt%modulo == 0) timer.DrawProgressBar( ievt );
970 Log() << kDEBUG <<Form(
"Dataset[%s] : ",DataInfo().GetName())
971 <<
"Elapsed time for evaluation of " << nEvents <<
" events: "
972 << timer.GetElapsedTime() <<
" " << Endl;
// Regression performance summary on the requested tree type: computes bias,
// mean |deviation|, RMS and mutual information over all events, plus the same
// quantities restricted to |d - bias| < 2*RMS (the *T "truncated" variants).
// The dataset's current type is saved and restored around the computation.
982 void TMVA::MethodBase::TestRegression( Double_t& bias, Double_t& biasT,
983 Double_t& dev, Double_t& devT,
984 Double_t& rms, Double_t& rmsT,
985 Double_t& mInf, Double_t& mInfT,
987 Types::ETreeType type )
989 Types::ETreeType savedType = Data()->GetCurrentType();
990 Data()->SetCurrentType(type);
992 bias = 0; biasT = 0; dev = 0; devT = 0; rms = 0; rmsT = 0;
994 Double_t m1 = 0, m2 = 0, s1 = 0, s2 = 0, s12 = 0;
995 const Int_t nevt = GetNEvents();
// Per-event caches of regression value, target and weight for the second pass.
// NOTE(review): raw new[] arrays; the matching delete[] calls are in elided
// lines — confirm they exist (a std::vector would make this leak-proof).
996 Float_t* rV =
new Float_t[nevt];
997 Float_t* tV =
new Float_t[nevt];
998 Float_t* wV =
new Float_t[nevt];
999 Float_t xmin = 1e30, xmax = -1e30;
1000 Log() << kINFO <<
"Calculate regression for all events" << Endl;
1001 Timer timer( nevt, GetName(), kTRUE );
// First pass: accumulate weighted moments and track the value range.
1002 for (Long64_t ievt=0; ievt<nevt; ievt++) {
1004 const Event* ev = Data()->GetEvent(ievt);
1005 Float_t t = ev->GetTarget(0);
1006 Float_t w = ev->GetWeight();
1007 Float_t r = GetRegressionValues()[0];
1011 xmin = TMath::Min(xmin, TMath::Min(t, r));
1012 xmax = TMath::Max(xmax, TMath::Max(t, r));
1022 dev += w * TMath::Abs(d);
1026 m1 += t*w; s1 += t*t*w;
1027 m2 += r*w; s2 += r*r*w;
// Update the progress bar every 256 events.
1029 if ((ievt & 0xFF) == 0) timer.DrawProgressBar(ievt);
1031 timer.DrawProgressBar(nevt - 1);
1032 Log() << kINFO <<
"Elapsed time for evaluation of " << nevt <<
" events: "
1033 << timer.GetElapsedTime() <<
" " << Endl;
1039 rms = TMath::Sqrt(rms - bias*bias);
// Weighted Pearson correlation between target and regression value.
1044 corr = s12/sumw - m1*m2;
1045 corr /= TMath::Sqrt( (s1/sumw - m1*m1) * (s2/sumw - m2*m2) );
// 2D target-vs-response histograms for the mutual-information estimate.
1048 TH2F* hist =
new TH2F(
"hist",
"hist", 150, xmin, xmax, 100, xmin, xmax );
1049 TH2F* histT =
new TH2F(
"histT",
"histT", 150, xmin, xmax, 100, xmin, xmax );
// Truncation window: deviations within bias +/- 2*RMS.
1052 Double_t devMax = bias + 2*rms;
1053 Double_t devMin = bias - 2*rms;
// Second pass over the cached values: truncated moments + histograms.
1056 for (Long64_t ievt=0; ievt<nevt; ievt++) {
1057 Float_t d = (rV[ievt] - tV[ievt]);
1058 hist->Fill( rV[ievt], tV[ievt], wV[ievt] );
1059 if (d >= devMin && d <= devMax) {
1061 biasT += wV[ievt] * d;
1062 devT += wV[ievt] * TMath::Abs(d);
1063 rmsT += wV[ievt] * d * d;
1064 histT->Fill( rV[ievt], tV[ievt], wV[ievt] );
1071 rmsT = TMath::Sqrt(rmsT - biasT*biasT);
1072 mInf = gTools().GetMutualInformation( *hist );
1073 mInfT = gTools().GetMutualInformation( *histT );
1082 Data()->SetCurrentType(savedType);
// Books the multiclass output/performance histograms for both the test and the
// training sample (suffixes "_Test" / "_Train") from the testing results
// container. Fatal if the results cannot be cast to ResultsMulticlass.
1089 void TMVA::MethodBase::TestMulticlass()
1091 ResultsMulticlass* resMulticlass =
dynamic_cast<ResultsMulticlass*
>(Data()->GetResults(GetMethodName(), Types::kTesting, Types::kMulticlass));
1092 if (!resMulticlass) Log() << kFATAL<<Form(
"Dataset[%s] : ",DataInfo().GetName())<<
"unable to create pointer in TestMulticlass, exiting."<<Endl;
1101 TString histNamePrefix(GetTestvarName());
1102 TString histNamePrefixTest{histNamePrefix +
"_Test"};
1103 TString histNamePrefixTrain{histNamePrefix +
"_Train"};
1105 resMulticlass->CreateMulticlassHistos(histNamePrefixTest, fNbinsMVAoutput, fNbinsH);
1106 resMulticlass->CreateMulticlassPerformanceHistos(histNamePrefixTest);
1108 resMulticlass->CreateMulticlassHistos(histNamePrefixTrain, fNbinsMVAoutput, fNbinsH);
1109 resMulticlass->CreateMulticlassPerformanceHistos(histNamePrefixTrain);
// Classification test-sample evaluation: computes the response statistics
// (means/RMS/range for signal and background), decides the cut orientation,
// books and fills the MVA / probability / rarity / high-resolution histograms,
// normalises them, and rebuilds the signal/background response PDFs (fSplS,
// fSplB) used later for efficiency calculations.
1116 void TMVA::MethodBase::TestClassification()
1118 Data()->SetCurrentType(Types::kTesting);
1120 ResultsClassification* mvaRes =
dynamic_cast<ResultsClassification*
>
1121 ( Data()->GetResults(GetMethodName(),Types::kTesting, Types::kClassification) );
// A missing result is fatal for every method except Cuts (which stores no
// per-event response values).
1124 if (0==mvaRes && !(GetMethodTypeName().Contains(
"Cuts"))) {
1125 Log()<<Form(
"Dataset[%s] : ",DataInfo().GetName()) <<
"mvaRes " << mvaRes <<
" GetMethodTypeName " << GetMethodTypeName()
1126 <<
" contains " << !(GetMethodTypeName().Contains(
"Cuts")) << Endl;
1127 Log() << kFATAL<<Form(
"Dataset[%s] : ",DataInfo().GetName()) <<
"<TestInit> Test variable " << GetTestvarName()
1128 <<
" not found in tree" << Endl;
// Basic statistics of the response separated by true class.
1132 gTools().ComputeStat( GetEventCollection(Types::kTesting), mvaRes->GetValueVector(),
1133 fMeanS, fMeanB, fRmsS, fRmsB, fXmin, fXmax, fSignalClass );
// Clip the histogram range to mean +/- nrms*RMS of either class.
1137 fXmin = TMath::Max( TMath::Min( fMeanS - nrms*fRmsS, fMeanB - nrms*fRmsB ), fXmin );
1138 fXmax = TMath::Min( TMath::Max( fMeanS + nrms*fRmsS, fMeanB + nrms*fRmsB ), fXmax );
// Cut orientation: signal above background mean => cut selects larger values.
1141 fCutOrientation = (fMeanS > fMeanB) ? kPositive : kNegative;
// Tiny offset so the maximum value falls inside the last bin.
1146 Double_t sxmax = fXmax+0.00001;
// Histogram name prefix; dataset-qualified variant chosen in elided lines.
1150 TString TestvarName;
1153 TestvarName=Form(
"[%s]%s",DataInfo().GetName(),GetTestvarName().Data());
1156 TestvarName=GetTestvarName();
// Signal/background MVA response histograms.
1158 TH1* mva_s =
new TH1D( TestvarName +
"_S",TestvarName +
"_S", fNbinsMVAoutput, fXmin, sxmax );
1159 TH1* mva_b =
new TH1D( TestvarName +
"_B",TestvarName +
"_B", fNbinsMVAoutput, fXmin, sxmax );
1160 mvaRes->Store(mva_s,
"MVA_S");
1161 mvaRes->Store(mva_b,
"MVA_B");
// Probability and rarity histograms (only filled when MVA PDFs exist; the
// declarations/guards are in elided lines).
1171 proba_s =
new TH1D( TestvarName +
"_Proba_S", TestvarName +
"_Proba_S", fNbinsMVAoutput, 0.0, 1.0 );
1172 proba_b =
new TH1D( TestvarName +
"_Proba_B", TestvarName +
"_Proba_B", fNbinsMVAoutput, 0.0, 1.0 );
1173 mvaRes->Store(proba_s,
"Prob_S");
1174 mvaRes->Store(proba_b,
"Prob_B");
1179 rarity_s =
new TH1D( TestvarName +
"_Rarity_S", TestvarName +
"_Rarity_S", fNbinsMVAoutput, 0.0, 1.0 );
1180 rarity_b =
new TH1D( TestvarName +
"_Rarity_B", TestvarName +
"_Rarity_B", fNbinsMVAoutput, 0.0, 1.0 );
1181 mvaRes->Store(rarity_s,
"Rar_S");
1182 mvaRes->Store(rarity_b,
"Rar_B");
// High-resolution (fNbinsH) response histograms for efficiency curves.
1188 TH1* mva_eff_s =
new TH1D( TestvarName +
"_S_high", TestvarName +
"_S_high", fNbinsH, fXmin, sxmax );
1189 TH1* mva_eff_b =
new TH1D( TestvarName +
"_B_high", TestvarName +
"_B_high", fNbinsH, fXmin, sxmax );
1190 mvaRes->Store(mva_eff_s,
"MVA_HIGHBIN_S");
1191 mvaRes->Store(mva_eff_b,
"MVA_HIGHBIN_B");
// Optional per-event probability results produced by AddClassifierOutputProb.
1197 ResultsClassification* mvaProb =
dynamic_cast<ResultsClassification*
>
1198 (Data()->GetResults( TString(
"prob_")+GetMethodName(), Types::kTesting, Types::kMaxAnalysisType ) );
1200 Log() << kHEADER <<Form(
"[%s] : ",DataInfo().GetName())<<
"Loop over test events and fill histograms with classifier response..." << Endl << Endl;
1201 if (mvaProb) Log() << kINFO <<
"Also filling probability and rarity histograms (on request)..." << Endl;
1202 std::vector<Bool_t>* mvaResTypes = mvaRes->GetValueVectorTypes();
// Consistency guard: one stored response per test event.
1205 if ( mvaRes->GetSize() != GetNEvents() ) {
1206 Log() << kFATAL << TString::Format(
"Inconsistent result size %lld with number of events %u ", mvaRes->GetSize() , GetNEvents() ) << Endl;
1207 assert(mvaRes->GetSize() == GetNEvents());
// Fill histograms per true class (signal branch first, background after).
1210 for (Long64_t ievt=0; ievt<GetNEvents(); ievt++) {
1212 const Event* ev = GetEvent(ievt);
1213 Float_t v = (*mvaRes)[ievt][0];
1214 Float_t w = ev->GetWeight();
1216 if (DataInfo().IsSignal(ev)) {
1217 mvaResTypes->push_back(kTRUE);
1218 mva_s ->Fill( v, w );
1220 proba_s->Fill( (*mvaProb)[ievt][0], w );
1221 rarity_s->Fill( GetRarity( v ), w );
1224 mva_eff_s ->Fill( v, w );
1227 mvaResTypes->push_back(kFALSE);
1228 mva_b ->Fill( v, w );
1230 proba_b->Fill( (*mvaProb)[ievt][0], w );
1231 rarity_b->Fill( GetRarity( v ), w );
1233 mva_eff_b ->Fill( v, w );
// Normalise all distributions so shapes are comparable.
1238 gTools().NormHist( mva_s );
1239 gTools().NormHist( mva_b );
1240 gTools().NormHist( proba_s );
1241 gTools().NormHist( proba_b );
1242 gTools().NormHist( rarity_s );
1243 gTools().NormHist( rarity_b );
1244 gTools().NormHist( mva_eff_s );
1245 gTools().NormHist( mva_eff_b );
// Rebuild the response PDFs from the freshly filled histograms.
1248 if (fSplS) {
delete fSplS; fSplS = 0; }
1249 if (fSplB) {
delete fSplB; fSplB = 0; }
1250 fSplS =
new PDF( TString(GetName()) +
" PDF Sig", mva_s, PDF::kSpline2 );
1251 fSplB =
new PDF( TString(GetName()) +
" PDF Bkg", mva_b, PDF::kSpline2 );
// Writes the legacy plain-text weight-file header to `tf`: general info (TMVA/
// ROOT versions, creator, date, host, training statistics, analysis type),
// followed by the options (#OPT) and variables (#VAR) sections.
1258 void TMVA::MethodBase::WriteStateToStream( std::ostream& tf )
const
1260 TString prefix =
"";
// NOTE(review): GetUserInfo() returns a heap object; its deletion is in lines
// elided from this view — confirm `delete userInfo` follows.
1261 UserGroup_t * userInfo = gSystem->GetUserInfo();
1263 tf << prefix <<
"#GEN -*-*-*-*-*-*-*-*-*-*-*- general info -*-*-*-*-*-*-*-*-*-*-*-" << std::endl << prefix << std::endl;
1264 tf << prefix <<
"Method : " << GetMethodTypeName() <<
"::" << GetMethodName() << std::endl;
1265 tf.setf(std::ios::left);
1266 tf << prefix <<
"TMVA Release : " << std::setw(10) << GetTrainingTMVAVersionString() <<
" ["
1267 << GetTrainingTMVAVersionCode() <<
"]" << std::endl;
1268 tf << prefix <<
"ROOT Release : " << std::setw(10) << GetTrainingROOTVersionString() <<
" ["
1269 << GetTrainingROOTVersionCode() <<
"]" << std::endl;
1270 tf << prefix <<
"Creator : " << userInfo->fUser << std::endl;
1271 tf << prefix <<
"Date : "; TDatime *d =
new TDatime; tf << d->AsString() << std::endl;
delete d;
1272 tf << prefix <<
"Host : " << gSystem->GetBuildNode() << std::endl;
1273 tf << prefix <<
"Dir : " << gSystem->WorkingDirectory() << std::endl;
1274 tf << prefix <<
"Training events: " << Data()->GetNTrainingEvents() << std::endl;
// const_cast needed because GetAnalysisType() is non-const.
1276 TString analysisType(((const_cast<TMVA::MethodBase*>(
this)->GetAnalysisType()==Types::kRegression) ?
"Regression" :
"Classification"));
1278 tf << prefix <<
"Analysis type : " <<
"[" << ((GetAnalysisType()==Types::kRegression) ?
"Regression" :
"Classification") <<
"]" << std::endl;
1279 tf << prefix << std::endl;
// Options section.
1284 tf << prefix << std::endl << prefix <<
"#OPT -*-*-*-*-*-*-*-*-*-*-*-*- options -*-*-*-*-*-*-*-*-*-*-*-*-" << std::endl << prefix << std::endl;
1285 WriteOptionsToStream( tf, prefix );
1286 tf << prefix << std::endl;
// Variables section.
1289 tf << prefix << std::endl << prefix <<
"#VAR -*-*-*-*-*-*-*-*-*-*-*-* variables *-*-*-*-*-*-*-*-*-*-*-*-" << std::endl << prefix << std::endl;
1290 WriteVarsToStream( tf, prefix );
1291 tf << prefix << std::endl;
// XML helper: appends an <Info name="..." value="..."/> child under node `gi`.
1297 void TMVA::MethodBase::AddInfoItem(
void* gi,
const TString& name,
const TString& value)
const
1299 void* it = gTools().AddChild(gi,
"Info");
1300 gTools().AddAttr(it,
"name", name);
1301 gTools().AddAttr(it,
"value", value);
// Dispatches result creation by analysis type: regression, multiclass, or
// (default) binary classification plus its probability output.
1306 void TMVA::MethodBase::AddOutput( Types::ETreeType type, Types::EAnalysisType analysisType ) {
1307 if (analysisType == Types::kRegression) {
1308 AddRegressionOutput( type );
1309 }
else if (analysisType == Types::kMulticlass) {
1310 AddMulticlassOutput( type );
1312 AddClassifierOutput( type );
1314 AddClassifierOutputProb( type );
// Serialises the complete method state as XML under `parent`: general info,
// options, variables, spectators, classes/targets, variable transformations,
// MVA PDFs, and finally the method-specific weights.
1322 void TMVA::MethodBase::WriteStateToXML(
void* parent )
const
1324 if (!parent)
return;
// NOTE(review): GetUserInfo() returns a heap object; its deletion is in lines
// elided from this view — confirm `delete userInfo` follows.
1326 UserGroup_t* userInfo = gSystem->GetUserInfo();
1328 void* gi = gTools().AddChild(parent,
"GeneralInfo");
1329 AddInfoItem( gi,
"TMVA Release", GetTrainingTMVAVersionString() +
" [" + gTools().StringFromInt(GetTrainingTMVAVersionCode()) +
"]" );
1330 AddInfoItem( gi,
"ROOT Release", GetTrainingROOTVersionString() +
" [" + gTools().StringFromInt(GetTrainingROOTVersionCode()) +
"]");
1331 AddInfoItem( gi,
"Creator", userInfo->fUser);
1332 TDatime dt; AddInfoItem( gi,
"Date", dt.AsString());
1333 AddInfoItem( gi,
"Host", gSystem->GetBuildNode() );
1334 AddInfoItem( gi,
"Dir", gSystem->WorkingDirectory());
1335 AddInfoItem( gi,
"Training events", gTools().StringFromInt(Data()->GetNTrainingEvents()));
// const_casts needed because GetTrainTime()/GetAnalysisType() are non-const.
1336 AddInfoItem( gi,
"TrainingTime", gTools().StringFromDouble(const_cast<TMVA::MethodBase*>(
this)->GetTrainTime()));
1338 Types::EAnalysisType aType =
const_cast<TMVA::MethodBase*
>(
this)->GetAnalysisType();
1339 TString analysisType((aType==Types::kRegression) ?
"Regression" :
1340 (aType==Types::kMulticlass ?
"Multiclass" :
"Classification"));
1341 AddInfoItem( gi,
"AnalysisType", analysisType );
1345 AddOptionsXMLTo( parent );
1348 AddVarsXMLTo( parent );
1351 if (fModelPersistence)
1352 AddSpectatorsXMLTo( parent );
1355 AddClassesXMLTo(parent);
1358 if (DoRegression()) AddTargetsXMLTo(parent);
1361 GetTransformationHandler(
false).AddXMLTo( parent );
// The MVAPdfs node is written even when no PDFs exist (children optional).
1364 void* pdfs = gTools().AddChild(parent,
"MVAPdfs");
1365 if (fMVAPdfS) fMVAPdfS->AddXMLTo(pdfs);
1366 if (fMVAPdfB) fMVAPdfB->AddXMLTo(pdfs);
1369 AddWeightsXMLTo( parent );
// Reads state from a ROOT file: retrieves the stored MVA PDFs (with histogram
// auto-registration temporarily disabled so they are not attached to the file)
// and then the method-specific weights.
1376 void TMVA::MethodBase::ReadStateFromStream( TFile& rf )
1378 Bool_t addDirStatus = TH1::AddDirectoryStatus();
1379 TH1::AddDirectory( 0 );
1380 fMVAPdfS = (TMVA::PDF*)rf.Get(
"MVA_PDF_Signal" );
1381 fMVAPdfB = (TMVA::PDF*)rf.Get(
"MVA_PDF_Background" );
1383 TH1::AddDirectory( addDirStatus );
1385 ReadWeightsFromStream( rf );
// Writes the method state to the XML weight file: derives the .xml filename
// from the configured weight-file name, builds the <MethodSetup> document via
// WriteStateToXML, and saves it with ROOT's XML engine.
1395 void TMVA::MethodBase::WriteStateToFile()
const
1398 TString tfname( GetWeightFileName() );
1401 TString xmlfname( tfname ); xmlfname.ReplaceAll(
".txt",
".xml" );
1403 <<
"Creating xml weight file: "
1404 << gTools().Color(
"lightblue") << xmlfname << gTools().Color(
"reset") << Endl;
1405 void* doc = gTools().xmlengine().NewDoc();
1406 void* rootnode = gTools().AddChild(0,
"MethodSetup",
"",
true);
1407 gTools().xmlengine().DocSetRootElement(doc,rootnode);
1408 gTools().AddAttr(rootnode,
"Method", GetMethodTypeName() +
"::" + GetMethodName());
1409 WriteStateToXML(rootnode);
1410 gTools().xmlengine().SaveDoc(doc,xmlfname);
1411 gTools().xmlengine().FreeDoc(doc);
// Reads the method state back from the weight file: XML files go through the
// XML engine + ReadStateFromXML; legacy text files go through a filebuf stream
// + ReadStateFromStream, with an optional companion .root file for non-text
// weights (fTxtWeightsOnly == kFALSE).
1417 void TMVA::MethodBase::ReadStateFromFile()
1421 TString tfname(GetWeightFileName());
1424 <<
"Reading weight file: "
1425 << gTools().Color(
"lightblue") << tfname << gTools().Color(
"reset") << Endl;
1427 if (tfname.EndsWith(
".xml") ) {
1428 void* doc = gTools().xmlengine().ParseFile(tfname,gTools().xmlenginebuffersize());
1430 Log() << kFATAL <<
"Error parsing XML file " << tfname << Endl;
1432 void* rootnode = gTools().xmlengine().DocGetRootElement(doc);
1433 ReadStateFromXML(rootnode);
1434 gTools().xmlengine().FreeDoc(doc);
// Legacy plain-text weight file path.
1438 fb.open(tfname.Data(),std::ios::in);
1439 if (!fb.is_open()) {
1440 Log() << kFATAL <<Form(
"Dataset[%s] : ",DataInfo().GetName())<<
"<ReadStateFromFile> "
1441 <<
"Unable to open input weight file: " << tfname << Endl;
1443 std::istream fin(&fb);
1444 ReadStateFromStream(fin);
// Companion ROOT file holding the non-text part of the weights.
1447 if (!fTxtWeightsOnly) {
1449 TString rfname( tfname ); rfname.ReplaceAll(
".txt",
".root" );
1450 Log() << kINFO <<Form(
"Dataset[%s] : ",DataInfo().GetName())<<
"Reading root weight file: "
1451 << gTools().Color(
"lightblue") << rfname << gTools().Color(
"reset") << Endl;
// NOTE(review): TFile::Open may return nullptr and the Close/delete of rfile
// is in elided lines — confirm both are handled in the full source.
1452 TFile* rfile = TFile::Open( rfname,
"READ" );
1453 ReadStateFromStream( *rfile );
// Convenience overload: parses an in-memory XML string, delegates to
// ReadStateFromXML on its root node, and frees the document.
1460 void TMVA::MethodBase::ReadStateFromXMLString(
const char* xmlstr ) {
1461 void* doc = gTools().xmlengine().ParseString(xmlstr);
1462 void* rootnode = gTools().xmlengine().DocGetRootElement(doc);
1463 ReadStateFromXML(rootnode);
1464 gTools().xmlengine().FreeDoc(doc);
1471 void TMVA::MethodBase::ReadStateFromXML(
void* methodNode )
1474 TString fullMethodName;
1475 gTools().ReadAttr( methodNode,
"Method", fullMethodName );
1477 fMethodName = fullMethodName(fullMethodName.Index(
"::")+2,fullMethodName.Length());
1480 Log().SetSource( GetName() );
1482 <<
"Read method \"" << GetMethodName() <<
"\" of type \"" << GetMethodTypeName() <<
"\"" << Endl;
1487 TString nodeName(
"");
1488 void* ch = gTools().GetChild(methodNode);
1490 nodeName = TString( gTools().GetName(ch) );
1492 if (nodeName==
"GeneralInfo") {
1495 TString name(
""),val(
"");
1496 void* antypeNode = gTools().GetChild(ch);
1497 while (antypeNode) {
1498 gTools().ReadAttr( antypeNode,
"name", name );
1500 if (name ==
"TrainingTime")
1501 gTools().ReadAttr( antypeNode,
"value", fTrainTime );
1503 if (name ==
"AnalysisType") {
1504 gTools().ReadAttr( antypeNode,
"value", val );
1506 if (val ==
"regression" ) SetAnalysisType( Types::kRegression );
1507 else if (val ==
"classification" ) SetAnalysisType( Types::kClassification );
1508 else if (val ==
"multiclass" ) SetAnalysisType( Types::kMulticlass );
1509 else Log() << kFATAL <<Form(
"Dataset[%s] : ",DataInfo().GetName())<<
"Analysis type " << val <<
" is not known." << Endl;
1512 if (name ==
"TMVA Release" || name ==
"TMVA") {
1514 gTools().ReadAttr( antypeNode,
"value", s);
1515 fTMVATrainingVersion = TString(s(s.Index(
"[")+1,s.Index(
"]")-s.Index(
"[")-1)).Atoi();
1516 Log() << kDEBUG <<Form(
"[%s] : ",DataInfo().GetName()) <<
"MVA method was trained with TMVA Version: " << GetTrainingTMVAVersionString() << Endl;
1519 if (name ==
"ROOT Release" || name ==
"ROOT") {
1521 gTools().ReadAttr( antypeNode,
"value", s);
1522 fROOTTrainingVersion = TString(s(s.Index(
"[")+1,s.Index(
"]")-s.Index(
"[")-1)).Atoi();
1524 <<
"MVA method was trained with ROOT Version: " << GetTrainingROOTVersionString() << Endl;
1526 antypeNode = gTools().GetNextChild(antypeNode);
1529 else if (nodeName==
"Options") {
1530 ReadOptionsFromXML(ch);
1534 else if (nodeName==
"Variables") {
1535 ReadVariablesFromXML(ch);
1537 else if (nodeName==
"Spectators") {
1538 ReadSpectatorsFromXML(ch);
1540 else if (nodeName==
"Classes") {
1541 if (DataInfo().GetNClasses()==0) ReadClassesFromXML(ch);
1543 else if (nodeName==
"Targets") {
1544 if (DataInfo().GetNTargets()==0 && DoRegression()) ReadTargetsFromXML(ch);
1546 else if (nodeName==
"Transformations") {
1547 GetTransformationHandler().ReadFromXML(ch);
1549 else if (nodeName==
"MVAPdfs") {
1551 if (fMVAPdfS) {
delete fMVAPdfS; fMVAPdfS=0; }
1552 if (fMVAPdfB) {
delete fMVAPdfB; fMVAPdfB=0; }
1553 void* pdfnode = gTools().GetChild(ch);
1555 gTools().ReadAttr(pdfnode,
"Name", pdfname);
1556 fMVAPdfS =
new PDF(pdfname);
1557 fMVAPdfS->ReadXML(pdfnode);
1558 pdfnode = gTools().GetNextChild(pdfnode);
1559 gTools().ReadAttr(pdfnode,
"Name", pdfname);
1560 fMVAPdfB =
new PDF(pdfname);
1561 fMVAPdfB->ReadXML(pdfnode);
1564 else if (nodeName==
"Weights") {
1565 ReadWeightsFromXML(ch);
1568 Log() << kWARNING <<Form(
"Dataset[%s] : ",DataInfo().GetName())<<
"Unparsed XML node: '" << nodeName <<
"'" << Endl;
1570 ch = gTools().GetNextChild(ch);
1575 if (GetTransformationHandler().GetCallerName() ==
"") GetTransformationHandler().SetCallerName( GetName() );
1581 void TMVA::MethodBase::ReadStateFromStream( std::istream& fin )
1586 SetAnalysisType(Types::kClassification);
1591 while (!TString(buf).BeginsWith(
"Method")) GetLine(fin,buf);
1592 TString namestr(buf);
1594 TString methodType = namestr(0,namestr.Index(
"::"));
1595 methodType = methodType(methodType.Last(
' '),methodType.Length());
1596 methodType = methodType.Strip(TString::kLeading);
1598 TString methodName = namestr(namestr.Index(
"::")+2,namestr.Length());
1599 methodName = methodName.Strip(TString::kLeading);
1600 if (methodName ==
"") methodName = methodType;
1601 fMethodName = methodName;
1603 Log() << kINFO <<Form(
"Dataset[%s] : ",DataInfo().GetName())<<
"Read method \"" << GetMethodName() <<
"\" of type \"" << GetMethodTypeName() <<
"\"" << Endl;
1606 Log().SetSource( GetName() );
1620 while (!TString(buf).BeginsWith(
"#OPT")) GetLine(fin,buf);
1621 ReadOptionsFromStream(fin);
1625 fin.getline(buf,512);
1626 while (!TString(buf).BeginsWith(
"#VAR")) fin.getline(buf,512);
1627 ReadVarsFromStream(fin);
1632 if (IsNormalised()) {
1633 VariableNormalizeTransform* norm = (VariableNormalizeTransform*)
1634 GetTransformationHandler().AddTransformation(
new VariableNormalizeTransform(DataInfo()), -1 );
1635 norm->BuildTransformationFromVarInfo( DataInfo().GetVariableInfos() );
1637 VariableTransformBase *varTrafo(0), *varTrafo2(0);
1638 if ( fVarTransformString ==
"None") {
1640 varTrafo = GetTransformationHandler().AddTransformation(
new VariableDecorrTransform(DataInfo()), -1 );
1641 }
else if ( fVarTransformString ==
"Decorrelate" ) {
1642 varTrafo = GetTransformationHandler().AddTransformation(
new VariableDecorrTransform(DataInfo()), -1 );
1643 }
else if ( fVarTransformString ==
"PCA" ) {
1644 varTrafo = GetTransformationHandler().AddTransformation(
new VariablePCATransform(DataInfo()), -1 );
1645 }
else if ( fVarTransformString ==
"Uniform" ) {
1646 varTrafo = GetTransformationHandler().AddTransformation(
new VariableGaussTransform(DataInfo(),
"Uniform"), -1 );
1647 }
else if ( fVarTransformString ==
"Gauss" ) {
1648 varTrafo = GetTransformationHandler().AddTransformation(
new VariableGaussTransform(DataInfo()), -1 );
1649 }
else if ( fVarTransformString ==
"GaussDecorr" ) {
1650 varTrafo = GetTransformationHandler().AddTransformation(
new VariableGaussTransform(DataInfo()), -1 );
1651 varTrafo2 = GetTransformationHandler().AddTransformation(
new VariableDecorrTransform(DataInfo()), -1 );
1653 Log() << kFATAL <<Form(
"Dataset[%s] : ",DataInfo().GetName())<<
"<ProcessOptions> Variable transform '"
1654 << fVarTransformString <<
"' unknown." << Endl;
1657 if (GetTransformationHandler().GetTransformationList().GetSize() > 0) {
1658 fin.getline(buf,512);
1659 while (!TString(buf).BeginsWith(
"#MAT")) fin.getline(buf,512);
1661 TString trafo(fVariableTransformTypeString); trafo.ToLower();
1662 varTrafo->ReadTransformationFromStream(fin, trafo );
1665 TString trafo(fVariableTransformTypeString); trafo.ToLower();
1666 varTrafo2->ReadTransformationFromStream(fin, trafo );
1673 fin.getline(buf,512);
1674 while (!TString(buf).BeginsWith(
"#MVAPDFS")) fin.getline(buf,512);
1675 if (fMVAPdfS != 0) {
delete fMVAPdfS; fMVAPdfS = 0; }
1676 if (fMVAPdfB != 0) {
delete fMVAPdfB; fMVAPdfB = 0; }
1677 fMVAPdfS =
new PDF(TString(GetName()) +
" MVA PDF Sig");
1678 fMVAPdfB =
new PDF(TString(GetName()) +
" MVA PDF Bkg");
1679 fMVAPdfS->SetReadingVersion( GetTrainingTMVAVersionCode() );
1680 fMVAPdfB->SetReadingVersion( GetTrainingTMVAVersionCode() );
1687 fin.getline(buf,512);
1688 while (!TString(buf).BeginsWith(
"#WGT")) fin.getline(buf,512);
1689 fin.getline(buf,512);
1690 ReadWeightsFromStream( fin );;
1693 if (GetTransformationHandler().GetCallerName() ==
"") GetTransformationHandler().SetCallerName( GetName() );
1701 void TMVA::MethodBase::WriteVarsToStream( std::ostream& o,
const TString& prefix )
const
1703 o << prefix <<
"NVar " << DataInfo().GetNVariables() << std::endl;
1704 std::vector<VariableInfo>::const_iterator varIt = DataInfo().GetVariableInfos().begin();
1705 for (; varIt!=DataInfo().GetVariableInfos().end(); ++varIt) { o << prefix; varIt->WriteToStream(o); }
1706 o << prefix <<
"NSpec " << DataInfo().GetNSpectators() << std::endl;
1707 varIt = DataInfo().GetSpectatorInfos().begin();
1708 for (; varIt!=DataInfo().GetSpectatorInfos().end(); ++varIt) { o << prefix; varIt->WriteToStream(o); }
1716 void TMVA::MethodBase::ReadVarsFromStream( std::istream& istr )
1720 istr >> dummy >> readNVar;
1722 if (readNVar!=DataInfo().GetNVariables()) {
1723 Log() << kFATAL <<Form(
"Dataset[%s] : ",DataInfo().GetName())<<
"You declared "<< DataInfo().GetNVariables() <<
" variables in the Reader"
1724 <<
" while there are " << readNVar <<
" variables declared in the file"
1729 VariableInfo varInfo;
1730 std::vector<VariableInfo>::iterator varIt = DataInfo().GetVariableInfos().begin();
1732 for (; varIt!=DataInfo().GetVariableInfos().end(); ++varIt, ++varIdx) {
1733 varInfo.ReadFromStream(istr);
1734 if (varIt->GetExpression() == varInfo.GetExpression()) {
1735 varInfo.SetExternalLink((*varIt).GetExternalLink());
1739 Log() << kINFO <<Form(
"Dataset[%s] : ",DataInfo().GetName())<<
"ERROR in <ReadVarsFromStream>" << Endl;
1740 Log() << kINFO <<
"The definition (or the order) of the variables found in the input file is" << Endl;
1741 Log() << kINFO <<
"is not the same as the one declared in the Reader (which is necessary for" << Endl;
1742 Log() << kINFO <<
"the correct working of the method):" << Endl;
1743 Log() << kINFO <<
" var #" << varIdx <<
" declared in Reader: " << varIt->GetExpression() << Endl;
1744 Log() << kINFO <<
" var #" << varIdx <<
" declared in file : " << varInfo.GetExpression() << Endl;
1745 Log() << kFATAL <<
"The expression declared to the Reader needs to be checked (name or order are wrong)" << Endl;
1753 void TMVA::MethodBase::AddVarsXMLTo(
void* parent )
const
1755 void* vars = gTools().AddChild(parent,
"Variables");
1756 gTools().AddAttr( vars,
"NVar", gTools().StringFromInt(DataInfo().GetNVariables()) );
1758 for (UInt_t idx=0; idx<DataInfo().GetVariableInfos().size(); idx++) {
1759 VariableInfo& vi = DataInfo().GetVariableInfos()[idx];
1760 void* var = gTools().AddChild( vars,
"Variable" );
1761 gTools().AddAttr( var,
"VarIndex", idx );
1769 void TMVA::MethodBase::AddSpectatorsXMLTo(
void* parent )
const
1771 void* specs = gTools().AddChild(parent,
"Spectators");
1774 for (UInt_t idx=0; idx<DataInfo().GetSpectatorInfos().size(); idx++) {
1776 VariableInfo& vi = DataInfo().GetSpectatorInfos()[idx];
1780 if (vi.GetVarType()==
'C')
continue;
1782 void* spec = gTools().AddChild( specs,
"Spectator" );
1783 gTools().AddAttr( spec,
"SpecIndex", writeIdx++ );
1784 vi.AddToXML( spec );
1786 gTools().AddAttr( specs,
"NSpec", gTools().StringFromInt(writeIdx) );
1792 void TMVA::MethodBase::AddClassesXMLTo(
void* parent )
const
1794 UInt_t nClasses=DataInfo().GetNClasses();
1796 void* classes = gTools().AddChild(parent,
"Classes");
1797 gTools().AddAttr( classes,
"NClass", nClasses );
1799 for (UInt_t iCls=0; iCls<nClasses; ++iCls) {
1800 ClassInfo *classInfo=DataInfo().GetClassInfo (iCls);
1801 TString className =classInfo->GetName();
1802 UInt_t classNumber=classInfo->GetNumber();
1804 void* classNode=gTools().AddChild(classes,
"Class");
1805 gTools().AddAttr( classNode,
"Name", className );
1806 gTools().AddAttr( classNode,
"Index", classNumber );
1812 void TMVA::MethodBase::AddTargetsXMLTo(
void* parent )
const
1814 void* targets = gTools().AddChild(parent,
"Targets");
1815 gTools().AddAttr( targets,
"NTrgt", gTools().StringFromInt(DataInfo().GetNTargets()) );
1817 for (UInt_t idx=0; idx<DataInfo().GetTargetInfos().size(); idx++) {
1818 VariableInfo& vi = DataInfo().GetTargetInfos()[idx];
1819 void* tar = gTools().AddChild( targets,
"Target" );
1820 gTools().AddAttr( tar,
"TargetIndex", idx );
1828 void TMVA::MethodBase::ReadVariablesFromXML(
void* varnode )
1831 gTools().ReadAttr( varnode,
"NVar", readNVar);
1833 if (readNVar!=DataInfo().GetNVariables()) {
1834 Log() << kFATAL <<Form(
"Dataset[%s] : ",DataInfo().GetName())<<
"You declared "<< DataInfo().GetNVariables() <<
" variables in the Reader"
1835 <<
" while there are " << readNVar <<
" variables declared in the file"
1840 VariableInfo readVarInfo, existingVarInfo;
1842 void* ch = gTools().GetChild(varnode);
1844 gTools().ReadAttr( ch,
"VarIndex", varIdx);
1845 existingVarInfo = DataInfo().GetVariableInfos()[varIdx];
1846 readVarInfo.ReadFromXML(ch);
1848 if (existingVarInfo.GetExpression() == readVarInfo.GetExpression()) {
1849 readVarInfo.SetExternalLink(existingVarInfo.GetExternalLink());
1850 existingVarInfo = readVarInfo;
1853 Log() << kINFO <<Form(
"Dataset[%s] : ",DataInfo().GetName())<<
"ERROR in <ReadVariablesFromXML>" << Endl;
1854 Log() << kINFO <<
"The definition (or the order) of the variables found in the input file is" << Endl;
1855 Log() << kINFO <<
"not the same as the one declared in the Reader (which is necessary for the" << Endl;
1856 Log() << kINFO <<
"correct working of the method):" << Endl;
1857 Log() << kINFO <<
" var #" << varIdx <<
" declared in Reader: " << existingVarInfo.GetExpression() << Endl;
1858 Log() << kINFO <<
" var #" << varIdx <<
" declared in file : " << readVarInfo.GetExpression() << Endl;
1859 Log() << kFATAL <<
"The expression declared to the Reader needs to be checked (name or order are wrong)" << Endl;
1861 ch = gTools().GetNextChild(ch);
1868 void TMVA::MethodBase::ReadSpectatorsFromXML(
void* specnode )
1871 gTools().ReadAttr( specnode,
"NSpec", readNSpec);
1873 if (readNSpec!=DataInfo().GetNSpectators(kFALSE)) {
1874 Log() << kFATAL<<Form(
"Dataset[%s] : ",DataInfo().GetName()) <<
"You declared "<< DataInfo().GetNSpectators(kFALSE) <<
" spectators in the Reader"
1875 <<
" while there are " << readNSpec <<
" spectators declared in the file"
1880 VariableInfo readSpecInfo, existingSpecInfo;
1882 void* ch = gTools().GetChild(specnode);
1884 gTools().ReadAttr( ch,
"SpecIndex", specIdx);
1885 existingSpecInfo = DataInfo().GetSpectatorInfos()[specIdx];
1886 readSpecInfo.ReadFromXML(ch);
1888 if (existingSpecInfo.GetExpression() == readSpecInfo.GetExpression()) {
1889 readSpecInfo.SetExternalLink(existingSpecInfo.GetExternalLink());
1890 existingSpecInfo = readSpecInfo;
1893 Log() << kINFO <<Form(
"Dataset[%s] : ",DataInfo().GetName())<<
"ERROR in <ReadSpectatorsFromXML>" << Endl;
1894 Log() << kINFO <<
"The definition (or the order) of the spectators found in the input file is" << Endl;
1895 Log() << kINFO <<
"not the same as the one declared in the Reader (which is necessary for the" << Endl;
1896 Log() << kINFO <<
"correct working of the method):" << Endl;
1897 Log() << kINFO <<
" spec #" << specIdx <<
" declared in Reader: " << existingSpecInfo.GetExpression() << Endl;
1898 Log() << kINFO <<
" spec #" << specIdx <<
" declared in file : " << readSpecInfo.GetExpression() << Endl;
1899 Log() << kFATAL <<
"The expression declared to the Reader needs to be checked (name or order are wrong)" << Endl;
1901 ch = gTools().GetNextChild(ch);
1908 void TMVA::MethodBase::ReadClassesFromXML(
void* clsnode )
1912 gTools().ReadAttr( clsnode,
"NClass", readNCls);
1914 TString className=
"";
1915 UInt_t classIndex=0;
1916 void* ch = gTools().GetChild(clsnode);
1918 for (UInt_t icls = 0; icls<readNCls;++icls) {
1919 TString classname = Form(
"class%i",icls);
1920 DataInfo().AddClass(classname);
1926 gTools().ReadAttr( ch,
"Index", classIndex);
1927 gTools().ReadAttr( ch,
"Name", className );
1928 DataInfo().AddClass(className);
1930 ch = gTools().GetNextChild(ch);
1935 if (DataInfo().GetClassInfo(
"Signal") != 0) {
1936 fSignalClass = DataInfo().GetClassInfo(
"Signal")->GetNumber();
1940 if (DataInfo().GetClassInfo(
"Background") != 0) {
1941 fBackgroundClass = DataInfo().GetClassInfo(
"Background")->GetNumber();
1950 void TMVA::MethodBase::ReadTargetsFromXML(
void* tarnode )
1953 gTools().ReadAttr( tarnode,
"NTrgt", readNTar);
1957 void* ch = gTools().GetChild(tarnode);
1959 gTools().ReadAttr( ch,
"TargetIndex", tarIdx);
1960 gTools().ReadAttr( ch,
"Expression", expression);
1961 DataInfo().AddTarget(expression,
"",
"",0,0);
1963 ch = gTools().GetNextChild(ch);
1971 TDirectory* TMVA::MethodBase::BaseDir()
const
1973 if (fBaseDir != 0)
return fBaseDir;
1974 Log()<<kDEBUG<<Form(
"Dataset[%s] : ",DataInfo().GetName())<<
" Base Directory for " << GetMethodName() <<
" not set yet --> check if already there.." <<Endl;
1976 if (IsSilentFile()) {
1977 Log() << kFATAL << Form(
"Dataset[%s] : ", DataInfo().GetName())
1978 <<
"MethodBase::BaseDir() - No directory exists when running a Method without output file. Enable the "
1979 "output when creating the factory"
1983 TDirectory* methodDir = MethodBaseDir();
1985 Log() << kFATAL <<Form(
"Dataset[%s] : ",DataInfo().GetName())<<
"MethodBase::BaseDir() - MethodBaseDir() return a NULL pointer!" << Endl;
1987 TString defaultDir = GetMethodName();
1988 TDirectory *sdir = methodDir->GetDirectory(defaultDir.Data());
1991 Log()<<kDEBUG<<Form(
"Dataset[%s] : ",DataInfo().GetName())<<
" Base Directory for " << GetMethodTypeName() <<
" does not exist yet--> created it" <<Endl;
1992 sdir = methodDir->mkdir(defaultDir);
1995 if (fModelPersistence) {
1996 TObjString wfilePath( gSystem->WorkingDirectory() );
1997 TObjString wfileName( GetWeightFileName() );
1998 wfilePath.Write(
"TrainingPath" );
1999 wfileName.Write(
"WeightFileName" );
2003 Log()<<kDEBUG<<Form(
"Dataset[%s] : ",DataInfo().GetName())<<
" Base Directory for " << GetMethodTypeName() <<
" existed, return it.." <<Endl;
2011 TDirectory *TMVA::MethodBase::MethodBaseDir()
const
2013 if (fMethodBaseDir != 0) {
2014 return fMethodBaseDir;
2017 const char *datasetName = DataInfo().GetName();
2019 Log() << kDEBUG << Form(
"Dataset[%s] : ", datasetName) <<
" Base Directory for " << GetMethodTypeName()
2020 <<
" not set yet --> check if already there.." << Endl;
2022 TDirectory *factoryBaseDir = GetFile();
2023 if (!factoryBaseDir)
return nullptr;
2024 fMethodBaseDir = factoryBaseDir->GetDirectory(datasetName);
2025 if (!fMethodBaseDir) {
2026 fMethodBaseDir = factoryBaseDir->mkdir(datasetName, Form(
"Base directory for dataset %s", datasetName));
2027 if (!fMethodBaseDir) {
2028 Log() << kFATAL <<
"Can not create dir " << datasetName;
2031 TString methodTypeDir = Form(
"Method_%s", GetMethodTypeName().Data());
2032 fMethodBaseDir = fMethodBaseDir->GetDirectory(methodTypeDir.Data());
2034 if (!fMethodBaseDir) {
2035 TDirectory *datasetDir = factoryBaseDir->GetDirectory(datasetName);
2036 TString methodTypeDirHelpStr = Form(
"Directory for all %s methods", GetMethodTypeName().Data());
2037 fMethodBaseDir = datasetDir->mkdir(methodTypeDir.Data(), methodTypeDirHelpStr);
2038 Log() << kDEBUG << Form(
"Dataset[%s] : ", datasetName) <<
" Base Directory for " << GetMethodName()
2039 <<
" does not exist yet--> created it" << Endl;
2042 Log() << kDEBUG << Form(
"Dataset[%s] : ", datasetName)
2043 <<
"Return from MethodBaseDir() after creating base directory " << Endl;
2044 return fMethodBaseDir;
2050 void TMVA::MethodBase::SetWeightFileDir( TString fileDir )
2053 gSystem->mkdir( fFileDir, kTRUE );
2059 void TMVA::MethodBase::SetWeightFileName( TString theWeightFile)
2061 fWeightFile = theWeightFile;
2067 TString TMVA::MethodBase::GetWeightFileName()
const
2069 if (fWeightFile!=
"")
return fWeightFile;
2073 TString suffix =
"";
2074 TString wFileDir(GetWeightFileDir());
2075 TString wFileName = GetJobName() +
"_" + GetMethodName() +
2076 suffix +
"." + gConfig().GetIONames().fWeightFileExtension +
".xml";
2077 if (wFileDir.IsNull() )
return wFileName;
2079 return ( wFileDir + (wFileDir[wFileDir.Length()-1]==
'/' ?
"" :
"/")
2085 void TMVA::MethodBase::WriteEvaluationHistosToFile(Types::ETreeType treetype)
2091 if (0 != fMVAPdfS) {
2092 fMVAPdfS->GetOriginalHist()->Write();
2093 fMVAPdfS->GetSmoothedHist()->Write();
2094 fMVAPdfS->GetPDFHist()->Write();
2096 if (0 != fMVAPdfB) {
2097 fMVAPdfB->GetOriginalHist()->Write();
2098 fMVAPdfB->GetSmoothedHist()->Write();
2099 fMVAPdfB->GetPDFHist()->Write();
2103 Results* results = Data()->GetResults( GetMethodName(), treetype, Types::kMaxAnalysisType );
2105 Log() << kFATAL <<Form(
"Dataset[%s] : ",DataInfo().GetName())<<
"<WriteEvaluationHistosToFile> Unknown result: "
2106 << GetMethodName() << (treetype==Types::kTraining?
"/kTraining":
"/kTesting")
2107 <<
"/kMaxAnalysisType" << Endl;
2108 results->GetStorage()->Write();
2109 if (treetype==Types::kTesting) {
2111 if ((
int) DataInfo().GetNVariables()< gConfig().GetVariablePlotting().fMaxNumOfAllowedVariables)
2112 GetTransformationHandler().PlotVariables (GetEventCollection( Types::kTesting ), BaseDir() );
2114 Log() << kINFO << TString::Format(
"Dataset[%s] : ",DataInfo().GetName())
2115 <<
" variable plots are not produces ! The number of variables is " << DataInfo().GetNVariables()
2116 <<
" , it is larger than " << gConfig().GetVariablePlotting().fMaxNumOfAllowedVariables << Endl;
2124 void TMVA::MethodBase::WriteMonitoringHistosToFile(
void )
const
2133 Bool_t TMVA::MethodBase::GetLine(std::istream& fin,
char* buf )
2135 fin.getline(buf,512);
2137 if (line.BeginsWith(
"TMVA Release")) {
2138 Ssiz_t start = line.First(
'[')+1;
2139 Ssiz_t length = line.Index(
"]",start)-start;
2140 TString code = line(start,length);
2141 std::stringstream s(code.Data());
2142 s >> fTMVATrainingVersion;
2143 Log() << kINFO <<Form(
"Dataset[%s] : ",DataInfo().GetName())<<
"MVA method was trained with TMVA Version: " << GetTrainingTMVAVersionString() << Endl;
2145 if (line.BeginsWith(
"ROOT Release")) {
2146 Ssiz_t start = line.First(
'[')+1;
2147 Ssiz_t length = line.Index(
"]",start)-start;
2148 TString code = line(start,length);
2149 std::stringstream s(code.Data());
2150 s >> fROOTTrainingVersion;
2151 Log() << kINFO <<Form(
"Dataset[%s] : ",DataInfo().GetName())<<
"MVA method was trained with ROOT Version: " << GetTrainingROOTVersionString() << Endl;
2153 if (line.BeginsWith(
"Analysis type")) {
2154 Ssiz_t start = line.First(
'[')+1;
2155 Ssiz_t length = line.Index(
"]",start)-start;
2156 TString code = line(start,length);
2157 std::stringstream s(code.Data());
2158 std::string analysisType;
2160 if (analysisType ==
"regression" || analysisType ==
"Regression") SetAnalysisType( Types::kRegression );
2161 else if (analysisType ==
"classification" || analysisType ==
"Classification") SetAnalysisType( Types::kClassification );
2162 else if (analysisType ==
"multiclass" || analysisType ==
"Multiclass") SetAnalysisType( Types::kMulticlass );
2163 else Log() << kFATAL <<
"Analysis type " << analysisType <<
" from weight-file not known!" << std::endl;
2165 Log() << kINFO <<Form(
"Dataset[%s] : ",DataInfo().GetName())<<
"Method was trained for "
2166 << (GetAnalysisType() == Types::kRegression ?
"Regression" :
2167 (GetAnalysisType() == Types::kMulticlass ?
"Multiclass" :
"Classification")) << Endl;
2176 void TMVA::MethodBase::CreateMVAPdfs()
2178 Data()->SetCurrentType(Types::kTraining);
2182 ResultsClassification * mvaRes =
dynamic_cast<ResultsClassification*
>
2183 ( Data()->GetResults(GetMethodName(), Types::kTraining, Types::kClassification) );
2185 if (mvaRes==0 || mvaRes->GetSize()==0) {
2186 Log() << kERROR<<Form(
"Dataset[%s] : ",DataInfo().GetName())<<
"<CreateMVAPdfs> No result of classifier testing available" << Endl;
2189 Double_t minVal = *std::min_element(mvaRes->GetValueVector()->begin(),mvaRes->GetValueVector()->end());
2190 Double_t maxVal = *std::max_element(mvaRes->GetValueVector()->begin(),mvaRes->GetValueVector()->end());
2193 TH1* histMVAPdfS =
new TH1D( GetMethodTypeName() +
"_tr_S", GetMethodTypeName() +
"_tr_S",
2194 fMVAPdfS->GetHistNBins( mvaRes->GetSize() ), minVal, maxVal );
2195 TH1* histMVAPdfB =
new TH1D( GetMethodTypeName() +
"_tr_B", GetMethodTypeName() +
"_tr_B",
2196 fMVAPdfB->GetHistNBins( mvaRes->GetSize() ), minVal, maxVal );
2200 histMVAPdfS->Sumw2();
2201 histMVAPdfB->Sumw2();
2204 for (UInt_t ievt=0; ievt<mvaRes->GetSize(); ievt++) {
2205 Double_t theVal = mvaRes->GetValueVector()->at(ievt);
2206 Double_t theWeight = Data()->GetEvent(ievt)->GetWeight();
2208 if (DataInfo().IsSignal(Data()->GetEvent(ievt))) histMVAPdfS->Fill( theVal, theWeight );
2209 else histMVAPdfB->Fill( theVal, theWeight );
2212 gTools().NormHist( histMVAPdfS );
2213 gTools().NormHist( histMVAPdfB );
2218 histMVAPdfS->Write();
2219 histMVAPdfB->Write();
2222 fMVAPdfS->BuildPDF ( histMVAPdfS );
2223 fMVAPdfB->BuildPDF ( histMVAPdfB );
2224 fMVAPdfS->ValidatePDF( histMVAPdfS );
2225 fMVAPdfB->ValidatePDF( histMVAPdfB );
2227 if (DataInfo().GetNClasses() == 2) {
2228 Log() << kINFO<<Form(
"Dataset[%s] : ",DataInfo().GetName())
2229 << Form(
"<CreateMVAPdfs> Separation from histogram (PDF): %1.3f (%1.3f)",
2230 GetSeparation( histMVAPdfS, histMVAPdfB ), GetSeparation( fMVAPdfS, fMVAPdfB ) )
2238 Double_t TMVA::MethodBase::GetProba(
const Event *ev){
2242 if (!fMVAPdfS || !fMVAPdfB) {
2243 Log() << kINFO<<Form(
"Dataset[%s] : ",DataInfo().GetName()) <<
"<GetProba> MVA PDFs for Signal and Background don't exist yet, we'll create them on demand" << Endl;
2246 Double_t sigFraction = DataInfo().GetTrainingSumSignalWeights() / (DataInfo().GetTrainingSumSignalWeights() + DataInfo().GetTrainingSumBackgrWeights() );
2247 Double_t mvaVal = GetMvaValue(ev);
2249 return GetProba(mvaVal,sigFraction);
2255 Double_t TMVA::MethodBase::GetProba( Double_t mvaVal, Double_t ap_sig )
2257 if (!fMVAPdfS || !fMVAPdfB) {
2258 Log() << kWARNING <<Form(
"Dataset[%s] : ",DataInfo().GetName())<<
"<GetProba> MVA PDFs for Signal and Background don't exist" << Endl;
2261 Double_t p_s = fMVAPdfS->GetVal( mvaVal );
2262 Double_t p_b = fMVAPdfB->GetVal( mvaVal );
2264 Double_t denom = p_s*ap_sig + p_b*(1 - ap_sig);
2266 return (denom > 0) ? (p_s*ap_sig) / denom : -1;
2276 Double_t TMVA::MethodBase::GetRarity( Double_t mvaVal, Types::ESBType reftype )
const
2278 if ((reftype == Types::kSignal && !fMVAPdfS) || (reftype == Types::kBackground && !fMVAPdfB)) {
2279 Log() << kWARNING <<Form(
"Dataset[%s] : ",DataInfo().GetName())<<
"<GetRarity> Required MVA PDF for Signal or Background does not exist: "
2280 <<
"select option \"CreateMVAPdfs\"" << Endl;
2284 PDF* thePdf = ((reftype == Types::kSignal) ? fMVAPdfS : fMVAPdfB);
2286 return thePdf->GetIntegral( thePdf->GetXmin(), mvaVal );
2293 Double_t TMVA::MethodBase::GetEfficiency(
const TString& theString, Types::ETreeType type,Double_t& effSerr )
2295 Data()->SetCurrentType(type);
2296 Results* results = Data()->GetResults( GetMethodName(), type, Types::kClassification );
2297 std::vector<Float_t>* mvaRes =
dynamic_cast<ResultsClassification*
>(results)->GetValueVector();
2300 TList* list = gTools().ParseFormatLine( theString );
2303 Bool_t computeArea = kFALSE;
2304 if (!list || list->GetSize() < 2) computeArea = kTRUE;
2305 else if (list->GetSize() > 2) {
2306 Log() << kFATAL <<Form(
"Dataset[%s] : ",DataInfo().GetName())<<
"<GetEfficiency> Wrong number of arguments"
2307 <<
" in string: " << theString
2308 <<
" | required format, e.g., Efficiency:0.05, or empty string" << Endl;
2314 if ( results->GetHist(
"MVA_S")->GetNbinsX() != results->GetHist(
"MVA_B")->GetNbinsX() ||
2315 results->GetHist(
"MVA_HIGHBIN_S")->GetNbinsX() != results->GetHist(
"MVA_HIGHBIN_B")->GetNbinsX() ) {
2316 Log() << kFATAL <<Form(
"Dataset[%s] : ",DataInfo().GetName())<<
"<GetEfficiency> Binning mismatch between signal and background histos" << Endl;
2324 TH1 * effhist = results->GetHist(
"MVA_HIGHBIN_S");
2325 Double_t xmin = effhist->GetXaxis()->GetXmin();
2326 Double_t xmax = effhist->GetXaxis()->GetXmax();
2328 TTHREAD_TLS(Double_t) nevtS;
2331 if (results->DoesExist("MVA_EFF_S")==0) {
2334 TH1* eff_s =
new TH1D( GetTestvarName() +
"_effS", GetTestvarName() +
" (signal)", fNbinsH, xmin, xmax );
2335 TH1* eff_b =
new TH1D( GetTestvarName() +
"_effB", GetTestvarName() +
" (background)", fNbinsH, xmin, xmax );
2336 results->Store(eff_s,
"MVA_EFF_S");
2337 results->Store(eff_b,
"MVA_EFF_B");
2340 Int_t sign = (fCutOrientation == kPositive) ? +1 : -1;
2344 for (UInt_t ievt=0; ievt<Data()->GetNEvents(); ievt++) {
2347 Bool_t isSignal = DataInfo().IsSignal(GetEvent(ievt));
2348 Float_t theWeight = GetEvent(ievt)->GetWeight();
2349 Float_t theVal = (*mvaRes)[ievt];
2352 TH1* theHist = isSignal ? eff_s : eff_b;
2355 if (isSignal) nevtS+=theWeight;
2357 TAxis* axis = theHist->GetXaxis();
2358 Int_t maxbin = Int_t((theVal - axis->GetXmin())/(axis->GetXmax() - axis->GetXmin())*fNbinsH) + 1;
2359 if (sign > 0 && maxbin > fNbinsH)
continue;
2360 if (sign < 0 && maxbin < 1 )
continue;
2361 if (sign > 0 && maxbin < 1 ) maxbin = 1;
2362 if (sign < 0 && maxbin > fNbinsH) maxbin = fNbinsH;
2365 for (Int_t ibin=1; ibin<=maxbin; ibin++) theHist->AddBinContent( ibin , theWeight);
2367 for (Int_t ibin=maxbin+1; ibin<=fNbinsH; ibin++) theHist->AddBinContent( ibin , theWeight );
2369 Log() << kFATAL <<Form(
"Dataset[%s] : ",DataInfo().GetName())<<
"<GetEfficiency> Mismatch in sign" << Endl;
2376 eff_s->Scale( 1.0/TMath::Max(std::numeric_limits<double>::epsilon(),eff_s->GetMaximum()) );
2377 eff_b->Scale( 1.0/TMath::Max(std::numeric_limits<double>::epsilon(),eff_b->GetMaximum()) );
2380 TH1* eff_BvsS =
new TH1D( GetTestvarName() +
"_effBvsS", GetTestvarName() +
"", fNbins, 0, 1 );
2381 results->Store(eff_BvsS,
"MVA_EFF_BvsS");
2382 eff_BvsS->SetXTitle(
"Signal eff" );
2383 eff_BvsS->SetYTitle(
"Backgr eff" );
2386 TH1* rej_BvsS =
new TH1D( GetTestvarName() +
"_rejBvsS", GetTestvarName() +
"", fNbins, 0, 1 );
2387 results->Store(rej_BvsS);
2388 rej_BvsS->SetXTitle(
"Signal eff" );
2389 rej_BvsS->SetYTitle(
"Backgr rejection (1-eff)" );
2392 TH1* inveff_BvsS =
new TH1D( GetTestvarName() +
"_invBeffvsSeff",
2393 GetTestvarName(), fNbins, 0, 1 );
2394 results->Store(inveff_BvsS);
2395 inveff_BvsS->SetXTitle(
"Signal eff" );
2396 inveff_BvsS->SetYTitle(
"Inverse backgr. eff (1/eff)" );
2401 if (Use_Splines_for_Eff_) {
2402 fSplRefS =
new TSpline1(
"spline2_signal",
new TGraph( eff_s ) );
2403 fSplRefB =
new TSpline1(
"spline2_background",
new TGraph( eff_b ) );
2406 gTools().CheckSplines( eff_s, fSplRefS );
2407 gTools().CheckSplines( eff_b, fSplRefB );
2413 RootFinder rootFinder(
this, fXmin, fXmax );
2417 for (Int_t bini=1; bini<=fNbins; bini++) {
2420 Double_t effS = eff_BvsS->GetBinCenter( bini );
2421 Double_t cut = rootFinder.Root( effS );
2424 if (Use_Splines_for_Eff_) effB = fSplRefB->Eval( cut );
2425 else effB = eff_b->GetBinContent( eff_b->FindBin( cut ) );
2428 eff_BvsS->SetBinContent( bini, effB );
2429 rej_BvsS->SetBinContent( bini, 1.0-effB );
2430 if (effB>std::numeric_limits<double>::epsilon())
2431 inveff_BvsS->SetBinContent( bini, 1.0/effB );
2435 fSpleffBvsS =
new TSpline1(
"effBvsS",
new TGraph( eff_BvsS ) );
2439 Double_t effS = 0., rejB, effS_ = 0., rejB_ = 0.;
2440 Int_t nbins_ = 5000;
2441 for (Int_t bini=1; bini<=nbins_; bini++) {
2444 effS = (bini - 0.5)/Float_t(nbins_);
2445 rejB = 1.0 - fSpleffBvsS->Eval( effS );
2448 if ((effS - rejB)*(effS_ - rejB_) < 0)
break;
2454 Double_t cut = rootFinder.Root( 0.5*(effS + effS_) );
2455 SetSignalReferenceCut( cut );
2460 if (0 == fSpleffBvsS) {
2466 Double_t effS = 0, effB = 0, effS_ = 0, effB_ = 0;
2467 Int_t nbins_ = 1000;
2472 Double_t integral = 0;
2473 for (Int_t bini=1; bini<=nbins_; bini++) {
2476 effS = (bini - 0.5)/Float_t(nbins_);
2477 effB = fSpleffBvsS->Eval( effS );
2478 integral += (1.0 - effB);
2489 Float_t effBref = atof( ((TObjString*)list->At(1))->GetString() );
2492 for (Int_t bini=1; bini<=nbins_; bini++) {
2495 effS = (bini - 0.5)/Float_t(nbins_);
2496 effB = fSpleffBvsS->Eval( effS );
2499 if ((effB - effBref)*(effB_ - effBref) <= 0)
break;
2505 effS = 0.5*(effS + effS_);
2508 if (nevtS > 0) effSerr = TMath::Sqrt( effS*(1.0 - effS)/nevtS );
////////////////////////////////////////////////////////////////////////////////
/// Computes the background-efficiency-vs-signal-efficiency curve on the
/// TRAINING sample and returns the signal efficiency at the background
/// efficiency requested in `theString` (format "Efficiency:<effB>").
/// NOTE(review): this listing is extraction-mangled — braces and some
/// statements (e.g. the `effS_ = effS; effB_ = effB;` roll-over at the loop
/// tails) appear to have been dropped; confirm against the canonical source.
2519 Double_t TMVA::MethodBase::GetTrainingEfficiency(
const TString& theString)
// Evaluate on the training sample...
2521 Data()->SetCurrentType(Types::kTraining);
// ...but fetch the results container keyed to kTesting.
// NOTE(review): kTesting here looks deliberate (the container is shared),
// but verify — it is surprising next to the SetCurrentType above.
2523 Results* results = Data()->GetResults(GetMethodName(), Types::kTesting, Types::kNoAnalysisType);
// Parse "Efficiency:<effB>" into its two tokens.
2529 TList* list = gTools().ParseFormatLine( theString );
2532 if (list->GetSize() != 2) {
2533 Log() << kFATAL <<Form(
"Dataset[%s] : ",DataInfo().GetName())<<
"<GetTrainingEfficiency> Wrong number of arguments"
2534 <<
" in string: " << theString
2535 <<
" | required format, e.g., Efficiency:0.05" << Endl;
// Requested reference background efficiency (second token).
2541 Float_t effBref = atof( ((TObjString*)list->At(1))->GetString() );
// Sanity check: signal and background MVA histograms must share binning.
2546 if (results->GetHist(
"MVA_S")->GetNbinsX() != results->GetHist(
"MVA_B")->GetNbinsX() ||
2547 results->GetHist(
"MVA_HIGHBIN_S")->GetNbinsX() != results->GetHist(
"MVA_HIGHBIN_B")->GetNbinsX() ) {
2548 Log() << kFATAL <<Form(
"Dataset[%s] : ",DataInfo().GetName())<<
"<GetTrainingEfficiency> Binning mismatch between signal and background histos"
2556 TH1 * effhist = results->GetHist(
"MVA_HIGHBIN_S");
2557 Double_t xmin = effhist->GetXaxis()->GetXmin();
2558 Double_t xmax = effhist->GetXaxis()->GetXmax();
// First call: build and cache the training-sample MVA and efficiency histos.
2561 if (results->DoesExist(
"MVA_TRAIN_S")==0) {
// Upper edge nudged so the maximum MVA value falls inside the last bin.
2564 Double_t sxmax = fXmax+0.00001;
2567 TH1* mva_s_tr =
new TH1D( GetTestvarName() +
"_Train_S",GetTestvarName() +
"_Train_S", fNbinsMVAoutput, fXmin, sxmax );
2568 TH1* mva_b_tr =
new TH1D( GetTestvarName() +
"_Train_B",GetTestvarName() +
"_Train_B", fNbinsMVAoutput, fXmin, sxmax );
2569 results->Store(mva_s_tr,
"MVA_TRAIN_S");
2570 results->Store(mva_b_tr,
"MVA_TRAIN_B");
2575 TH1* mva_eff_tr_s =
new TH1D( GetTestvarName() +
"_trainingEffS", GetTestvarName() +
" (signal)",
2576 fNbinsH, xmin, xmax );
2577 TH1* mva_eff_tr_b =
new TH1D( GetTestvarName() +
"_trainingEffB", GetTestvarName() +
" (background)",
2578 fNbinsH, xmin, xmax );
2579 results->Store(mva_eff_tr_s,
"MVA_TRAINEFF_S");
2580 results->Store(mva_eff_tr_b,
"MVA_TRAINEFF_B");
// Cut orientation: +1 means signal lives above the cut, -1 below.
2583 Int_t sign = (fCutOrientation == kPositive) ? +1 : -1;
// Batch-evaluate the MVA on all training events.
2585 std::vector<Double_t> mvaValues = GetMvaValues(0,Data()->GetNEvents());
2586 assert( (Long64_t) mvaValues.size() == Data()->GetNEvents());
2589 for (Int_t ievt=0; ievt<Data()->GetNEvents(); ievt++) {
2591 Data()->SetCurrentEvent(ievt);
2592 const Event* ev = GetEvent();
2594 Double_t theVal = mvaValues[ievt];
2595 Double_t theWeight = ev->GetWeight();
2597 TH1* theEffHist = DataInfo().IsSignal(ev) ? mva_eff_tr_s : mva_eff_tr_b;
2598 TH1* theClsHist = DataInfo().IsSignal(ev) ? mva_s_tr : mva_b_tr;
2600 theClsHist->Fill( theVal, theWeight );
// Fill the cumulative efficiency histogram: every cut value the event
// passes (given the orientation) gets this event's weight.
2602 TAxis* axis = theEffHist->GetXaxis();
2603 Int_t maxbin = Int_t((theVal - axis->GetXmin())/(axis->GetXmax() - axis->GetXmin())*fNbinsH) + 1;
2604 if (sign > 0 && maxbin > fNbinsH)
continue;
2605 if (sign < 0 && maxbin < 1 )
continue;
2606 if (sign > 0 && maxbin < 1 ) maxbin = 1;
2607 if (sign < 0 && maxbin > fNbinsH) maxbin = fNbinsH;
2609 if (sign > 0)
for (Int_t ibin=1; ibin<=maxbin; ibin++) theEffHist->AddBinContent( ibin , theWeight );
2610 else for (Int_t ibin=maxbin+1; ibin<=fNbinsH; ibin++) theEffHist->AddBinContent( ibin , theWeight );
// Normalise the MVA shape histograms to unit area...
2615 gTools().NormHist( mva_s_tr );
2616 gTools().NormHist( mva_b_tr );
// ...and the efficiency histograms so their maximum is 1 (epsilon guards /0).
2619 mva_eff_tr_s->Scale( 1.0/TMath::Max(std::numeric_limits<double>::epsilon(), mva_eff_tr_s->GetMaximum()) );
2620 mva_eff_tr_b->Scale( 1.0/TMath::Max(std::numeric_limits<double>::epsilon(), mva_eff_tr_b->GetMaximum()) );
// Background-efficiency (and rejection) as a function of signal efficiency.
2623 TH1* eff_bvss =
new TH1D( GetTestvarName() +
"_trainingEffBvsS", GetTestvarName() +
"", fNbins, 0, 1 );
2625 TH1* rej_bvss =
new TH1D( GetTestvarName() +
"_trainingRejBvsS", GetTestvarName() +
"", fNbins, 0, 1 );
2626 results->Store(eff_bvss,
"EFF_BVSS_TR");
2627 results->Store(rej_bvss,
"REJ_BVSS_TR");
// Smooth the efficiency histograms with splines (file-level switch).
2632 if (Use_Splines_for_Eff_) {
2633 if (fSplTrainRefS)
delete fSplTrainRefS;
2634 if (fSplTrainRefB)
delete fSplTrainRefB;
2635 fSplTrainRefS =
new TSpline1(
"spline2_signal",
new TGraph( mva_eff_tr_s ) );
2636 fSplTrainRefB =
new TSpline1(
"spline2_background",
new TGraph( mva_eff_tr_b ) );
2639 gTools().CheckSplines( mva_eff_tr_s, fSplTrainRefS );
2640 gTools().CheckSplines( mva_eff_tr_b, fSplTrainRefB );
// For each signal-efficiency bin, invert effS(cut) via root finding and
// look up the background efficiency at that cut.
2646 RootFinder rootFinder(
this, fXmin, fXmax );
2649 fEffS = results->GetHist(
"MVA_TRAINEFF_S");
2650 for (Int_t bini=1; bini<=fNbins; bini++) {
2653 Double_t effS = eff_bvss->GetBinCenter( bini );
2655 Double_t cut = rootFinder.Root( effS );
2658 if (Use_Splines_for_Eff_) effB = fSplTrainRefB->Eval( cut );
2659 else effB = mva_eff_tr_b->GetBinContent( mva_eff_tr_b->FindBin( cut ) );
2662 eff_bvss->SetBinContent( bini, effB );
2663 rej_bvss->SetBinContent( bini, 1.0-effB );
// Spline through effB(effS) for the fine scan below.
2668 fSplTrainEffBvsS =
new TSpline1(
"effBvsS",
new TGraph( eff_bvss ) );
2672 if (0 == fSplTrainEffBvsS)
return 0.0;
// Fine scan: find where effB crosses the requested reference effBref;
// the sign-change test detects the bracketing pair (effS_, effS).
2675 Double_t effS = 0., effB, effS_ = 0., effB_ = 0.;
2676 Int_t nbins_ = 1000;
2677 for (Int_t bini=1; bini<=nbins_; bini++) {
2680 effS = (bini - 0.5)/Float_t(nbins_);
2681 effB = fSplTrainEffBvsS->Eval( effS );
2684 if ((effB - effBref)*(effB_ - effBref) <= 0)
break;
// Midpoint of the bracketing signal efficiencies.
2689 return 0.5*(effS + effS_);
2694 std::vector<Float_t> TMVA::MethodBase::GetMulticlassEfficiency(std::vector<std::vector<Float_t> >& purity)
2696 Data()->SetCurrentType(Types::kTesting);
2697 ResultsMulticlass* resMulticlass =
dynamic_cast<ResultsMulticlass*
>(Data()->GetResults(GetMethodName(), Types::kTesting, Types::kMulticlass));
2698 if (!resMulticlass) Log() << kFATAL<<Form(
"Dataset[%s] : ",DataInfo().GetName())<<
"unable to create pointer in GetMulticlassEfficiency, exiting."<<Endl;
2700 purity.push_back(resMulticlass->GetAchievablePur());
2701 return resMulticlass->GetAchievableEff();
2706 std::vector<Float_t> TMVA::MethodBase::GetMulticlassTrainingEfficiency(std::vector<std::vector<Float_t> >& purity)
2708 Data()->SetCurrentType(Types::kTraining);
2709 ResultsMulticlass* resMulticlass =
dynamic_cast<ResultsMulticlass*
>(Data()->GetResults(GetMethodName(), Types::kTraining, Types::kMulticlass));
2710 if (!resMulticlass) Log() << kFATAL<<
"unable to create pointer in GetMulticlassTrainingEfficiency, exiting."<<Endl;
2712 Log() << kINFO <<Form(
"Dataset[%s] : ",DataInfo().GetName())<<
"Determine optimal multiclass cuts for training data..." << Endl;
2713 for (UInt_t icls = 0; icls<DataInfo().GetNClasses(); ++icls) {
2714 resMulticlass->GetBestMultiClassCuts(icls);
2717 purity.push_back(resMulticlass->GetAchievablePur());
2718 return resMulticlass->GetAchievableEff();
2741 TMatrixD TMVA::MethodBase::GetMulticlassConfusionMatrix(Double_t effB, Types::ETreeType type)
2743 if (GetAnalysisType() != Types::kMulticlass) {
2744 Log() << kFATAL <<
"Cannot get confusion matrix for non-multiclass analysis." << std::endl;
2745 return TMatrixD(0, 0);
2748 Data()->SetCurrentType(type);
2749 ResultsMulticlass *resMulticlass =
2750 dynamic_cast<ResultsMulticlass *
>(Data()->GetResults(GetMethodName(), type, Types::kMulticlass));
2752 if (resMulticlass ==
nullptr) {
2753 Log() << kFATAL << Form(
"Dataset[%s] : ", DataInfo().GetName())
2754 <<
"unable to create pointer in GetMulticlassEfficiency, exiting." << Endl;
2755 return TMatrixD(0, 0);
2758 return resMulticlass->GetConfusionMatrix(effB);
2767 Double_t TMVA::MethodBase::GetSignificance(
void )
const
2769 Double_t rms = sqrt( fRmsS*fRmsS + fRmsB*fRmsB );
2771 return (rms > 0) ? TMath::Abs(fMeanS - fMeanB)/rms : 0;
2780 Double_t TMVA::MethodBase::GetSeparation( TH1* histoS, TH1* histoB )
const
2782 return gTools().GetSeparation( histoS, histoB );
2791 Double_t TMVA::MethodBase::GetSeparation( PDF* pdfS, PDF* pdfB )
const
2795 if ((!pdfS && pdfB) || (pdfS && !pdfB))
2796 Log() << kFATAL <<Form(
"Dataset[%s] : ",DataInfo().GetName())<<
"<GetSeparation> Mismatch in pdfs" << Endl;
2797 if (!pdfS) pdfS = fSplS;
2798 if (!pdfB) pdfB = fSplB;
2800 if (!fSplS || !fSplB) {
2801 Log()<<kDEBUG<<Form(
"[%s] : ",DataInfo().GetName())<<
"could not calculate the separation, distributions"
2802 <<
" fSplS or fSplB are not yet filled" << Endl;
2805 return gTools().GetSeparation( *pdfS, *pdfB );
2813 Double_t TMVA::MethodBase::GetROCIntegral(TH1D *histS, TH1D *histB)
const
2817 if ((!histS && histB) || (histS && !histB))
2818 Log() << kFATAL <<Form(
"Dataset[%s] : ",DataInfo().GetName())<<
"<GetROCIntegral(TH1D*, TH1D*)> Mismatch in hists" << Endl;
2820 if (histS==0 || histB==0)
return 0.;
2822 TMVA::PDF *pdfS =
new TMVA::PDF(
" PDF Sig", histS, TMVA::PDF::kSpline3 );
2823 TMVA::PDF *pdfB =
new TMVA::PDF(
" PDF Bkg", histB, TMVA::PDF::kSpline3 );
2826 Double_t xmin = TMath::Min(pdfS->GetXmin(), pdfB->GetXmin());
2827 Double_t xmax = TMath::Max(pdfS->GetXmax(), pdfB->GetXmax());
2829 Double_t integral = 0;
2830 UInt_t nsteps = 1000;
2831 Double_t step = (xmax-xmin)/Double_t(nsteps);
2832 Double_t cut = xmin;
2833 for (UInt_t i=0; i<nsteps; i++) {
2834 integral += (1-pdfB->GetIntegral(cut,xmax)) * pdfS->GetVal(cut);
2839 return integral*step;
2847 Double_t TMVA::MethodBase::GetROCIntegral(PDF *pdfS, PDF *pdfB)
const
2851 if ((!pdfS && pdfB) || (pdfS && !pdfB))
2852 Log() << kFATAL <<Form(
"Dataset[%s] : ",DataInfo().GetName())<<
"<GetSeparation> Mismatch in pdfs" << Endl;
2853 if (!pdfS) pdfS = fSplS;
2854 if (!pdfB) pdfB = fSplB;
2856 if (pdfS==0 || pdfB==0)
return 0.;
2858 Double_t xmin = TMath::Min(pdfS->GetXmin(), pdfB->GetXmin());
2859 Double_t xmax = TMath::Max(pdfS->GetXmax(), pdfB->GetXmax());
2861 Double_t integral = 0;
2862 UInt_t nsteps = 1000;
2863 Double_t step = (xmax-xmin)/Double_t(nsteps);
2864 Double_t cut = xmin;
2865 for (UInt_t i=0; i<nsteps; i++) {
2866 integral += (1-pdfB->GetIntegral(cut,xmax)) * pdfS->GetVal(cut);
2869 return integral*step;
2877 Double_t TMVA::MethodBase::GetMaximumSignificance( Double_t SignalEvents,
2878 Double_t BackgroundEvents,
2879 Double_t& max_significance_value )
const
2881 Results* results = Data()->GetResults( GetMethodName(), Types::kTesting, Types::kMaxAnalysisType );
2883 Double_t max_significance(0);
2884 Double_t effS(0),effB(0),significance(0);
2885 TH1D *temp_histogram =
new TH1D(
"temp",
"temp", fNbinsH, fXmin, fXmax );
2887 if (SignalEvents <= 0 || BackgroundEvents <= 0) {
2888 Log() << kFATAL <<Form(
"Dataset[%s] : ",DataInfo().GetName())<<
"<GetMaximumSignificance> "
2889 <<
"Number of signal or background events is <= 0 ==> abort"
2893 Log() << kINFO <<Form(
"Dataset[%s] : ",DataInfo().GetName())<<
"Using ratio SignalEvents/BackgroundEvents = "
2894 << SignalEvents/BackgroundEvents << Endl;
2896 TH1* eff_s = results->GetHist(
"MVA_EFF_S");
2897 TH1* eff_b = results->GetHist(
"MVA_EFF_B");
2899 if ( (eff_s==0) || (eff_b==0) ) {
2900 Log() << kWARNING <<Form(
"Dataset[%s] : ",DataInfo().GetName())<<
"Efficiency histograms empty !" << Endl;
2901 Log() << kWARNING <<Form(
"Dataset[%s] : ",DataInfo().GetName())<<
"no maximum cut found, return 0" << Endl;
2905 for (Int_t bin=1; bin<=fNbinsH; bin++) {
2906 effS = eff_s->GetBinContent( bin );
2907 effB = eff_b->GetBinContent( bin );
2910 significance = sqrt(SignalEvents)*( effS )/sqrt( effS + ( BackgroundEvents / SignalEvents) * effB );
2912 temp_histogram->SetBinContent(bin,significance);
2916 max_significance = temp_histogram->GetBinCenter( temp_histogram->GetMaximumBin() );
2917 max_significance_value = temp_histogram->GetBinContent( temp_histogram->GetMaximumBin() );
2920 delete temp_histogram;
2922 Log() << kINFO <<Form(
"Dataset[%s] : ",DataInfo().GetName())<<
"Optimal cut at : " << max_significance << Endl;
2923 Log() << kINFO<<Form(
"Dataset[%s] : ",DataInfo().GetName()) <<
"Maximum significance: " << max_significance_value << Endl;
2925 return max_significance;
////////////////////////////////////////////////////////////////////////////////
/// Computes weighted mean, RMS, and the overall [xmin, xmax] range of the
/// variable `theVarName` separately for signal and background events of the
/// requested tree type; results are returned via the reference parameters.
/// NOTE(review): this listing is extraction-mangled — the `if (entries<=0)`
/// guard before the kFATAL, the zero-initialisation of the accumulators, and
/// the sumwS/sumwB updates inside the loop appear to have been dropped;
/// confirm against the canonical source before editing.
2933 void TMVA::MethodBase::Statistics( Types::ETreeType treeType,
const TString& theVarName,
2934 Double_t& meanS, Double_t& meanB,
2935 Double_t& rmsS, Double_t& rmsB,
2936 Double_t& xmin, Double_t& xmax )
// Remember the current tree type so it can be restored on exit.
2938 Types::ETreeType previousTreeType = Data()->GetCurrentType();
2939 Data()->SetCurrentType(treeType);
2941 Long64_t entries = Data()->GetNEvents();
// Sanity check on the tree type (aborts via kFATAL).
2945 Log() << kFATAL <<Form(
"Dataset[%s] : ",DataInfo().GetName())<<
"<CalculateEstimator> Wrong tree type: " << treeType << Endl;
// Index of the requested variable in the dataset.
2948 UInt_t varIndex = DataInfo().FindVarIndex( theVarName );
2953 Long64_t nEventsS = -1;
2954 Long64_t nEventsB = -1;
// Sums of weights, used to normalise the weighted moments below.
2961 Double_t sumwS = 0, sumwB = 0;
// Accumulate weighted first and second moments per class.
2964 for (Int_t ievt = 0; ievt < entries; ievt++) {
2966 const Event* ev = GetEvent(ievt);
2968 Double_t theVar = ev->GetValue(varIndex);
2969 Double_t weight = ev->GetWeight();
2971 if (DataInfo().IsSignal(ev)) {
2973 meanS += weight*theVar;
2974 rmsS += weight*theVar*theVar;
2978 meanB += weight*theVar;
2979 rmsB += weight*theVar*theVar;
// Track the global (class-independent) variable range.
2981 xmin = TMath::Min( xmin, theVar );
2982 xmax = TMath::Max( xmax, theVar );
// Normalise: mean = sum(w*x)/sum(w), rms from the second moment.
2987 meanS = meanS/sumwS;
2988 meanB = meanB/sumwB;
2989 rmsS = TMath::Sqrt( rmsS/sumwS - meanS*meanS );
2990 rmsB = TMath::Sqrt( rmsB/sumwB - meanB*meanB );
// Restore the tree type the caller had selected.
2992 Data()->SetCurrentType(previousTreeType);
////////////////////////////////////////////////////////////////////////////////
/// Writes a standalone C++ "Read<MethodName>" classifier class to
/// `theClassFileName` (default: <weightdir>/<job>_<method>.class.C). The
/// generated class re-implements the trained response without any TMVA
/// dependency: an abstract IClassifierReader interface plus a concrete
/// reader whose method-specific body is emitted by MakeClassSpecific().
/// NOTE(review): this listing is extraction-mangled — braces, else-branches
/// and some guards (e.g. `if (!fout.good())` before the kFATAL) appear to
/// have been dropped; confirm against the canonical source before editing.
2998 void TMVA::MethodBase::MakeClass(
const TString& theClassFileName )
const
// Resolve the output file name (caller-supplied or default location).
3001 TString classFileName =
"";
3002 if (theClassFileName ==
"")
3003 classFileName = GetWeightFileDir() +
"/" + GetJobName() +
"_" + GetMethodName() +
".class.C";
3005 classFileName = theClassFileName;
3007 TString className = TString(
"Read") + GetMethodName();
3009 TString tfname( classFileName );
3011 <<
"Creating standalone class: "
3012 << gTools().Color(
"lightblue") << classFileName << gTools().Color(
"reset") << Endl;
3014 std::ofstream fout( classFileName );
// kFATAL if the output stream could not be opened.
3016 Log() << kFATAL <<
"<MakeClass> Unable to open file: " << classFileName << Endl;
// ---- file header: class banner and full configuration dump ----
3021 fout <<
"// Class: " << className << std::endl;
3022 fout <<
"// Automatically generated by MethodBase::MakeClass" << std::endl <<
"//" << std::endl;
3026 fout <<
"/* configuration options =====================================================" << std::endl << std::endl;
3027 WriteStateToStream( fout );
3029 fout <<
"============================================================================ */" << std::endl;
// ---- standard-library includes of the generated file ----
3032 fout <<
"" << std::endl;
3033 fout <<
"#include <array>" << std::endl;
3034 fout <<
"#include <vector>" << std::endl;
3035 fout <<
"#include <cmath>" << std::endl;
3036 fout <<
"#include <string>" << std::endl;
3037 fout <<
"#include <iostream>" << std::endl;
3038 fout <<
"" << std::endl;
// Method-specific #includes/declarations.
3041 this->MakeClassSpecificHeader( fout, className );
// ---- abstract reader interface (guarded against double definition) ----
3043 fout <<
"#ifndef IClassifierReader__def" << std::endl;
3044 fout <<
"#define IClassifierReader__def" << std::endl;
3046 fout <<
"class IClassifierReader {" << std::endl;
3048 fout <<
" public:" << std::endl;
3050 fout <<
" // constructor" << std::endl;
3051 fout <<
" IClassifierReader() : fStatusIsClean( true ) {}" << std::endl;
3052 fout <<
" virtual ~IClassifierReader() {}" << std::endl;
3054 fout <<
" // return classifier response" << std::endl;
3055 if(GetAnalysisType() == Types::kMulticlass) {
3056 fout <<
" virtual std::vector<double> GetMulticlassValues( const std::vector<double>& inputValues ) const = 0;" << std::endl;
3058 fout <<
" virtual double GetMvaValue( const std::vector<double>& inputValues ) const = 0;" << std::endl;
3061 fout <<
" // returns classifier status" << std::endl;
3062 fout <<
" bool IsStatusClean() const { return fStatusIsClean; }" << std::endl;
3064 fout <<
" protected:" << std::endl;
3066 fout <<
" bool fStatusIsClean;" << std::endl;
3067 fout <<
"};" << std::endl;
3069 fout <<
"#endif" << std::endl;
// ---- concrete reader class: constructor with input-variable checks ----
3071 fout <<
"class " << className <<
" : public IClassifierReader {" << std::endl;
3073 fout <<
" public:" << std::endl;
3075 fout <<
" // constructor" << std::endl;
3076 fout <<
" " << className <<
"( std::vector<std::string>& theInputVars )" << std::endl;
3077 fout <<
" : IClassifierReader()," << std::endl;
3078 fout <<
" fClassName( \"" << className <<
"\" )," << std::endl;
3079 fout <<
" fNvars( " << GetNvar() <<
" )" << std::endl;
3080 fout <<
" {" << std::endl;
3081 fout <<
" // the training input variables" << std::endl;
3082 fout <<
" const char* inputVars[] = { ";
3083 for (UInt_t ivar=0; ivar<GetNvar(); ivar++) {
3084 fout <<
"\"" << GetOriginalVarName(ivar) <<
"\"";
3085 if (ivar<GetNvar()-1) fout <<
", ";
3087 fout <<
" };" << std::endl;
3089 fout <<
" // sanity checks" << std::endl;
3090 fout <<
" if (theInputVars.size() <= 0) {" << std::endl;
3091 fout <<
" std::cout << \"Problem in class \\\"\" << fClassName << \"\\\": empty input vector\" << std::endl;" << std::endl;
3092 fout <<
" fStatusIsClean = false;" << std::endl;
3093 fout <<
" }" << std::endl;
3095 fout <<
" if (theInputVars.size() != fNvars) {" << std::endl;
3096 fout <<
" std::cout << \"Problem in class \\\"\" << fClassName << \"\\\": mismatch in number of input values: \"" << std::endl;
3097 fout <<
" << theInputVars.size() << \" != \" << fNvars << std::endl;" << std::endl;
3098 fout <<
" fStatusIsClean = false;" << std::endl;
3099 fout <<
" }" << std::endl;
3101 fout <<
" // validate input variables" << std::endl;
3102 fout <<
" for (size_t ivar = 0; ivar < theInputVars.size(); ivar++) {" << std::endl;
3103 fout <<
" if (theInputVars[ivar] != inputVars[ivar]) {" << std::endl;
3104 fout <<
" std::cout << \"Problem in class \\\"\" << fClassName << \"\\\": mismatch in input variable names\" << std::endl" << std::endl;
3105 fout <<
" << \" for variable [\" << ivar << \"]: \" << theInputVars[ivar].c_str() << \" != \" << inputVars[ivar] << std::endl;" << std::endl;
3106 fout <<
" fStatusIsClean = false;" << std::endl;
3107 fout <<
" }" << std::endl;
3108 fout <<
" }" << std::endl;
// Emit per-variable normalisation bounds at full double precision.
3110 fout <<
" // initialize min and max vectors (for normalisation)" << std::endl;
3111 for (UInt_t ivar = 0; ivar < GetNvar(); ivar++) {
3112 fout <<
" fVmin[" << ivar <<
"] = " << std::setprecision(15) << GetXmin( ivar ) <<
";" << std::endl;
3113 fout <<
" fVmax[" << ivar <<
"] = " << std::setprecision(15) << GetXmax( ivar ) <<
";" << std::endl;
3116 fout <<
" // initialize input variable types" << std::endl;
3117 for (UInt_t ivar=0; ivar<GetNvar(); ivar++) {
3118 fout <<
" fType[" << ivar <<
"] = \'" << DataInfo().GetVariableInfo(ivar).GetVarType() <<
"\';" << std::endl;
3121 fout <<
" // initialize constants" << std::endl;
3122 fout <<
" Initialize();" << std::endl;
3124 if (GetTransformationHandler().GetTransformationList().GetSize() != 0) {
3125 fout <<
" // initialize transformation" << std::endl;
3126 fout <<
" InitTransform();" << std::endl;
3128 fout <<
" }" << std::endl;
3130 fout <<
" // destructor" << std::endl;
3131 fout <<
" virtual ~" << className <<
"() {" << std::endl;
3132 fout <<
" Clear(); // method-specific" << std::endl;
3133 fout <<
" }" << std::endl;
// ---- public response accessors (regression via multiclass flag) ----
3135 fout <<
" // the classifier response" << std::endl;
3136 fout <<
" // \"inputValues\" is a vector of input values in the same order as the" << std::endl;
3137 fout <<
" // variables given to the constructor" << std::endl;
3138 if(GetAnalysisType() == Types::kMulticlass) {
3139 fout <<
" std::vector<double> GetMulticlassValues( const std::vector<double>& inputValues ) const override;" << std::endl;
3141 fout <<
" double GetMvaValue( const std::vector<double>& inputValues ) const override;" << std::endl;
3144 fout <<
" private:" << std::endl;
3146 fout <<
" // method-specific destructor" << std::endl;
3147 fout <<
" void Clear();" << std::endl;
3149 if (GetTransformationHandler().GetTransformationList().GetSize()!=0) {
3150 fout <<
" // input variable transformation" << std::endl;
3151 GetTransformationHandler().MakeFunction(fout, className,1);
3152 fout <<
" void InitTransform();" << std::endl;
3153 fout <<
" void Transform( std::vector<double> & iv, int sigOrBgd ) const;" << std::endl;
3156 fout <<
" // common member variables" << std::endl;
3157 fout <<
" const char* fClassName;" << std::endl;
3159 fout <<
" const size_t fNvars;" << std::endl;
3160 fout <<
" size_t GetNvar() const { return fNvars; }" << std::endl;
3161 fout <<
" char GetType( int ivar ) const { return fType[ivar]; }" << std::endl;
3163 fout <<
" // normalisation of input variables" << std::endl;
3164 fout <<
" double fVmin[" << GetNvar() <<
"];" << std::endl;
3165 fout <<
" double fVmax[" << GetNvar() <<
"];" << std::endl;
3166 fout <<
" double NormVariable( double x, double xmin, double xmax ) const {" << std::endl;
3167 fout <<
" // normalise to output range: [-1, 1]" << std::endl;
3168 fout <<
" return 2*(x - xmin)/(xmax - xmin) - 1.0;" << std::endl;
3169 fout <<
" }" << std::endl;
3171 fout <<
" // type of input variable: 'F' or 'I'" << std::endl;
3172 fout <<
" char fType[" << GetNvar() <<
"];" << std::endl;
3174 fout <<
" // initialize internal variables" << std::endl;
3175 fout <<
" void Initialize();" << std::endl;
3176 if(GetAnalysisType() == Types::kMulticlass) {
3177 fout <<
" std::vector<double> GetMulticlassValues__( const std::vector<double>& inputValues ) const;" << std::endl;
3179 fout <<
" double GetMvaValue__( const std::vector<double>& inputValues ) const;" << std::endl;
3181 fout <<
"" << std::endl;
3182 fout <<
" // private members (method specific)" << std::endl;
// Method-specific data members and the __-suffixed core implementation.
3185 MakeClassSpecific( fout, className );
// ---- out-of-line response wrapper: normalise/transform, then delegate ----
3187 if(GetAnalysisType() == Types::kMulticlass) {
3188 fout <<
"inline std::vector<double> " << className <<
"::GetMulticlassValues( const std::vector<double>& inputValues ) const" << std::endl;
3190 fout <<
"inline double " << className <<
"::GetMvaValue( const std::vector<double>& inputValues ) const" << std::endl;
3192 fout <<
"{" << std::endl;
3193 fout <<
" // classifier response value" << std::endl;
3194 if(GetAnalysisType() == Types::kMulticlass) {
3195 fout <<
" std::vector<double> retval;" << std::endl;
3197 fout <<
" double retval = 0;" << std::endl;
3200 fout <<
" // classifier response, sanity check first" << std::endl;
3201 fout <<
" if (!IsStatusClean()) {" << std::endl;
3202 fout <<
" std::cout << \"Problem in class \\\"\" << fClassName << \"\\\": cannot return classifier response\"" << std::endl;
3203 fout <<
" << \" because status is dirty\" << std::endl;" << std::endl;
3204 fout <<
" }" << std::endl;
3205 fout <<
" else {" << std::endl;
// Three emission variants: normalised inputs, transformed-only, or raw.
3206 if (IsNormalised()) {
3207 fout <<
" // normalise variables" << std::endl;
3208 fout <<
" std::vector<double> iV;" << std::endl;
3209 fout <<
" iV.reserve(inputValues.size());" << std::endl;
3210 fout <<
" int ivar = 0;" << std::endl;
3211 fout <<
" for (std::vector<double>::const_iterator varIt = inputValues.begin();" << std::endl;
3212 fout <<
" varIt != inputValues.end(); varIt++, ivar++) {" << std::endl;
3213 fout <<
" iV.push_back(NormVariable( *varIt, fVmin[ivar], fVmax[ivar] ));" << std::endl;
3214 fout <<
" }" << std::endl;
3215 if (GetTransformationHandler().GetTransformationList().GetSize() != 0 && GetMethodType() != Types::kLikelihood &&
3216 GetMethodType() != Types::kHMatrix) {
3217 fout <<
" Transform( iV, -1 );" << std::endl;
3220 if(GetAnalysisType() == Types::kMulticlass) {
3221 fout <<
" retval = GetMulticlassValues__( iV );" << std::endl;
3223 fout <<
" retval = GetMvaValue__( iV );" << std::endl;
3226 if (GetTransformationHandler().GetTransformationList().GetSize() != 0 && GetMethodType() != Types::kLikelihood &&
3227 GetMethodType() != Types::kHMatrix) {
3228 fout <<
" std::vector<double> iV(inputValues);" << std::endl;
3229 fout <<
" Transform( iV, -1 );" << std::endl;
3230 if(GetAnalysisType() == Types::kMulticlass) {
3231 fout <<
" retval = GetMulticlassValues__( iV );" << std::endl;
3233 fout <<
" retval = GetMvaValue__( iV );" << std::endl;
3236 if(GetAnalysisType() == Types::kMulticlass) {
3237 fout <<
" retval = GetMulticlassValues__( inputValues );" << std::endl;
3239 fout <<
" retval = GetMvaValue__( inputValues );" << std::endl;
3243 fout <<
" }" << std::endl;
3245 fout <<
" return retval;" << std::endl;
3246 fout <<
"}" << std::endl;
// Emit the transformation function bodies, if any transformation is defined.
3249 if (GetTransformationHandler().GetTransformationList().GetSize()!=0)
3250 GetTransformationHandler().MakeFunction(fout, className,2);
////////////////////////////////////////////////////////////////////////////////
/// Prints the method-specific help message, either to the screen or — when
/// the global WriteOptionsReference option is set — appended to the
/// reference file, by temporarily redirecting std::cout's buffer.
/// NOTE(review): extraction-mangled — conditions (e.g. a goodness check on
/// the ofstream before the kFATAL) and braces appear to have been dropped;
/// confirm against the canonical source before editing.
3259 void TMVA::MethodBase::PrintHelpMessage()
const
// Remember cout's buffer so it can be restored at the end.
3262 std::streambuf* cout_sbuf = std::cout.rdbuf();
3263 std::ofstream* o = 0;
3264 if (gConfig().WriteOptionsReference()) {
3265 Log() << kINFO <<
"Print Help message for class " << GetName() <<
" into file: " << GetReferenceFile() << Endl;
3266 o =
new std::ofstream( GetReferenceFile(), std::ios::app );
3268 Log() << kFATAL <<
"<PrintHelpMessage> Unable to append to output file: " << GetReferenceFile() << Endl;
// Redirect cout into the reference file.
3270 std::cout.rdbuf( o->rdbuf() );
// ---- screen variant: coloured banner ----
3275 Log() << kINFO << Endl;
3276 Log() << gTools().Color(
"bold")
3277 <<
"================================================================"
3278 << gTools().Color(
"reset" )
3280 Log() << gTools().Color(
"bold")
3281 <<
"H e l p f o r M V A m e t h o d [ " << GetName() <<
" ] :"
3282 << gTools().Color(
"reset" )
// ---- file variant: plain header ----
3286 Log() <<
"Help for MVA method [ " << GetName() <<
" ] :" << Endl;
// Closing hint and banner.
3294 Log() <<
"<Suppress this message by specifying \"!H\" in the booking option>" << Endl;
3295 Log() << gTools().Color(
"bold")
3296 <<
"================================================================"
3297 << gTools().Color(
"reset" )
3303 Log() <<
"# End of Message___" << Endl;
// Restore cout's original buffer.
3306 std::cout.rdbuf( cout_sbuf );
////////////////////////////////////////////////////////////////////////////////
/// Interface for the root finder: returns the (spline-smoothed or binned)
/// signal efficiency at cut position `theCut`, clamped to the exact
/// endpoint values at the edges of the cut range.
/// NOTE(review): extraction-mangled — the declaration of `retval` and the
/// final `return retval;` appear to have been dropped; confirm against the
/// canonical source before editing.
3315 Double_t TMVA::MethodBase::GetValueForRoot( Double_t theCut )
// Spline evaluation when enabled (file-level switch), else histogram lookup.
3320 if (Use_Splines_for_Eff_) {
3321 retval = fSplRefS->Eval( theCut );
3323 else retval = fEffS->GetBinContent( fEffS->FindBin( theCut ) );
// Force exact efficiencies at the range edges, according to whether signal
// lies above (kPositive) or below the cut.
3331 Double_t eps = 1.0e-5;
3332 if (theCut-fXmin < eps) retval = (GetCutOrientation() == kPositive) ? 1.0 : 0.0;
3333 else if (fXmax-theCut < eps) retval = (GetCutOrientation() == kPositive) ? 0.0 : 1.0;
3342 const std::vector<TMVA::Event*>& TMVA::MethodBase::GetEventCollection( Types::ETreeType type)
3346 if (GetTransformationHandler().GetTransformationList().GetEntries() <= 0) {
3347 return (Data()->GetEventCollection(type));
3353 Int_t idx = Data()->TreeIndex(type);
3354 if (fEventCollections.at(idx) == 0) {
3355 fEventCollections.at(idx) = &(Data()->GetEventCollection(type));
3356 fEventCollections.at(idx) = GetTransformationHandler().CalcTransformations(*(fEventCollections.at(idx)),kTRUE);
3358 return *(fEventCollections.at(idx));
3364 TString TMVA::MethodBase::GetTrainingTMVAVersionString()
const
3366 UInt_t a = GetTrainingTMVAVersionCode() & 0xff0000; a>>=16;
3367 UInt_t b = GetTrainingTMVAVersionCode() & 0x00ff00; b>>=8;
3368 UInt_t c = GetTrainingTMVAVersionCode() & 0x0000ff;
3370 return TString(Form(
"%i.%i.%i",a,b,c));
3376 TString TMVA::MethodBase::GetTrainingROOTVersionString()
const
3378 UInt_t a = GetTrainingROOTVersionCode() & 0xff0000; a>>=16;
3379 UInt_t b = GetTrainingROOTVersionCode() & 0x00ff00; b>>=8;
3380 UInt_t c = GetTrainingROOTVersionCode() & 0x0000ff;
3382 return TString(Form(
"%i.%02i/%02i",a,b,c));
3387 Double_t TMVA::MethodBase::GetKSTrainingVsTest(Char_t SorB, TString opt){
3388 ResultsClassification* mvaRes =
dynamic_cast<ResultsClassification*
>
3389 ( Data()->GetResults(GetMethodName(),Types::kTesting, Types::kClassification) );
3391 if (mvaRes != NULL) {
3392 TH1D *mva_s =
dynamic_cast<TH1D*
> (mvaRes->GetHist(
"MVA_S"));
3393 TH1D *mva_b =
dynamic_cast<TH1D*
> (mvaRes->GetHist(
"MVA_B"));
3394 TH1D *mva_s_tr =
dynamic_cast<TH1D*
> (mvaRes->GetHist(
"MVA_TRAIN_S"));
3395 TH1D *mva_b_tr =
dynamic_cast<TH1D*
> (mvaRes->GetHist(
"MVA_TRAIN_B"));
3397 if ( !mva_s || !mva_b || !mva_s_tr || !mva_b_tr)
return -1;
3399 if (SorB ==
's' || SorB ==
'S')
3400 return mva_s->KolmogorovTest( mva_s_tr, opt.Data() );
3402 return mva_b->KolmogorovTest( mva_b_tr, opt.Data() );