Logo ROOT   6.30.04
Reference Guide
 All Namespaces Files Pages
Rule.cxx
Go to the documentation of this file.
1 // @(#)root/tmva $Id$
2 // Author: Andreas Hoecker, Joerg Stelzer, Fredrik Tegenfeldt, Helge Voss
3 
4 /**********************************************************************************
5  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6  * Package: TMVA *
7  * Class : Rule *
8  * Web : http://tmva.sourceforge.net *
9  * *
10  * Description: *
11  * A class describing a 'rule' *
12  * Each internal node of a tree defines a rule from all the parental nodes. *
13  * A rule consists of at least 2 nodes. *
14  * Input: a decision tree (in the constructor) *
15  * *
16  * Authors (alphabetical): *
17  * Fredrik Tegenfeldt <Fredrik.Tegenfeldt@cern.ch> - Iowa State U., USA *
18  * Helge Voss <Helge.Voss@cern.ch> - MPI-KP Heidelberg, Ger. *
19  * *
20  * Copyright (c) 2005: *
21  * CERN, Switzerland *
22  * Iowa State U. *
23  * MPI-K Heidelberg, Germany *
24  * *
25  * Redistribution and use in source and binary forms, with or without *
26  * modification, are permitted according to the terms listed in LICENSE *
27  * (http://tmva.sourceforge.net/LICENSE) *
28  **********************************************************************************/
29 
30 /*! \class TMVA::Rule
31 \ingroup TMVA
32 
33 Implementation of a rule.
34 
35 A rule is simply a branch or a part of a branch in a tree.
36 It fulfills the following:
37 
38  - First node is the root node of the originating tree
39  - Consists of a minimum of 2 nodes
40  - A rule returns for a given event:
41     - 0 : if the event fails at any node
42     - 1 : otherwise
43  - If the rule contains fewer than 2 nodes, it returns 0 (this should never happen).
44 
45 The coefficient is found by either brute force or some sort of
46 intelligent fitting. See the RuleEnsemble class for more info.
47 */
48 
49 #include "TMVA/Rule.h"
50 
51 #include "TMVA/Event.h"
52 #include "TMVA/MethodBase.h"
53 #include "TMVA/MethodRuleFit.h"
54 #include "TMVA/MsgLogger.h"
55 #include "TMVA/RuleCut.h"
56 #include "TMVA/RuleFit.h"
57 #include "TMVA/RuleEnsemble.h"
58 #include "TMVA/Tools.h"
59 #include "TMVA/Types.h"
60 
61 ////////////////////////////////////////////////////////////////////////////////
62 /// the main constructor for a Rule
63 
64 TMVA::Rule::Rule( RuleEnsemble *re,
65  const std::vector< const Node * >& nodes )
66  : fCut ( 0 )
67  , fNorm ( 1.0 )
68  , fSupport ( 0.0 )
69  , fSigma ( 0.0 )
70  , fCoefficient ( 0.0 )
71  , fImportance ( 0.0 )
72  , fImportanceRef ( 1.0 )
73  , fRuleEnsemble ( re )
74  , fSSB ( 0 )
75  , fSSBNeve ( 0 )
76  , fLogger( new MsgLogger("RuleFit") )
77 {
78  //
79  // input:
80  // nodes - a vector of Node; from these all possible rules will be created
81  //
82  //
83 
84  fCut = new RuleCut( nodes );
85  fSSB = fCut->GetPurity();
86  fSSBNeve = fCut->GetCutNeve();
87 }
88 
89 ////////////////////////////////////////////////////////////////////////////////
90 /// the simple constructor
91 
92 TMVA::Rule::Rule( RuleEnsemble *re )
93  : fCut ( 0 )
94  , fNorm ( 1.0 )
95  , fSupport ( 0.0 )
96  , fSigma ( 0.0 )
97  , fCoefficient ( 0.0 )
98  , fImportance ( 0.0 )
99  , fImportanceRef ( 1.0 )
100  , fRuleEnsemble ( re )
101  , fSSB ( 0 )
102  , fSSBNeve ( 0 )
103  , fLogger( new MsgLogger("RuleFit") )
104 {
105 }
106 
107 ////////////////////////////////////////////////////////////////////////////////
108 /// the simple constructor
109 
110 TMVA::Rule::Rule()
111  : fCut ( 0 )
112  , fNorm ( 1.0 )
113  , fSupport ( 0.0 )
114  , fSigma ( 0.0 )
115  , fCoefficient ( 0.0 )
116  , fImportance ( 0.0 )
117  , fImportanceRef ( 1.0 )
118  , fRuleEnsemble ( 0 )
119  , fSSB ( 0 )
120  , fSSBNeve ( 0 )
121  , fLogger( new MsgLogger("RuleFit") )
122 {
123 }
124 
125 ////////////////////////////////////////////////////////////////////////////////
126 /// destructor
127 
128 TMVA::Rule::~Rule()
129 {
130  delete fCut;
131  delete fLogger;
132 }
133 
134 ////////////////////////////////////////////////////////////////////////////////
135 /// check if variable in node
136 
137 Bool_t TMVA::Rule::ContainsVariable(UInt_t iv) const
138 {
139  Bool_t found = kFALSE;
140  Bool_t doneLoop = kFALSE;
141  UInt_t nvars = fCut->GetNvars();
142  UInt_t i = 0;
143  //
144  while (!doneLoop) {
145  found = (fCut->GetSelector(i) == iv);
146  i++;
147  doneLoop = (found || (i==nvars));
148  }
149  return found;
150 }
151 
152 ////////////////////////////////////////////////////////////////////////////////
153 
154 void TMVA::Rule::SetMsgType( EMsgType t )
155 {
156  fLogger->SetMinType(t);
157 }
158 
159 
160 ////////////////////////////////////////////////////////////////////////////////
161 /// Compare two rules.
162 ///
163 /// - useCutValue:
164 /// - true -> calculate a distance between the two rules based on the cut values
165 /// if the rule cuts are not equal, the distance is < 0 (-1.0)
166 /// return true if d<mindist
167 /// - false-> ignore mindist, return true if rules are equal, ignoring cut values
168 /// - mindist: min distance allowed between rules; if < 0 => set useCutValue=false;
169 
170 Bool_t TMVA::Rule::Equal( const Rule& other, Bool_t useCutValue, Double_t mindist ) const
171 {
172  Bool_t rval=kFALSE;
173  if (mindist<0) useCutValue=kFALSE;
174  Double_t d = RuleDist( other, useCutValue );
175  // cut value used - return true if 0<=d<mindist
176  if (useCutValue) rval = ( (!(d<0)) && (d<mindist) );
177  else rval = (!(d<0));
178  // cut value not used, return true if <> -1
179  return rval;
180 }
181 
182 ////////////////////////////////////////////////////////////////////////////////
183 /// Returns:
184 ///
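185 /// * -1.0 : rules are NOT equal, i.e. the cut variables and/or cut directions differ
186 /// * >=0: rules are equal apart from the cut values; returns \f$ d = \sqrt{\sum_i \left( (c^{(1)}_i - c^{(2)}_i) / \mathrm{RMS}_i \right)^2} \f$, summed over the applied lower and upper cut bounds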
187 ///
188 /// If not useCutValue, the distance is exactly zero if they are equal
189 
190 Double_t TMVA::Rule::RuleDist( const Rule& other, Bool_t useCutValue ) const
191 {
192  if (fCut->GetNvars()!=other.GetRuleCut()->GetNvars()) return -1.0; // check number of cuts
193  //
194  const UInt_t nvars = fCut->GetNvars();
195  //
196  Int_t sel; // cut variable
197  Double_t rms; // rms of cut variable
198  Double_t smin; // distance between the lower range
199  Double_t smax; // distance between the upper range
200  Double_t vminA,vmaxA; // min,max range of cut A (cut from this Rule)
201  Double_t vminB,vmaxB; // idem from other Rule
202  //
203  // compare nodes
204  // A 'distance' is assigned if the two rules has exactly the same set of cuts but with
205  // different cut values.
206  // The distance is given in number of sigmas
207  //
208  UInt_t in = 0; // cut index
209  Double_t sumdc2 = 0; // sum of 'distances'
210  Bool_t equal = true; // flag if cut are equal
211  //
212  const RuleCut *otherCut = other.GetRuleCut();
213  while ((equal) && (in<nvars)) {
214  // check equality in cut topology
215  equal = ( (fCut->GetSelector(in) == (otherCut->GetSelector(in))) &&
216  (fCut->GetCutDoMin(in) == (otherCut->GetCutDoMin(in))) &&
217  (fCut->GetCutDoMax(in) == (otherCut->GetCutDoMax(in))) );
218  // if equal topology, check cut values
219  if (equal) {
220  if (useCutValue) {
221  sel = fCut->GetSelector(in);
222  vminA = fCut->GetCutMin(in);
223  vmaxA = fCut->GetCutMax(in);
224  vminB = other.GetRuleCut()->GetCutMin(in);
225  vmaxB = other.GetRuleCut()->GetCutMax(in);
226  // messy - but ok...
227  rms = fRuleEnsemble->GetRuleFit()->GetMethodBase()->GetRMS(sel);
228  smin=0;
229  smax=0;
230  if (fCut->GetCutDoMin(in))
231  smin = ( rms>0 ? (vminA-vminB)/rms : 0 );
232  if (fCut->GetCutDoMax(in))
233  smax = ( rms>0 ? (vmaxA-vmaxB)/rms : 0 );
234  sumdc2 += smin*smin + smax*smax;
235  // sumw += 1.0/(rms*rms); // TODO: probably not needed
236  }
237  }
238  in++;
239  }
240  if (!useCutValue) sumdc2 = (equal ? 0.0:-1.0); // ignore cut values
241  else sumdc2 = (equal ? sqrt(sumdc2) : -1.0);
242 
243  return sumdc2;
244 }
245 
246 ////////////////////////////////////////////////////////////////////////////////
247 /// comparison operator ==
248 
249 Bool_t TMVA::Rule::operator==( const Rule& other ) const
250 {
251  return this->Equal( other, kTRUE, 1e-3 );
252 }
253 
254 ////////////////////////////////////////////////////////////////////////////////
255 /// comparison operator <
256 
257 Bool_t TMVA::Rule::operator<( const Rule& other ) const
258 {
259  return (fImportance < other.GetImportance());
260 }
261 
262 ////////////////////////////////////////////////////////////////////////////////
263 /// std::ostream operator
264 
265 std::ostream& TMVA::operator<< ( std::ostream& os, const Rule& rule )
266 {
267  rule.Print( os );
268  return os;
269 }
270 
271 ////////////////////////////////////////////////////////////////////////////////
272 /// returns the name of a rule
273 
274 const TString & TMVA::Rule::GetVarName( Int_t i ) const
275 {
276  return fRuleEnsemble->GetMethodBase()->GetInputLabel(i);
277 }
278 
279 ////////////////////////////////////////////////////////////////////////////////
280 /// copy function
281 
282 void TMVA::Rule::Copy( const Rule& other )
283 {
284  if(this != &other) {
285  SetRuleEnsemble( other.GetRuleEnsemble() );
286  fCut = new RuleCut( *(other.GetRuleCut()) );
287  fSSB = other.GetSSB();
288  fSSBNeve = other.GetSSBNeve();
289  SetCoefficient(other.GetCoefficient());
290  SetSupport( other.GetSupport() );
291  SetSigma( other.GetSigma() );
292  SetNorm( other.GetNorm() );
293  CalcImportance();
294  SetImportanceRef( other.GetImportanceRef() );
295  }
296 }
297 
298 ////////////////////////////////////////////////////////////////////////////////
299 /// print function
300 
301 void TMVA::Rule::Print( std::ostream& os ) const
302 {
303  const UInt_t nvars = fCut->GetNvars();
304  if (nvars<1) os << " *** WARNING - <EMPTY RULE> ***" << std::endl; // TODO: Fix this, use fLogger
305  //
306  Int_t sel;
307  Double_t valmin, valmax;
308  //
309  os << " Importance = " << Form("%1.4f", fImportance/fImportanceRef) << std::endl;
310  os << " Coefficient = " << Form("%1.4f", fCoefficient) << std::endl;
311  os << " Support = " << Form("%1.4f", fSupport) << std::endl;
312  os << " S/(S+B) = " << Form("%1.4f", fSSB) << std::endl;
313 
314  for ( UInt_t i=0; i<nvars; i++) {
315  os << " ";
316  sel = fCut->GetSelector(i);
317  valmin = fCut->GetCutMin(i);
318  valmax = fCut->GetCutMax(i);
319  //
320  os << Form("* Cut %2d",i+1) << " : " << std::flush;
321  if (fCut->GetCutDoMin(i)) os << Form("%10.3g",valmin) << " < " << std::flush;
322  else os << " " << std::flush;
323  os << GetVarName(sel) << std::flush;
324  if (fCut->GetCutDoMax(i)) os << " < " << Form("%10.3g",valmax) << std::flush;
325  else os << " " << std::flush;
326  os << std::endl;
327  }
328 }
329 
330 ////////////////////////////////////////////////////////////////////////////////
331 /// print function
332 
333 void TMVA::Rule::PrintLogger(const char *title) const
334 {
335  const UInt_t nvars = fCut->GetNvars();
336  if (nvars<1) Log() << kWARNING << "BUG TRAP: EMPTY RULE!!!" << Endl;
337  //
338  Int_t sel;
339  Double_t valmin, valmax;
340  //
341  if (title) Log() << kINFO << title;
342  Log() << kINFO
343  << "Importance = " << Form("%1.4f", fImportance/fImportanceRef) << Endl;
344 
345  for ( UInt_t i=0; i<nvars; i++) {
346 
347  Log() << kINFO << " ";
348  sel = fCut->GetSelector(i);
349  valmin = fCut->GetCutMin(i);
350  valmax = fCut->GetCutMax(i);
351  //
352  Log() << kINFO << Form("Cut %2d",i+1) << " : ";
353  if (fCut->GetCutDoMin(i)) Log() << kINFO << Form("%10.3g",valmin) << " < ";
354  else Log() << kINFO << " ";
355  Log() << kINFO << GetVarName(sel);
356  if (fCut->GetCutDoMax(i)) Log() << kINFO << " < " << Form("%10.3g",valmax);
357  else Log() << kINFO << " ";
358  Log() << Endl;
359  }
360 }
361 
362 ////////////////////////////////////////////////////////////////////////////////
363 /// extensive print function used to print info for the weight file
364 
365 void TMVA::Rule::PrintRaw( std::ostream& os ) const
366 {
367  Int_t dp = os.precision();
368  const UInt_t nvars = fCut->GetNvars();
369  os << "Parameters: "
370  << std::setprecision(10)
371  << fImportance << " "
372  << fImportanceRef << " "
373  << fCoefficient << " "
374  << fSupport << " "
375  << fSigma << " "
376  << fNorm << " "
377  << fSSB << " "
378  << fSSBNeve << " "
379  << std::endl; \
380  os << "N(cuts): " << nvars << std::endl; // mark end of nodes
381  for ( UInt_t i=0; i<nvars; i++) {
382  os << "Cut " << i << " : " << std::flush;
383  os << fCut->GetSelector(i)
384  << std::setprecision(10)
385  << " " << fCut->GetCutMin(i)
386  << " " << fCut->GetCutMax(i)
387  << " " << (fCut->GetCutDoMin(i) ? "T":"F")
388  << " " << (fCut->GetCutDoMax(i) ? "T":"F")
389  << std::endl;
390  }
391  os << std::setprecision(dp);
392 }
393 
394 ////////////////////////////////////////////////////////////////////////////////
395 
396 void* TMVA::Rule::AddXMLTo( void* parent ) const
397 {
398  void* rule = gTools().AddChild( parent, "Rule" );
399  const UInt_t nvars = fCut->GetNvars();
400 
401  gTools().AddAttr( rule, "Importance", fImportance );
402  gTools().AddAttr( rule, "Ref", fImportanceRef );
403  gTools().AddAttr( rule, "Coeff", fCoefficient );
404  gTools().AddAttr( rule, "Support", fSupport );
405  gTools().AddAttr( rule, "Sigma", fSigma );
406  gTools().AddAttr( rule, "Norm", fNorm );
407  gTools().AddAttr( rule, "SSB", fSSB );
408  gTools().AddAttr( rule, "SSBNeve", fSSBNeve );
409  gTools().AddAttr( rule, "Nvars", nvars );
410 
411  for (UInt_t i=0; i<nvars; i++) {
412  void* cut = gTools().AddChild( rule, "Cut" );
413  gTools().AddAttr( cut, "Selector", fCut->GetSelector(i) );
414  gTools().AddAttr( cut, "Min", fCut->GetCutMin(i) );
415  gTools().AddAttr( cut, "Max", fCut->GetCutMax(i) );
416  gTools().AddAttr( cut, "DoMin", (fCut->GetCutDoMin(i) ? "T":"F") );
417  gTools().AddAttr( cut, "DoMax", (fCut->GetCutDoMax(i) ? "T":"F") );
418  }
419 
420  return rule;
421 }
422 
423 ////////////////////////////////////////////////////////////////////////////////
424 /// read rule from XML
425 
426 void TMVA::Rule::ReadFromXML( void* wghtnode )
427 {
428  TString nodeName = TString( gTools().GetName(wghtnode) );
429  if (nodeName != "Rule") Log() << kFATAL << "<ReadFromXML> Unexpected node name: " << nodeName << Endl;
430 
431  gTools().ReadAttr( wghtnode, "Importance", fImportance );
432  gTools().ReadAttr( wghtnode, "Ref", fImportanceRef );
433  gTools().ReadAttr( wghtnode, "Coeff", fCoefficient );
434  gTools().ReadAttr( wghtnode, "Support", fSupport );
435  gTools().ReadAttr( wghtnode, "Sigma", fSigma );
436  gTools().ReadAttr( wghtnode, "Norm", fNorm );
437  gTools().ReadAttr( wghtnode, "SSB", fSSB );
438  gTools().ReadAttr( wghtnode, "SSBNeve", fSSBNeve );
439 
440  UInt_t nvars;
441  gTools().ReadAttr( wghtnode, "Nvars", nvars );
442  if (fCut) delete fCut;
443  fCut = new RuleCut();
444  fCut->SetNvars( nvars );
445 
446  // read Cut
447  void* ch = gTools().GetChild( wghtnode );
448  UInt_t i = 0;
449  UInt_t ui;
450  Double_t d;
451  Char_t c;
452  while (ch) {
453  gTools().ReadAttr( ch, "Selector", ui );
454  fCut->SetSelector( i, ui );
455  gTools().ReadAttr( ch, "Min", d );
456  fCut->SetCutMin ( i, d );
457  gTools().ReadAttr( ch, "Max", d );
458  fCut->SetCutMax ( i, d );
459  gTools().ReadAttr( ch, "DoMin", c );
460  fCut->SetCutDoMin( i, (c == 'T' ? kTRUE : kFALSE ) );
461  gTools().ReadAttr( ch, "DoMax", c );
462  fCut->SetCutDoMax( i, (c == 'T' ? kTRUE : kFALSE ) );
463 
464  i++;
465  ch = gTools().GetNextChild(ch);
466  }
467 
468  // sanity check
469  if (i != nvars) Log() << kFATAL << "<ReadFromXML> Mismatch in number of cuts: " << i << " != " << nvars << Endl;
470 }
471 
472 ////////////////////////////////////////////////////////////////////////////////
473 /// read function (format is the same as written by PrintRaw)
474 
475 void TMVA::Rule::ReadRaw( std::istream& istr )
476 {
477  TString dummy;
478  UInt_t nvars;
479  istr >> dummy
480  >> fImportance
481  >> fImportanceRef
482  >> fCoefficient
483  >> fSupport
484  >> fSigma
485  >> fNorm
486  >> fSSB
487  >> fSSBNeve;
488  // coverity[tainted_data_argument]
489  istr >> dummy >> nvars;
490  Double_t cutmin,cutmax;
491  UInt_t sel,idum;
492  Char_t bA, bB;
493  //
494  if (fCut) delete fCut;
495  fCut = new RuleCut();
496  fCut->SetNvars( nvars );
497  for ( UInt_t i=0; i<nvars; i++) {
498  istr >> dummy >> idum; // get 'Node' and index
499  istr >> dummy; // get ':'
500  istr >> sel >> cutmin >> cutmax >> bA >> bB;
501  fCut->SetSelector(i,sel);
502  fCut->SetCutMin(i,cutmin);
503  fCut->SetCutMax(i,cutmax);
504  fCut->SetCutDoMin(i,(bA=='T' ? kTRUE:kFALSE));
505  fCut->SetCutDoMax(i,(bB=='T' ? kTRUE:kFALSE));
506  }
507 }