Logo ROOT   6.30.04
Reference Guide
 All Namespaces Files Pages
DecisionTree.h
Go to the documentation of this file.
1 // @(#)root/tmva $Id$
2 // Author: Andreas Hoecker, Joerg Stelzer, Helge Voss, Kai Voss, Jan Therhaag, Eckhard von Toerne
3 
4 /**********************************************************************************
5  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6  * Package: TMVA *
7  * Class : DecisionTree *
8  * Web : http://tmva.sourceforge.net *
9  * *
10  * Description: *
11  * Implementation of a Decision Tree *
12  * *
13  * Authors (alphabetical): *
14  * Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland *
15  * Helge Voss <Helge.Voss@cern.ch> - MPI-K Heidelberg, Germany *
16  * Kai Voss <Kai.Voss@cern.ch> - U. of Victoria, Canada *
17  * Jan Therhaag <Jan.Therhaag@cern.ch> - U of Bonn, Germany *
18  * Eckhard v. Toerne <evt@uni-bonn.de> - U of Bonn, Germany *
19  * *
20  * Copyright (c) 2005-2011: *
21  * CERN, Switzerland *
22  * U. of Victoria, Canada *
23  * MPI-K Heidelberg, Germany *
24  * U. of Bonn, Germany *
25  * *
26  * Redistribution and use in source and binary forms, with or without *
27  * modification, are permitted according to the terms listed in LICENSE *
28  * (http://tmva.sourceforge.net/license.txt)                                    *
29  * *
30  **********************************************************************************/
31 
32 #ifndef ROOT_TMVA_DecisionTree
33 #define ROOT_TMVA_DecisionTree
34 
35 //////////////////////////////////////////////////////////////////////////
36 // //
37 // DecisionTree //
38 // //
39 // Implementation of a Decision Tree //
40 // //
41 //////////////////////////////////////////////////////////////////////////
42 
#include "TH2.h"

#include "TMVA/Types.h"
#include "TMVA/DecisionTreeNode.h"
#include "TMVA/BinaryTree.h"
#include "TMVA/BinarySearchTree.h"
#include "TMVA/SeparationBase.h"
#include "TMVA/RegressionVariance.h"
#include "TMVA/DataSetInfo.h"
52 
53 #ifdef R__USE_IMT
54 #include <ROOT/TThreadExecutor.hxx>
55 #include "TSystem.h"
56 #endif
57 
58 class TRandom3;
59 
60 namespace TMVA {
61 
62  class Event;
63 
   // Implementation of a decision tree: a binary tree whose nodes cut on one
   // (or, with Fisher cuts, a linear combination of) input variable(s), grown
   // by recursive splitting of a training event sample and optionally pruned
   // back using a validation sample.
   class DecisionTree : public BinaryTree {

   private:

      static const Int_t fgRandomSeed; // set nonzero for debugging and zero for random seeds

   public:

      // container aliases: FillTree takes mutable events, training/pruning use const events
      typedef std::vector<TMVA::Event*> EventList;
      typedef std::vector<const TMVA::Event*> EventConstList;

      // the constructor needed for the "reading" of the decision tree from weight files
      DecisionTree( void );

      // the constructor needed for constructing the decision tree via training with events
      DecisionTree( SeparationBase *sepType, Float_t minSize,
                    Int_t nCuts, DataSetInfo* = NULL,
                    UInt_t cls =0,
                    Bool_t randomisedTree=kFALSE, Int_t useNvars=0, Bool_t usePoissonNvars=kFALSE,
                    UInt_t nMaxDepth=9999999,
                    Int_t iSeed=fgRandomSeed, Float_t purityLimit=0.5,
                    Int_t treeID = 0);

      // copy constructor
      DecisionTree (const DecisionTree &d);

      virtual ~DecisionTree( void );

      // Retrieves the address of the root node
      virtual DecisionTreeNode* GetRoot() const { return static_cast<TMVA::DecisionTreeNode*>(fRoot); }
      // factory for a node of the concrete type this tree is built from (BinaryTree hook)
      virtual DecisionTreeNode * CreateNode(UInt_t) const { return new DecisionTreeNode(); }
      // factory for an empty tree of this concrete type (BinaryTree hook)
      virtual BinaryTree* CreateTree() const { return new DecisionTree(); }
      // reconstruct a tree from its XML representation in a weight file
      static DecisionTree* CreateFromXML(void* node, UInt_t tmva_Version_Code = TMVA_VERSION_CODE);
      virtual const char* ClassName() const { return "DecisionTree"; }

      // building of a tree by recursively splitting the nodes

      // UInt_t BuildTree( const EventList & eventSample,
      // DecisionTreeNode *node = NULL);
      UInt_t BuildTree( const EventConstList & eventSample,
                        DecisionTreeNode *node = NULL);
      // determine the way how a node is split (which variable, which cut value)

      Double_t TrainNode( const EventConstList & eventSample, DecisionTreeNode *node ) { return TrainNodeFast( eventSample, node ); }
      Double_t TrainNodeFast( const EventConstList & eventSample, DecisionTreeNode *node );
      Double_t TrainNodeFull( const EventConstList & eventSample, DecisionTreeNode *node );
      // choose the random subset of variables to be considered at a split (randomised trees)
      void GetRandomisedVariables(Bool_t *useVariable, UInt_t *variableMap, UInt_t & nVars);
      // coefficients of the linear (Fisher) discriminant used for multivariate splits
      std::vector<Double_t> GetFisherCoefficients(const EventConstList &eventSample, UInt_t nFisherVars, UInt_t *mapVarInFisher);

      // fill a tree that already has a given structure (just see how many signal/bkgr
      // events end up in each node)

      void FillTree( const EventList & eventSample);

      // fill the existing decision tree structure by filling events
      // in from the top node and see where they happen to end up
      void FillEvent( const TMVA::Event & event,
                      TMVA::DecisionTreeNode *node );

      // returns: 1 = Signal (right), -1 = Bkg (left)

      Double_t CheckEvent( const TMVA::Event * , Bool_t UseYesNoLeaf = kFALSE ) const;
      // descend the tree with the event and return the leaf node it ends up in
      TMVA::DecisionTreeNode* GetEventNode(const TMVA::Event & e) const;

      // return the individual relative variable importance
      std::vector< Double_t > GetVariableImportance();

      Double_t GetVariableImportance(UInt_t ivar);

      // clear the tree nodes (their S/N, Nevents etc), just keep the structure of the tree

      void ClearTree();

      // set pruning method
      enum EPruneMethod { kExpectedErrorPruning=0, kCostComplexityPruning, kNoPruning };
      void SetPruneMethod( EPruneMethod m = kCostComplexityPruning ) { fPruneMethod = m; }

      // recursive pruning of the tree, validation sample required for automatic pruning
      Double_t PruneTree( const EventConstList* validationSample = NULL );

      // manage the pruning strength parameter (iff < 0 -> automate the pruning process)
      void SetPruneStrength( Double_t p ) { fPruneStrength = p; }
      Double_t GetPruneStrength( ) const { return fPruneStrength; }

      // apply pruning validation sample to a decision tree
      void ApplyValidationSample( const EventConstList* validationSample ) const;

      // return the misclassification rate of a pruned tree
      Double_t TestPrunedTreeQuality( const DecisionTreeNode* dt = NULL, Int_t mode=0 ) const;

      // pass a single validation event through a pruned decision tree
      void CheckEventWithPrunedTree( const TMVA::Event* ) const;

      // calculate the normalization factor for a pruning validation sample
      Double_t GetSumWeights( const EventConstList* validationSample ) const;

      void SetNodePurityLimit( Double_t p ) { fNodePurityLimit = p; }
      Double_t GetNodePurityLimit( ) const { return fNodePurityLimit; }

      void DescendTree( Node *n = NULL );
      void SetParentTreeInNodes( Node *n = NULL );

      // retrieve node from the tree. Its position (up to a maximal tree depth of 64)
      // is coded as a sequence of left-right moves starting from the root, coded as
      // 0-1 bit patterns stored in the "long-integer" together with the depth
      Node* GetNode( ULong_t sequence, UInt_t depth );

      UInt_t CleanTree(DecisionTreeNode *node=NULL);

      void PruneNode(TMVA::DecisionTreeNode *node);

      // prune a node from the tree without deleting its descendants; allows one to
      // effectively prune a tree many times without making deep copies
      void PruneNodeInPlace( TMVA::DecisionTreeNode* node );

      // node count before pruning, lazily cached from GetNNodes() on first call
      Int_t GetNNodesBeforePruning(){return (fNNodesBeforePruning)?fNNodesBeforePruning:fNNodesBeforePruning=GetNNodes();}


      UInt_t CountLeafNodes(TMVA::Node *n = NULL);

      void SetTreeID(Int_t treeID){fTreeID = treeID;};
      Int_t GetTreeID(){return fTreeID;};

      Bool_t DoRegression() const { return fAnalysisType == Types::kRegression; }
      void SetAnalysisType (Types::EAnalysisType t) { fAnalysisType = t;}
      Types::EAnalysisType GetAnalysisType ( void ) { return fAnalysisType;}
      inline void SetUseFisherCuts(Bool_t t=kTRUE) { fUseFisherCuts = t;}
      inline void SetMinLinCorrForFisher(Double_t min){fMinLinCorrForFisher = min;}
      inline void SetUseExclusiveVars(Bool_t t=kTRUE){fUseExclusiveVars = t;}
      inline void SetNVars(Int_t n){fNvars = n;}

   private:
      // utility functions

      // calculate the Purity out of the number of sig and bkg events collected
      // from individual samples.

      // calculates the purity S/(S+B) of a given event sample
      Double_t SamplePurity(EventList eventSample);

      UInt_t fNvars; // number of variables used to separate S and B
      Int_t fNCuts; // number of grid points in variable cut scans
      Bool_t fUseFisherCuts; // use multivariate splits using the Fisher criterion
      Double_t fMinLinCorrForFisher; // the minimum linear correlation between two variables demanded for use in Fisher criterion in node splitting
      Bool_t fUseExclusiveVars; // individual variables already used in Fisher criterion are not anymore analysed individually for node splitting

      SeparationBase *fSepType; // the separation criterion
      RegressionVariance *fRegType; // the separation criterion used in Regression

      Double_t fMinSize; // min number of events in node
      Double_t fMinNodeSize; // min fraction of training events in node
      Double_t fMinSepGain; // min separation gain required to perform node splitting

      Bool_t fUseSearchTree; // cut scan done with binary trees or simple event loop.
      Double_t fPruneStrength; // a parameter to set the "amount" of pruning..needs to be adjusted

      EPruneMethod fPruneMethod; // method used for pruning
      Int_t fNNodesBeforePruning; //remember this one (in case of pruning, it allows to monitor the before/after

      Double_t fNodePurityLimit;// purity limit to decide whether a node is signal

      Bool_t fRandomisedTree; // choose at each node splitting a random set of variables
      Int_t fUseNvars; // the number of variables used in randomised trees;
      Bool_t fUsePoissonNvars; // use "fUseNvars" not as fixed number but as mean of a Poisson distr. in each split

      TRandom3 *fMyTrandom; // random number generator for randomised trees

      std::vector< Double_t > fVariableImportance; // the relative importance of the different variables

      UInt_t fMaxDepth; // max depth
      UInt_t fSigClass; // class which is treated as signal when building the tree
      static const Int_t fgDebugLevel = 0; // debug level determining some printout/control plots etc.
      Int_t fTreeID; // just an ID number given to the tree.. makes debugging easier as tree knows who he is.

      Types::EAnalysisType fAnalysisType; // kClassification(=0=false) or kRegression(=1=true)

      DataSetInfo* fDataSetInfo; // dataset meta-information (variable layout etc.); may be NULL — TODO confirm ownership (not deleted here)

      ClassDef(DecisionTree,0); // implementation of a Decision Tree
   };
244 
245 } // namespace TMVA
246 
247 #endif