32 #ifndef ROOT_TMVA_DecisionTree
33 #define ROOT_TMVA_DecisionTree
64 class DecisionTree :
public BinaryTree {
// Class-wide default random seed. It serves as the default value of the
// constructor's iSeed parameter; its value is defined out-of-line (not visible here).
68 static const Int_t fgRandomSeed;
// Event-sample containers: EventList holds mutable event pointers,
// EventConstList holds pointers to read-only events.
72 typedef std::vector<TMVA::Event*> EventList;
73 typedef std::vector<const TMVA::Event*> EventConstList;
// Main constructor. Its parameters are:
//   - the separation criterion (sepType)
//   - the minimum node size (minSize)
//   - the number of cut points (nCuts)
//   - an optional DataSetInfo (default NULL)
//   - randomisation switches (randomisedTree, useNvars, usePoissonNvars)
//   - the maximum depth (nMaxDepth)
//   - the random seed (iSeed, default fgRandomSeed)
//   - the node purity limit (purityLimit, default 0.5)
// The very large nMaxDepth default presumably means "no practical depth limit"
// (TODO confirm).
// NOTE(review): the parameter list below is truncated in this view. It ends
// with ',' and the remaining parameters and the closing ')' are not visible.
// Check the full header before changing this declaration.
79 DecisionTree( SeparationBase *sepType, Float_t minSize,
80 Int_t nCuts, DataSetInfo* = NULL,
82 Bool_t randomisedTree=kFALSE, Int_t useNvars=0, Bool_t usePoissonNvars=kFALSE,
83 UInt_t nMaxDepth=9999999,
84 Int_t iSeed=fgRandomSeed, Float_t purityLimit=0.5,
88 DecisionTree (
const DecisionTree &d);
90 virtual ~DecisionTree(
void );
93 virtual DecisionTreeNode* GetRoot()
const {
return static_cast<TMVA::DecisionTreeNode*
>(fRoot); }
94 virtual DecisionTreeNode * CreateNode(UInt_t)
const {
return new DecisionTreeNode(); }
95 virtual BinaryTree* CreateTree()
const {
return new DecisionTree(); }
// Static factory returning a DecisionTree built, per its name, from an XML
// node that is passed as an opaque void*.
// tmva_Version_Code defaults to TMVA_VERSION_CODE. Presumably it lets the
// reader handle output of older TMVA versions (confirm in the definition).
96 static DecisionTree* CreateFromXML(
void* node, UInt_t tmva_Version_Code = TMVA_VERSION_CODE);
97 virtual const char* ClassName()
const {
return "DecisionTree"; }
// Builds the tree from eventSample, starting at node.
// NOTE(review): the default node = NULL presumably means "start from the
// root". The meaning of the UInt_t result is not visible here; confirm both
// in the out-of-line definition.
103 UInt_t BuildTree(
const EventConstList & eventSample,
104 DecisionTreeNode *node = NULL);
107 Double_t TrainNode(
const EventConstList & eventSample, DecisionTreeNode *node ) {
return TrainNodeFast( eventSample, node ); }
// Two alternative split-search implementations for a single node.
// TrainNode() forwards to TrainNodeFast.
108 Double_t TrainNodeFast(
const EventConstList & eventSample, DecisionTreeNode *node );
109 Double_t TrainNodeFull(
const EventConstList & eventSample, DecisionTreeNode *node );
// Chooses the variable subset for randomised trees.
// Writes through the caller-provided arrays useVariable/variableMap and the
// nVars reference. The arrays are presumably sized to the number of input
// variables (TODO confirm).
110 void GetRandomisedVariables(Bool_t *useVariable, UInt_t *variableMap, UInt_t & nVars);
// Computes Fisher-discriminant coefficients over nFisherVars variables.
// mapVarInFisher presumably maps Fisher index -> input-variable index
// (confirm in the definition).
111 std::vector<Double_t> GetFisherCoefficients(
const EventConstList &eventSample, UInt_t nFisherVars, UInt_t *mapVarInFisher);
// Per its name, fills an existing tree with eventSample; FillEvent does the
// same for a single event starting at node. Details are out-of-line.
116 void FillTree(
const EventList & eventSample);
120 void FillEvent(
const TMVA::Event & event,
121 TMVA::DecisionTreeNode *node );
// Evaluates an event on the tree and returns a Double_t response.
// UseYesNoLeaf (default kFALSE) presumably selects a yes/no leaf answer
// instead of a continuous leaf value (TODO confirm).
125 Double_t CheckEvent(
const TMVA::Event * , Bool_t UseYesNoLeaf = kFALSE )
const;
// Per its name, returns the node that event e ends up in.
126 TMVA::DecisionTreeNode* GetEventNode(
const TMVA::Event & e)
const;
// Variable importance: one value per variable, or a single value for
// variable index ivar.
129 std::vector< Double_t > GetVariableImportance();
131 Double_t GetVariableImportance(UInt_t ivar);
138 enum EPruneMethod { kExpectedErrorPruning=0, kCostComplexityPruning, kNoPruning };
139 void SetPruneMethod( EPruneMethod m = kCostComplexityPruning ) { fPruneMethod = m; }
// Prunes the tree, optionally using a validation sample (default NULL = none
// supplied). The returned Double_t's meaning is defined out-of-line.
142 Double_t PruneTree(
const EventConstList* validationSample = NULL );
145 void SetPruneStrength( Double_t p ) { fPruneStrength = p; }
146 Double_t GetPruneStrength( )
const {
return fPruneStrength; }
// Helpers for validation-sample based pruning. All are const member functions
// that take a pointer to the validation sample or a single event; their
// effects are defined out-of-line.
149 void ApplyValidationSample(
const EventConstList* validationSample )
const;
// Quality measure of the pruned tree. dt defaults to NULL (presumably "the
// root") and mode to 0; the meaning of mode is not visible here.
152 Double_t TestPrunedTreeQuality(
const DecisionTreeNode* dt = NULL, Int_t mode=0 )
const;
155 void CheckEventWithPrunedTree(
const TMVA::Event* )
const;
// Per its name, the total event weight of the validation sample.
158 Double_t GetSumWeights(
const EventConstList* validationSample )
const;
160 void SetNodePurityLimit( Double_t p ) { fNodePurityLimit = p; }
161 Double_t GetNodePurityLimit( )
const {
return fNodePurityLimit; }
// Tree-walking helpers. A NULL start node presumably means "begin at the
// root" (confirm in the definitions).
163 void DescendTree( Node *n = NULL );
164 void SetParentTreeInNodes( Node *n = NULL );
// Locates a node from a (sequence, depth) pair. Presumably sequence encodes
// the left/right path down to that depth (TODO confirm).
169 Node* GetNode( ULong_t sequence, UInt_t depth );
// Node clean-up and pruning primitives; CleanTree returns a UInt_t whose
// meaning is defined out-of-line.
171 UInt_t CleanTree(DecisionTreeNode *node=NULL);
173 void PruneNode(TMVA::DecisionTreeNode *node);
177 void PruneNodeInPlace( TMVA::DecisionTreeNode* node );
179 Int_t GetNNodesBeforePruning(){
return (fNNodesBeforePruning)?fNNodesBeforePruning:fNNodesBeforePruning=GetNNodes();}
// Per its name, counts leaf nodes below n. A NULL n presumably means "from
// the root" (confirm in the definition).
182 UInt_t CountLeafNodes(TMVA::Node *n = NULL);
184 void SetTreeID(Int_t treeID){fTreeID = treeID;};
185 Int_t GetTreeID(){
return fTreeID;};
187 Bool_t DoRegression()
const {
return fAnalysisType == Types::kRegression; }
188 void SetAnalysisType (Types::EAnalysisType t) { fAnalysisType = t;}
189 Types::EAnalysisType GetAnalysisType (
void ) {
return fAnalysisType;}
190 inline void SetUseFisherCuts(Bool_t t=kTRUE) { fUseFisherCuts = t;}
191 inline void SetMinLinCorrForFisher(Double_t min){fMinLinCorrForFisher = min;}
192 inline void SetUseExclusiveVars(Bool_t t=kTRUE){fUseExclusiveVars = t;}
193 inline void SetNVars(Int_t n){fNvars = n;}
// Purity of an event sample (exact definition is out-of-line).
// NOTE(review): EventList is taken by value, so the whole vector of event
// pointers is copied on every call. Switching to const EventList& would avoid
// the copy, but the out-of-line definition must be changed in lockstep.
202 Double_t SamplePurity(EventList eventSample);
// Data members. Declaration order fixes object layout and member-initialisation
// order, so do not reorder.
// Fisher-cut options, written by SetUseFisherCuts / SetMinLinCorrForFisher /
// SetUseExclusiveVars.
206 Bool_t fUseFisherCuts;
207 Double_t fMinLinCorrForFisher;
208 Bool_t fUseExclusiveVars;
// Separation criterion (the constructor takes a SeparationBase* sepType) and
// regression variance object.
// NOTE(review): ownership/deletion of these raw pointers is decided
// out-of-line (destructor not visible).
210 SeparationBase *fSepType;
211 RegressionVariance *fRegType;
// Growth limits: minimum node size and minimum separation gain.
// Note that fMinNodeSize is Double_t while the constructor's minSize is Float_t.
214 Double_t fMinNodeSize;
215 Double_t fMinSepGain;
217 Bool_t fUseSearchTree;
// Pruning configuration (SetPruneStrength / SetPruneMethod).
218 Double_t fPruneStrength;
220 EPruneMethod fPruneMethod;
// Lazy cache for GetNNodesBeforePruning(); 0 means "not computed yet".
221 Int_t fNNodesBeforePruning;
// Purity limit (SetNodePurityLimit / GetNodePurityLimit).
223 Double_t fNodePurityLimit;
// Randomisation options and the random generator.
// NOTE(review): ownership of fMyTrandom is not visible here.
225 Bool_t fRandomisedTree;
227 Bool_t fUsePoissonNvars;
229 TRandom3 *fMyTrandom;
// Per-variable importance values.
231 std::vector< Double_t > fVariableImportance;
// Compile-time debug verbosity, fixed at 0.
235 static const Int_t fgDebugLevel = 0;
// Analysis type read by DoRegression() and set by SetAnalysisType().
238 Types::EAnalysisType fAnalysisType;
// Dataset description (the constructor accepts a DataSetInfo*, default NULL).
240 DataSetInfo* fDataSetInfo;
// ROOT dictionary macro. Per ROOT convention, class version 0 presumably marks
// the class as not streamed to ROOT files (confirm).
242 ClassDef(DecisionTree,0);