27 #ifndef ROOT_TMVA_RuleFit
28 #define ROOT_TMVA_RuleFit
/// Constructor taking a read-only pointer to a TMVA::MethodBase.
/// Declaration only; the definition is outside this view.
/// @param rfbase method-base pointer — NOTE(review): presumably handed on to
///               Initialize()/SetMethodBase(); confirm against the definition.
50 RuleFit(
const TMVA::MethodBase *rfbase );
/// Virtual destructor. Declaration only; what it releases is decided by the
/// definition, which is outside this view.
55 virtual ~RuleFit(
void );
/// Declaration only (definition not in this view).
/// @param rfbase read-only method-base pointer — NOTE(review): the name suggests
///               this sets up the pointer members from rfbase; confirm.
58 void InitPtrs(
const TMVA::MethodBase *rfbase );
/// Declaration only (definition not in this view).
/// @param rfbase read-only method-base pointer — NOTE(review): presumably the
///               full initialisation entry point; confirm against the definition.
59 void Initialize(
const TMVA::MethodBase *rfbase );
/// Declaration only (definition not in this view).
/// @param t message type — NOTE(review): presumably applied to the logger
///          behind fLogger; confirm.
61 void SetMsgType( EMsgType t );
/// Declaration only (definition not in this view).
/// @param el list of read-only event pointers, passed by const reference —
///           NOTE(review): presumably copied into fTrainingEvents; confirm.
63 void SetTrainingEvents(
const std::vector<const TMVA::Event *> & el );
/// Randomly permutes fTrainingEventsRndm in place with std::shuffle, drawing
/// randomness from the member engine fRNGEngine. fTrainingEvents is not touched.
65 void ReshuffleEvents()
67 std::shuffle(fTrainingEventsRndm.begin(), fTrainingEventsRndm.end(), fRNGEngine);
/// Declaration only (definition not in this view).
/// @param rfbase read-only method-base pointer — NOTE(review): presumably stored
///               in fMethodBase (and possibly fMethodRuleFit); confirm.
70 void SetMethodBase(
const MethodBase *rfbase );
/// Declaration only (definition not in this view).
/// @param dt non-const decision tree pointer, so the callee may modify the tree —
///           NOTE(review): presumably trains/builds the given tree; confirm.
76 void BuildTree( TMVA::DecisionTree *dt );
/// Declaration only (definition not in this view).
/// NOTE(review): presumably snapshots the event weights into fEventWeights; confirm.
79 void SaveEventWeights();
/// Declaration only (definition not in this view).
/// NOTE(review): presumably the inverse of SaveEventWeights(); confirm.
82 void RestoreEventWeights();
/// Declaration only (definition not in this view).
/// @param dt non-const decision tree pointer — NOTE(review): presumably a
///           boosting step driven by this tree; confirm against the definition.
85 void Boost( TMVA::DecisionTree *dt );
/// Declaration only (definition not in this view).
/// NOTE(review): presumably summarises the trees held in fForest; confirm.
88 void ForestStatistics();
/// Declaration only (definition not in this view).
/// @param e event, taken by const reference.
/// @return a Double_t — NOTE(review): presumably the model response for e; confirm.
91 Double_t EvalEvent(
const Event& e );
/// Declaration only (definition not in this view).
/// @param events pointer to a vector of read-only event pointers.
/// @param neve   optional count, defaults to 0 — NOTE(review): 0 presumably means
///               "use all events"; confirm against the definition.
/// @return a Double_t — presumably the summed event weight; confirm.
94 Double_t CalcWeightSum(
const std::vector<const TMVA::Event *> *events, UInt_t neve=0 );
/// Declaration only (definition not in this view).
/// NOTE(review): presumably runs the coefficient fit via fRuleFitParams; confirm.
97 void FitCoefficients();
/// Declaration only (definition not in this view).
/// NOTE(review): presumably computes rule/variable importance; confirm.
100 void CalcImportance();
/// Delegates to SetModelLinear() on the owned fRuleEnsemble.
103 void SetModelLinear()
{
   this->fRuleEnsemble.SetModelLinear();
}
/// Delegates to SetModelRules() on the owned fRuleEnsemble.
105 void SetModelRules()
{
   this->fRuleEnsemble.SetModelRules();
}
/// Delegates to SetModelFull() on the owned fRuleEnsemble.
107 void SetModelFull()
{
   this->fRuleEnsemble.SetModelFull();
}
/// Passes the importance cut on to fRuleEnsemble.SetImportanceCut().
/// @param minimp cut value handed to the ensemble; 0 when the argument is omitted.
109 void SetImportanceCut( Double_t minimp=0 )
{
   this->fRuleEnsemble.SetImportanceCut( minimp );
}
/// Passes the distance value on to fRuleEnsemble.SetRuleMinDist().
/// @param d value handed to the ensemble unchanged.
111 void SetRuleMinDist( Double_t d )
{
   this->fRuleEnsemble.SetRuleMinDist( d );
}
/// Passes the tau value on to fRuleFitParams.SetGDTau().
/// @param t value handed to the parameter object; 0.0 when the argument is omitted.
113 void SetGDTau( Double_t t=0.0 )
{
   this->fRuleFitParams.SetGDTau( t );
}
/// Passes the path step on to fRuleFitParams.SetGDPathStep().
/// @param s value handed to the parameter object; 0.01 when the argument is omitted.
114 void SetGDPathStep( Double_t s=0.01 )
{
   this->fRuleFitParams.SetGDPathStep( s );
}
/// Passes the number of path steps on to fRuleFitParams.SetGDNPathSteps().
/// @param n value handed to the parameter object; 100 when the argument is omitted.
115 void SetGDNPathSteps( Int_t n=100 )
{
   this->fRuleFitParams.SetGDNPathSteps( n );
}
/// Stores the given flag in fVisHistsUseImp.
/// @param f new value of the flag.
117 void SetVisHistsUseImp( Bool_t f )
{
   this->fVisHistsUseImp = f;
}
/// Switches fVisHistsUseImp on (kTRUE).
118 void UseImportanceVisHists()
{
   this->fVisHistsUseImp = kTRUE;
}
/// Switches fVisHistsUseImp off (kFALSE).
119 void UseCoefficientsVisHists()
{
   this->fVisHistsUseImp = kFALSE;
}
/// Declaration only (definition not in this view).
/// @param rule  read-only rule pointer.
/// @param hlist list of TH2F pointers, taken by non-const reference —
///              NOTE(review): presumably the histograms that get filled; confirm.
121 void FillVisHistCut(
const Rule * rule, std::vector<TH2F *> & hlist);
/// Declaration only (definition not in this view).
/// @param rule  read-only rule pointer.
/// @param hlist list of TH2F pointers, taken by non-const reference —
///              NOTE(review): presumably correlation histograms to fill; confirm.
122 void FillVisHistCorr(
const Rule * rule, std::vector<TH2F *> & hlist);
/// Declaration only (definition not in this view).
/// @param h2   histogram pointer (non-const, so it may be modified).
/// @param rule read-only rule pointer.
/// @param vind integer index — NOTE(review): presumably a variable index; confirm.
123 void FillCut(TH2F* h2,
const TMVA::Rule *rule,Int_t vind);
/// Declaration only (definition not in this view).
/// @param h2   histogram pointer (non-const, so it may be modified).
/// @param vind integer index — NOTE(review): presumably a variable index; confirm.
124 void FillLin(TH2F* h2,Int_t vind);
/// Declaration only (definition not in this view).
/// @param h2   histogram pointer (non-const, so it may be modified).
/// @param rule read-only rule pointer.
/// @param v1   first integer index — NOTE(review): presumably variable indices
/// @param v2   second integer index  for the two correlated variables; confirm.
125 void FillCorr(TH2F* h2,
const TMVA::Rule *rule,Int_t v1, Int_t v2);
/// Declaration only (definition not in this view).
/// @param hlist list of TH2F pointers, taken by non-const reference —
///              NOTE(review): presumably normalised in place; confirm.
126 void NormVisHists(std::vector<TH2F *> & hlist);
/// Declaration only (definition not in this view).
/// NOTE(review): presumably produces diagnostic histograms; confirm.
127 void MakeDebugHists();
/// Declaration only (definition not in this view).
/// All three strings are taken by non-const reference, so any of them may be
/// written by the callee.
/// NOTE(review): presumably parses title into var1/var2 and returns whether
/// that succeeded; confirm against the definition.
128 Bool_t GetCorrVars(TString & title, TString & var1, TString & var2);
/// Read-only accessor.
/// @return the current value of fNTreeSample.
130 UInt_t GetNTreeSample() const
{
   return this->fNTreeSample;
}
/// Read-only accessor.
/// @return the current value of fNEveEffTrain.
131 Double_t GetNEveEff() const
{
   return this->fNEveEffTrain;
}
/// Returns the i-th entry of fTrainingEvents as a pointer to const Event.
/// The element is fetched with operator[], i.e. without a bounds check.
/// @param i index into fTrainingEvents.
132 const Event* GetTrainingEvent(UInt_t i) const
{
   const TMVA::Event *stored = fTrainingEvents[i];
   return static_cast< const Event * >( stored );
}
/// Returns GetWeight() of the i-th entry of fTrainingEvents.
/// The element is fetched with operator[] and dereferenced directly, so there
/// is neither a bounds check nor a null check.
/// @param i index into fTrainingEvents.
133 Double_t GetTrainingEventWeight(UInt_t i) const
{
   const TMVA::Event *evt = fTrainingEvents[i];
   return evt->GetWeight();
}
/// Read-only accessor.
/// @return const reference to fTrainingEvents (no copy is made).
137 const std::vector< const TMVA::Event * > & GetTrainingEvents() const
{
   return this->fTrainingEvents;
}
/// Declaration only (definition not in this view).
/// @param evevec  output vector, taken by non-const reference.
/// @param nevents requested count — NOTE(review): presumably the number of
///                randomly drawn events written to evevec; confirm.
141 void GetRndmSampleEvents(std::vector< const TMVA::Event * > & evevec, UInt_t nevents);
/// Read-only accessor.
/// @return const reference to fForest (no copy is made).
143 const std::vector< const TMVA::DecisionTree *> & GetForest() const
{
   return this->fForest;
}
/// Read-only accessor.
/// @return const reference to the owned fRuleEnsemble.
144 const RuleEnsemble & GetRuleEnsemble() const
{
   return this->fRuleEnsemble;
}
/// Mutable accessor.
/// @return address of the owned fRuleEnsemble; the pointee is a member of this
///         object, so the pointer must not be deleted by the caller.
145 RuleEnsemble * GetRuleEnsemblePtr()
{
   RuleEnsemble *ensemble = &fRuleEnsemble;
   return ensemble;
}
/// Read-only accessor.
/// @return const reference to the owned fRuleFitParams.
146 const RuleFitParams & GetRuleFitParams() const
{
   return this->fRuleFitParams;
}
/// Mutable accessor.
/// @return address of the owned fRuleFitParams; the pointee is a member of this
///         object, so the pointer must not be deleted by the caller.
147 RuleFitParams * GetRuleFitParamsPtr()
{
   RuleFitParams *params = &fRuleFitParams;
   return params;
}
/// Read-only accessor.
/// @return the stored fMethodRuleFit pointer, exactly as held (no null check).
148 const MethodRuleFit * GetMethodRuleFit() const
{
   return this->fMethodRuleFit;
}
/// Read-only accessor.
/// @return the stored fMethodBase pointer, exactly as held (no null check).
149 const MethodBase * GetMethodBase() const
{
   return this->fMethodBase;
}
/// Copy constructor. Declaration only; the definition is outside this view.
/// NOTE(review): the access specifier in force here is not visible in this
/// view — confirm whether copying is meant to be available to callers.
154 RuleFit(
const RuleFit & other );
/// Declaration only (definition not in this view).
/// @param other source object, taken by const reference — NOTE(review):
///              presumably its state is copied into this object; confirm.
157 void Copy(
const RuleFit & other );
// Event list exposed through GetTrainingEvents()/GetTrainingEvent().
159 std::vector<const TMVA::Event *> fTrainingEvents;
// Second event list; this is the one ReshuffleEvents() permutes.
160 std::vector<const TMVA::Event *> fTrainingEventsRndm;
// NOTE(review): presumably the buffer used by Save/RestoreEventWeights(); confirm.
161 std::vector<Double_t> fEventWeights;
// Value returned by GetNEveEff().
164 Double_t fNEveEffTrain;
// Tree list exposed (read-only) through GetForest().
165 std::vector< const TMVA::DecisionTree *> fForest;
// Held by value; target of the SetModel*/SetImportanceCut/SetRuleMinDist forwards.
166 RuleEnsemble fRuleEnsemble;
// Held by value; target of the SetGD* forwards.
167 RuleFitParams fRuleFitParams;
// Pointer returned by GetMethodRuleFit().
168 const MethodRuleFit *fMethodRuleFit;
// Pointer returned by GetMethodBase().
169 const MethodBase *fMethodBase;
// Flag written by SetVisHistsUseImp()/UseImportanceVisHists()/UseCoefficientsVisHists().
170 Bool_t fVisHistsUseImp;
// Logger pointer dereferenced by Log(); mutable so it is usable from const members.
172 mutable MsgLogger* fLogger;
/// Gives access to the logger behind fLogger.
/// The pointer is dereferenced without a null check.
/// @return reference to the MsgLogger that fLogger points to.
173 MsgLogger& Log() const
{
   MsgLogger *logger = fLogger;
   return *logger;
}
// Class-level integer constant, fixed at 0.
// NOTE(review): presumably the seed used for fRNGEngine — its use is not
// visible in this view; confirm where it is consumed.
175 static const Int_t randSEED = 0;
// Random engine consumed by std::shuffle in ReshuffleEvents().
176 std::default_random_engine fRNGEngine;