Logo ROOT   6.30.04
Reference Guide
 All Namespaces Files Pages
Rule.h
Go to the documentation of this file.
1 // @(#)root/tmva $Id$
2 // Author: Andreas Hoecker, Joerg Stelzer, Fredrik Tegenfeldt, Helge Voss
3 
4 /**********************************************************************************
5  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6  * Package: TMVA *
7  * Class : Rule *
8  * *
9  * Description: *
10  * A class describing a 'rule' *
11  * Each internal node of a tree defines a rule from all the parental nodes. *
12  * A rule consists of at least 2 nodes. *
13  * Input: a decision tree (in the constructor) *
14  * its coefficient *
15  * *
16  * *
17  * Authors (alphabetical): *
18  * Fredrik Tegenfeldt <Fredrik.Tegenfeldt@cern.ch> - Iowa State U., USA *
19  * Helge Voss <Helge.Voss@cern.ch> - MPI-KP Heidelberg, Ger. *
20  * *
21  * Copyright (c) 2005: *
22  * CERN, Switzerland *
23  * Iowa State U. *
24  * MPI-K Heidelberg, Germany *
25  * *
26  * Redistribution and use in source and binary forms, with or without *
27  * modification, are permitted according to the terms listed in LICENSE *
28  * (http://tmva.sourceforge.net/LICENSE) *
29  **********************************************************************************/
30 
31 #ifndef ROOT_TMVA_Rule
32 #define ROOT_TMVA_Rule
33 
34 #include "TMath.h"
35 
36 #include "TMVA/DecisionTree.h"
37 #include "TMVA/Event.h"
38 #include "TMVA/RuleCut.h"
39 
40 namespace TMVA {
41 
42  class RuleEnsemble;
43  class MsgLogger;
44  class Rule;
45 
46  std::ostream& operator<<( std::ostream& os, const Rule & rule );
47 
48  class Rule {
49 
50  // output operator for a Rule
51  friend std::ostream& operator<< ( std::ostream& os, const Rule & rule );
52 
53  public:
54 
55  // main constructor
56  Rule( RuleEnsemble *re, const std::vector< const TMVA::Node * > & nodes );
57 
58  // main constructor
59  Rule( RuleEnsemble *re );
60 
61  // copy constructor
62  Rule( const Rule & other ) { Copy( other ); }
63 
64  // empty constructor
65  Rule();
66 
67  virtual ~Rule();
68 
69  // set message type
70  void SetMsgType( EMsgType t );
71 
72  // set RuleEnsemble ptr
73  void SetRuleEnsemble( const RuleEnsemble *re ) { fRuleEnsemble = re; }
74 
75  // set RuleCut ptr
76  void SetRuleCut( RuleCut *rc ) { fCut = rc; }
77 
78  // set Rule norm
79  void SetNorm(Double_t norm) { fNorm = (norm>0 ? 1.0/norm:1.0); }
80 
81  // set coefficient
82  void SetCoefficient(Double_t v) { fCoefficient=v; }
83 
84  // set support
85  void SetSupport(Double_t v) { fSupport=v; fSigma = TMath::Sqrt(v*(1.0-v));}
86 
87  // set s/(s+b)
88  void SetSSB(Double_t v) { fSSB=v; }
89 
90  // set N(eve) accepted by rule
91  void SetSSBNeve(Double_t v) { fSSBNeve=v; }
92 
93  // set reference importance
94  void SetImportanceRef(Double_t v) { fImportanceRef=(v>0 ? v:1.0); }
95 
96  // calculate importance
97  void CalcImportance() { fImportance = TMath::Abs(fCoefficient)*fSigma; }
98 
99  // get the relative importance
100  Double_t GetRelImportance() const { return fImportance/fImportanceRef; }
101 
102  // evaluate the Rule for the given Event using the coefficient
103  // inline Double_t EvalEvent( const Event& e, Bool_t norm ) const;
104 
105  // evaluate the Rule for the given Event, not using normalization or the coefficient
106  inline Bool_t EvalEvent( const Event& e ) const;
107 
108  // test if two rules are equal
109  Bool_t Equal( const Rule & other, Bool_t useCutValue, Double_t maxdist ) const;
110 
111  // get distance between two equal (ie apart from the cut values) rules
112  Double_t RuleDist( const Rule & other, Bool_t useCutValue ) const;
113 
114  // returns true if the trained S/(S+B) of the last node is > 0.5
115  Double_t GetSSB() const { return fSSB; }
116  Double_t GetSSBNeve() const { return fSSBNeve; }
117  Bool_t IsSignalRule() const { return (fSSB>0.5); }
118 
119  // copy operator
120  void operator=( const Rule & other ) { Copy( other ); }
121 
122  // identical operator
123  Bool_t operator==( const Rule & other ) const;
124 
125  Bool_t operator<( const Rule & other ) const;
126 
127  // get number of variables used in Rule
128  UInt_t GetNumVarsUsed() const { return fCut->GetNvars(); }
129 
130  // get number of cuts in Rule
131  UInt_t GetNcuts() const { return fCut->GetNcuts(); }
132 
133  // check if variable is used by the rule
134  Bool_t ContainsVariable(UInt_t iv) const;
135 
136  // accessors
137  const RuleCut* GetRuleCut() const { return fCut; }
138  const RuleEnsemble* GetRuleEnsemble() const { return fRuleEnsemble; }
139  Double_t GetCoefficient() const { return fCoefficient; }
140  Double_t GetSupport() const { return fSupport; }
141  Double_t GetSigma() const { return fSigma; }
142  Double_t GetNorm() const { return fNorm; }
143  Double_t GetImportance() const { return fImportance; }
144  Double_t GetImportanceRef() const { return fImportanceRef; }
145 
146  // print the rule using flogger
147  void PrintLogger( const char *title=0 ) const;
148 
149  // print just the raw info, used for weight file generation
150  void PrintRaw ( std::ostream& os ) const; // obsolete
151  void* AddXMLTo ( void* parent ) const;
152 
153  void ReadRaw ( std::istream& os ); // obsolete
154  void ReadFromXML( void* wghtnode );
155 
156  private:
157 
158  // set sigma - don't use this as non private!
159  void SetSigma(Double_t v) { fSigma=v; }
160 
161  // print info about the Rule
162  void Print( std::ostream& os ) const;
163 
164  // copy from another rule
165  void Copy( const Rule & other );
166 
167  // get the name of variable with index i
168  const TString & GetVarName( Int_t i) const;
169 
170  RuleCut* fCut; // all cuts associated with the rule
171  Double_t fNorm; // normalization - usually 1.0/t(k)
172  Double_t fSupport; // s(k)
173  Double_t fSigma; // t(k) = sqrt(s*(1-s))
174  Double_t fCoefficient; // rule coeff. a(k)
175  Double_t fImportance; // importance of rule
176  Double_t fImportanceRef; // importance ref
177  const RuleEnsemble* fRuleEnsemble; // pointer to parent RuleEnsemble
178  Double_t fSSB; // S/(S+B) for rule
179  Double_t fSSBNeve; // N(events) reaching the last node in reevaluation
180 
181  mutable MsgLogger* fLogger; //! message logger
182  MsgLogger& Log() const { return *fLogger; }
183 
184  };
185 
186 } // end of TMVA namespace
187 
188 //_______________________________________________________________________
189 inline Bool_t TMVA::Rule::EvalEvent( const TMVA::Event& e ) const
190 {
191  // Checks if event is accepted by rule.
192  // Return true if yes and false if not.
193  //
194  return fCut->EvalEvent(e);
195 }
196 
197 #endif