59 ClassImp(TMVA::DecisionTreeNode);
61 bool TMVA::DecisionTreeNode::fgIsTraining =
false;
62 UInt_t TMVA::DecisionTreeNode::fgTmva_Version_Code = 0;
// Default constructor. fIsTerminalNode starts as kFALSE; the training-only
// bookkeeping (DTNodeTrainingInfo) is allocated only while the static flag
// DecisionTreeNode::fgIsTraining is true.
// NOTE(review): the rest of the initializer list and the non-training branch
// are not visible in this excerpt — presumably fTrainInfo is nulled there; confirm.
67 TMVA::DecisionTreeNode::DecisionTreeNode()
76 fIsTerminalNode( kFALSE )
78 if (DecisionTreeNode::fgIsTraining){
79 fTrainInfo =
new DTNodeTrainingInfo();
// Constructor taking a Node* and a position character.
// NOTE(review): how p and pos are used is not visible in this excerpt —
// presumably p is the parent node and pos the side of the parent this node
// hangs on; confirm against the hidden initializer list.
// fIsTerminalNode starts as kFALSE; fTrainInfo is allocated only when the
// static flag fgIsTraining is true.
91 TMVA::DecisionTreeNode::DecisionTreeNode(TMVA::Node* p,
char pos)
100 fIsTerminalNode( kFALSE )
102 if (DecisionTreeNode::fgIsTraining){
103 fTrainInfo =
new DTNodeTrainingInfo();
// Copy constructor. Copies the cut definition, response, node type, purity
// and terminal flag of n, attaches the new node to `parent`, and deep-copies
// both daughter subtrees recursively (each copied daughter gets this node as
// its parent). A missing daughter in n stays missing in the copy.
116 TMVA::DecisionTreeNode::DecisionTreeNode(
const TMVA::DecisionTreeNode &n,
117 DecisionTreeNode* parent)
119 fCutValue( n.fCutValue ),
120 fCutType ( n.fCutType ),
121 fSelector( n.fSelector ),
122 fResponse( n.fResponse ),
124 fNodeType( n.fNodeType ),
125 fPurity ( n.fPurity),
126 fIsTerminalNode( n.fIsTerminalNode )
128 this->SetParent( parent );
129 if (n.GetLeft() == 0 ) this->SetLeft(NULL);
130 else this->SetLeft(
new DecisionTreeNode( *((DecisionTreeNode*)(n.GetLeft())),
this));
132 if (n.GetRight() == 0 ) this->SetRight(NULL);
133 else this->SetRight(
new DecisionTreeNode( *((DecisionTreeNode*)(n.GetRight())),
this));
// NOTE(review): while fgIsTraining is true, n.fTrainInfo is dereferenced
// unconditionally. The constructors allocate fTrainInfo only when
// fgIsTraining is set, so a node built with training off may carry a null
// fTrainInfo and copying it here would dereference a null pointer.
// Consider copying only when n.fTrainInfo is non-null.
135 if (DecisionTreeNode::fgIsTraining){
136 fTrainInfo =
new DTNodeTrainingInfo(*(n.fTrainInfo));
// Destructor. NOTE(review): the body is not visible in this excerpt —
// presumably it releases fTrainInfo; confirm.
148 TMVA::DecisionTreeNode::~DecisionTreeNode(){
// Decide whether event e is sent to the right daughter.
// - Without Fisher coefficients: single-variable cut
//     e[selector] >= cut
// - With Fisher coefficients: the last coefficient is a constant term, the
//   others weight e[0..N-2]; the test is the strict comparison
//     fisher > cut
// When fCutType is kTRUE the raw comparison result is returned.
// NOTE(review): the other fCutType branch is not visible in this excerpt.
155 Bool_t TMVA::DecisionTreeNode::GoesRight(
const TMVA::Event & e)
const
159 if (GetNFisherCoeff() == 0){
161 result = (e.GetValueFast(this->GetSelector()) >= this->GetCutValue() );
165 Double_t fisher = this->GetFisherCoeff(fFisherCoeff.size()-1);
166 for (UInt_t ivar=0; ivar<fFisherCoeff.size()-1; ivar++)
167 fisher += this->GetFisherCoeff(ivar)*(e.GetValueFast(ivar));
169 result = fisher > this->GetCutValue();
172 if (fCutType == kTRUE)
return result;
// Decide whether event e is sent to the left daughter: returns kTRUE whenever
// GoesRight(e) is false.
// NOTE(review): the remaining return is not visible in this excerpt —
// presumably kFALSE; confirm.
179 Bool_t TMVA::DecisionTreeNode::GoesLeft(
const TMVA::Event & e)
const
181 if (!this->GoesRight(e))
return kTRUE;
// Recompute fPurity = S / (S + B) from the node's signal (S) and background (B)
// event counts. When S + B is not positive, an INFO message is logged
// announcing a purity of 0.5, and an ostringstream is created (presumably for
// a dump of the node). The rest of that branch is not visible in this excerpt.
191 void TMVA::DecisionTreeNode::SetPurity(
void )
193 if ( ( this->GetNSigEvents() + this->GetNBkgEvents() ) > 0 ) {
194 fPurity = this->GetNSigEvents() / ( this->GetNSigEvents() + this->GetNBkgEvents());
197 Log() << kINFO <<
"Zero events in purity calculation , return purity=0.5" << Endl;
198 std::ostringstream oss;
// Print a summary of this node to os, including:
// - depth, Fisher coefficients, selector and cut, cut type;
// - weighted and unweighted signal/background/total counts;
// - separation index and gain, node type;
// - the addresses of this node, of its parent, and of each daughter that exists.
// Numbers after the depth use std::setprecision(6).
209 void TMVA::DecisionTreeNode::Print(std::ostream& os)
const
211 os <<
"< *** " << std::endl;
212 os <<
" d: " << this->GetDepth()
213 << std::setprecision(6)
214 <<
"NCoef: " << this->GetNFisherCoeff();
215 for (Int_t i=0; i< (Int_t) this->GetNFisherCoeff(); i++) { os <<
"fC"<<i<<
": " << this->GetFisherCoeff(i);}
216 os <<
" ivar: " << this->GetSelector()
217 <<
" cut: " << this->GetCutValue()
218 <<
" cType: " << this->GetCutType()
219 <<
" s: " << this->GetNSigEvents()
220 <<
" b: " << this->GetNBkgEvents()
221 <<
" nEv: " << this->GetNEvents()
222 <<
" suw: " << this->GetNSigEvents_unweighted()
223 <<
" buw: " << this->GetNBkgEvents_unweighted()
224 <<
" nEvuw: " << this->GetNEvents_unweighted()
225 <<
" sepI: " << this->GetSeparationIndex()
226 <<
" sepG: " << this->GetSeparationGain()
227 <<
" nType: " << this->GetNodeType()
// NOTE(review): long(ptr) narrows pointers on LLP64 platforms (64-bit Windows,
// where long is 32 bits), so the printed addresses can be truncated. Printing
// static_cast<const void*>(ptr) is portable, but it changes the output to hex.
230 os <<
"My address is " << long(
this) <<
", ";
231 if (this->GetParent() != NULL) os <<
" parent at addr: " << long(this->GetParent()) ;
232 if (this->GetLeft() != NULL) os <<
" left daughter at addr: " << long(this->GetLeft());
233 if (this->GetRight() != NULL) os <<
" right daughter at addr: " << long(this->GetRight()) ;
235 os <<
" **** > " << std::endl;
// Print this node on one line, then its left subtree, then its right subtree.
// The line holds:
// - depth and position, Fisher coefficients, selector and cut, cut type;
// - weighted and unweighted counts, separation index and gain;
// - response, RMS, node type and CC.
// CC values above 1e13 are printed as 100000. This is presumably meant to keep
// effectively infinite cost-complexity values readable; confirm.
241 void TMVA::DecisionTreeNode::PrintRec(std::ostream& os)
const
243 os << this->GetDepth()
244 << std::setprecision(6)
245 <<
" " << this->GetPos()
246 <<
"NCoef: " << this->GetNFisherCoeff();
247 for (Int_t i=0; i< (Int_t) this->GetNFisherCoeff(); i++) {os <<
"fC"<<i<<
": " << this->GetFisherCoeff(i);}
248 os <<
" ivar: " << this->GetSelector()
249 <<
" cut: " << this->GetCutValue()
250 <<
" cType: " << this->GetCutType()
251 <<
" s: " << this->GetNSigEvents()
252 <<
" b: " << this->GetNBkgEvents()
253 <<
" nEv: " << this->GetNEvents()
254 <<
" suw: " << this->GetNSigEvents_unweighted()
255 <<
" buw: " << this->GetNBkgEvents_unweighted()
256 <<
" nEvuw: " << this->GetNEvents_unweighted()
257 <<
" sepI: " << this->GetSeparationIndex()
258 <<
" sepG: " << this->GetSeparationGain()
259 <<
" res: " << this->GetResponse()
260 <<
" rms: " << this->GetRMS()
261 <<
" nType: " << this->GetNodeType();
262 if (this->GetCC() > 10000000000000.) os <<
" CC: " << 100000. << std::endl;
263 else os <<
" CC: " << this->GetCC() << std::endl;
// Recurse into each existing daughter (pre-order traversal).
265 if (this->GetLeft() != NULL) this->GetLeft() ->PrintRec(os);
266 if (this->GetRight() != NULL) this->GetRight()->PrintRec(os);
// Read one node record from a text stream and store the parsed values through
// the setters below.
// - The version code is stored in the static fgTmva_Version_Code, which is
//   shared by all nodes.
// - Returns kFALSE when the depth read is -1 (presumably an end-of-tree marker).
// - Streams older than TMVA 4.0.0 are parsed with a separate branch.
// NOTE(review): most of the stream parsing and the final return are not
// visible in this excerpt.
272 Bool_t TMVA::DecisionTreeNode::ReadDataRecord( std::istream& is, UInt_t tmva_Version_Code )
274 fgTmva_Version_Code=tmva_Version_Code;
277 Float_t cutVal, cutType, nsig, nbkg, nEv, nsig_unweighted, nbkg_unweighted, nEv_unweighted;
278 Float_t separationIndex, separationGain, response(-99), cc(0);
279 Int_t depth, ivar, nodeType;
284 if ( depth==-1 ) {
return kFALSE; }
287 this->SetDepth(depth);
290 if (tmva_Version_Code < TMVA_VERSION(4,0,0)) {
298 >> tmp >> nsig_unweighted
299 >> tmp >> nbkg_unweighted
300 >> tmp >> nEv_unweighted
301 >> tmp >> separationIndex
302 >> tmp >> separationGain
312 >> tmp >> nsig_unweighted
313 >> tmp >> nbkg_unweighted
314 >> tmp >> nEv_unweighted
315 >> tmp >> separationIndex
316 >> tmp >> separationGain
322 this->SetSelector((UInt_t)ivar);
323 this->SetCutValue(cutVal);
324 this->SetCutType(cutType);
325 this->SetNodeType(nodeType);
327 this->SetNSigEvents(nsig);
328 this->SetNBkgEvents(nbkg);
329 this->SetNEvents(nEv);
330 this->SetNSigEvents_unweighted(nsig_unweighted);
331 this->SetNBkgEvents_unweighted(nbkg_unweighted);
332 this->SetNEvents_unweighted(nEv_unweighted);
333 this->SetSeparationIndex(separationIndex);
334 this->SetSeparationGain(separationGain);
// Reset this node's statistics, then recurse into each daughter that exists.
// Visible resets: unweighted counts go to 0, separation index and gain to -1.
// NOTE(review): further resets may be hidden in this excerpt.
346 void TMVA::DecisionTreeNode::ClearNodeAndAllDaughters()
351 SetNSigEvents_unweighted(0);
352 SetNBkgEvents_unweighted(0);
353 SetNEvents_unweighted(0);
354 SetSeparationIndex(-1);
355 SetSeparationGain(-1);
358 if (this->GetLeft() != NULL) ((DecisionTreeNode*)(this->GetLeft()))->ClearNodeAndAllDaughters();
359 if (this->GetRight() != NULL) ((DecisionTreeNode*)(this->GetRight()))->ClearNodeAndAllDaughters();
// Zero this node's validation-sample counters (background and signal).
// Recurse into the subtree only when BOTH daughters exist; a node with a
// single daughter does not recurse.
// NOTE(review): part of the body is not visible in this excerpt.
366 void TMVA::DecisionTreeNode::ResetValidationData( ) {
367 SetNBValidation( 0.0 );
368 SetNSValidation( 0.0 );
372 if(GetLeft() != NULL && GetRight() != NULL) {
373 GetLeft()->ResetValidationData();
374 GetRight()->ResetValidationData();
381 void TMVA::DecisionTreeNode::PrintPrune( std::ostream& os )
const {
382 os <<
"----------------------" << std::endl
383 <<
"|~T_t| " << GetNTerminal() << std::endl
384 <<
"R(t): " << GetNodeR() << std::endl
385 <<
"R(T_t): " << GetSubTreeR() << std::endl
386 <<
"g(t): " << GetAlpha() << std::endl
387 <<
"G(t): " << GetAlphaMinSubtree() << std::endl;
393 void TMVA::DecisionTreeNode::PrintRecPrune( std::ostream& os )
const {
394 this->PrintPrune(os);
395 if(this->GetLeft() != NULL && this->GetRight() != NULL) {
396 ((DecisionTreeNode*)this->GetLeft())->PrintRecPrune(os);
397 ((DecisionTreeNode*)this->GetRight())->PrintRecPrune(os);
403 void TMVA::DecisionTreeNode::SetCC(Double_t cc)
405 if (fTrainInfo) fTrainInfo->fCC = cc;
406 else Log() << kFATAL <<
"call to SetCC without trainingInfo" << Endl;
413 Float_t TMVA::DecisionTreeNode::GetSampleMin(UInt_t ivar)
const {
414 if (fTrainInfo && ivar < fTrainInfo->fSampleMin.size())
return fTrainInfo->fSampleMin[ivar];
415 else Log() << kFATAL <<
"You asked for Min of the event sample in node for variable "
416 << ivar <<
" that is out of range" << Endl;
424 Float_t TMVA::DecisionTreeNode::GetSampleMax(UInt_t ivar)
const {
425 if (fTrainInfo && ivar < fTrainInfo->fSampleMin.size())
return fTrainInfo->fSampleMax[ivar];
426 else Log() << kFATAL <<
"You asked for Max of the event sample in node for variable "
427 << ivar <<
" that is out of range" << Endl;
435 void TMVA::DecisionTreeNode::SetSampleMin(UInt_t ivar, Float_t xmin){
437 if ( ivar >= fTrainInfo->fSampleMin.size()) fTrainInfo->fSampleMin.resize(ivar+1);
438 fTrainInfo->fSampleMin[ivar]=xmin;
446 void TMVA::DecisionTreeNode::SetSampleMax(UInt_t ivar, Float_t xmax){
447 if( ! fTrainInfo )
return;
448 if ( ivar >= fTrainInfo->fSampleMax.size() )
449 fTrainInfo->fSampleMax.resize(ivar+1);
450 fTrainInfo->fSampleMax[ivar]=xmax;
455 void TMVA::DecisionTreeNode::ReadAttributes(
void* node, UInt_t )
457 Float_t tempNSigEvents,tempNBkgEvents;
460 if (gTools().HasAttr(node,
"NCoef")){
461 gTools().ReadAttr(node,
"NCoef", nCoef );
462 this->SetNFisherCoeff(nCoef);
464 for (Int_t i=0; i< (Int_t) this->GetNFisherCoeff(); i++) {
465 gTools().ReadAttr(node, Form(
"fC%d",i), tmp );
466 this->SetFisherCoeff(i,tmp);
469 this->SetNFisherCoeff(0);
471 gTools().ReadAttr(node,
"IVar", fSelector );
472 gTools().ReadAttr(node,
"Cut", fCutValue );
473 gTools().ReadAttr(node,
"cType", fCutType );
474 if (gTools().HasAttr(node,
"res")) gTools().ReadAttr(node,
"res", fResponse);
475 if (gTools().HasAttr(node,
"rms")) gTools().ReadAttr(node,
"rms", fRMS);
477 if( gTools().HasAttr(node,
"purity") ) {
478 gTools().ReadAttr(node,
"purity",fPurity );
480 gTools().ReadAttr(node,
"nS", tempNSigEvents );
481 gTools().ReadAttr(node,
"nB", tempNBkgEvents );
482 fPurity = tempNSigEvents / (tempNSigEvents + tempNBkgEvents);
485 gTools().ReadAttr(node,
"nType", fNodeType );
// Write this node's attributes onto `node` via gTools().AddAttr. `node` is an
// opaque handle, presumably an XML element; confirm.
// Attributes written:
// - NCoef and fC<i> for each Fisher coefficient;
// - IVar, Cut, cType;
// - res and rms;
// - purity;
// - nType.
// These names match what ReadAttributes reads back.
// NOTE(review): a few lines of this function are not visible in this excerpt.
492 void TMVA::DecisionTreeNode::AddAttributesToNode(
void* node)
const
494 gTools().AddAttr(node,
"NCoef", GetNFisherCoeff());
495 for (Int_t i=0; i< (Int_t) this->GetNFisherCoeff(); i++)
496 gTools().AddAttr(node, Form(
"fC%d",i), this->GetFisherCoeff(i));
498 gTools().AddAttr(node,
"IVar", GetSelector());
499 gTools().AddAttr(node,
"Cut", GetCutValue());
500 gTools().AddAttr(node,
"cType", GetCutType());
504 gTools().AddAttr(node,
"res", GetResponse());
505 gTools().AddAttr(node,
"rms", GetRMS());
507 gTools().AddAttr(node,
"purity",GetPurity());
509 gTools().AddAttr(node,
"nType", GetNodeType());
515 void TMVA::DecisionTreeNode::SetFisherCoeff(Int_t ivar, Double_t coeff)
517 if ((Int_t) fFisherCoeff.size()<ivar+1) fFisherCoeff.resize(ivar+1) ;
518 fFisherCoeff[ivar]=coeff;
// Takes an unnamed std::stringstream& and is const; the body is not visible in
// this excerpt. NOTE(review): the unnamed parameter suggests it is unused; confirm.
526 void TMVA::DecisionTreeNode::AddContentToNode( std::stringstream& )
const
// Takes an unnamed std::stringstream&; the body is not visible in this excerpt.
// NOTE(review): the unnamed parameter suggests it is unused; confirm.
535 void TMVA::DecisionTreeNode::ReadContent( std::stringstream& )
// Access the MsgLogger named "DecisionTreeNode", declared through
// TTHREAD_TLS_DECL_ARG (presumably one logger per thread; confirm).
// The return statement is not visible in this excerpt.
540 TMVA::MsgLogger& TMVA::DecisionTreeNode::Log() {
541 TTHREAD_TLS_DECL_ARG(MsgLogger,logger,
"DecisionTreeNode");