33 #ifndef ROOT_TMVA_MethodBase
34 #define ROOT_TMVA_MethodBase
89 namespace Experimental {
92 class TrainingHistory;
94 class IPythonInteractive {
97 ~IPythonInteractive();
98 void Init(std::vector<TString>& graphTitles);
100 void AddPoint(Double_t x, Double_t y1, Double_t y2);
101 void AddPoint(std::vector<Double_t>& dat);
102 inline TMultiGraph* Get() {
return fMultiGraph;}
103 inline bool NotInitialized(){
return fNumGraphs==0;};
105 TMultiGraph* fMultiGraph;
106 std::vector<TGraph*> fGraphs;
111 class MethodBase :
virtual public IMethod,
public Configurable {
113 friend class CrossValidation;
114 friend class Factory;
115 friend class RootFinder;
116 friend class MethodBoost;
117 friend class MethodCrossValidation;
118 friend class Experimental::Classification;
/// Weight-file type selector with two enumerators.
enum EWeightFileType {
   kROOT = 0, ///< explicitly 0
   kTEXT      ///< implicitly 1 (follows kROOT)
};
125 MethodBase(
const TString& jobName,
126 Types::EMVA methodType,
127 const TString& methodTitle,
129 const TString& theOption =
"" );
133 MethodBase( Types::EMVA methodType,
135 const TString& weightFile );
138 virtual ~MethodBase();
143 virtual void CheckSetup();
148 void AddOutput( Types::ETreeType type, Types::EAnalysisType analysisType );
155 virtual std::map<TString,Double_t> OptimizeTuningParameters(TString fomType=
"ROCIntegral", TString fitType=
"FitGA");
156 virtual void SetTuneParameters(std::map<TString,Double_t> tuneParameters);
158 virtual void Train() = 0;
161 void SetTrainTime( Double_t trainTime ) { fTrainTime = trainTime; }
162 Double_t GetTrainTime()
const {
return fTrainTime; }
165 void SetTestTime ( Double_t testTime ) { fTestTime = testTime; }
166 Double_t GetTestTime ()
const {
return fTestTime; }
169 virtual void TestClassification();
170 virtual Double_t GetKSTrainingVsTest(Char_t SorB, TString opt=
"X");
173 virtual void TestMulticlass();
176 virtual void TestRegression( Double_t& bias, Double_t& biasT,
177 Double_t& dev, Double_t& devT,
178 Double_t& rms, Double_t& rmsT,
179 Double_t& mInf, Double_t& mInfT,
181 Types::ETreeType type );
/// Pure virtual (declared `= 0`): this class provides no body, so each
/// concrete subclass must implement it.
/// NOTE(review): presumably method-specific initialisation -- confirm against implementers.
184 virtual void Init() = 0;
/// Pure virtual (declared `= 0`): must be implemented by each concrete subclass.
/// NOTE(review): presumably registers the method's configuration options -- confirm.
185 virtual void DeclareOptions() = 0;
/// Pure virtual (declared `= 0`): must be implemented by each concrete subclass.
/// NOTE(review): presumably interprets the options after they were parsed -- confirm.
186 virtual void ProcessOptions() = 0;
187 virtual void DeclareCompatibilityOptions();
193 virtual void Reset(){
return;}
198 virtual Double_t GetMvaValue( Double_t* errLower = 0, Double_t* errUpper = 0) = 0;
201 Double_t GetMvaValue(
const TMVA::Event*
const ev, Double_t* err = 0, Double_t* errUpper = 0 );
205 void NoErrorCalc(Double_t*
const err, Double_t*
const errUpper);
208 virtual std::vector<Double_t> GetMvaValues(Long64_t firstEvt = 0, Long64_t lastEvt = -1, Bool_t logProgress =
false);
213 const std::vector<Float_t>& GetRegressionValues(
const TMVA::Event*
const ev){
215 const std::vector<Float_t>* ptr = &GetRegressionValues();
220 virtual const std::vector<Float_t>& GetRegressionValues() {
221 std::vector<Float_t>* ptr =
new std::vector<Float_t>(0);
226 virtual const std::vector<Float_t>& GetMulticlassValues() {
227 std::vector<Float_t>* ptr =
new std::vector<Float_t>(0);
232 virtual const std::vector<Float_t>& GetTrainingHistory(
const char* ) {
233 std::vector<Float_t>* ptr =
new std::vector<Float_t>(0);
238 virtual Double_t GetProba(
const Event *ev);
239 virtual Double_t GetProba( Double_t mvaVal, Double_t ap_sig );
242 virtual Double_t GetRarity( Double_t mvaVal, Types::ESBType reftype = Types::kBackground )
const;
245 virtual const Ranking* CreateRanking() = 0;
248 virtual void MakeClass(
const TString& classFileName = TString(
"") )
const;
251 void PrintHelpMessage()
const;
257 void WriteStateToFile ()
const;
258 void ReadStateFromFile ();
/// Pure virtual (declared `= 0`), const: must be implemented by each concrete subclass.
/// NOTE(review): presumably appends the method's weights below the XML node `parent` -- confirm.
262 virtual void AddWeightsXMLTo (
void* parent )
const = 0;
/// Pure virtual (declared `= 0`): must be implemented by each concrete subclass.
/// NOTE(review): presumably the reading counterpart, starting at XML node `wghtnode` -- confirm.
263 virtual void ReadWeightsFromXML (
void* wghtnode ) = 0;
/// Pure virtual (declared `= 0`) overload taking a std::istream; must be
/// implemented by each concrete subclass.
264 virtual void ReadWeightsFromStream( std::istream& ) = 0;
265 virtual void ReadWeightsFromStream( TFile& ) {}
268 friend class MethodCategory;
269 friend class MethodCompositeBase;
270 void WriteStateToXML (
void* parent )
const;
271 void ReadStateFromXML (
void* parent );
272 void WriteStateToStream ( std::ostream& tf )
const;
273 void WriteVarsToStream ( std::ostream& tf,
const TString& prefix =
"" )
const;
277 void ReadStateFromStream ( std::istream& tf );
278 void ReadStateFromStream ( TFile& rf );
279 void ReadStateFromXMLString(
const char* xmlstr );
283 void AddVarsXMLTo (
void* parent )
const;
284 void AddSpectatorsXMLTo (
void* parent )
const;
285 void AddTargetsXMLTo (
void* parent )
const;
286 void AddClassesXMLTo (
void* parent )
const;
287 void ReadVariablesFromXML (
void* varnode );
288 void ReadSpectatorsFromXML(
void* specnode);
289 void ReadTargetsFromXML (
void* tarnode );
290 void ReadClassesFromXML (
void* clsnode );
291 void ReadVarsFromStream ( std::istream& istr );
297 virtual void WriteEvaluationHistosToFile(Types::ETreeType treetype);
300 virtual void WriteMonitoringHistosToFile()
const;
312 virtual Double_t GetEfficiency(
const TString&, Types::ETreeType, Double_t& err );
313 virtual Double_t GetTrainingEfficiency(
const TString& );
314 virtual std::vector<Float_t> GetMulticlassEfficiency( std::vector<std::vector<Float_t> >& purity );
315 virtual std::vector<Float_t> GetMulticlassTrainingEfficiency(std::vector<std::vector<Float_t> >& purity );
316 virtual TMatrixD GetMulticlassConfusionMatrix(Double_t effB, Types::ETreeType type);
317 virtual Double_t GetSignificance()
const;
318 virtual Double_t GetROCIntegral(TH1D *histS, TH1D *histB)
const;
319 virtual Double_t GetROCIntegral(PDF *pdfS=0, PDF *pdfB=0)
const;
320 virtual Double_t GetMaximumSignificance( Double_t SignalEvents, Double_t BackgroundEvents,
321 Double_t& optimal_significance_value )
const;
322 virtual Double_t GetSeparation( TH1*, TH1* )
const;
323 virtual Double_t GetSeparation( PDF* pdfS = 0, PDF* pdfB = 0 )
const;
325 virtual void GetRegressionDeviation(UInt_t tgtNum, Types::ETreeType type, Double_t& stddev,Double_t& stddev90Percent )
const;
329 const TString& GetJobName ()
const {
return fJobName; }
330 const TString& GetMethodName ()
const {
return fMethodName; }
331 TString GetMethodTypeName()
const {
return Types::Instance().GetMethodName(fMethodType); }
332 Types::EMVA GetMethodType ()
const {
return fMethodType; }
333 const char* GetName ()
const {
return fMethodName.Data(); }
334 const TString& GetTestvarName ()
const {
return fTestvar; }
335 const TString GetProbaName ()
const {
return fTestvar +
"_Proba"; }
336 TString GetWeightFileName()
const;
340 void SetTestvarName (
const TString & v=
"" ) { fTestvar = (v==
"") ? (
"MVA_" + GetMethodName()) : v; }
343 UInt_t GetNvar()
const {
return DataInfo().GetNVariables(); }
344 UInt_t GetNVariables()
const {
return DataInfo().GetNVariables(); }
345 UInt_t GetNTargets()
const {
return DataInfo().GetNTargets(); };
348 const TString& GetInputVar ( Int_t i )
const {
return DataInfo().GetVariableInfo(i).GetInternalName(); }
349 const TString& GetInputLabel( Int_t i )
const {
return DataInfo().GetVariableInfo(i).GetLabel(); }
350 const char * GetInputTitle( Int_t i )
const {
return DataInfo().GetVariableInfo(i).GetTitle(); }
353 Double_t GetMean( Int_t ivar )
const {
return GetTransformationHandler().GetMean(ivar); }
354 Double_t GetRMS ( Int_t ivar )
const {
return GetTransformationHandler().GetRMS(ivar); }
355 Double_t GetXmin( Int_t ivar )
const {
return GetTransformationHandler().GetMin(ivar); }
356 Double_t GetXmax( Int_t ivar )
const {
return GetTransformationHandler().GetMax(ivar); }
359 Double_t GetSignalReferenceCut()
const {
return fSignalReferenceCut; }
360 Double_t GetSignalReferenceCutOrientation()
const {
return fSignalReferenceCutOrientation; }
363 void SetSignalReferenceCut( Double_t cut ) { fSignalReferenceCut = cut; }
364 void SetSignalReferenceCutOrientation( Double_t cutOrientation ) { fSignalReferenceCutOrientation = cutOrientation; }
367 TDirectory* BaseDir()
const;
368 TDirectory* MethodBaseDir()
const;
369 TFile* GetFile()
const {
return fFile;}
371 void SetMethodDir ( TDirectory* methodDir ) { fBaseDir = fMethodBaseDir = methodDir; }
372 void SetBaseDir( TDirectory* methodDir ){ fBaseDir = methodDir; }
373 void SetMethodBaseDir( TDirectory* methodDir ){ fMethodBaseDir = methodDir; }
374 void SetFile(TFile* file){fFile=file;}
377 void SetSilentFile(Bool_t status) {fSilentFile=status;}
378 Bool_t IsSilentFile()
const {
return fSilentFile;}
381 void SetModelPersistence(Bool_t status){fModelPersistence=status;}
382 Bool_t IsModelPersistence()
const {
return fModelPersistence;}
388 UInt_t GetTrainingTMVAVersionCode()
const {
return fTMVATrainingVersion; }
389 UInt_t GetTrainingROOTVersionCode()
const {
return fROOTTrainingVersion; }
390 TString GetTrainingTMVAVersionString()
const;
391 TString GetTrainingROOTVersionString()
const;
393 TransformationHandler& GetTransformationHandler(Bool_t takeReroutedIfAvailable=
true)
395 if(fTransformationPointer && takeReroutedIfAvailable)
return *fTransformationPointer;
else return fTransformation;
397 const TransformationHandler& GetTransformationHandler(Bool_t takeReroutedIfAvailable=
true)
const
399 if(fTransformationPointer && takeReroutedIfAvailable)
return *fTransformationPointer;
else return fTransformation;
402 void RerouteTransformationHandler (TransformationHandler* fTargetTransformation) { fTransformationPointer=fTargetTransformation; }
408 DataSet* Data()
const {
return DataInfo().GetDataSet(); }
409 DataSetInfo& DataInfo()
const {
return fDataSetInfo; }
411 mutable const Event* fTmpEvent;
416 UInt_t GetNEvents ()
const {
return Data()->GetNEvents(); }
417 const Event* GetEvent ()
const;
418 const Event* GetEvent (
const TMVA::Event* ev )
const;
419 const Event* GetEvent ( Long64_t ievt )
const;
420 const Event* GetEvent ( Long64_t ievt , Types::ETreeType type )
const;
421 const Event* GetTrainingEvent( Long64_t ievt )
const;
422 const Event* GetTestingEvent ( Long64_t ievt )
const;
423 const std::vector<TMVA::Event*>& GetEventCollection( Types::ETreeType type );
425 TrainingHistory fTrainHistory;
431 virtual Bool_t IsSignalLike();
432 virtual Bool_t IsSignalLike(Double_t mvaVal);
435 Bool_t HasMVAPdfs()
const {
return fHasMVAPdfs; }
436 virtual void SetAnalysisType( Types::EAnalysisType type ) { fAnalysisType = type; }
437 Types::EAnalysisType GetAnalysisType()
const {
return fAnalysisType; }
438 Bool_t DoRegression()
const {
return fAnalysisType == Types::kRegression; }
439 Bool_t DoMulticlass()
const {
return fAnalysisType == Types::kMulticlass; }
442 void DisableWriting(Bool_t setter){ fModelPersistence = setter?kFALSE:kTRUE; }
446 IPythonInteractive *fInteractive =
nullptr;
447 bool fExitFromTraining =
false;
448 UInt_t fIPyMaxIter = 0, fIPyCurrentIter = 0;
453 inline void InitIPythonInteractive(){
454 if (fInteractive)
delete fInteractive;
455 fInteractive =
new IPythonInteractive();
459 inline TMultiGraph* GetInteractiveTrainingError(){
return fInteractive->Get();}
462 inline void ExitFromTraining(){
463 fExitFromTraining =
true;
467 inline bool TrainingEnded(){
468 if (fExitFromTraining && fInteractive){
470 fInteractive =
nullptr;
472 return fExitFromTraining;
476 inline UInt_t GetMaxIter(){
return fIPyMaxIter; }
479 inline UInt_t GetCurrentIter(){
return fIPyCurrentIter; }
488 void SetWeightFileName( TString );
490 const TString& GetWeightFileDir()
const {
return fFileDir; }
491 void SetWeightFileDir( TString fileDir );
494 Bool_t IsNormalised()
const {
return fNormalise; }
495 void SetNormalised( Bool_t norm ) { fNormalise = norm; }
501 Bool_t Verbose()
const {
return fVerbose; }
502 Bool_t Help ()
const {
return fHelp; }
508 const TString& GetInternalVarName( Int_t ivar )
const {
return (*fInputVars)[ivar]; }
509 const TString& GetOriginalVarName( Int_t ivar )
const {
return DataInfo().GetVariableInfo(ivar).GetExpression(); }
511 Bool_t HasTrainingTree()
const {
return Data()->GetNTrainingEvents() != 0; }
518 virtual void MakeClassSpecific( std::ostream&,
const TString& =
"" )
const {}
521 virtual void MakeClassSpecificHeader( std::ostream&,
const TString& =
"" )
const {}
527 void Statistics( Types::ETreeType treeType,
const TString& theVarName,
528 Double_t&, Double_t&, Double_t&,
529 Double_t&, Double_t&, Double_t& );
532 Bool_t TxtWeightsOnly()
const {
return kTRUE; }
538 Bool_t IsConstructedFromWeightFile()
const {
return fConstructedFromWeightFile; }
545 void DeclareBaseOptions();
546 void ProcessBaseOptions();
/// Cut orientation with two signed enumerators.
enum ECutOrientation {
   kNegative = -1,
   kPositive = 1
};
550 ECutOrientation GetCutOrientation()
const {
return fCutOrientation; }
555 void ResetThisBase();
560 void CreateMVAPdfs();
564 virtual Double_t GetValueForRoot ( Double_t );
567 Bool_t GetLine( std::istream& fin,
char * buf );
570 virtual void AddClassifierOutput ( Types::ETreeType type );
571 virtual void AddClassifierOutputProb( Types::ETreeType type );
572 virtual void AddRegressionOutput ( Types::ETreeType type );
573 virtual void AddMulticlassOutput ( Types::ETreeType type );
577 void AddInfoItem(
void* gi,
const TString& name,
578 const TString& value)
const;
586 std::vector<TString>* fInputVars;
590 Int_t fNbinsMVAoutput;
593 Types::EAnalysisType fAnalysisType;
595 std::vector<Float_t>* fRegressionReturnVal;
596 std::vector<Float_t>* fMulticlassReturnVal;
601 friend class MethodCuts;
605 DataSetInfo& fDataSetInfo;
607 Double_t fSignalReferenceCut;
608 Double_t fSignalReferenceCutOrientation;
609 Types::ESBType fVariableTransformType;
614 Types::EMVA fMethodType;
616 UInt_t fTMVATrainingVersion;
617 UInt_t fROOTTrainingVersion;
618 Bool_t fConstructedFromWeightFile;
623 TDirectory* fBaseDir;
624 mutable TDirectory* fMethodBaseDir;
631 Bool_t fModelPersistence;
650 TSpline* fSpleffBvsS;
654 TSpline* fSplTrainEffBvsS;
667 TString fVarTransformString;
669 TransformationHandler* fTransformationPointer;
670 TransformationHandler fTransformation;
675 TString fVerbosityLevelString;
676 EMsgType fVerbosityLevel;
680 Bool_t fIgnoreNegWeightsInTraining;
684 Bool_t IgnoreEventsWithNegWeightsInTraining()
const {
return fIgnoreNegWeightsInTraining; }
688 UInt_t fBackgroundClass;
697 ECutOrientation fCutOrientation;
703 TSpline1* fSplTrainRefS;
704 TSpline1* fSplTrainRefB;
706 mutable std::vector<const std::vector<TMVA::Event*>*> fEventCollections;
709 Bool_t fSetupCompleted;
722 TString fVariableTransformTypeString;
723 Bool_t fTxtWeightsOnly;
725 Int_t fNsmoothMVAPdf;
729 ClassDef(MethodBase,0);
744 inline const TMVA::Event* TMVA::MethodBase::GetEvent(
const TMVA::Event* ev )
const
746 return GetTransformationHandler().Transform(ev);
749 inline const TMVA::Event* TMVA::MethodBase::GetEvent()
const
752 return GetTransformationHandler().Transform(fTmpEvent);
754 return GetTransformationHandler().Transform(Data()->GetEvent());
757 inline const TMVA::Event* TMVA::MethodBase::GetEvent( Long64_t ievt )
const
759 assert(fTmpEvent==0);
760 return GetTransformationHandler().Transform(Data()->GetEvent(ievt));
763 inline const TMVA::Event* TMVA::MethodBase::GetEvent( Long64_t ievt, Types::ETreeType type )
const
765 assert(fTmpEvent==0);
766 return GetTransformationHandler().Transform(Data()->GetEvent(ievt, type));
769 inline const TMVA::Event* TMVA::MethodBase::GetTrainingEvent( Long64_t ievt )
const
771 assert(fTmpEvent==0);
772 return GetEvent(ievt, Types::kTraining);
775 inline const TMVA::Event* TMVA::MethodBase::GetTestingEvent( Long64_t ievt )
const
777 assert(fTmpEvent==0);
778 return GetEvent(ievt, Types::kTesting);