Logo ROOT   6.30.04
Reference Guide
 All Namespaces Files Pages
CCTreeWrapper.h
Go to the documentation of this file.
1 
2 /**********************************************************************************
3  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
4  * Package: TMVA *
5  * Class : CCTreeWrapper *
6  * Web : http://tmva.sourceforge.net *
7  * *
8  * Description: a light wrapper of a decision tree, used to perform *
9  * "in-place" cost complexity pruning *
10  * *
11  * Author: Doug Schouten (dschoute@sfu.ca) *
12  * *
13  * *
14  * Copyright (c) 2007: *
15  * CERN, Switzerland *
16  * MPI-K Heidelberg, Germany *
17  * U. of Texas at Austin, USA *
18  * *
19  * Redistribution and use in source and binary forms, with or without *
20  * modification, are permitted according to the terms listed in LICENSE *
21  * (http://tmva.sourceforge.net/LICENSE) *
22  **********************************************************************************/
23 
24 #ifndef ROOT_TMVA_CCTreeWrapper
25 #define ROOT_TMVA_CCTreeWrapper
26 
27 #include "TMVA/Event.h"
28 #include "TMVA/SeparationBase.h"
29 #include "TMVA/DecisionTree.h"
30 #include "TMVA/DataSet.h"
31 #include "TMVA/Version.h"
32 
33 
34 namespace TMVA {
35 
36  class CCTreeWrapper {
37 
38  public:
39 
40  typedef std::vector<Event*> EventList;
41 
42  /////////////////////////////////////////////////////////////
43  // CCTreeNode - a light wrapper of a decision tree node //
44  // //
45  /////////////////////////////////////////////////////////////
46 
47  class CCTreeNode : virtual public Node {
48 
49  public:
50 
51  CCTreeNode( DecisionTreeNode* n = NULL );
52  virtual ~CCTreeNode( );
53 
54  virtual Node* CreateNode() const { return new CCTreeNode(); }
55 
56  // set |~T_t|, the number of terminal descendants of node t
57  inline void SetNLeafDaughters( Int_t N ) { fNLeafDaughters = (N > 0 ? N : 0); }
58 
59  // return |~T_t|
60  inline Int_t GetNLeafDaughters() const { return fNLeafDaughters; }
61 
62  // set R(t), the node resubstitution estimate (Gini, misclassification, etc.) for the node t
63  inline void SetNodeResubstitutionEstimate( Double_t R ) { fNodeResubstitutionEstimate = (R >= 0 ? R : 0.0); }
64 
65  // return R(t) for node t
66  inline Double_t GetNodeResubstitutionEstimate( ) const { return fNodeResubstitutionEstimate; }
67 
68  // set R(T_t) = sum[t' in ~T_t]{ R(t) }, the resubstitution estimate for the branch rooted at
69  // node t (it is an estimate because it is calculated from the training dataset, i.e., the original tree)
70  inline void SetResubstitutionEstimate( Double_t R ) { fResubstitutionEstimate = (R >= 0 ? R : 0.0); }
71 
72  // return R(T_t) for node t
73  inline Double_t GetResubstitutionEstimate( ) const { return fResubstitutionEstimate; }
74 
75  // set the critical point of alpha
76  // R(t) - R(T_t)
77  // alpha_c < ------------- := g(t)
78  // |~T_t| - 1
79  // which is the value of alpha such that the branch rooted at node t is pruned
80  inline void SetAlphaC( Double_t alpha ) { fAlphaC = alpha; }
81 
82  // get the critical alpha value for this node
83  inline Double_t GetAlphaC( ) const { return fAlphaC; }
84 
85  // set the minimum critical alpha value for descendants of node t ( G(t) = min(alpha_c, g(t_l), g(t_r)) )
86  inline void SetMinAlphaC( Double_t alpha ) { fMinAlphaC = alpha; }
87 
88  // get the minimum critical alpha value
89  inline Double_t GetMinAlphaC( ) const { return fMinAlphaC; }
90 
91  // get the pointer to the wrapped DT node
92  inline DecisionTreeNode* GetDTNode( ) const { return fDTNode; }
93 
94  // get pointers to children, mother in the CC tree
95  inline CCTreeNode* GetLeftDaughter( ) { return dynamic_cast<CCTreeNode*>(GetLeft()); }
96  inline CCTreeNode* GetRightDaughter( ) { return dynamic_cast<CCTreeNode*>(GetRight()); }
97  inline CCTreeNode* GetMother( ) { return dynamic_cast<CCTreeNode*>(GetParent()); }
98 
99  // printout of the node (can be read in with ReadDataRecord)
100  virtual void Print( std::ostream& os ) const;
101 
102  // recursive printout of the node and its daughters
103  virtual void PrintRec ( std::ostream& os ) const;
104 
105  virtual void AddAttributesToNode(void* node) const;
106  virtual void AddContentToNode(std::stringstream& s) const;
107 
108 
109  // test event if it decends the tree at this node to the right
110  inline virtual Bool_t GoesRight( const Event& e ) const { return (GetDTNode() != NULL ?
111  GetDTNode()->GoesRight(e) : false); }
112 
113  // test event if it decends the tree at this node to the left
114  inline virtual Bool_t GoesLeft ( const Event& e ) const { return (GetDTNode() != NULL ?
115  GetDTNode()->GoesLeft(e) : false); }
116  // initialize a node from a data record
117  virtual void ReadAttributes(void* node, UInt_t tmva_Version_Code = TMVA_VERSION_CODE);
118  virtual void ReadContent(std::stringstream& s);
119  virtual Bool_t ReadDataRecord( std::istream& in, UInt_t tmva_Version_Code = TMVA_VERSION_CODE );
120 
121  private:
122 
123  Int_t fNLeafDaughters; //! number of terminal descendants
124  Double_t fNodeResubstitutionEstimate; //! R(t) = misclassification rate for node t
125  Double_t fResubstitutionEstimate; //! R(T_t) = sum[t' in ~T_t]{ R(t) }
126  Double_t fAlphaC; //! critical point, g(t) = alpha_c(t)
127  Double_t fMinAlphaC; //! G(t), minimum critical point of t and its descendants
128  DecisionTreeNode* fDTNode; //! pointer to wrapped node in the decision tree
129  };
130 
131  CCTreeWrapper( DecisionTree* T, SeparationBase* qualityIndex );
132  ~CCTreeWrapper( );
133 
134  // return the decision tree output for an event
135  Double_t CheckEvent( const TMVA::Event & e, Bool_t useYesNoLeaf = false );
136  // return the misclassification rate of a pruned tree for a validation event sample
137  Double_t TestTreeQuality( const EventList* validationSample );
138  Double_t TestTreeQuality( const DataSet* validationSample );
139 
140  // remove the branch rooted at node t
141  void PruneNode( CCTreeNode* t );
142  // initialize the node t and all its descendants
143  void InitTree( CCTreeNode* t );
144 
145  // return the root node for this tree
146  CCTreeNode* GetRoot() { return fRoot; }
147  private:
148  SeparationBase* fQualityIndex; //! pointer to the used quality index calculator
149  DecisionTree* fDTParent; //! pointer to underlying DecisionTree
150  CCTreeNode* fRoot; //! the root node of the (wrapped) decision Tree
151  };
152 
153 }
154 
155 #endif
156 
157 
158