30 #ifndef ROOT_TMVA_DecisionTreeNode
31 #define ROOT_TMVA_DecisionTreeNode
// Per-node bookkeeping that DecisionTreeNode reaches through its fTrainInfo
// pointer: weighted / unweighted / unboosted event counts, regression target
// sums, validation counts (fNB/fNS), pruning quantities (fNodeR, fSubTreeR,
// fAlpha, fG, fNTerminal), fCC and the separation index/gain.
50 class DTNodeTrainingInfo
// Default constructor: counters, target sums, pruning quantities and fCC start
// at 0; fSeparationIndex and fSeparationGain start at -1 (presumably a
// "not computed yet" marker -- confirm).
53 DTNodeTrainingInfo():fSampleMin(),
55 fNodeR(0),fSubTreeR(0),fAlpha(0),fG(0),fNTerminal(0),
56 fNB(0),fNS(0),fSumTarget(0),fSumTarget2(0),fCC(0),
57 fNSigEvents ( 0 ), fNBkgEvents ( 0 ),
59 fNSigEvents_unweighted ( 0 ),
60 fNBkgEvents_unweighted ( 0 ),
61 fNEvents_unweighted ( 0 ),
62 fNSigEvents_unboosted ( 0 ),
63 fNBkgEvents_unboosted ( 0 ),
64 fNEvents_unboosted ( 0 ),
65 fSeparationIndex (-1 ),
66 fSeparationGain ( -1 )
// Sample range of the events in this node, accessed through
// DecisionTreeNode::Get/SetSampleMin and Get/SetSampleMax (presumably one
// entry per input variable -- confirm in the implementation).
69 std::vector< Float_t > fSampleMin;
70 std::vector< Float_t > fSampleMax;
// Event counts ignoring weights: the IncrementN*_unweighted accessors add
// exactly 1 per call.
85 Float_t fNSigEvents_unweighted;
86 Float_t fNBkgEvents_unweighted;
87 Float_t fNEvents_unweighted;
// Event counts written by the SetN*_unboosted accessors (presumably counted
// before boost weights are applied -- confirm).
88 Float_t fNSigEvents_unboosted;
89 Float_t fNBkgEvents_unboosted;
90 Float_t fNEvents_unboosted;
// Separation index of the node and separation gain of its split; both are -1
// until set.
91 Float_t fSeparationIndex;
92 Float_t fSeparationGain;
// Copy constructor. Copies the pruning quantities, validation counts, event
// counts and separation values from n. fSampleMin/fSampleMax start empty and
// fSumTarget/fSumTarget2 start at 0 instead of being copied.
// NOTE(review): fCC and the three *_unboosted counters do not appear in this
// initialiser list (the default constructor above sets them all to 0). Unless
// the constructor body assigns them, they are left uninitialised in the copy.
// Consider adding fCC(0) and copying n.fNSigEvents_unboosted,
// n.fNBkgEvents_unboosted and n.fNEvents_unboosted.
95 DTNodeTrainingInfo(
const DTNodeTrainingInfo& n) :
96 fSampleMin(),fSampleMax(),
97 fNodeR(n.fNodeR), fSubTreeR(n.fSubTreeR),
98 fAlpha(n.fAlpha), fG(n.fG),
99 fNTerminal(n.fNTerminal),
100 fNB(n.fNB), fNS(n.fNS),
101 fSumTarget(0),fSumTarget2(0),
103 fNSigEvents ( n.fNSigEvents ), fNBkgEvents ( n.fNBkgEvents ),
104 fNEvents ( n.fNEvents ),
105 fNSigEvents_unweighted ( n.fNSigEvents_unweighted ),
106 fNBkgEvents_unweighted ( n.fNBkgEvents_unweighted ),
107 fNEvents_unweighted ( n.fNEvents_unweighted ),
108 fSeparationIndex( n.fSeparationIndex ),
109 fSeparationGain ( n.fSeparationGain )
// Node of a TMVA decision tree (derived from Node). Stores:
//  - the split definition: selector, cut value, cut type and optional Fisher
//    coefficients;
//  - the node type, purity, regression response/RMS and a terminal flag;
//  - a pointer fTrainInfo to the DTNodeTrainingInfo bookkeeping.
// fTrainInfo may be null: GetSumTarget/GetSumTarget2/GetCC check for that,
// but the other training accessors dereference it directly.
116 class DecisionTreeNode:
public Node {
123 DecisionTreeNode (Node* p,
char pos);
126 DecisionTreeNode (
const DecisionTreeNode &n, DecisionTreeNode* parent = NULL);
129 virtual ~DecisionTreeNode();
131 virtual Node* CreateNode()
const {
return new DecisionTreeNode(); }
133 inline void SetNFisherCoeff(Int_t nvars){fFisherCoeff.resize(nvars);}
134 inline UInt_t GetNFisherCoeff()
const {
return fFisherCoeff.size();}
136 void SetFisherCoeff(Int_t ivar, Double_t coeff);
138 Double_t GetFisherCoeff(Int_t ivar)
const {
return fFisherCoeff.at(ivar);}
141 virtual Bool_t GoesRight(
const Event & )
const;
144 virtual Bool_t GoesLeft (
const Event & )
const;
147 void SetSelector( Short_t i) { fSelector = i; }
149 Short_t GetSelector()
const {
return fSelector; }
152 void SetCutValue ( Float_t c ) { fCutValue = c; }
154 Float_t GetCutValue (
void )
const {
return fCutValue; }
157 void SetCutType( Bool_t t ) { fCutType = t; }
159 Bool_t GetCutType(
void )
const {
return fCutType; }
162 void SetNodeType( Int_t t ) { fNodeType = t;}
164 Int_t GetNodeType(
void )
const {
return fNodeType; }
167 Float_t GetPurity(
void )
const {
return fPurity;}
169 void SetPurity(
void );
172 void SetResponse( Float_t r ) { fResponse = r;}
175 Float_t GetResponse(
void )
const {
return fResponse;}
178 void SetRMS( Float_t r ) { fRMS = r;}
181 Float_t GetRMS(
void )
const {
return fRMS;}
184 void SetNSigEvents( Float_t s ) { fTrainInfo->fNSigEvents = s; }
187 void SetNBkgEvents( Float_t b ) { fTrainInfo->fNBkgEvents = b; }
190 void SetNEvents( Float_t nev ){ fTrainInfo->fNEvents =nev ; }
193 void SetNSigEvents_unweighted( Float_t s ) { fTrainInfo->fNSigEvents_unweighted = s; }
196 void SetNBkgEvents_unweighted( Float_t b ) { fTrainInfo->fNBkgEvents_unweighted = b; }
199 void SetNEvents_unweighted( Float_t nev ){ fTrainInfo->fNEvents_unweighted =nev ; }
202 void SetNSigEvents_unboosted( Float_t s ) { fTrainInfo->fNSigEvents_unboosted = s; }
205 void SetNBkgEvents_unboosted( Float_t b ) { fTrainInfo->fNBkgEvents_unboosted = b; }
208 void SetNEvents_unboosted( Float_t nev ){ fTrainInfo->fNEvents_unboosted =nev ; }
211 void IncrementNSigEvents( Float_t s ) { fTrainInfo->fNSigEvents += s; }
214 void IncrementNBkgEvents( Float_t b ) { fTrainInfo->fNBkgEvents += b; }
217 void IncrementNEvents( Float_t nev ){ fTrainInfo->fNEvents +=nev ; }
220 void IncrementNSigEvents_unweighted( ) { fTrainInfo->fNSigEvents_unweighted += 1; }
223 void IncrementNBkgEvents_unweighted( ) { fTrainInfo->fNBkgEvents_unweighted += 1; }
226 void IncrementNEvents_unweighted( ){ fTrainInfo->fNEvents_unweighted +=1 ; }
229 Float_t GetNSigEvents(
void )
const {
return fTrainInfo->fNSigEvents; }
232 Float_t GetNBkgEvents(
void )
const {
return fTrainInfo->fNBkgEvents; }
235 Float_t GetNEvents(
void )
const {
return fTrainInfo->fNEvents; }
238 Float_t GetNSigEvents_unweighted(
void )
const {
return fTrainInfo->fNSigEvents_unweighted; }
241 Float_t GetNBkgEvents_unweighted(
void )
const {
return fTrainInfo->fNBkgEvents_unweighted; }
244 Float_t GetNEvents_unweighted(
void )
const {
return fTrainInfo->fNEvents_unweighted; }
247 Float_t GetNSigEvents_unboosted(
void )
const {
return fTrainInfo->fNSigEvents_unboosted; }
250 Float_t GetNBkgEvents_unboosted(
void )
const {
return fTrainInfo->fNBkgEvents_unboosted; }
253 Float_t GetNEvents_unboosted(
void )
const {
return fTrainInfo->fNEvents_unboosted; }
257 void SetSeparationIndex( Float_t sep ){ fTrainInfo->fSeparationIndex =sep ; }
259 Float_t GetSeparationIndex(
void )
const {
return fTrainInfo->fSeparationIndex; }
262 void SetSeparationGain( Float_t sep ){ fTrainInfo->fSeparationGain =sep ; }
264 Float_t GetSeparationGain(
void )
const {
return fTrainInfo->fSeparationGain; }
267 virtual void Print( std::ostream& os )
const;
270 virtual void PrintRec( std::ostream& os )
const;
272 virtual void AddAttributesToNode(
void* node)
const;
273 virtual void AddContentToNode(std::stringstream& s)
const;
// Implemented out of line; the name suggests it resets this node and,
// recursively, its daughters -- confirm in the implementation.
void ClearNodeAndAllDaughters();
281 inline virtual DecisionTreeNode* GetLeft( )
const {
return static_cast<DecisionTreeNode*
>(fLeft); }
282 inline virtual DecisionTreeNode* GetRight( )
const {
return static_cast<DecisionTreeNode*
>(fRight); }
283 inline virtual DecisionTreeNode* GetParent( )
const {
return static_cast<DecisionTreeNode*
>(fParent); }
286 inline virtual void SetLeft (Node* l) { fLeft = l;}
287 inline virtual void SetRight (Node* r) { fRight = r;}
288 inline virtual void SetParent(Node* p) { fParent = p;}
294 inline void SetNodeR( Double_t r ) { fTrainInfo->fNodeR = r; }
295 inline Double_t GetNodeR( )
const {
return fTrainInfo->fNodeR; }
298 inline void SetSubTreeR( Double_t r ) { fTrainInfo->fSubTreeR = r; }
299 inline Double_t GetSubTreeR( )
const {
return fTrainInfo->fSubTreeR; }
304 inline void SetAlpha( Double_t alpha ) { fTrainInfo->fAlpha = alpha; }
305 inline Double_t GetAlpha( )
const {
return fTrainInfo->fAlpha; }
308 inline void SetAlphaMinSubtree( Double_t g ) { fTrainInfo->fG = g; }
309 inline Double_t GetAlphaMinSubtree( )
const {
return fTrainInfo->fG; }
312 inline void SetNTerminal( Int_t n ) { fTrainInfo->fNTerminal = n; }
313 inline Int_t GetNTerminal( )
const {
return fTrainInfo->fNTerminal; }
316 inline void SetNBValidation( Double_t b ) { fTrainInfo->fNB = b; }
317 inline void SetNSValidation( Double_t s ) { fTrainInfo->fNS = s; }
318 inline Double_t GetNBValidation( )
const {
return fTrainInfo->fNB; }
319 inline Double_t GetNSValidation( )
const {
return fTrainInfo->fNS; }
322 inline void SetSumTarget(Float_t t) {fTrainInfo->fSumTarget = t; }
323 inline void SetSumTarget2(Float_t t2){fTrainInfo->fSumTarget2 = t2; }
325 inline void AddToSumTarget(Float_t t) {fTrainInfo->fSumTarget += t; }
326 inline void AddToSumTarget2(Float_t t2){fTrainInfo->fSumTarget2 += t2; }
328 inline Float_t GetSumTarget()
const {
return fTrainInfo? fTrainInfo->fSumTarget : -9999;}
329 inline Float_t GetSumTarget2()
const {
return fTrainInfo? fTrainInfo->fSumTarget2: -9999;}
// Reset the validation bookkeeping of this node (implemented out of line).
void ResetValidationData();
336 inline Bool_t IsTerminal()
const {
return fIsTerminalNode; }
337 inline void SetTerminal( Bool_t s = kTRUE ) { fIsTerminalNode = s; }
338 void PrintPrune( std::ostream& os )
const ;
339 void PrintRecPrune( std::ostream& os )
const;
341 void SetCC(Double_t cc);
342 Double_t GetCC()
const {
return (fTrainInfo? fTrainInfo->fCC : -1.);}
344 Float_t GetSampleMin(UInt_t ivar)
const;
345 Float_t GetSampleMax(UInt_t ivar)
const;
346 void SetSampleMin(UInt_t ivar, Float_t xmin);
347 void SetSampleMax(UInt_t ivar, Float_t xmax);
// Static flag shared by all nodes (presumably selects whether the training
// bookkeeping fTrainInfo is set up -- confirm in the constructors).
349 static bool fgIsTraining;
// Static TMVA version code shared by all nodes (presumably the version of the
// weight file being read -- confirm).
350 static UInt_t fgTmva_Version_Code;
352 virtual Bool_t ReadDataRecord( std::istream& is, UInt_t tmva_Version_Code = TMVA_VERSION_CODE );
353 virtual void ReadAttributes(
void* node, UInt_t tmva_Version_Code = TMVA_VERSION_CODE );
354 virtual void ReadContent(std::stringstream& s);
358 static MsgLogger& Log();
// Fisher coefficients (see SetNFisherCoeff/SetFisherCoeff/GetFisherCoeff).
360 std::vector<Double_t> fFisherCoeff;
// Terminal-node flag (see IsTerminal/SetTerminal).
371 Bool_t fIsTerminalNode;
// Training bookkeeping used by the event-count, target-sum, pruning and
// separation accessors. Declared mutable, so const member functions may
// modify it. It may be null: GetSumTarget/GetSumTarget2/GetCC test for that.
// NOTE(review): a copy constructor and a destructor are declared, but no
// copy-assignment operator is visible in this header. If the destructor
// deletes fTrainInfo, the implicit operator= would share the pointer between
// nodes (double delete). Confirm, and declare or disable operator=.
373 mutable DTNodeTrainingInfo* fTrainInfo;
// ROOT class-dictionary macro, class version 0.
377 ClassDef(DecisionTreeNode,0);