29 #ifndef ROOT_TMVA_RuleEnsemble
30 #define ROOT_TMVA_RuleEnsemble
/// Stream-insertion operator for RuleEnsemble (declaration only; the
/// implementation is not visible here — presumably prints the ensemble, verify).
/// NOTE(review): the parameter is named `event` although it is a RuleEnsemble;
/// the friend declaration inside the class calls it `rules`.
49 std::ostream& operator<<( std::ostream& os,
const RuleEnsemble& event );
/// Grants the stream-insertion operator access to the private members.
54 friend std::ostream& operator<< ( std::ostream& os,
const RuleEnsemble& rules );
/// Which terms enter the model: kFull = rules and linear terms,
/// kRules = rules only, kLinear = linear terms only (see DoRules()/DoLinear()).
58 enum ELearningModel { kFull=0, kRules=1, kLinear=2 };
/// Construct an ensemble attached to the given RuleFit object.
/// NOTE(review): ownership/lifetime of rf is not visible here — verify.
61 RuleEnsemble( RuleFit* rf );
/// Copy constructor (implementation not visible).
/// NOTE(review): fRules holds raw Rule* that DeleteRules() deletes; confirm the
/// copy does not end up sharing those pointers with `other`.
64 RuleEnsemble(
const RuleEnsemble& other );
/// Destructor (implementation not visible).
70 virtual ~RuleEnsemble();
/// Initialize the ensemble from the given RuleFit (implementation not visible).
73 void Initialize(
const RuleFit* rf );
/// Presumably sets the minimum message type printed by the logger — verify.
76 void SetMsgType( EMsgType t );
/// Build rules from a forest of decision trees (implementation not visible;
/// presumably extracts rules from each tree — verify in the .cxx).
82 void MakeRules(
const std::vector< const TMVA::DecisionTree *>& forest );
/// Set up the linear terms (implementation not visible).
85 void MakeLinearTerms();
88 void SetModelLinear() { fLearningModel = kLinear; }
91 void SetModelRules() { fLearningModel = kRules; }
94 void SetModelFull() { fLearningModel = kFull; }
/// Replace the rule list (implementation not visible).
/// NOTE(review): ownership of the passed Rule* pointers is not visible here;
/// DeleteRules() deletes whatever fRules holds.
97 void SetRules(
const std::vector< TMVA::Rule *> & rules );
/// Attach the RuleFit object this ensemble refers to (only the pointer is stored).
100 void SetRuleFit(
const RuleFit *rf ) { fRuleFit = rf; }
/// Set all rule coefficients from v (implementation not visible).
103 void SetCoefficients(
const std::vector< Double_t >& v );
104 void SetCoefficient( UInt_t i, Double_t v ) {
if (i<fRules.size()) fRules[i]->SetCoefficient(v); }
106 void SetOffset(Double_t v=0.0) { fOffset=v; }
107 void AddOffset(Double_t v) { fOffset+=v; }
108 void SetLinCoefficients(
const std::vector< Double_t >& v ) { fLinCoefficients = v; }
109 void SetLinCoefficient( UInt_t i, Double_t v ) { fLinCoefficients[i] = v; }
110 void SetLinDM(
const std::vector<Double_t> & xmin ) { fLinDM = xmin; }
111 void SetLinDP(
const std::vector<Double_t> & xmax ) { fLinDP = xmax; }
112 void SetLinNorm(
const std::vector<Double_t> & norm ) { fLinNorm = norm; }
114 Double_t CalcLinNorm( Double_t stdev ) {
return ( stdev>0 ? fAverageRuleSigma/stdev : 1.0 ); }
117 void ClearCoefficients( Double_t val=0 ) {
for (UInt_t i=0; i<fRules.size(); i++) fRules[i]->SetCoefficient(val); }
118 void ClearLinCoefficients( Double_t val=0 ) {
for (UInt_t i=0; i<fLinCoefficients.size(); i++) fLinCoefficients[i]=val; }
119 void ClearLinNorm( Double_t val=1.0 ) {
for (UInt_t i=0; i<fLinNorm.size(); i++) fLinNorm[i]=val; }
122 void SetRuleMinDist(Double_t d) { fRuleMinDist = d; }
125 void SetImportanceCut(Double_t minimp=0) { fImportanceCut=minimp; }
128 void SetLinQuantile(Double_t q) { fLinQuantile=q; }
131 void SetAverageRuleSigma(Double_t v) {
if (v>0.5) v=0.5; fAverageRuleSigma = v; fAverageSupport = 0.5*(1.0+TMath::Sqrt(1.0-4.0*v*v)); }
/// Presumably returns the number of rules obtainable from dtree (implementation not visible).
134 Int_t CalcNRules(
const TMVA::DecisionTree* dtree );
/// Presumably counts the end (leaf) nodes below node, updating nendnodes through
/// the reference — verify in the implementation.
136 void FindNEndNodes(
const TMVA::Node* node, Int_t& nendnodes );
139 void SetEvent(
const Event & e ) { fEvent = &e; fEventCacheOK = kFALSE; }
/// Recompute the cached rule responses and linear values for the current event;
/// returns immediately while the cache is still valid. Defined inline below.
142 void UpdateEventVal();
/// Build the event-to-rule map for a range of events (implementation not visible;
/// the meaning of the null/zero defaults is decided there — verify).
145 void MakeRuleMap(
const std::vector<const TMVA::Event *> *events=0, UInt_t ifirst=0, UInt_t ilast=0);
148 void ClearRuleMap() { fRuleMap.clear(); fRuleMapEvents=0; }
/// Evaluate the model for the current event from the cached values.
/// Being const it cannot call UpdateEventVal(); the cache must be up to date.
152 Double_t EvalEvent()
const;
/// Evaluate the model for event e.
153 Double_t EvalEvent(
const Event & e );
/// Evaluate the current (cached) event with caller-supplied offset,
/// rule coefficients and linear coefficients.
156 Double_t EvalEvent( Double_t ofs,
157 const std::vector<Double_t> & coefs,
158 const std::vector<Double_t> & lincoefs)
const;
/// Evaluate event e with caller-supplied coefficients.
/// NOTE(review): the definition forwards an `ofs` argument that is not among
/// the parameter lines visible in this declaration.
159 Double_t EvalEvent(
const Event & e,
161 const std::vector<Double_t> & coefs,
162 const std::vector<Double_t> & lincoefs);
/// Evaluate event evtidx of the rule-map sample using the precomputed rule map.
166 Double_t EvalEvent( UInt_t evtidx )
const;
/// As above, with caller-supplied coefficients.
167 Double_t EvalEvent( UInt_t evtidx,
169 const std::vector<Double_t> & coefs,
170 const std::vector<Double_t> & lincoefs)
const;
/// Linear part of the model for the current (cached) event, using fLinCoefficients.
174 Double_t EvalLinEvent()
const;
/// Linear part for the current (cached) event, using the given coefficients.
175 Double_t EvalLinEvent(
const std::vector<Double_t> & coefs )
const;
/// Linear part for event e.
176 Double_t EvalLinEvent(
const Event &e );
/// Normalised value of linear term vind for event e.
/// NOTE(review): the definition does NOT multiply by the coefficient, unlike
/// EvalLinEvent(UInt_t evtidx, UInt_t vind).
177 Double_t EvalLinEvent(
const Event &e, UInt_t vind );
/// Linear part for event e with the given coefficients.
178 Double_t EvalLinEvent(
const Event &e,
const std::vector<Double_t> & coefs );
/// Linear part for rule-map event evtidx (0 outside the mapped range).
181 Double_t EvalLinEvent( UInt_t evtidx )
const;
/// As above, with the given coefficients.
182 Double_t EvalLinEvent( UInt_t evtidx,
const std::vector<Double_t> & coefs )
const;
/// fLinCoefficients[vind] times the normalised term vind for rule-map event evtidx.
183 Double_t EvalLinEvent( UInt_t evtidx, UInt_t vind )
const;
/// As above, but with the single coefficient `coefs` instead of fLinCoefficients[vind].
184 Double_t EvalLinEvent( UInt_t evtidx, UInt_t vind, Double_t coefs )
const;
/// Clamped value of variable vind for event e, times fLinNorm[vind] if norm.
187 Double_t EvalLinEventRaw( UInt_t vind,
const Event &e, Bool_t norm )
const;
/// As above for rule-map event evtidx.
188 Double_t EvalLinEventRaw( UInt_t vind, UInt_t evtidx, Bool_t norm )
const;
/// Implementations below are not visible here; descriptions are inferred from
/// names and should be verified against RuleEnsemble.cxx.
/// PdfLinear/PdfRule: nsig and ntot are reference outputs (presumably signal and
/// total estimates from the linear terms / rules).
191 Double_t PdfLinear( Double_t & nsig, Double_t & ntot )
const;
194 Double_t PdfRule( Double_t & nsig, Double_t & ntot )
const;
/// FStar: presumably the target function estimate for the current event / event e.
197 Double_t FStar()
const;
198 Double_t FStar(
const TMVA::Event & e );
/// Presumably stores the reference importance returned by GetImportanceRef().
201 void SetImportanceRef(Double_t impref);
/// Importance and support computations (presumed from names).
204 void CalcRuleSupport();
207 void CalcImportance();
210 Double_t CalcRuleImportance();
213 Double_t CalcLinImportance();
216 void CalcVarImportance();
/// Pruning / statistics helpers (presumed from names).
222 void CleanupLinear();
225 void RemoveSimilarRules();
228 void RuleStatistics();
231 void RuleResponseStats();
234 void operator=(
const RuleEnsemble& other ) { Copy( other ); }
/// Presumably a norm of the coefficient vector (implementation not visible — verify).
237 Double_t CoefficientRadius();
/// Fill v with the rule coefficients (output parameter; implementation not visible).
240 void GetCoefficients( std::vector< Double_t >& v );
/// Owning method objects (implementations not visible; presumably obtained via fRuleFit).
243 const MethodRuleFit* GetMethodRuleFit()
const;
244 const MethodBase* GetMethodBase()
const;
/// The RuleFit pointer set by SetRuleFit().
245 const RuleFit* GetRuleFit()
const {
return fRuleFit; }
/// Training sample and its i-th event (implementations not visible).
247 const std::vector<const TMVA::Event *>* GetTrainingEvents()
const;
248 const Event* GetTrainingEvent(UInt_t i)
const;
/// The current event as set by SetEvent().
249 const Event* GetEvent()
const {
return fEvent; }
251 Bool_t DoLinear()
const {
return (fLearningModel==kFull) || (fLearningModel==kLinear); }
252 Bool_t DoRules()
const {
return (fLearningModel==kFull) || (fLearningModel==kRules); }
253 Bool_t DoOnlyRules()
const {
return (fLearningModel==kRules); }
254 Bool_t DoOnlyLinear()
const {
return (fLearningModel==kLinear); }
255 Bool_t DoFull()
const {
return (fLearningModel==kFull); }
256 ELearningModel GetLearningModel()
const {
return fLearningModel; }
257 Double_t GetImportanceCut()
const {
return fImportanceCut; }
258 Double_t GetImportanceRef()
const {
return fImportanceRef; }
259 Double_t GetOffset()
const {
return fOffset; }
260 UInt_t GetNRules()
const {
return (DoRules() ? fRules.size():0); }
261 const std::vector<TMVA::Rule*>& GetRulesConst()
const {
return fRules; }
262 std::vector<TMVA::Rule*>& GetRules() {
return fRules; }
263 const std::vector< Double_t >& GetLinCoefficients()
const {
return fLinCoefficients; }
264 const std::vector< Double_t >& GetLinNorm()
const {
return fLinNorm; }
265 const std::vector< Double_t >& GetLinImportance()
const {
return fLinImportance; }
266 const std::vector< Double_t >& GetVarImportance()
const {
return fVarImportance; }
267 UInt_t GetNLinear()
const {
return (DoLinear() ? fLinNorm.size():0); }
268 Double_t GetLinQuantile()
const {
return fLinQuantile; }
270 const Rule *GetRulesConst(
int i)
const {
return fRules[i]; }
271 Rule *GetRules(
int i) {
return fRules[i]; }
273 UInt_t GetRulesNCuts(
int i)
const {
return fRules[i]->GetRuleCut()->GetNcuts(); }
274 Double_t GetRuleMinDist()
const {
return fRuleMinDist; }
275 Double_t GetLinCoefficients(
int i)
const {
return fLinCoefficients[i]; }
276 Double_t GetLinNorm(
int i)
const {
return fLinNorm[i]; }
277 Double_t GetLinDM(
int i)
const {
return fLinDM[i]; }
278 Double_t GetLinDP(
int i)
const {
return fLinDP[i]; }
279 Double_t GetLinImportance(
int i)
const {
return fLinImportance[i]; }
280 Double_t GetVarImportance(
int i)
const {
return fVarImportance[i]; }
281 Double_t GetRulePTag(
int i)
const {
return fRulePTag[i]; }
282 Double_t GetRulePSS(
int i)
const {
return fRulePSS[i]; }
283 Double_t GetRulePSB(
int i)
const {
return fRulePSB[i]; }
284 Double_t GetRulePBS(
int i)
const {
return fRulePBS[i]; }
285 Double_t GetRulePBB(
int i)
const {
return fRulePBB[i]; }
287 Bool_t IsLinTermOK(
int i)
const {
return fLinTermOK[i]; }
289 Double_t GetAverageSupport()
const {
return fAverageSupport; }
290 Double_t GetAverageRuleSigma()
const {
return fAverageRuleSigma; }
291 Double_t GetEventRuleVal(UInt_t i)
const {
return (fEventRuleVal[i] ? 1.0:0.0); }
292 Double_t GetEventLinearVal(UInt_t i)
const {
return fEventLinearVal[i]; }
293 Double_t GetEventLinearValNorm(UInt_t i)
const {
return fEventLinearVal[i]*fLinNorm[i]; }
295 const std::vector<UInt_t> & GetEventRuleMap(UInt_t evtidx)
const {
return fRuleMap[evtidx]; }
296 const TMVA::Event *GetRuleMapEvent(UInt_t evtidx)
const {
return (*fRuleMapEvents)[evtidx]; }
297 Bool_t IsRuleMapOK()
const {
return fRuleMapOK; }
/// Print information about rule generation (implementation not visible).
300 void PrintRuleGen()
const;
/// Write the ensemble in raw text form to os (implementation not visible).
306 void PrintRaw ( std::ostream& os )
const;
/// Append an XML representation below the given parent node; returns a node
/// pointer (opaque void* here — implementation not visible).
307 void* AddXMLTo (
void* parent )
const;
/// Read the ensemble back from raw text / from an XML weight node.
310 void ReadRaw ( std::istream& istr );
311 void ReadFromXML(
void* wghtnode );
317 void DeleteRules() {
for (UInt_t i=0; i<fRules.size(); i++)
delete fRules[i]; fRules.clear(); }
/// Copy the state of other into this ensemble (used by operator=; implementation not visible).
320 void Copy( RuleEnsemble
const& other );
/// Reset the rule coefficients (implementation not visible).
323 void ResetCoefficients();
/// Rule construction from a single tree / node (implementations not visible;
/// MakeTheRule returns a Rule* — ownership presumably passes to fRules, verify).
326 void MakeRulesFromTree(
const DecisionTree *dtree );
329 void AddRule(
const Node *node );
332 Rule *MakeTheRule(
const Node *node );
// which terms enter the model (see DoRules()/DoLinear())
335 ELearningModel fLearningModel;
// importance cut (SetImportanceCut)
336 Double_t fImportanceCut;
// quantile for linear terms (SetLinQuantile)
337 Double_t fLinQuantile;
// rule list; DeleteRules() deletes these pointers
339 std::vector< TMVA::Rule* > fRules;
// per linear term flag; its size is the number of linear terms iterated over
340 std::vector< Char_t > fLinTermOK;
// per linear term upper clamp used by EvalLinEventRaw()
341 std::vector< Double_t > fLinDP;
// per linear term lower clamp used by EvalLinEventRaw()
342 std::vector< Double_t > fLinDM;
// linear-term coefficients
343 std::vector< Double_t > fLinCoefficients;
// linear-term normalisation factors
344 std::vector< Double_t > fLinNorm;
// histograms per linear term (presumably background/signal PDFs — verify)
345 std::vector< TH1F* > fLinPDFB;
346 std::vector< TH1F* > fLinPDFS;
// importance per linear term / per input variable
347 std::vector< Double_t > fLinImportance;
348 std::vector< Double_t > fVarImportance;
// reference importance
349 Double_t fImportanceRef;
// average support and rule sigma, kept consistent by SetAverageRuleSigma()
350 Double_t fAverageSupport;
351 Double_t fAverageRuleSigma;
// per-rule statistics (exact meaning not visible here — verify in the .cxx)
353 std::vector< Double_t > fRuleVarFrac;
354 std::vector< Double_t > fRulePSS;
355 std::vector< Double_t > fRulePSB;
356 std::vector< Double_t > fRulePBS;
357 std::vector< Double_t > fRulePBB;
358 std::vector< Double_t > fRulePTag;
// minimum rule distance (SetRuleMinDist)
363 Double_t fRuleMinDist;
// presumably the number of rules generated before pruning — verify
364 UInt_t fNRulesGenerated;
// true while fEventRuleVal/fEventLinearVal match the current event; SetEvent() clears it
367 Bool_t fEventCacheOK;
// cached response of each rule for the current event (filled by UpdateEventVal)
368 std::vector<Char_t> fEventRuleVal;
// cached clamped, unnormalised linear-term values for the current event
369 std::vector<Double_t> fEventLinearVal;
// per event: indices of the rules recorded for that event
372 std::vector< std::vector<UInt_t> > fRuleMap;
// event sample the rule map refers to (not owned here, presumably)
375 const std::vector<const TMVA::Event *> *fRuleMapEvents;
// RuleFit object set by SetRuleFit()
377 const RuleFit* fRuleFit;
// message logger returned by Log()
379 mutable MsgLogger* fLogger;
380 MsgLogger& Log()
const {
return *fLogger; }
/// Refresh the per-event cache for the current event fEvent.
/// Stores each rule's response (Rule::EvalEvent) in fEventRuleVal and the
/// clamped, unnormalised linear-term values (EvalLinEventRaw(..., kFALSE)) in
/// fEventLinearVal, then marks the cache valid.
/// NOTE(review): some lines of this body are not visible here; the two fills
/// may be guarded by conditions on those lines.
385 inline void TMVA::RuleEnsemble::UpdateEventVal()
// cache still matches fEvent (SetEvent() resets this flag)
390 if (fEventCacheOK)
return;
393 UInt_t nrules = fRules.size();
394 fEventRuleVal.resize(nrules,kFALSE);
395 for (UInt_t r=0; r<nrules; r++) {
396 fEventRuleVal[r] = fRules[r]->EvalEvent(*fEvent);
// linear values are cached without normalisation; fLinNorm is applied by the readers
400 UInt_t nlin = fLinTermOK.size();
401 fEventLinearVal.resize(nlin,0);
402 for (UInt_t r=0; r<nlin; r++) {
403 fEventLinearVal[r] = EvalLinEventRaw(r,*fEvent,kFALSE);
406 fEventCacheOK = kTRUE;
/// Evaluate the model for the current event from the cached values: start at
/// fOffset and add the coefficient of every rule whose cached response
/// fEventRuleVal[i] is nonzero. When DoLinear() is true, the linear part comes
/// from EvalLinEvent(); how it is combined is on lines not visible here.
/// Being const, this cannot call UpdateEventVal(); the cache must be current.
/// NOTE(review): if UpdateEventVal() has not sized fEventRuleVal to fRules.size(),
/// the loop indexes out of range.
410 inline Double_t TMVA::RuleEnsemble::EvalEvent()
const
414 Int_t nrules = fRules.size();
415 Double_t rval=fOffset;
422 for ( Int_t i=0; i<nrules; i++ ) {
423 if (fEventRuleVal[i])
424 rval += fRules[i]->GetCoefficient();
430 if (DoLinear()) linear = EvalLinEvent();
/// Evaluate the current (cached) event with caller-supplied offset ofs, rule
/// coefficients coefs and linear coefficients lincoefs. Loops over the cached rule
/// responses; when DoLinear() is true the linear part comes from EvalLinEvent(lincoefs).
/// NOTE(review): how ofs and coefs enter the sum is on lines not visible here;
/// presumably coefs must have at least fRules.size() entries.
437 inline Double_t TMVA::RuleEnsemble::EvalEvent( Double_t ofs,
438 const std::vector<Double_t> & coefs,
439 const std::vector<Double_t> & lincoefs )
const
443 Int_t nrules = fRules.size();
450 for ( Int_t i=0; i<nrules; i++ ) {
451 if (fEventRuleVal[i])
458 if (DoLinear()) linear = EvalLinEvent(lincoefs);
/// Evaluate the model for event e. The body is not visible here; presumably it
/// makes e current, refreshes the cache and calls EvalEvent() — verify.
465 inline Double_t TMVA::RuleEnsemble::EvalEvent(
const TMVA::Event & e)
/// Evaluate event e with caller-supplied coefficients, ending by delegating to
/// the const EvalEvent(ofs, coefs, lincoefs) overload.
/// NOTE(review): `ofs` is declared on a parameter line not visible here, and any
/// lines that make e current / refresh the cache are also not shown.
474 inline Double_t TMVA::RuleEnsemble::EvalEvent(
const TMVA::Event & e,
476 const std::vector<Double_t> & coefs,
477 const std::vector<Double_t> & lincoefs )
482 return EvalEvent(ofs,coefs,lincoefs);
/// Evaluate rule-map event evtidx using the precomputed rule map instead of
/// re-evaluating the rules. Returns 0 when evtidx lies outside
/// [fRuleMapInd0, fRuleMapInd1]. Otherwise it starts at fOffset, adds the
/// coefficient of every rule listed in fRuleMap[evtidx], and adds
/// fLinCoefficients[r] times the normalised, clamped value of each linear term r.
/// NOTE(review): any DoRules()/DoLinear()/fLinTermOK guards are on lines not
/// visible here. fRuleMapOK is not checked on the visible lines, and fRuleMap is
/// indexed with the absolute evtidx — confirm MakeRuleMap() sizes it accordingly.
486 inline Double_t TMVA::RuleEnsemble::EvalEvent(UInt_t evtidx)
const
// outside the mapped event range: return 0
489 if ((evtidx<fRuleMapInd0) || (evtidx>fRuleMapInd1))
return 0;
491 Double_t rval=fOffset;
493 UInt_t nrules = fRuleMap[evtidx].size();
495 for (UInt_t ir = 0; ir<nrules; ir++) {
496 rind = fRuleMap[evtidx][ir];
497 rval += fRules[rind]->GetCoefficient();
501 UInt_t nlin = fLinTermOK.size();
502 for (UInt_t r=0; r<nlin; r++) {
504 rval += fLinCoefficients[r] * EvalLinEventRaw(r,*(*fRuleMapEvents)[evtidx],kTRUE);
/// Rule-map evaluation of event evtidx with caller-supplied coefficients.
/// Returns 0 outside [fRuleMapInd0, fRuleMapInd1]. Loops over the rules listed
/// in fRuleMap[evtidx]; the linear part is added via EvalLinEvent(evtidx, lincoefs).
/// NOTE(review): the offset parameter and the rule-coefficient accumulation are
/// on lines not visible here.
512 inline Double_t TMVA::RuleEnsemble::EvalEvent(UInt_t evtidx,
514 const std::vector<Double_t> & coefs,
515 const std::vector<Double_t> & lincoefs )
const
519 if ((evtidx<fRuleMapInd0) || (evtidx>fRuleMapInd1))
return 0;
522 UInt_t nrules = fRuleMap[evtidx].size();
524 for (UInt_t ir = 0; ir<nrules; ir++) {
525 rind = fRuleMap[evtidx][ir];
530 rval += EvalLinEvent( evtidx, lincoefs );
/// Value of linear term vind for event e: e.GetValue(vind) clamped to
/// [fLinDM[vind], fLinDP[vind]], multiplied by fLinNorm[vind] when norm is true.
536 inline Double_t TMVA::RuleEnsemble::EvalLinEventRaw( UInt_t vind,
const TMVA::Event & e, Bool_t norm)
const
540 Double_t val = e.GetValue(vind);
// limit outliers to the stored lower/upper bounds
541 Double_t rval = TMath::Min( fLinDP[vind], TMath::Max( fLinDM[vind], val ) );
542 if (norm) rval *= fLinNorm[vind];
/// As EvalLinEventRaw(vind, e, norm), for event evtidx of the rule-map sample.
/// NOTE(review): unlike the EvalLinEvent(evtidx, ...) overloads, the visible lines
/// do not check evtidx against [fRuleMapInd0, fRuleMapInd1].
547 inline Double_t TMVA::RuleEnsemble::EvalLinEventRaw( UInt_t vind, UInt_t evtidx, Bool_t norm)
const
551 Double_t val = (*fRuleMapEvents)[evtidx]->GetValue(vind);
// limit outliers to the stored lower/upper bounds
552 Double_t rval = TMath::Min( fLinDP[vind], TMath::Max( fLinDM[vind], val ) );
553 if (norm) rval *= fLinNorm[vind];
/// Linear part for the current (cached) event: sum over terms v of
/// fLinCoefficients[v] * fEventLinearVal[v] * fLinNorm[v].
/// UpdateEventVal() caches fEventLinearVal unnormalised, so fLinNorm is applied here.
/// NOTE(review): any per-term guard (e.g. on fLinTermOK[v]) is on a line not visible here.
558 inline Double_t TMVA::RuleEnsemble::EvalLinEvent()
const
563 for (UInt_t v=0; v<fLinTermOK.size(); v++) {
565 rval += fLinCoefficients[v]*fEventLinearVal[v]*fLinNorm[v];
/// As EvalLinEvent(), but with caller-supplied coefficients coefs instead of
/// fLinCoefficients (coefs presumably needs fLinTermOK.size() entries).
571 inline Double_t TMVA::RuleEnsemble::EvalLinEvent(
const std::vector<Double_t> & coefs)
const
576 for (UInt_t v=0; v<fLinTermOK.size(); v++) {
578 rval += coefs[v]*fEventLinearVal[v]*fLinNorm[v];
/// Linear part for event e; ends by returning EvalLinEvent() on the cached values.
/// The preceding lines are not visible — presumably SetEvent(e)/UpdateEventVal(), verify.
584 inline Double_t TMVA::RuleEnsemble::EvalLinEvent(
const TMVA::Event& e )
590 return EvalLinEvent();
/// Linear term vind for event e, returned as GetEventLinearValNorm(vind),
/// i.e. the cached value times fLinNorm[vind].
/// NOTE(review): no coefficient is applied here, whereas
/// EvalLinEvent(UInt_t evtidx, UInt_t vind) multiplies by fLinCoefficients[vind];
/// confirm the difference is intended.
594 inline Double_t TMVA::RuleEnsemble::EvalLinEvent(
const TMVA::Event& e, UInt_t vind )
600 return GetEventLinearValNorm(vind);
/// Linear part for event e with coefficients coefs; ends by returning
/// EvalLinEvent(coefs) on the cached values (earlier lines not visible here).
604 inline Double_t TMVA::RuleEnsemble::EvalLinEvent(
const TMVA::Event& e,
const std::vector<Double_t> & coefs )
610 return EvalLinEvent(coefs);
/// Linear part for rule-map event evtidx with coefficients coefs: sum of
/// coefs[r] times the normalised, clamped value of each term r.
/// Returns 0 outside [fRuleMapInd0, fRuleMapInd1].
614 inline Double_t TMVA::RuleEnsemble::EvalLinEvent( UInt_t evtidx,
const std::vector<Double_t> & coefs )
const
617 if ((evtidx<fRuleMapInd0) || (evtidx>fRuleMapInd1))
return 0;
619 UInt_t nlin = fLinTermOK.size();
620 for (UInt_t r=0; r<nlin; r++) {
622 rval += coefs[r] * EvalLinEventRaw(r,*(*fRuleMapEvents)[evtidx],kTRUE);
/// Linear part for rule-map event evtidx using fLinCoefficients.
/// Returns 0 outside [fRuleMapInd0, fRuleMapInd1].
629 inline Double_t TMVA::RuleEnsemble::EvalLinEvent( UInt_t evtidx )
const
632 if ((evtidx<fRuleMapInd0) || (evtidx>fRuleMapInd1))
return 0;
634 UInt_t nlin = fLinTermOK.size();
635 for (UInt_t r=0; r<nlin; r++) {
637 rval += fLinCoefficients[r] * EvalLinEventRaw(r,*(*fRuleMapEvents)[evtidx],kTRUE);
/// Single linear term vind for rule-map event evtidx: fLinCoefficients[vind]
/// times the normalised, clamped value. Returns 0 outside [fRuleMapInd0, fRuleMapInd1].
644 inline Double_t TMVA::RuleEnsemble::EvalLinEvent( UInt_t evtidx, UInt_t vind )
const
647 if ((evtidx<fRuleMapInd0) || (evtidx>fRuleMapInd1))
return 0;
649 rval = fLinCoefficients[vind] * EvalLinEventRaw(vind,*(*fRuleMapEvents)[evtidx],kTRUE);
/// As EvalLinEvent(evtidx, vind), but multiplies by the caller-supplied single
/// coefficient `coefs` instead of fLinCoefficients[vind].
/// Returns 0 outside [fRuleMapInd0, fRuleMapInd1].
654 inline Double_t TMVA::RuleEnsemble::EvalLinEvent( UInt_t evtidx, UInt_t vind, Double_t coefs )
const
657 if ((evtidx<fRuleMapInd0) || (evtidx>fRuleMapInd1))
return 0;
659 rval = coefs * EvalLinEventRaw(vind,*(*fRuleMapEvents)[evtidx],kTRUE);