Logo ROOT   6.30.04
Reference Guide
 All Namespaces Files Pages
DecisionTreeNode.cxx
Go to the documentation of this file.
1 // @(#)root/tmva $Id$
2 // Author: Andreas Hoecker, Joerg Stelzer, Helge Voss, Kai Voss, Eckhard von Toerne
3 
4 /**********************************************************************************
5  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6  * Package: TMVA *
7  * Class : TMVA::DecisionTreeNode *
8  * Web : http://tmva.sourceforge.net *
9  * *
10  * Description: *
11  * Implementation of a Decision Tree Node *
12  * *
13  * Authors (alphabetical): *
14  * Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland *
15  * Helge Voss <Helge.Voss@cern.ch> - MPI-K Heidelberg, Germany *
16  * Kai Voss <Kai.Voss@cern.ch> - U. of Victoria, Canada *
17  * Eckhard von Toerne <evt@physik.uni-bonn.de> - U. of Bonn, Germany *
18  * *
19  * CopyRight (c) 2009: *
20  * CERN, Switzerland *
21  * U. of Victoria, Canada *
22  * MPI-K Heidelberg, Germany *
23  * U. of Bonn, Germany *
24  * *
25  * Redistribution and use in source and binary forms, with or without *
26  * modification, are permitted according to the terms listed in LICENSE *
27  * (http://tmva.sourceforge.net/LICENSE) *
28  **********************************************************************************/
29 
30 /*! \class TMVA::DecisionTreeNode
31 \ingroup TMVA
32 
33 Node for the Decision Tree.
34 
35 The node specifies ONE variable out of the given set of selection variables
36 that is used to split the sample that "arrives" at the node into a left
37 (background-enhanced) and a right (signal-enhanced) sample.
38 
39 */
40 
41 #include "TMVA/DecisionTreeNode.h"
42 
43 #include "TMVA/Types.h"
44 #include "TMVA/MsgLogger.h"
45 #include "TMVA/Tools.h"
46 #include "TMVA/Event.h"
47 
48 #include "ThreadLocalStorage.h"
49 #include "TString.h"
50 
51 #include <algorithm>
52 #include <exception>
53 #include <iomanip>
54 #include <limits>
55 #include <sstream>
56 
57 using std::string;
58 
ClassImp(TMVA::DecisionTreeNode);

// Steers whether newly constructed nodes allocate a DTNodeTrainingInfo;
// set to true while a tree is being grown (see the constructors).
bool TMVA::DecisionTreeNode::fgIsTraining = false;
// Version code of the weight file last read; recorded by ReadDataRecord().
UInt_t TMVA::DecisionTreeNode::fgTmva_Version_Code = 0;
63 
64 ////////////////////////////////////////////////////////////////////////////////
65 /// constructor of an essentially "empty" node floating in space
66 
67 TMVA::DecisionTreeNode::DecisionTreeNode()
68  : TMVA::Node(),
69  fCutValue(0),
70  fCutType ( kTRUE ),
71  fSelector ( -1 ),
72  fResponse(-99 ),
73  fRMS(0),
74  fNodeType (-99 ),
75  fPurity (-99),
76  fIsTerminalNode( kFALSE )
77 {
78  if (DecisionTreeNode::fgIsTraining){
79  fTrainInfo = new DTNodeTrainingInfo();
80  //std::cout << "Node constructor with TrainingINFO"<<std::endl;
81  }
82  else {
83  //std::cout << "**Node constructor WITHOUT TrainingINFO"<<std::endl;
84  fTrainInfo = 0;
85  }
86 }
87 
88 ////////////////////////////////////////////////////////////////////////////////
89 /// constructor of a daughter node as a daughter of 'p'
90 
91 TMVA::DecisionTreeNode::DecisionTreeNode(TMVA::Node* p, char pos)
92  : TMVA::Node(p, pos),
93  fCutValue( 0 ),
94  fCutType ( kTRUE ),
95  fSelector( -1 ),
96  fResponse(-99 ),
97  fRMS(0),
98  fNodeType( -99 ),
99  fPurity (-99),
100  fIsTerminalNode( kFALSE )
101 {
102  if (DecisionTreeNode::fgIsTraining){
103  fTrainInfo = new DTNodeTrainingInfo();
104  //std::cout << "Node constructor with TrainingINFO"<<std::endl;
105  }
106  else {
107  //std::cout << "**Node constructor WITHOUT TrainingINFO"<<std::endl;
108  fTrainInfo = 0;
109  }
110 }
111 
112 ////////////////////////////////////////////////////////////////////////////////
113 /// copy constructor of a node. It will result in an explicit copy of
114 /// the node and recursively all it's daughters
115 
116 TMVA::DecisionTreeNode::DecisionTreeNode(const TMVA::DecisionTreeNode &n,
117  DecisionTreeNode* parent)
118  : TMVA::Node(n),
119  fCutValue( n.fCutValue ),
120  fCutType ( n.fCutType ),
121  fSelector( n.fSelector ),
122  fResponse( n.fResponse ),
123  fRMS ( n.fRMS),
124  fNodeType( n.fNodeType ),
125  fPurity ( n.fPurity),
126  fIsTerminalNode( n.fIsTerminalNode )
127 {
128  this->SetParent( parent );
129  if (n.GetLeft() == 0 ) this->SetLeft(NULL);
130  else this->SetLeft( new DecisionTreeNode( *((DecisionTreeNode*)(n.GetLeft())),this));
131 
132  if (n.GetRight() == 0 ) this->SetRight(NULL);
133  else this->SetRight( new DecisionTreeNode( *((DecisionTreeNode*)(n.GetRight())),this));
134 
135  if (DecisionTreeNode::fgIsTraining){
136  fTrainInfo = new DTNodeTrainingInfo(*(n.fTrainInfo));
137  //std::cout << "Node constructor with TrainingINFO"<<std::endl;
138  }
139  else {
140  //std::cout << "**Node constructor WITHOUT TrainingINFO"<<std::endl;
141  fTrainInfo = 0;
142  }
143 }
144 
////////////////////////////////////////////////////////////////////////////////
/// destructor

TMVA::DecisionTreeNode::~DecisionTreeNode(){
   delete fTrainInfo;   // may be null (node created outside training); delete of null is a no-op
}
151 
152 ////////////////////////////////////////////////////////////////////////////////
153 /// test event if it descends the tree at this node to the right
154 
155 Bool_t TMVA::DecisionTreeNode::GoesRight(const TMVA::Event & e) const
156 {
157  Bool_t result;
158  // first check if the fisher criterium is used or ordinary cuts:
159  if (GetNFisherCoeff() == 0){
160 
161  result = (e.GetValueFast(this->GetSelector()) >= this->GetCutValue() );
162 
163  }else{
164 
165  Double_t fisher = this->GetFisherCoeff(fFisherCoeff.size()-1); // the offset
166  for (UInt_t ivar=0; ivar<fFisherCoeff.size()-1; ivar++)
167  fisher += this->GetFisherCoeff(ivar)*(e.GetValueFast(ivar));
168 
169  result = fisher > this->GetCutValue();
170  }
171 
172  if (fCutType == kTRUE) return result; //the cuts are selecting Signal ;
173  else return !result;
174 }
175 
176 ////////////////////////////////////////////////////////////////////////////////
177 /// test event if it descends the tree at this node to the left
178 
179 Bool_t TMVA::DecisionTreeNode::GoesLeft(const TMVA::Event & e) const
180 {
181  if (!this->GoesRight(e)) return kTRUE;
182  else return kFALSE;
183 }
184 
185 
186 ////////////////////////////////////////////////////////////////////////////////
187 /// return the S/(S+B) (purity) for the node
188 /// REM: even if nodes with purity 0.01 are very PURE background nodes, they still
189 /// get a small value of the purity.
190 
191 void TMVA::DecisionTreeNode::SetPurity( void )
192 {
193  if ( ( this->GetNSigEvents() + this->GetNBkgEvents() ) > 0 ) {
194  fPurity = this->GetNSigEvents() / ( this->GetNSigEvents() + this->GetNBkgEvents());
195  }
196  else {
197  Log() << kINFO << "Zero events in purity calculation , return purity=0.5" << Endl;
198  std::ostringstream oss;
199  this->Print(oss);
200  Log() <<oss.str();
201  fPurity = 0.5;
202  }
203  return;
204 }
205 
////////////////////////////////////////////////////////////////////////////////
/// print the node's content (depth, cut, event statistics, links) to `os`
/// as one human-readable record; this is NOT the format ReadDataRecord parses

void TMVA::DecisionTreeNode::Print(std::ostream& os) const
{
   os << "< *** " << std::endl;
   os << " d: " << this->GetDepth()
      << std::setprecision(6)
      << "NCoef: " << this->GetNFisherCoeff();
   for (Int_t i=0; i< (Int_t) this->GetNFisherCoeff(); i++) { os << "fC"<<i<<": " << this->GetFisherCoeff(i);}
   os << " ivar: " << this->GetSelector()                 // index of the cut variable
      << " cut: " << this->GetCutValue()
      << " cType: " << this->GetCutType()                 // kTRUE: cut selects signal (see GoesRight)
      << " s: " << this->GetNSigEvents()
      << " b: " << this->GetNBkgEvents()
      << " nEv: " << this->GetNEvents()
      << " suw: " << this->GetNSigEvents_unweighted()
      << " buw: " << this->GetNBkgEvents_unweighted()
      << " nEvuw: " << this->GetNEvents_unweighted()
      << " sepI: " << this->GetSeparationIndex()
      << " sepG: " << this->GetSeparationGain()
      << " nType: " << this->GetNodeType()
      << std::endl;

   // NOTE(review): long(this) truncates pointer values on LLP64 platforms
   // (e.g. 64-bit Windows, where long is 32 bit) -- printing via (void*)
   // would be portable, but would change the output format.
   os << "My address is " << long(this) << ", ";
   if (this->GetParent() != NULL) os << " parent at addr: " << long(this->GetParent()) ;
   if (this->GetLeft() != NULL) os << " left daughter at addr: " << long(this->GetLeft());
   if (this->GetRight() != NULL) os << " right daughter at addr: " << long(this->GetRight()) ;

   os << " **** > " << std::endl;
}
237 
////////////////////////////////////////////////////////////////////////////////
/// recursively print the node and its daughters (--> print the 'tree')
/// in the one-line-per-node format that ReadDataRecord-style parsers expect

void TMVA::DecisionTreeNode::PrintRec(std::ostream& os) const
{
   os << this->GetDepth()
      << std::setprecision(6)
      << " " << this->GetPos()
      << "NCoef: " << this->GetNFisherCoeff();
   for (Int_t i=0; i< (Int_t) this->GetNFisherCoeff(); i++) {os << "fC"<<i<<": " << this->GetFisherCoeff(i);}
   os << " ivar: " << this->GetSelector()
      << " cut: " << this->GetCutValue()
      << " cType: " << this->GetCutType()
      << " s: " << this->GetNSigEvents()
      << " b: " << this->GetNBkgEvents()
      << " nEv: " << this->GetNEvents()
      << " suw: " << this->GetNSigEvents_unweighted()
      << " buw: " << this->GetNBkgEvents_unweighted()
      << " nEvuw: " << this->GetNEvents_unweighted()
      << " sepI: " << this->GetSeparationIndex()
      << " sepG: " << this->GetSeparationGain()
      << " res: " << this->GetResponse()
      << " rms: " << this->GetRMS()
      << " nType: " << this->GetNodeType();
   // clamp absurdly large cost-complexity values so the record stays readable
   if (this->GetCC() > 10000000000000.) os << " CC: " << 100000. << std::endl;
   else os << " CC: " << this->GetCC() << std::endl;

   // depth-first: left subtree first, then right
   if (this->GetLeft() != NULL) this->GetLeft() ->PrintRec(os);
   if (this->GetRight() != NULL) this->GetRight()->PrintRec(os);
}
268 
////////////////////////////////////////////////////////////////////////////////
/// Read one node record from a (text) weight-file stream.
/// Returns kFALSE when the end-of-tree marker (depth == -1) is hit,
/// kTRUE after successfully filling this node.

Bool_t TMVA::DecisionTreeNode::ReadDataRecord( std::istream& is, UInt_t tmva_Version_Code )
{
   // remember the weight-file version class-wide (used when reading)
   fgTmva_Version_Code=tmva_Version_Code;
   string tmp;   // swallows the textual labels ("ivar:", "cut:", ...) between values

   Float_t cutVal, cutType, nsig, nbkg, nEv, nsig_unweighted, nbkg_unweighted, nEv_unweighted;
   Float_t separationIndex, separationGain, response(-99), cc(0);
   Int_t depth, ivar, nodeType;
   ULong_t lseq;   // node sequence number in the stream; read but not stored
   char pos;

   is >> depth; // 2
   if ( depth==-1 ) { return kFALSE; }   // depth -1 marks the end of the tree record
   // if ( depth==-1 ) { delete this; return kFALSE; }
   is >> pos ; // r
   this->SetDepth(depth);
   this->SetPos(pos);

   // pre-4.0.0 weight files carry neither the regression response nor the
   // cost-complexity value; the field order below must match PrintRec
   if (tmva_Version_Code < TMVA_VERSION(4,0,0)) {
      is >> tmp >> lseq
         >> tmp >> ivar
         >> tmp >> cutVal
         >> tmp >> cutType
         >> tmp >> nsig
         >> tmp >> nbkg
         >> tmp >> nEv
         >> tmp >> nsig_unweighted
         >> tmp >> nbkg_unweighted
         >> tmp >> nEv_unweighted
         >> tmp >> separationIndex
         >> tmp >> separationGain
         >> tmp >> nodeType;
   } else {
      is >> tmp >> lseq
         >> tmp >> ivar
         >> tmp >> cutVal
         >> tmp >> cutType
         >> tmp >> nsig
         >> tmp >> nbkg
         >> tmp >> nEv
         >> tmp >> nsig_unweighted
         >> tmp >> nbkg_unweighted
         >> tmp >> nEv_unweighted
         >> tmp >> separationIndex
         >> tmp >> separationGain
         >> tmp >> response
         >> tmp >> nodeType
         >> tmp >> cc;
   }

   this->SetSelector((UInt_t)ivar);
   this->SetCutValue(cutVal);
   this->SetCutType(cutType);
   this->SetNodeType(nodeType);
   // the training statistics can only be stored if this node carries a
   // DTNodeTrainingInfo (i.e. it was constructed while fgIsTraining was set)
   if (fTrainInfo){
      this->SetNSigEvents(nsig);
      this->SetNBkgEvents(nbkg);
      this->SetNEvents(nEv);
      this->SetNSigEvents_unweighted(nsig_unweighted);
      this->SetNBkgEvents_unweighted(nbkg_unweighted);
      this->SetNEvents_unweighted(nEv_unweighted);
      this->SetSeparationIndex(separationIndex);
      this->SetSeparationGain(separationGain);
      this->SetPurity();
      // this->SetResponse(response); old .txt weightfiles don't know regression yet
      this->SetCC(cc);
   }

   return kTRUE;
}
342 
343 ////////////////////////////////////////////////////////////////////////////////
344 /// clear the nodes (their S/N, Nevents etc), just keep the structure of the tree
345 
346 void TMVA::DecisionTreeNode::ClearNodeAndAllDaughters()
347 {
348  SetNSigEvents(0);
349  SetNBkgEvents(0);
350  SetNEvents(0);
351  SetNSigEvents_unweighted(0);
352  SetNBkgEvents_unweighted(0);
353  SetNEvents_unweighted(0);
354  SetSeparationIndex(-1);
355  SetSeparationGain(-1);
356  SetPurity();
357 
358  if (this->GetLeft() != NULL) ((DecisionTreeNode*)(this->GetLeft()))->ClearNodeAndAllDaughters();
359  if (this->GetRight() != NULL) ((DecisionTreeNode*)(this->GetRight()))->ClearNodeAndAllDaughters();
360 }
361 
362 ////////////////////////////////////////////////////////////////////////////////
363 /// temporary stored node values (number of events, etc.) that originate
364 /// not from the training but from the validation data (used in pruning)
365 
366 void TMVA::DecisionTreeNode::ResetValidationData( ) {
367  SetNBValidation( 0.0 );
368  SetNSValidation( 0.0 );
369  SetSumTarget( 0 );
370  SetSumTarget2( 0 );
371 
372  if(GetLeft() != NULL && GetRight() != NULL) {
373  GetLeft()->ResetValidationData();
374  GetRight()->ResetValidationData();
375  }
376 }
377 
378 ////////////////////////////////////////////////////////////////////////////////
379 /// printout of the node (can be read in with ReadDataRecord)
380 
381 void TMVA::DecisionTreeNode::PrintPrune( std::ostream& os ) const {
382  os << "----------------------" << std::endl
383  << "|~T_t| " << GetNTerminal() << std::endl
384  << "R(t): " << GetNodeR() << std::endl
385  << "R(T_t): " << GetSubTreeR() << std::endl
386  << "g(t): " << GetAlpha() << std::endl
387  << "G(t): " << GetAlphaMinSubtree() << std::endl;
388 }
389 
390 ////////////////////////////////////////////////////////////////////////////////
391 /// recursive printout of the node and its daughters
392 
393 void TMVA::DecisionTreeNode::PrintRecPrune( std::ostream& os ) const {
394  this->PrintPrune(os);
395  if(this->GetLeft() != NULL && this->GetRight() != NULL) {
396  ((DecisionTreeNode*)this->GetLeft())->PrintRecPrune(os);
397  ((DecisionTreeNode*)this->GetRight())->PrintRecPrune(os);
398  }
399 }
400 
401 ////////////////////////////////////////////////////////////////////////////////
402 
403 void TMVA::DecisionTreeNode::SetCC(Double_t cc)
404 {
405  if (fTrainInfo) fTrainInfo->fCC = cc;
406  else Log() << kFATAL << "call to SetCC without trainingInfo" << Endl;
407 }
408 
409 ////////////////////////////////////////////////////////////////////////////////
410 /// return the minimum of variable ivar from the training sample
411 /// that pass/end up in this node
412 
413 Float_t TMVA::DecisionTreeNode::GetSampleMin(UInt_t ivar) const {
414  if (fTrainInfo && ivar < fTrainInfo->fSampleMin.size()) return fTrainInfo->fSampleMin[ivar];
415  else Log() << kFATAL << "You asked for Min of the event sample in node for variable "
416  << ivar << " that is out of range" << Endl;
417  return -9999;
418 }
419 
420 ////////////////////////////////////////////////////////////////////////////////
421 /// return the maximum of variable ivar from the training sample
422 /// that pass/end up in this node
423 
424 Float_t TMVA::DecisionTreeNode::GetSampleMax(UInt_t ivar) const {
425  if (fTrainInfo && ivar < fTrainInfo->fSampleMin.size()) return fTrainInfo->fSampleMax[ivar];
426  else Log() << kFATAL << "You asked for Max of the event sample in node for variable "
427  << ivar << " that is out of range" << Endl;
428  return 9999;
429 }
430 
431 ////////////////////////////////////////////////////////////////////////////////
432 /// set the minimum of variable ivar from the training sample
433 /// that pass/end up in this node
434 
435 void TMVA::DecisionTreeNode::SetSampleMin(UInt_t ivar, Float_t xmin){
436  if ( fTrainInfo) {
437  if ( ivar >= fTrainInfo->fSampleMin.size()) fTrainInfo->fSampleMin.resize(ivar+1);
438  fTrainInfo->fSampleMin[ivar]=xmin;
439  }
440 }
441 
442 ////////////////////////////////////////////////////////////////////////////////
443 /// set the maximum of variable ivar from the training sample
444 /// that pass/end up in this node
445 
446 void TMVA::DecisionTreeNode::SetSampleMax(UInt_t ivar, Float_t xmax){
447  if( ! fTrainInfo ) return;
448  if ( ivar >= fTrainInfo->fSampleMax.size() )
449  fTrainInfo->fSampleMax.resize(ivar+1);
450  fTrainInfo->fSampleMax[ivar]=xmax;
451 }
452 
453 ////////////////////////////////////////////////////////////////////////////////
454 
455 void TMVA::DecisionTreeNode::ReadAttributes(void* node, UInt_t /* tmva_Version_Code */ )
456 {
457  Float_t tempNSigEvents,tempNBkgEvents;
458 
459  Int_t nCoef;
460  if (gTools().HasAttr(node, "NCoef")){
461  gTools().ReadAttr(node, "NCoef", nCoef );
462  this->SetNFisherCoeff(nCoef);
463  Double_t tmp;
464  for (Int_t i=0; i< (Int_t) this->GetNFisherCoeff(); i++) {
465  gTools().ReadAttr(node, Form("fC%d",i), tmp );
466  this->SetFisherCoeff(i,tmp);
467  }
468  }else{
469  this->SetNFisherCoeff(0);
470  }
471  gTools().ReadAttr(node, "IVar", fSelector );
472  gTools().ReadAttr(node, "Cut", fCutValue );
473  gTools().ReadAttr(node, "cType", fCutType );
474  if (gTools().HasAttr(node,"res")) gTools().ReadAttr(node, "res", fResponse);
475  if (gTools().HasAttr(node,"rms")) gTools().ReadAttr(node, "rms", fRMS);
476  // else {
477  if( gTools().HasAttr(node, "purity") ) {
478  gTools().ReadAttr(node, "purity",fPurity );
479  } else {
480  gTools().ReadAttr(node, "nS", tempNSigEvents );
481  gTools().ReadAttr(node, "nB", tempNBkgEvents );
482  fPurity = tempNSigEvents / (tempNSigEvents + tempNBkgEvents);
483  }
484  // }
485  gTools().ReadAttr(node, "nType", fNodeType );
486 }
487 
488 
489 ////////////////////////////////////////////////////////////////////////////////
490 /// add attribute to xml
491 
492 void TMVA::DecisionTreeNode::AddAttributesToNode(void* node) const
493 {
494  gTools().AddAttr(node, "NCoef", GetNFisherCoeff());
495  for (Int_t i=0; i< (Int_t) this->GetNFisherCoeff(); i++)
496  gTools().AddAttr(node, Form("fC%d",i), this->GetFisherCoeff(i));
497 
498  gTools().AddAttr(node, "IVar", GetSelector());
499  gTools().AddAttr(node, "Cut", GetCutValue());
500  gTools().AddAttr(node, "cType", GetCutType());
501 
502  //UInt_t analysisType = (dynamic_cast<const TMVA::DecisionTree*>(GetParentTree()) )->GetAnalysisType();
503  // if ( analysisType == TMVA::Types:: kRegression) {
504  gTools().AddAttr(node, "res", GetResponse());
505  gTools().AddAttr(node, "rms", GetRMS());
506  //} else if ( analysisType == TMVA::Types::kClassification) {
507  gTools().AddAttr(node, "purity",GetPurity());
508  //}
509  gTools().AddAttr(node, "nType", GetNodeType());
510 }
511 
512 ////////////////////////////////////////////////////////////////////////////////
513 /// set fisher coefficients
514 
515 void TMVA::DecisionTreeNode::SetFisherCoeff(Int_t ivar, Double_t coeff)
516 {
517  if ((Int_t) fFisherCoeff.size()<ivar+1) fFisherCoeff.resize(ivar+1) ;
518  fFisherCoeff[ivar]=coeff;
519 }
520 
////////////////////////////////////////////////////////////////////////////////
/// adding attributes to tree node (well, was used in BinarySearchTree,
/// and somehow I guess someone programmed it such that we need this in
/// this tree too, although we don't..)

void TMVA::DecisionTreeNode::AddContentToNode( std::stringstream& /*s*/ ) const
{
   // intentionally empty: this node type serialises via XML attributes only
   // (see AddAttributesToNode)
}
529 
////////////////////////////////////////////////////////////////////////////////
/// reading attributes from tree node (well, was used in BinarySearchTree,
/// and somehow I guess someone programmed it such that we need this in
/// this tree too, although we don't..)

void TMVA::DecisionTreeNode::ReadContent( std::stringstream& /*s*/ )
{
   // intentionally empty: this node type deserialises via XML attributes only
   // (see ReadAttributes)
}
////////////////////////////////////////////////////////////////////////////////
/// message logger shared by all DecisionTreeNode instances of a thread

TMVA::MsgLogger& TMVA::DecisionTreeNode::Log() {
   // thread-local function-static rather than a per-node member,
   TTHREAD_TLS_DECL_ARG(MsgLogger,logger,"DecisionTreeNode"); // static because there is a huge number of nodes...
   return logger;
}