Logo ROOT   6.30.04
Reference Guide
 All Namespaces Files Pages
CCPruner.cxx
Go to the documentation of this file.
1 /**********************************************************************************
2  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
3  * Package: TMVA *
4  * Class : CCPruner *
5  * Web : http://tmva.sourceforge.net *
6  * *
7  * Description: Cost Complexity Pruning *
8  *
9  * Author: Doug Schouten (dschoute@sfu.ca)
10  *
11  * *
12  * Copyright (c) 2007: *
13  * CERN, Switzerland *
14  * MPI-K Heidelberg, Germany *
15  * U. of Texas at Austin, USA *
16  * *
17  * Redistribution and use in source and binary forms, with or without *
18  * modification, are permitted according to the terms listed in LICENSE *
19  * (http://tmva.sourceforge.net/LICENSE) *
20  **********************************************************************************/
21 
#include "TMVA/CCPruner.h"
#include "TMVA/CCTreeWrapper.h"
#include "TMVA/DataSet.h"
#include "TMVA/GiniIndex.h"
#include "TMVA/MisClassificationError.h"
#include "TMVA/SeparationBase.h"

#include "Rtypes.h"

#include <algorithm>
#include <cmath>
#include <fstream>
#include <iostream>
#include <limits>
#include <math.h>
#include <vector>
35 
36 /*! \class TMVA::CCPruner
37 \ingroup TMVA
38 A helper class to prune a decision tree using the Cost Complexity method
39 (see Classification and Regression Trees by Leo Breiman et al)
40 
41 ### Some definitions:
42 
43  - \f$ T_{max} \f$ - the initial, usually highly overtrained tree, that is to be pruned back
44  - \f$ R(T) \f$ - quality index (Gini, misclassification rate, or other) of a tree \f$ T \f$
45  - \f$ \sim T \f$ - set of terminal nodes in \f$ T \f$
 - \f$ T' \f$ - the pruned subtree of \f$ T_{max} \f$ that has the best quality index \f$ R(T') \f$
 - \f$ \alpha \f$ - the prune strength parameter in Cost Complexity pruning, \f$ R_{\alpha}(T) = R(T) + \alpha \cdot |\sim T| \f$
48 
49 There are two running modes in CCPruner: (i) one may select a prune strength and prune back
50 the tree \f$ T_{max}\f$ until the criterion:
51 \f[
 \alpha < \frac{R(t) - R(T_t)}{|\sim T_t| - 1}
\f]

is true for all non-terminal nodes \f$ t \f$ in \f$ T \f$ (here \f$ T_t \f$ is the branch of \f$ T \f$ rooted at \f$ t \f$,
and \f$ R(t) \f$ is the quality index of \f$ t \f$ treated as a terminal node), or (ii) the algorithm finds the sequence
of critical points \f$ \alpha_1 \le \alpha_2 \le \dots \le \alpha_K \f$ such that \f$ T_K = root(T_{max}) \f$ and then
selects the optimally-pruned subtree, defined to be the subtree with the best quality index for the validation sample.
58 */
59 
// Forward declaration: in this file DecisionTree is only handled through pointers.
namespace TMVA { class DecisionTree; }

using namespace TMVA;
65 
66 ////////////////////////////////////////////////////////////////////////////////
67 /// constructor
68 
69 CCPruner::CCPruner( DecisionTree* t_max, const EventList* validationSample,
70  SeparationBase* qualityIndex ) :
71  fAlpha(-1.0),
72  fValidationSample(validationSample),
73  fValidationDataSet(NULL),
74  fOptimalK(-1)
75 {
76  fTree = t_max;
77 
78  if(qualityIndex == NULL) {
79  fOwnQIndex = true;
80  fQualityIndex = new MisClassificationError();
81  }
82  else {
83  fOwnQIndex = false;
84  fQualityIndex = qualityIndex;
85  }
86  fDebug = kTRUE;
87 }
88 
89 ////////////////////////////////////////////////////////////////////////////////
90 /// constructor
91 
92 CCPruner::CCPruner( DecisionTree* t_max, const DataSet* validationSample,
93  SeparationBase* qualityIndex ) :
94  fAlpha(-1.0),
95  fValidationSample(NULL),
96  fValidationDataSet(validationSample),
97  fOptimalK(-1)
98 {
99  fTree = t_max;
100 
101  if(qualityIndex == NULL) {
102  fOwnQIndex = true;
103  fQualityIndex = new MisClassificationError();
104  }
105  else {
106  fOwnQIndex = false;
107  fQualityIndex = qualityIndex;
108  }
109  fDebug = kTRUE;
110 }
111 
112 
113 ////////////////////////////////////////////////////////////////////////////////
114 
115 CCPruner::~CCPruner( )
116 {
117  if(fOwnQIndex) delete fQualityIndex;
118  // destructor
119 }
120 
121 ////////////////////////////////////////////////////////////////////////////////
122 /// determine the pruning sequence
123 
124 void CCPruner::Optimize( )
125 {
126  Bool_t HaveStopCondition = fAlpha > 0; // keep pruning the tree until reach the limit fAlpha
127 
128  // build a wrapper tree to perform work on
129  CCTreeWrapper* dTWrapper = new CCTreeWrapper(fTree, fQualityIndex);
130 
131  Int_t k = 0;
132  Double_t epsilon = std::numeric_limits<double>::epsilon();
133  Double_t alpha = -1.0e10;
134 
135  std::ofstream outfile;
136  if (fDebug) outfile.open("costcomplexity.log");
137  if(!HaveStopCondition && (fValidationSample == NULL && fValidationDataSet == NULL) ) {
138  if (fDebug) outfile << "ERROR: no validation sample, so cannot optimize pruning!" << std::endl;
139  delete dTWrapper;
140  if (fDebug) outfile.close();
141  return;
142  }
143 
144  CCTreeWrapper::CCTreeNode* R = dTWrapper->GetRoot();
145  while(R->GetNLeafDaughters() > 1) { // prune upwards to the root node
146  if(R->GetMinAlphaC() > alpha)
147  alpha = R->GetMinAlphaC(); // initialize alpha
148 
149  if(HaveStopCondition && alpha > fAlpha) break;
150 
151  CCTreeWrapper::CCTreeNode* t = R;
152 
153  while(t->GetMinAlphaC() < t->GetAlphaC()) { // descend to the weakest link
154 
155  if(fabs(t->GetMinAlphaC() - t->GetLeftDaughter()->GetMinAlphaC())/fabs(t->GetMinAlphaC()) < epsilon)
156  t = t->GetLeftDaughter();
157  else
158  t = t->GetRightDaughter();
159  }
160 
161  if( t == R ) {
162  if (fDebug) outfile << std::endl << "Caught trying to prune the root node!" << std::endl;
163  break;
164  }
165 
166  CCTreeWrapper::CCTreeNode* n = t;
167 
168  if (fDebug){
169  outfile << "===========================" << std::endl
170  << "Pruning branch listed below" << std::endl
171  << "===========================" << std::endl;
172  t->PrintRec( outfile );
173 
174  }
175  if (!(t->GetLeftDaughter()) && !(t->GetRightDaughter()) ) {
176  break;
177  }
178  dTWrapper->PruneNode(t); // prune the branch rooted at node t
179 
180  while(t != R) { // go back up the (pruned) tree and recalculate R(T), alpha_c
181  t = t->GetMother();
182  t->SetNLeafDaughters(t->GetLeftDaughter()->GetNLeafDaughters() + t->GetRightDaughter()->GetNLeafDaughters());
183  t->SetResubstitutionEstimate(t->GetLeftDaughter()->GetResubstitutionEstimate() +
184  t->GetRightDaughter()->GetResubstitutionEstimate());
185  t->SetAlphaC((t->GetNodeResubstitutionEstimate() - t->GetResubstitutionEstimate())/(t->GetNLeafDaughters() - 1));
186  t->SetMinAlphaC(std::min(t->GetAlphaC(), std::min(t->GetLeftDaughter()->GetMinAlphaC(),
187  t->GetRightDaughter()->GetMinAlphaC())));
188  }
189  k += 1;
190  if(!HaveStopCondition) {
191  Double_t q;
192  if (fValidationDataSet != NULL) q = dTWrapper->TestTreeQuality(fValidationDataSet);
193  else q = dTWrapper->TestTreeQuality(fValidationSample);
194  fQualityIndexList.push_back(q);
195  }
196  else {
197  fQualityIndexList.push_back(1.0);
198  }
199  fPruneSequence.push_back(n->GetDTNode());
200  fPruneStrengthList.push_back(alpha);
201  }
202 
203  Double_t qmax = -1.0e6;
204  if(!HaveStopCondition) {
205  for(UInt_t i = 0; i < fQualityIndexList.size(); i++) {
206  if(fQualityIndexList[i] > qmax) {
207  qmax = fQualityIndexList[i];
208  k = i;
209  }
210  }
211  fOptimalK = k;
212  }
213  else {
214  fOptimalK = fPruneSequence.size() - 1;
215  }
216 
217  if (fDebug){
218  outfile << std::endl << "************ Summary **************" << std::endl
219  << "Number of trees in the sequence: " << fPruneSequence.size() << std::endl;
220 
221  outfile << "Pruning strength parameters: [";
222  for(UInt_t i = 0; i < fPruneStrengthList.size()-1; i++)
223  outfile << fPruneStrengthList[i] << ", ";
224  outfile << fPruneStrengthList[fPruneStrengthList.size()-1] << "]" << std::endl;
225 
226  outfile << "Misclassification rates: [";
227  for(UInt_t i = 0; i < fQualityIndexList.size()-1; i++)
228  outfile << fQualityIndexList[i] << ", ";
229  outfile << fQualityIndexList[fQualityIndexList.size()-1] << "]" << std::endl;
230 
231  outfile << "Optimal index: " << fOptimalK+1 << std::endl;
232  outfile.close();
233  }
234  delete dTWrapper;
235 }
236 
237 ////////////////////////////////////////////////////////////////////////////////
238 /// return the prune strength (=alpha) corresponding to the prune sequence
239 
240 std::vector<DecisionTreeNode*> CCPruner::GetOptimalPruneSequence( ) const
241 {
242  std::vector<DecisionTreeNode*> optimalSequence;
243  if( fOptimalK >= 0 ) {
244  for( Int_t i = 0; i < fOptimalK; i++ ) {
245  optimalSequence.push_back(fPruneSequence[i]);
246  }
247  }
248  return optimalSequence;
249 }
250 
251