Logo ROOT   6.30.04
Reference Guide
 All Namespaces Files Pages
ExpectedErrorPruneTool.h
Go to the documentation of this file.
1 /**********************************************************************************
2  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
3  * Package: TMVA *
 * Class  : TMVA::ExpectedErrorPruneTool                                         *
 * Web    : http://tmva.sourceforge.net                                          *
 *                                                                               *
 * Description:                                                                  *
 *      Helper class to prune a decision tree with the expected error (C4.5) method *
9  * *
10  * Authors (alphabetical): *
11  * Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland *
12  * Helge Voss <Helge.Voss@cern.ch> - MPI-K Heidelberg, Germany *
13  * Kai Voss <Kai.Voss@cern.ch> - U. of Victoria, Canada *
14  * Doug Schouten <dschoute@sfu.ca> - Simon Fraser U., Canada *
15  * *
16  * Copyright (c) 2005: *
17  * CERN, Switzerland *
18  * U. of Victoria, Canada *
19  * MPI-K Heidelberg, Germany *
20  * *
21  * Redistribution and use in source and binary forms, with or without *
22  * modification, are permitted according to the terms listed in LICENSE *
23  * (http://mva.sourceforge.net/license.txt) *
24  * *
25  **********************************************************************************/
26 
27 #ifndef ROOT_TMVA_ExpectedErrorPruneTool
28 #define ROOT_TMVA_ExpectedErrorPruneTool
29 
/////////////////////////////////////////////////////////////////////////////////////////////////////////////
// ExpectedErrorPruneTool - a helper class to prune a decision tree using the expected error (C4.5) method
//
// Uses an upper limit on the error made by the classification done by each node. If the purity S/(S+B)
// of a node is f, then according to the training sample the error rate (the fraction of events
// misclassified by this node) is (1-f). Because f has a statistical uncertainty given by the binomial
// distribution, the error on f can be estimated in the same way as the binomial error used in
// efficiency calculations:
//     sigma = sqrt( eff*(1-eff)/nEvts )
//
// This tool prunes a branch from the tree if the expected error of a node is less than the sum of the
// expected errors of its descendants.
//
/////////////////////////////////////////////////////////////////////////////////////////////////////////////
43 
44 #include <vector>
45 #include <map>
46 
47 #include "TMath.h"
48 
49 #include "TMVA/IPruneTool.h"
50 
51 namespace TMVA {
52 
53  class MsgLogger;
54 
55  class ExpectedErrorPruneTool : public IPruneTool {
56  public:
57  ExpectedErrorPruneTool( );
58  virtual ~ExpectedErrorPruneTool( );
59 
60  // returns the PruningInfo object for a given tree and test sample
61  virtual PruningInfo* CalculatePruningInfo( DecisionTree* dt, const IPruneTool::EventSample* testEvents = NULL,
62  Bool_t isAutomatic = kFALSE );
63 
64  public:
65  // set the increment dalpha with which to scan for the optimal prune strength
66  inline void SetPruneStrengthIncrement( Double_t dalpha ) { fDeltaPruneStrength = dalpha; }
67 
68  private:
69  void FindListOfNodes( DecisionTreeNode* node );
70  Double_t GetNodeError( DecisionTreeNode* node ) const;
71  Double_t GetSubTreeError( DecisionTreeNode* node ) const;
72  Int_t CountNodes( DecisionTreeNode* node, Int_t icount = 0 );
73 
74  Double_t fDeltaPruneStrength; //! the stepsize for optimizing the pruning strength parameter
75  Double_t fNodePurityLimit; //! the purity limit for labelling a terminal node as signal
76  std::vector<DecisionTreeNode*> fPruneSequence; //! the (optimal) prune sequence
77  // std::multimap<const Double_t, Double_t> fQualityMap; //! map of tree quality <=> prune strength
78  mutable MsgLogger* fLogger; // message logger
79  MsgLogger& Log() const { return *fLogger; }
80  };
81 
82  inline Int_t ExpectedErrorPruneTool::CountNodes( DecisionTreeNode* node, Int_t icount ) {
83  DecisionTreeNode* l = (DecisionTreeNode*)node->GetLeft();
84  DecisionTreeNode* r = (DecisionTreeNode*)node->GetRight();
85  Int_t counter = icount + 1; // count this node
86  if(!(node->IsTerminal()) && l != NULL && r != NULL) {
87  counter = CountNodes(l,counter);
88  counter = CountNodes(r,counter);
89  }
90  return counter;
91  }
92 }
93 
94 #endif
95