// Fixed default seed for the tree's internal TRandom3; used e.g. by the
// copy constructor when it builds a fresh random generator.
91 const Int_t TMVA::DecisionTree::fgRandomSeed = 0;
// ROOT macro generating dictionary/streamer support for the class.
95 ClassImp(TMVA::DecisionTree);
// Near-equality test for floats: the machine epsilon is scaled by the
// magnitude of the operands and by `ulp` (units in the last place).
97 bool almost_equal_float(
float x,
float y,
int ulp=4){
100 return std::abs(x-y) < std::numeric_limits<float>::epsilon() * std::abs(x+y) * ulp
// Fall back to an absolute test against the smallest normalised float,
// so that subnormal differences still compare as "equal".
102 || std::abs(x-y) < std::numeric_limits<float>::min();
// Double-precision counterpart of almost_equal_float: epsilon scaled by
// operand magnitude and `ulp`, with a subnormal fallback.
105 bool almost_equal_double(
double x,
double y,
int ulp=4){
108 return std::abs(x-y) < std::numeric_limits<double>::epsilon() * std::abs(x+y) * ulp
110 || std::abs(x-y) < std::numeric_limits<double>::min();
// Default constructor: initialises a classification tree with no pruning,
// purity limit 0.5, and Fisher cuts disabled. (Further member initialisers
// and the constructor body are not visible in this view.)
118 TMVA::DecisionTree::DecisionTree():
122 fUseFisherCuts (kFALSE),
123 fMinLinCorrForFisher (1),
124 fUseExclusiveVars (kTRUE),
130 fUseSearchTree(kFALSE),
132 fPruneMethod (kNoPruning),
133 fNNodesBeforePruning(0),
134 fNodePurityLimit(0.5),
135 fRandomisedTree (kFALSE),
137 fUsePoissonNvars(kFALSE),
142 fAnalysisType (Types::kClassification),
// Main constructor. sepType selects the separation criterion; passing NULL
// switches the tree to regression mode (RegressionVariance). minSize is the
// minimal node size in percent, nCuts the cut grid granularity, iSeed seeds
// the private TRandom3.
153 TMVA::DecisionTree::DecisionTree( TMVA::SeparationBase *sepType, Float_t minSize, Int_t nCuts, DataSetInfo* dataInfo, UInt_t cls,
154 Bool_t randomisedTree, Int_t useNvars, Bool_t usePoissonNvars,
155 UInt_t nMaxDepth, Int_t iSeed, Float_t purityLimit, Int_t treeID):
159 fUseFisherCuts (kFALSE),
160 fMinLinCorrForFisher (1),
161 fUseExclusiveVars (kTRUE),
165 fMinNodeSize (minSize),
167 fUseSearchTree (kFALSE),
169 fPruneMethod (kNoPruning),
170 fNNodesBeforePruning(0),
171 fNodePurityLimit(purityLimit),
172 fRandomisedTree (randomisedTree),
173 fUseNvars (useNvars),
174 fUsePoissonNvars(usePoissonNvars),
175 fMyTrandom (new TRandom3(iSeed)),
176 fMaxDepth (nMaxDepth),
179 fAnalysisType (Types::kClassification),
180 fDataSetInfo (dataInfo)
// No separation type given => regression tree with variance-based splitting.
182 if (sepType == NULL) {
185 fAnalysisType = Types::kRegression;
186 fRegType =
new RegressionVariance();
// Warn that "optimal cuts" (NCuts < 0) are not implemented; a cut grid of
// fNCuts points is used instead. (The surrounding condition is not visible
// in this view.)
189 Log() << kWARNING <<
" You had chosen the training mode using optimal cuts, not\n"
190 <<
" based on a grid of " << fNCuts <<
" by setting the option NCuts < 0\n"
191 <<
" as this doesn't exist yet, I set it to " << fNCuts <<
" and use the grid"
195 fAnalysisType = Types::kClassification;
// Copy constructor: copies all configuration from `d`, deep-copies the node
// structure starting at d's root, and re-links parent pointers.
// NOTE: the random generator is NOT copied — it is re-created from the
// fixed class-wide fgRandomSeed, not from d's seed.
203 TMVA::DecisionTree::DecisionTree(
const DecisionTree &d ):
207 fUseFisherCuts (d.fUseFisherCuts),
208 fMinLinCorrForFisher (d.fMinLinCorrForFisher),
209 fUseExclusiveVars (d.fUseExclusiveVars),
210 fSepType (d.fSepType),
211 fRegType (d.fRegType),
212 fMinSize (d.fMinSize),
213 fMinNodeSize(d.fMinNodeSize),
214 fMinSepGain (d.fMinSepGain),
215 fUseSearchTree (d.fUseSearchTree),
216 fPruneStrength (d.fPruneStrength),
217 fPruneMethod (d.fPruneMethod),
218 fNodePurityLimit(d.fNodePurityLimit),
219 fRandomisedTree (d.fRandomisedTree),
220 fUseNvars (d.fUseNvars),
221 fUsePoissonNvars(d.fUsePoissonNvars),
222 fMyTrandom (new TRandom3(fgRandomSeed)),
223 fMaxDepth (d.fMaxDepth),
224 fSigClass (d.fSigClass),
226 fAnalysisType(d.fAnalysisType),
227 fDataSetInfo (d.fDataSetInfo)
// Deep copy of the whole node tree, then fix up each node's back-pointer
// to this tree and recompute the total depth.
229 this->SetRoot(
new TMVA::DecisionTreeNode ( *((DecisionTreeNode*)(d.GetRoot())) ) );
230 this->SetParentTreeInNodes();
// Destructor: releases the owned random generator and the regression
// separation object. (Node cleanup is presumably handled elsewhere, e.g.
// by the base class — not visible in this view.)
239 TMVA::DecisionTree::~DecisionTree()
243 if (fMyTrandom)
delete fMyTrandom;
244 if (fRegType)
delete fRegType;
// Recursively descend from node `n` (the root when called with the default),
// setting each node's parent-tree pointer to this tree and updating the
// recorded total tree depth. Aborts (kFATAL) on a malformed tree: an
// undefined root or a node with exactly one daughter.
251 void TMVA::DecisionTree::SetParentTreeInNodes( Node *n )
256 Log() << kFATAL <<
"SetParentTreeNodes: started with undefined ROOT node" <<Endl;
// Sanity check: an interior node must have either two daughters or none.
261 if ((this->GetLeftDaughter(n) == NULL) && (this->GetRightDaughter(n) != NULL) ) {
262 Log() << kFATAL <<
" Node with only one daughter?? Something went wrong" << Endl;
264 }
else if ((this->GetLeftDaughter(n) != NULL) && (this->GetRightDaughter(n) == NULL) ) {
265 Log() << kFATAL <<
" Node with only one daughter?? Something went wrong" << Endl;
// Recurse into both daughters before fixing up this node.
269 if (this->GetLeftDaughter(n) != NULL) {
270 this->SetParentTreeInNodes( this->GetLeftDaughter(n) );
272 if (this->GetRightDaughter(n) != NULL) {
273 this->SetParentTreeInNodes( this->GetRightDaughter(n) );
276 n->SetParentTree(
this);
// Keep the cached maximum depth up to date.
277 if (n->GetDepth() > this->GetTotalTreeDepth()) this->SetTotalTreeDepth(n->GetDepth());
// Factory: build a new DecisionTree from an XML node, reading the "type"
// attribute and then delegating to ReadXML (versioned by tmva_Version_Code).
// Caller takes ownership of the returned tree.
284 TMVA::DecisionTree* TMVA::DecisionTree::CreateFromXML(
void* node, UInt_t tmva_Version_Code ) {
285 std::string type(
"");
286 gTools().ReadAttr(node,
"type", type);
287 DecisionTree* dt =
new DecisionTree();
289 dt->ReadXML( node, tmva_Version_Code );
// Per-partition accumulator used by the multi-threaded BuildTree: holds the
// weighted/unweighted/unboosted signal and background sums, the regression
// target sums, and the per-variable min/max seen in one chunk of events.
// operator+ merges two partial results so they can be reduced with
// std::accumulate.
300 struct BuildNodeInfo{
// Construct from an event: sizes the min/max vectors for fNvars variables
// and seeds them from the event's values.
302 BuildNodeInfo(Int_t fNvars,
const TMVA::Event* evt){
305 xmin = std::vector<Float_t>(nvars);
306 xmax = std::vector<Float_t>(nvars);
309 for (Int_t ivar=0; ivar<fNvars; ivar++) {
310 const Double_t val = evt->GetValueFast(ivar);
// Construct from explicit min/max vectors (used when merging results).
316 BuildNodeInfo(Int_t fNvars, std::vector<Float_t>& inxmin, std::vector<Float_t>& inxmax){
319 xmin = std::vector<Float_t>(nvars);
320 xmax = std::vector<Float_t>(nvars);
323 for (Int_t ivar=0; ivar<fNvars; ivar++) {
324 xmin[ivar]=inxmin[ivar];
325 xmax[ivar]=inxmax[ivar];
// target2 accumulates weight*target^2 (for RMS); xmin/xmax track the
// per-variable sample range.
338 Double_t target2 = 0;
339 std::vector<Float_t> xmin;
340 std::vector<Float_t> xmax;
// Element-wise merge of two partial accumulators; the variable counts must
// agree or an error is printed.
344 BuildNodeInfo operator+(
const BuildNodeInfo& other)
346 BuildNodeInfo ret(nvars, xmin, xmax);
347 if(nvars != other.nvars)
349 std::cout <<
"!!! ERROR BuildNodeInfo1+BuildNodeInfo2 failure. Nvars1 != Nvars2." << std::endl;
353 ret.suw = suw + other.suw;
354 ret.sub = sub + other.sub;
356 ret.buw = buw + other.buw;
357 ret.bub = bub + other.bub;
358 ret.target = target + other.target;
359 ret.target2 = target2 + other.target2;
// Combine ranges: take the smaller min and the larger max per variable.
362 for(Int_t i=0; i<nvars; i++)
364 ret.xmin[i]=xmin[i]<other.xmin[i]?xmin[i]:other.xmin[i];
365 ret.xmax[i]=xmax[i]>other.xmax[i]?xmax[i]:other.xmax[i];
// Recursively build the decision tree (multi-threaded variant). Event sums
// (signal/background weights, regression targets, per-variable ranges) are
// computed in parallel via the ROOT thread executor's MapReduce over
// BuildNodeInfo accumulators, then the node is either split (TrainNodeFast/
// TrainNodeFull) and recursed into, or turned into a leaf.
380 UInt_t TMVA::DecisionTree::BuildTree(
const std::vector<const TMVA::Event*> & eventSample,
381 TMVA::DecisionTreeNode *node)
// No node given: create the root and initialise its bookkeeping.
385 node =
new TMVA::DecisionTreeNode();
389 this->GetRoot()->SetPos(
's');
390 this->GetRoot()->SetDepth(0);
391 this->GetRoot()->SetParentTree(
this);
// Translate the percentage MinNodeSize into an absolute event count.
392 fMinSize = fMinNodeSize/100. * eventSample.size();
394 Log() << kDEBUG <<
"\tThe minimal node size MinNodeSize=" << fMinNodeSize <<
" fMinNodeSize="<<fMinNodeSize<<
"% is translated to an actual number of events = "<< fMinSize<<
" for the training sample size of " << eventSample.size() << Endl;
395 Log() << kDEBUG <<
"\tNote: This number will be taken as absolute minimum in the node, " << Endl;
396 Log() << kDEBUG <<
" \tin terms of 'weighted events' and unweighted ones !! " << Endl;
400 UInt_t nevents = eventSample.size();
// Lazily determine the number of input variables from the first event.
403 if (fNvars==0) fNvars = eventSample[0]->GetNVariables();
404 fVariableImportance.resize(fNvars);
406 else Log() << kFATAL <<
":<BuildTree> eventsample Size == 0 " << Endl;
// One partition per worker thread in the ROOT thread pool.
413 UInt_t nPartitions = TMVA::Config::Instance().GetThreadExecutor().GetPoolSize();
414 auto seeds = ROOT::TSeqU(nPartitions);
// Map step: each partition accumulates its slice of events into a
// BuildNodeInfo.
417 auto f = [
this, &eventSample, &nPartitions](UInt_t partition = 0){
419 Int_t start = 1.0*partition/nPartitions*eventSample.size();
420 Int_t end = (partition+1.0)/nPartitions*eventSample.size();
422 BuildNodeInfo nodeInfof(fNvars, eventSample[0]);
424 for(Int_t iev=start; iev<end; iev++){
425 const TMVA::Event* evt = eventSample[iev];
426 const Double_t weight = evt->GetWeight();
427 const Double_t orgWeight = evt->GetOriginalWeight();
428 if (evt->GetClass() == fSigClass) {
429 nodeInfof.s += weight;
431 nodeInfof.sub += orgWeight;
434 nodeInfof.b += weight;
436 nodeInfof.bub += orgWeight;
// Regression: accumulate weighted target and target^2 for response/RMS.
438 if ( DoRegression() ) {
439 const Double_t tgt = evt->GetTarget(0);
440 nodeInfof.target +=weight*tgt;
441 nodeInfof.target2+=weight*tgt*tgt;
// Track the observed min/max of each variable in this partition.
445 for (UInt_t ivar=0; ivar<fNvars; ivar++) {
446 const Double_t val = evt->GetValueFast(ivar);
448 nodeInfof.xmin[ivar]=val;
449 nodeInfof.xmax[ivar]=val;
451 if (val < nodeInfof.xmin[ivar]) nodeInfof.xmin[ivar]=val;
452 if (val > nodeInfof.xmax[ivar]) nodeInfof.xmax[ivar]=val;
// Reduce step: fold the per-partition accumulators via BuildNodeInfo::operator+.
459 BuildNodeInfo nodeInfoInit(fNvars, eventSample[0]);
462 auto redfunc = [nodeInfoInit](std::vector<BuildNodeInfo> v) -> BuildNodeInfo {
return std::accumulate(v.begin(), v.end(), nodeInfoInit); };
463 BuildNodeInfo nodeInfo = TMVA::Config::Instance().GetThreadExecutor().MapReduce(f, seeds, redfunc);
// Negative total weight => negative-weight MC events dominate this node;
// warn and dump per-event diagnostics below.
466 if (nodeInfo.s+nodeInfo.b < 0) {
467 Log() << kWARNING <<
" One of the Decision Tree nodes has negative total number of signal or background events. "
468 <<
"(Nsig="<<nodeInfo.s<<
" Nbkg="<<nodeInfo.b<<
" Probaby you use a Monte Carlo with negative weights. That should in principle "
469 <<
"be fine as long as on average you end up with something positive. For this you have to make sure that the "
470 <<
"minimal number of (unweighted) events demanded for a tree node (currently you use: MinNodeSize="<<fMinNodeSize
471 <<
"% of training events, you can set this via the BDT option string when booking the classifier) is large enough "
472 <<
"to allow for reasonable averaging!!!" << Endl
473 <<
" If this does not help.. maybe you want to try the option: NoNegWeightsInTraining which ignores events "
474 <<
"with negative weight in the training." << Endl;
476 for (UInt_t i=0; i<eventSample.size(); i++) {
477 if (eventSample[i]->GetClass() != fSigClass) {
478 nBkg += eventSample[i]->GetWeight();
479 Log() << kDEBUG <<
"Event "<< i<<
" has (original) weight: " << eventSample[i]->GetWeight()/eventSample[i]->GetBoostWeight()
480 <<
" boostWeight: " << eventSample[i]->GetBoostWeight() << Endl;
483 Log() << kDEBUG <<
" that gives in total: " << nBkg<<Endl;
// Store the accumulated statistics on the node.
486 node->SetNSigEvents(nodeInfo.s);
487 node->SetNBkgEvents(nodeInfo.b);
488 node->SetNSigEvents_unweighted(nodeInfo.suw);
489 node->SetNBkgEvents_unweighted(nodeInfo.buw);
490 node->SetNSigEvents_unboosted(nodeInfo.sub);
491 node->SetNBkgEvents_unboosted(nodeInfo.bub);
// Totals are only set explicitly on the root; daughters get theirs when
// they are created during the split below.
493 if (node == this->GetRoot()) {
494 node->SetNEvents(nodeInfo.s+nodeInfo.b);
495 node->SetNEvents_unweighted(nodeInfo.suw+nodeInfo.buw);
496 node->SetNEvents_unboosted(nodeInfo.sub+nodeInfo.bub);
500 for (UInt_t ivar=0; ivar<fNvars; ivar++) {
501 node->SetSampleMin(ivar,nodeInfo.xmin[ivar]);
502 node->SetSampleMax(ivar,nodeInfo.xmax[ivar]);
// Split only if the node is big enough (2*fMinSize in both raw and weighted
// counts), not too deep, and not already pure (classification) / non-empty
// (regression).
516 if ((eventSample.size() >= 2*fMinSize && nodeInfo.s+nodeInfo.b >= 2*fMinSize) && node->GetDepth() < fMaxDepth
517 && ( ( nodeInfo.s!=0 && nodeInfo.b !=0 && !DoRegression()) || ( (nodeInfo.s+nodeInfo.b)!=0 && DoRegression()) ) ) {
520 Double_t separationGain;
522 separationGain = this->TrainNodeFast(eventSample, node);
525 separationGain = this->TrainNodeFull(eventSample, node);
// No usable separation gain: make this node a leaf.
529 if (separationGain < std::numeric_limits<double>::epsilon()) {
532 if (DoRegression()) {
533 node->SetSeparationIndex(fRegType->GetSeparationIndex(nodeInfo.s+nodeInfo.b,nodeInfo.target,nodeInfo.target2));
534 node->SetResponse(nodeInfo.target/(nodeInfo.s+nodeInfo.b));
// Guard the RMS sqrt against a (numerically) zero variance.
535 if( almost_equal_double(nodeInfo.target2/(nodeInfo.s+nodeInfo.b),nodeInfo.target/(nodeInfo.s+nodeInfo.b)*nodeInfo.target/(nodeInfo.s+nodeInfo.b)) ){
538 node->SetRMS(TMath::Sqrt(nodeInfo.target2/(nodeInfo.s+nodeInfo.b) - nodeInfo.target/(nodeInfo.s+nodeInfo.b)*nodeInfo.target/(nodeInfo.s+nodeInfo.b)));
// Classification leaf: type +1 if purity above the limit, else -1.
542 node->SetSeparationIndex(fSepType->GetSeparationIndex(nodeInfo.s,nodeInfo.b));
543 if (node->GetPurity() > fNodePurityLimit) node->SetNodeType(1);
544 else node->SetNodeType(-1);
546 if (node->GetDepth() > this->GetTotalTreeDepth()) this->SetTotalTreeDepth(node->GetDepth());
// A valid split was found: route every event left/right of the cut and
// accumulate the daughter weights.
551 std::vector<const TMVA::Event*> leftSample; leftSample.reserve(nevents);
552 std::vector<const TMVA::Event*> rightSample; rightSample.reserve(nevents);
554 Double_t nRight=0, nLeft=0;
555 Double_t nRightUnBoosted=0, nLeftUnBoosted=0;
557 for (UInt_t ie=0; ie< nevents ; ie++) {
558 if (node->GoesRight(*eventSample[ie])) {
559 rightSample.push_back(eventSample[ie]);
560 nRight += eventSample[ie]->GetWeight();
561 nRightUnBoosted += eventSample[ie]->GetOriginalWeight();
564 leftSample.push_back(eventSample[ie]);
565 nLeft += eventSample[ie]->GetWeight();
566 nLeftUnBoosted += eventSample[ie]->GetOriginalWeight();
// Should never happen after a positive separation gain: every event fell on
// one side of the cut — report and abort (the trailing kFATAL terminates).
570 if (leftSample.empty() || rightSample.empty()) {
572 Log() << kERROR <<
"<TrainNode> all events went to the same branch" << Endl
573 <<
"--- Hence new node == old node ... check" << Endl
574 <<
"--- left:" << leftSample.size()
575 <<
" right:" << rightSample.size() << Endl
576 <<
" while the separation is thought to be " << separationGain
577 <<
"\n when cutting on variable " << node->GetSelector()
578 <<
" at value " << node->GetCutValue()
579 << kFATAL <<
"--- this should never happen, please write a bug report to Helge.Voss@cern.ch" << Endl;
// Create the daughters, attach them, and recurse (right first, then left).
583 TMVA::DecisionTreeNode *rightNode =
new TMVA::DecisionTreeNode(node,
'r');
585 rightNode->SetNEvents(nRight);
586 rightNode->SetNEvents_unboosted(nRightUnBoosted);
587 rightNode->SetNEvents_unweighted(rightSample.size());
589 TMVA::DecisionTreeNode *leftNode =
new TMVA::DecisionTreeNode(node,
'l');
592 leftNode->SetNEvents(nLeft);
593 leftNode->SetNEvents_unboosted(nLeftUnBoosted);
594 leftNode->SetNEvents_unweighted(leftSample.size());
596 node->SetNodeType(0);
597 node->SetLeft(leftNode);
598 node->SetRight(rightNode);
600 this->BuildTree(rightSample, rightNode);
601 this->BuildTree(leftSample, leftNode );
// Node did not qualify for splitting at all: finalise it as a leaf
// (same leaf logic as above).
606 if (DoRegression()) {
607 node->SetSeparationIndex(fRegType->GetSeparationIndex(nodeInfo.s+nodeInfo.b,nodeInfo.target,nodeInfo.target2));
608 node->SetResponse(nodeInfo.target/(nodeInfo.s+nodeInfo.b));
609 if( almost_equal_double(nodeInfo.target2/(nodeInfo.s+nodeInfo.b), nodeInfo.target/(nodeInfo.s+nodeInfo.b)*nodeInfo.target/(nodeInfo.s+nodeInfo.b)) ) {
612 node->SetRMS(TMath::Sqrt(nodeInfo.target2/(nodeInfo.s+nodeInfo.b) - nodeInfo.target/(nodeInfo.s+nodeInfo.b)*nodeInfo.target/(nodeInfo.s+nodeInfo.b)));
616 node->SetSeparationIndex(fSepType->GetSeparationIndex(nodeInfo.s,nodeInfo.b));
617 if (node->GetPurity() > fNodePurityLimit) node->SetNodeType(1);
618 else node->SetNodeType(-1);
627 if (node->GetDepth() > this->GetTotalTreeDepth()) this->SetTotalTreeDepth(node->GetDepth());
// Recursively build the decision tree (serial variant of the overload
// above). Accumulates the node statistics in a single loop using plain
// arrays/scalars instead of the BuildNodeInfo MapReduce, then applies the
// same split-or-leaf logic and recursion.
637 UInt_t TMVA::DecisionTree::BuildTree(
const std::vector<const TMVA::Event*> & eventSample,
638 TMVA::DecisionTreeNode *node)
// No node given: create the root and initialise its bookkeeping.
642 node =
new TMVA::DecisionTreeNode();
646 this->GetRoot()->SetPos(
's');
647 this->GetRoot()->SetDepth(0);
648 this->GetRoot()->SetParentTree(
this);
// Translate the percentage MinNodeSize into an absolute event count.
649 fMinSize = fMinNodeSize/100. * eventSample.size();
651 Log() << kDEBUG <<
"\tThe minimal node size MinNodeSize=" << fMinNodeSize <<
" fMinNodeSize="<<fMinNodeSize<<
"% is translated to an actual number of events = "<< fMinSize<<
" for the training sample size of " << eventSample.size() << Endl;
652 Log() << kDEBUG <<
"\tNote: This number will be taken as absolute minimum in the node, " << Endl;
653 Log() << kDEBUG <<
" \tin terms of 'weighted events' and unweighted ones !! " << Endl;
657 UInt_t nevents = eventSample.size();
// Lazily determine the number of input variables from the first event.
660 if (fNvars==0) fNvars = eventSample[0]->GetNVariables();
661 fVariableImportance.resize(fNvars);
663 else Log() << kFATAL <<
":<BuildTree> eventsample Size == 0 " << Endl;
// Accumulators: unweighted (uw) and unboosted (ub) signal/background sums,
// regression target sums, and per-variable min/max (raw new[] arrays,
// presumably deleted further down — not visible in this view).
666 Double_t suw=0, buw=0;
667 Double_t sub=0, bub=0;
668 Double_t target=0, target2=0;
669 Float_t *xmin =
new Float_t[fNvars];
670 Float_t *xmax =
new Float_t[fNvars];
673 for (UInt_t ivar=0; ivar<fNvars; ivar++) {
674 xmin[ivar]=xmax[ivar]=0;
// Single pass over all events accumulating weights, targets and ranges.
679 for (UInt_t iev=0; iev<eventSample.size(); iev++) {
680 const TMVA::Event* evt = eventSample[iev];
681 const Double_t weight = evt->GetWeight();
682 const Double_t orgWeight = evt->GetOriginalWeight();
683 if (evt->GetClass() == fSigClass) {
693 if ( DoRegression() ) {
694 const Double_t tgt = evt->GetTarget(0);
696 target2+=weight*tgt*tgt;
// Track the observed min/max of each variable (seeded on the first event).
700 for (UInt_t ivar=0; ivar<fNvars; ivar++) {
701 const Double_t val = evt->GetValueFast(ivar);
702 if (iev==0) xmin[ivar]=xmax[ivar]=val;
703 if (val < xmin[ivar]) xmin[ivar]=val;
704 if (val > xmax[ivar]) xmax[ivar]=val;
// Negative total weight => negative-weight MC events dominate this node;
// warn and dump per-event diagnostics below.
710 Log() << kWARNING <<
" One of the Decision Tree nodes has negative total number of signal or background events. "
711 <<
"(Nsig="<<s<<
" Nbkg="<<b<<
" Probaby you use a Monte Carlo with negative weights. That should in principle "
712 <<
"be fine as long as on average you end up with something positive. For this you have to make sure that the "
713 <<
"minimul number of (unweighted) events demanded for a tree node (currently you use: MinNodeSize="<<fMinNodeSize
714 <<
"% of training events, you can set this via the BDT option string when booking the classifier) is large enough "
715 <<
"to allow for reasonable averaging!!!" << Endl
716 <<
" If this does not help.. maybe you want to try the option: NoNegWeightsInTraining which ignores events "
717 <<
"with negative weight in the training." << Endl;
719 for (UInt_t i=0; i<eventSample.size(); i++) {
720 if (eventSample[i]->GetClass() != fSigClass) {
721 nBkg += eventSample[i]->GetWeight();
722 Log() << kDEBUG <<
"Event "<< i<<
" has (original) weight: " << eventSample[i]->GetWeight()/eventSample[i]->GetBoostWeight()
723 <<
" boostWeight: " << eventSample[i]->GetBoostWeight() << Endl;
726 Log() << kDEBUG <<
" that gives in total: " << nBkg<<Endl;
// Store the accumulated statistics on the node.
729 node->SetNSigEvents(s);
730 node->SetNBkgEvents(b);
731 node->SetNSigEvents_unweighted(suw);
732 node->SetNBkgEvents_unweighted(buw);
733 node->SetNSigEvents_unboosted(sub);
734 node->SetNBkgEvents_unboosted(bub);
// Totals are only set explicitly on the root; daughters get theirs when
// they are created during the split below.
736 if (node == this->GetRoot()) {
737 node->SetNEvents(s+b);
738 node->SetNEvents_unweighted(suw+buw);
739 node->SetNEvents_unboosted(sub+bub);
743 for (UInt_t ivar=0; ivar<fNvars; ivar++) {
744 node->SetSampleMin(ivar,xmin[ivar]);
745 node->SetSampleMax(ivar,xmax[ivar]);
// Split only if the node is big enough (2*fMinSize in both raw and weighted
// counts), not too deep, and not already pure (classification) / non-empty
// (regression).
762 if ((eventSample.size() >= 2*fMinSize && s+b >= 2*fMinSize) && node->GetDepth() < fMaxDepth
763 && ( ( s!=0 && b !=0 && !DoRegression()) || ( (s+b)!=0 && DoRegression()) ) ) {
764 Double_t separationGain;
766 separationGain = this->TrainNodeFast(eventSample, node);
768 separationGain = this->TrainNodeFull(eventSample, node);
// No usable separation gain: make this node a leaf.
770 if (separationGain < std::numeric_limits<double>::epsilon()) {
774 if (DoRegression()) {
775 node->SetSeparationIndex(fRegType->GetSeparationIndex(s+b,target,target2));
776 node->SetResponse(target/(s+b));
// Guard the RMS sqrt against a (numerically) zero variance.
777 if( almost_equal_double(target2/(s+b),target/(s+b)*target/(s+b)) ){
780 node->SetRMS(TMath::Sqrt(target2/(s+b) - target/(s+b)*target/(s+b)));
// Classification leaf: type +1 if purity above the limit, else -1.
784 node->SetSeparationIndex(fSepType->GetSeparationIndex(s,b));
786 if (node->GetPurity() > fNodePurityLimit) node->SetNodeType(1);
787 else node->SetNodeType(-1);
789 if (node->GetDepth() > this->GetTotalTreeDepth()) this->SetTotalTreeDepth(node->GetDepth());
// A valid split was found: route every event left/right of the cut and
// accumulate the daughter weights.
793 std::vector<const TMVA::Event*> leftSample; leftSample.reserve(nevents);
794 std::vector<const TMVA::Event*> rightSample; rightSample.reserve(nevents);
796 Double_t nRight=0, nLeft=0;
797 Double_t nRightUnBoosted=0, nLeftUnBoosted=0;
799 for (UInt_t ie=0; ie< nevents ; ie++) {
800 if (node->GoesRight(*eventSample[ie])) {
801 rightSample.push_back(eventSample[ie]);
802 nRight += eventSample[ie]->GetWeight();
803 nRightUnBoosted += eventSample[ie]->GetOriginalWeight();
806 leftSample.push_back(eventSample[ie]);
807 nLeft += eventSample[ie]->GetWeight();
808 nLeftUnBoosted += eventSample[ie]->GetOriginalWeight();
// Should never happen after a positive separation gain: every event fell on
// one side of the cut — report and abort (the trailing kFATAL terminates).
813 if (leftSample.empty() || rightSample.empty()) {
815 Log() << kERROR <<
"<TrainNode> all events went to the same branch" << Endl
816 <<
"--- Hence new node == old node ... check" << Endl
817 <<
"--- left:" << leftSample.size()
818 <<
" right:" << rightSample.size() << Endl
819 <<
" while the separation is thought to be " << separationGain
820 <<
"\n when cutting on variable " << node->GetSelector()
821 <<
" at value " << node->GetCutValue()
822 << kFATAL <<
"--- this should never happen, please write a bug report to Helge.Voss@cern.ch" << Endl;
// Create the daughters, attach them, and recurse (right first, then left).
826 TMVA::DecisionTreeNode *rightNode =
new TMVA::DecisionTreeNode(node,
'r');
828 rightNode->SetNEvents(nRight);
829 rightNode->SetNEvents_unboosted(nRightUnBoosted);
830 rightNode->SetNEvents_unweighted(rightSample.size());
832 TMVA::DecisionTreeNode *leftNode =
new TMVA::DecisionTreeNode(node,
'l');
835 leftNode->SetNEvents(nLeft);
836 leftNode->SetNEvents_unboosted(nLeftUnBoosted);
837 leftNode->SetNEvents_unweighted(leftSample.size());
839 node->SetNodeType(0);
840 node->SetLeft(leftNode);
841 node->SetRight(rightNode);
843 this->BuildTree(rightSample, rightNode);
844 this->BuildTree(leftSample, leftNode );
// Node did not qualify for splitting at all: finalise it as a leaf
// (same leaf logic as above).
849 if (DoRegression()) {
850 node->SetSeparationIndex(fRegType->GetSeparationIndex(s+b,target,target2));
851 node->SetResponse(target/(s+b));
852 if( almost_equal_double(target2/(s+b), target/(s+b)*target/(s+b)) ) {
855 node->SetRMS(TMath::Sqrt(target2/(s+b) - target/(s+b)*target/(s+b)));
859 node->SetSeparationIndex(fSepType->GetSeparationIndex(s,b));
860 if (node->GetPurity() > fNodePurityLimit) node->SetNodeType(1);
861 else node->SetNodeType(-1);
870 if (node->GetDepth() > this->GetTotalTreeDepth()) this->SetTotalTreeDepth(node->GetDepth());
// Fill all events of the sample through the (already built) tree, starting
// from the root (NULL node => FillEvent starts at the root).
883 void TMVA::DecisionTree::FillTree(
const std::vector<TMVA::Event*> & eventSample )
885 for (UInt_t i=0; i<eventSample.size(); i++) {
886 this->FillEvent(*(eventSample[i]),NULL);
// Drop a single event through the tree from `node` (root when NULL),
// incrementing the event/signal/background counters and refreshing the
// separation index of every node it traverses, then recursing into the
// daughter the node's cut selects.
894 void TMVA::DecisionTree::FillEvent(
const TMVA::Event & event,
895 TMVA::DecisionTreeNode *node )
898 node = this->GetRoot();
901 node->IncrementNEvents( event.GetWeight() );
902 node->IncrementNEvents_unweighted( );
904 if (event.GetClass() == fSigClass) {
905 node->IncrementNSigEvents( event.GetWeight() );
906 node->IncrementNSigEvents_unweighted( );
909 node->IncrementNBkgEvents( event.GetWeight() );
910 node->IncrementNBkgEvents_unweighted( );
// Recompute the node's separation index from the updated counts.
912 node->SetSeparationIndex(fSepType->GetSeparationIndex(node->GetNSigEvents(),
913 node->GetNBkgEvents()));
// Interior node (type 0): follow the cut decision downwards.
915 if (node->GetNodeType() == 0) {
916 if (node->GoesRight(event))
917 this->FillEvent(event, node->GetRight());
919 this->FillEvent(event, node->GetLeft());
// Reset the stored statistics of every node in the tree.
926 void TMVA::DecisionTree::ClearTree()
928 if (this->GetRoot()!=NULL) this->GetRoot()->ClearNodeAndAllDaughters();
// Remove redundant splits: if both daughters of an interior node are leaves
// of the same type (product of types > 0), prune the split away. Starts at
// the root when called with a NULL node; returns the remaining node count.
940 UInt_t TMVA::DecisionTree::CleanTree( DecisionTreeNode *node )
943 node = this->GetRoot();
946 DecisionTreeNode *l = node->GetLeft();
947 DecisionTreeNode *r = node->GetRight();
949 if (node->GetNodeType() == 0) {
// Both daughters classify identically => the split adds nothing.
952 if (l->GetNodeType() * r->GetNodeType() > 0) {
954 this->PruneNode(node);
958 return this->CountNodes();
// Prune the tree according to fPruneMethod. For automatic methods a
// validation sample is mandatory; the selected IPruneTool computes a
// PruningInfo whose PruneSequence is then applied node by node.
// Returns the prune strength actually used (0 when no pruning was done).
967 Double_t TMVA::DecisionTree::PruneTree(
const EventConstList* validationSample )
969 IPruneTool* tool(NULL);
970 PruningInfo* info(NULL);
972 if( fPruneMethod == kNoPruning )
return 0.0;
// Select the pruning algorithm; any other method is a fatal error.
974 if (fPruneMethod == kExpectedErrorPruning)
976 tool =
new ExpectedErrorPruneTool();
977 else if (fPruneMethod == kCostComplexityPruning)
979 tool =
new CostComplexityPruneTool();
982 Log() << kFATAL <<
"Selected pruning method not yet implemented "
986 if(!tool)
return 0.0;
988 tool->SetPruneStrength(GetPruneStrength());
// Automatic prune-strength optimisation requires a non-empty validation
// sample.
989 if(tool->IsAutomatic()) {
990 if(validationSample == NULL){
991 Log() << kFATAL <<
"Cannot automate the pruning algorithm without an "
992 <<
"independent validation sample!" << Endl;
993 }
else if(validationSample->size() == 0) {
994 Log() << kFATAL <<
"Cannot automate the pruning algorithm with "
995 <<
"independent validation sample of ZERO events!" << Endl;
999 info = tool->CalculatePruningInfo(
this,validationSample);
1000 Double_t pruneStrength=0;
1002 Log() << kFATAL <<
"Error pruning tree! Check prune.log for more information."
1005 pruneStrength = info->PruneStrength;
// Apply the computed sequence of prune operations.
1011 for (UInt_t i = 0; i < info->PruneSequence.size(); ++i) {
1013 PruneNode(info->PruneSequence[i]);
1022 return pruneStrength;
// Reset all node validation counters, then run every validation event down
// the pruned tree to re-fill them (used to judge pruned-tree quality).
1032 void TMVA::DecisionTree::ApplyValidationSample(
const EventConstList* validationSample )
const
1034 GetRoot()->ResetValidationData();
1035 for (UInt_t ievt=0; ievt < validationSample->size(); ievt++) {
1036 CheckEventWithPrunedTree((*validationSample)[ievt]);
// Recursively compute a quality figure for the pruned tree from the
// validation counters: sum over leaves of the misclassified validation
// weight (mode 0), a smoothed/purity-weighted variant (mode 1), or the
// weighted squared regression error. Starts at the root when n is NULL.
1046 Double_t TMVA::DecisionTree::TestPrunedTreeQuality(
const DecisionTreeNode* n, Int_t mode )
const
1049 n = this->GetRoot();
1051 Log() << kFATAL <<
"TestPrunedTreeQuality: started with undefined ROOT node" <<Endl;
// Interior (non-terminal) node: quality is the sum over both subtrees.
1056 if( n->GetLeft() != NULL && n->GetRight() != NULL && !n->IsTerminal() ) {
1057 return (TestPrunedTreeQuality( n->GetLeft(), mode ) +
1058 TestPrunedTreeQuality( n->GetRight(), mode ));
// Regression leaf: weighted squared error around the node response.
1061 if (DoRegression()) {
1062 Double_t sumw = n->GetNSValidation() + n->GetNBValidation();
1063 return n->GetSumTarget2() - 2*n->GetSumTarget()*n->GetResponse() + sumw*n->GetResponse()*n->GetResponse();
// mode 0: count the validation weight of the minority class in the leaf.
1067 if (n->GetPurity() > this->GetNodePurityLimit())
1068 return n->GetNBValidation();
1070 return n->GetNSValidation();
1072 else if ( mode == 1 ) {
// mode 1: purity-weighted misclassification estimate.
1074 return (n->GetPurity() * n->GetNBValidation() + (1.0 - n->GetPurity()) * n->GetNSValidation());
1077 throw std::string(
"Unknown ValidationQualityMode");
// Pass one validation event down the tree, adding its weight to the
// signal/background validation counters (and regression target sums) of
// every node on its path, stopping at a leaf.
1088 void TMVA::DecisionTree::CheckEventWithPrunedTree(
const Event* e )
const
1090 DecisionTreeNode* current = this->GetRoot();
1091 if (current == NULL) {
1092 Log() << kFATAL <<
"CheckEventWithPrunedTree: started with undefined ROOT node" <<Endl;
1095 while(current != NULL) {
1096 if(e->GetClass() == fSigClass)
1097 current->SetNSValidation(current->GetNSValidation() + e->GetWeight());
1099 current->SetNBValidation(current->GetNBValidation() + e->GetWeight());
// Regression bookkeeping: weighted target and target^2 sums per node.
1101 if (e->GetNTargets() > 0) {
1102 current->AddToSumTarget(e->GetWeight()*e->GetTarget(0));
1103 current->AddToSumTarget2(e->GetWeight()*e->GetTarget(0)*e->GetTarget(0));
// Leaf reached: stop descending.
1106 if (current->GetRight() == NULL || current->GetLeft() == NULL) {
1110 if (current->GoesRight(*e))
1111 current = (TMVA::DecisionTreeNode*)current->GetRight();
1113 current = (TMVA::DecisionTreeNode*)current->GetLeft();
// Return the total event weight of the given validation sample.
1121 Double_t TMVA::DecisionTree::GetSumWeights(
const EventConstList* validationSample )
const
1123 Double_t sumWeights = 0.0;
1124 for( EventConstList::const_iterator it = validationSample->begin();
1125 it != validationSample->end(); ++it ) {
1126 sumWeights += (*it)->GetWeight();
// Recursively count the leaf nodes of the subtree rooted at `n` (the tree's
// root when n is NULL); fatal if the root itself is undefined.
1134 UInt_t TMVA::DecisionTree::CountLeafNodes( TMVA::Node *n )
1137 n = this->GetRoot();
1139 Log() << kFATAL <<
"CountLeafNodes: started with undefined ROOT node" <<Endl;
1144 UInt_t countLeafs=0;
// A node with no daughters is a leaf.
1146 if ((this->GetLeftDaughter(n) == NULL) && (this->GetRightDaughter(n) == NULL) ) {
1150 if (this->GetLeftDaughter(n) != NULL) {
1151 countLeafs += this->CountLeafNodes( this->GetLeftDaughter(n) );
1153 if (this->GetRightDaughter(n) != NULL) {
1154 countLeafs += this->CountLeafNodes( this->GetRightDaughter(n) );
// Walk the whole subtree below `n` (the root when NULL) purely as a sanity
// traversal: fatal on an undefined root or on any node that has exactly one
// daughter.
1163 void TMVA::DecisionTree::DescendTree( Node* n )
1166 n = this->GetRoot();
1168 Log() << kFATAL <<
"DescendTree: started with undefined ROOT node" <<Endl;
// Leaf: nothing to do.
1173 if ((this->GetLeftDaughter(n) == NULL) && (this->GetRightDaughter(n) == NULL) ) {
1176 else if ((this->GetLeftDaughter(n) == NULL) && (this->GetRightDaughter(n) != NULL) ) {
1177 Log() << kFATAL <<
" Node with only one daughter?? Something went wrong" << Endl;
1180 else if ((this->GetLeftDaughter(n) != NULL) && (this->GetRightDaughter(n) == NULL) ) {
1181 Log() << kFATAL <<
" Node with only one daughter?? Something went wrong" << Endl;
1185 if (this->GetLeftDaughter(n) != NULL) {
1186 this->DescendTree( this->GetLeftDaughter(n) );
1188 if (this->GetRightDaughter(n) != NULL) {
1189 this->DescendTree( this->GetRightDaughter(n) );
// Turn an interior node into a leaf: detach and delete both subtrees,
// invalidate the cut (selector/separation gain set to -1), and assign the
// leaf type from the node's purity relative to fNodePurityLimit.
1197 void TMVA::DecisionTree::PruneNode( DecisionTreeNode* node )
1199 DecisionTreeNode *l = node->GetLeft();
1200 DecisionTreeNode *r = node->GetRight();
1202 node->SetRight(NULL);
1203 node->SetLeft(NULL);
1204 node->SetSelector(-1);
1205 node->SetSeparationGain(-1);
1206 if (node->GetPurity() > fNodePurityLimit) node->SetNodeType(1);
1207 else node->SetNodeType(-1);
1208 this->DeleteNode(l);
1209 this->DeleteNode(r);
// Mark a node as terminal WITHOUT deleting its daughters (used by the
// cost-complexity pruning bookkeeping): reset its terminal count, subtree
// risk and alpha values, and flag it terminal. No-op on NULL.
1220 void TMVA::DecisionTree::PruneNodeInPlace( DecisionTreeNode* node ) {
1221 if(node == NULL)
return;
1222 node->SetNTerminal(1);
1223 node->SetSubTreeR( node->GetNodeR() );
// Infinite alpha => this node is never selected for further pruning.
1224 node->SetAlpha( std::numeric_limits<double>::infinity( ) );
1225 node->SetAlphaMinSubtree( std::numeric_limits<double>::infinity( ) );
1226 node->SetTerminal(kTRUE);
// Locate a node by its bit-encoded path: for each of the `depth` levels,
// bit i of `sequence` selects the right (set) or left (clear) daughter.
1234 TMVA::Node* TMVA::DecisionTree::GetNode( ULong_t sequence, UInt_t depth )
1236 Node* current = this->GetRoot();
1238 for (UInt_t i =0; i < depth; i++) {
1239 ULong_t tmp = 1 << i;
1240 if ( tmp & sequence) current = this->GetRightDaughter(current);
1241 else current = this->GetLeftDaughter(current);
// For randomised trees: pick a random subset of input variables for this
// split. The subset size is fUseNvars (default ~sqrt(fNvars)), optionally
// Poisson-fluctuated and clamped to [1, fNvars]. Fills the caller-provided
// useVariable flags and the mapVariable index map; useNvars returns the
// actual count.
1250 void TMVA::DecisionTree::GetRandomisedVariables(Bool_t *useVariable, UInt_t *mapVariable, UInt_t &useNvars){
1251 for (UInt_t ivar=0; ivar<fNvars; ivar++) useVariable[ivar]=kFALSE;
// Default subset size: sqrt(Nvars), rounded with a 0.6 offset.
1254 fUseNvars = UInt_t(TMath::Sqrt(fNvars)+0.6);
1256 if (fUsePoissonNvars) useNvars=TMath::Min(fNvars,TMath::Max(UInt_t(1),(UInt_t) fMyTrandom->Poisson(fUseNvars)));
1257 else useNvars = fUseNvars;
// Draw random variables (with repetition discarded by the flag array)
// until the requested number of distinct ones is selected.
1259 UInt_t nSelectedVars = 0;
1260 while (nSelectedVars < useNvars) {
1261 Double_t bla = fMyTrandom->Rndm()*fNvars;
1262 useVariable[Int_t (bla)] = kTRUE;
1264 for (UInt_t ivar=0; ivar < fNvars; ivar++) {
1265 if (useVariable[ivar] == kTRUE) {
1266 mapVariable[nSelectedVars] = ivar;
1271 if (nSelectedVars != useNvars) { std::cout <<
"Bug in TrainNode - GetRandisedVariables()... sorry" << std::endl; std::exit(1);}
// Per-partition accumulator for the parallel node-training step: for each
// candidate variable it keeps per-cut-bin cumulative signal/background
// counts (weighted and unweighted) and regression target sums, plus the
// node totals. operator+ merges two partial results for the reduce step.
1281 struct TrainNodeInfo{
// Allocate the [variable][bin] histograms; nBins_ gives the bin count per
// variable.
1283 TrainNodeInfo(Int_t cNvars_, UInt_t* nBins_){
1288 nSelS = std::vector< std::vector<Double_t> >(cNvars);
1289 nSelB = std::vector< std::vector<Double_t> >(cNvars);
1290 nSelS_unWeighted = std::vector< std::vector<Double_t> >(cNvars);
1291 nSelB_unWeighted = std::vector< std::vector<Double_t> >(cNvars);
1292 target = std::vector< std::vector<Double_t> >(cNvars);
1293 target2 = std::vector< std::vector<Double_t> >(cNvars);
1295 for(Int_t ivar=0; ivar<cNvars; ivar++){
1296 nSelS[ivar] = std::vector<Double_t>(nBins[ivar], 0);
1297 nSelB[ivar] = std::vector<Double_t>(nBins[ivar], 0);
1298 nSelS_unWeighted[ivar] = std::vector<Double_t>(nBins[ivar], 0);
1299 nSelB_unWeighted[ivar] = std::vector<Double_t>(nBins[ivar], 0);
1300 target[ivar] = std::vector<Double_t>(nBins[ivar], 0);
1301 target2[ivar] = std::vector<Double_t>(nBins[ivar], 0);
// Node-level totals (weighted and unweighted signal/background sums).
1319 Double_t nTotS_unWeighted = 0;
1321 Double_t nTotB_unWeighted = 0;
1323 std::vector< std::vector<Double_t> > nSelS;
1324 std::vector< std::vector<Double_t> > nSelB;
1325 std::vector< std::vector<Double_t> > nSelS_unWeighted;
1326 std::vector< std::vector<Double_t> > nSelB_unWeighted;
1327 std::vector< std::vector<Double_t> > target;
1328 std::vector< std::vector<Double_t> > target2;
// Bin-wise merge of two partial accumulators; variable counts must agree.
1332 TrainNodeInfo operator+(
const TrainNodeInfo& other)
1334 TrainNodeInfo ret(cNvars, nBins);
1337 if(cNvars != other.cNvars)
1339 std::cout <<
"!!! ERROR TrainNodeInfo1+TrainNodeInfo2 failure. cNvars1 != cNvars2." << std::endl;
1344 for (Int_t ivar=0; ivar<cNvars; ivar++) {
1345 for (UInt_t ibin=0; ibin<nBins[ivar]; ibin++) {
1346 ret.nSelS[ivar][ibin] = nSelS[ivar][ibin] + other.nSelS[ivar][ibin];
1347 ret.nSelB[ivar][ibin] = nSelB[ivar][ibin] + other.nSelB[ivar][ibin];
1348 ret.nSelS_unWeighted[ivar][ibin] = nSelS_unWeighted[ivar][ibin] + other.nSelS_unWeighted[ivar][ibin];
1349 ret.nSelB_unWeighted[ivar][ibin] = nSelB_unWeighted[ivar][ibin] + other.nSelB_unWeighted[ivar][ibin];
1350 ret.target[ivar][ibin] = target[ivar][ibin] + other.target[ivar][ibin];
1351 ret.target2[ivar][ibin] = target2[ivar][ibin] + other.target2[ivar][ibin];
1355 ret.nTotS = nTotS + other.nTotS;
1356 ret.nTotS_unWeighted = nTotS_unWeighted + other.nTotS_unWeighted;
1357 ret.nTotB = nTotB + other.nTotB;
1358 ret.nTotB_unWeighted = nTotB_unWeighted + other.nTotB_unWeighted;
// Find the best cut for this node by scanning a fixed grid of fNCuts cut
// values per variable (multi-threaded variant: the per-variable and per-event
// work is dispatched through TMVA::Config's ThreadExecutor).
// Returns the best separation gain found (0 if no valid cut exists).
// NOTE(review): the doxygen-extracted listing below has dropped lines — the
// matching else-branches, several closing braces, and some of the delete[]
// cleanup are not visible; the code is kept byte-identical and only annotated.
 1377 Double_t TMVA::DecisionTree::TrainNodeFast(
 const EventConstList & eventSample,
 1378 TMVA::DecisionTreeNode *node )
 // Best gain over all variables; per-variable best gain and best bin index.
 // Size fNvars+1: the extra slot is for the optional Fisher discriminant.
 1381 Double_t separationGainTotal = -1;
 1382 Double_t *separationGain =
 new Double_t[fNvars+1];
 1383 Int_t *cutIndex =
 new Int_t[fNvars+1];
 1386 for (UInt_t ivar=0; ivar <= fNvars; ivar++) {
 1387 separationGain[ivar]=-1;
 1392 Bool_t cutType = kTRUE;
 1393 UInt_t nevents = eventSample.size();
 // Which variables participate in the cut search at this node.
 1398 Bool_t *useVariable =
 new Bool_t[fNvars+1];
 1399 UInt_t *mapVariable =
 new UInt_t[fNvars+1];
 1401 std::vector<Double_t> fisherCoeff;
 // Randomised trees (random forests): pick a random subset of fUseNvars
 // variables; otherwise enable all of them.
 1404 if (fRandomisedTree) {
 1405 UInt_t tmp=fUseNvars;
 1406 GetRandomisedVariables(useVariable,mapVariable,tmp);
 1409 for (UInt_t ivar=0; ivar < fNvars; ivar++) {
 1410 useVariable[ivar] = kTRUE;
 1411 mapVariable[ivar] = ivar;
 // Slot fNvars is the Fisher-discriminant pseudo-variable; off by default.
 1415 useVariable[fNvars] = kFALSE;
 // Optionally build a Fisher discriminant from the linearly-correlated
 // variables and offer it as an additional cut variable.
 1418 Bool_t fisherOK = kFALSE;
 1419 if (fUseFisherCuts) {
 1420 useVariable[fNvars] = kTRUE;
 1424 Bool_t *useVarInFisher =
 new Bool_t[fNvars];
 1425 UInt_t *mapVarInFisher =
 new UInt_t[fNvars];
 1426 for (UInt_t ivar=0; ivar < fNvars; ivar++) {
 1427 useVarInFisher[ivar] = kFALSE;
 1428 mapVarInFisher[ivar] = ivar;
 1431 std::vector<TMatrixDSym*>* covMatrices;
 1432 covMatrices = gTools().CalcCovarianceMatrices( eventSample, 2 );
 1434 Log() << kWARNING <<
 " in TrainNodeFast, the covariance Matrices needed for the Fisher-Cuts returned error --> revert to just normal cuts for this node" << Endl;
 1437 TMatrixD *ss =
 new TMatrixD(*(covMatrices->at(0)));
 1438 TMatrixD *bb =
 new TMatrixD(*(covMatrices->at(1)));
 1439 const TMatrixD *s = gTools().GetCorrelationMatrix(ss);
 1440 const TMatrixD *b = gTools().GetCorrelationMatrix(bb);
 // A variable enters the Fisher combination if its linear correlation with
 // any other variable (in signal OR background) exceeds fMinLinCorrForFisher.
 1442 for (UInt_t ivar=0; ivar < fNvars; ivar++) {
 1443 for (UInt_t jvar=ivar+1; jvar < fNvars; jvar++) {
 1444 if ( ( TMath::Abs( (*s)(ivar, jvar)) > fMinLinCorrForFisher) ||
 1445 ( TMath::Abs( (*b)(ivar, jvar)) > fMinLinCorrForFisher) ){
 1446 useVarInFisher[ivar] = kTRUE;
 1447 useVarInFisher[jvar] = kTRUE;
 // Compact the selected variables and (optionally) remove them from the
 // individual-cut search when fUseExclusiveVars is set.
 1454 UInt_t nFisherVars = 0;
 1455 for (UInt_t ivar=0; ivar < fNvars; ivar++) {
 1458 if (useVarInFisher[ivar] && useVariable[ivar]) {
 1459 mapVarInFisher[nFisherVars++]=ivar;
 1462 if (fUseExclusiveVars) useVariable[ivar] = kFALSE;
 1467 fisherCoeff = this->GetFisherCoefficients(eventSample, nFisherVars, mapVarInFisher);
 1470 delete [] useVarInFisher;
 1471 delete [] mapVarInFisher;
 // cNvars = fNvars (+1 if the Fisher pseudo-variable was built successfully).
 1477 UInt_t cNvars = fNvars;
 1478 if (fUseFisherCuts && fisherOK) cNvars++;
 // Per-variable binning of the cut grid.
 1482 UInt_t* nBins =
 new UInt_t [cNvars];
 1483 Double_t* binWidth =
 new Double_t [cNvars];
 1484 Double_t* invBinWidth =
 new Double_t [cNvars];
 1485 Double_t** cutValues =
 new Double_t* [cNvars];
 1489 Double_t *xmin =
 new Double_t[cNvars];
 1490 Double_t *xmax =
 new Double_t[cNvars];
 1493 for (UInt_t ivar=0; ivar<cNvars; ivar++) {
 1495 nBins[ivar] = fNCuts+1;
 1496 if (ivar < fNvars) {
 // Integer variables: one bin per integer value in the node's sample range.
 1497 if (fDataSetInfo->GetVariableInfo(ivar).GetVarType() ==
 'I') {
 1498 nBins[ivar] = node->GetSampleMax(ivar) - node->GetSampleMin(ivar) + 1;
 1502 cutValues[ivar] =
 new Double_t [nBins[ivar]];
 // Determine the value range per variable; a degenerate range disables it.
 1506 for (UInt_t ivar=0; ivar < cNvars; ivar++) {
 1508 xmin[ivar]=node->GetSampleMin(ivar);
 1509 xmax[ivar]=node->GetSampleMax(ivar);
 1510 if (almost_equal_float(xmax[ivar], xmin[ivar])) {
 1513 useVariable[ivar]=kFALSE;
 // For the Fisher pseudo-variable the range must be computed from the events.
 1521 for (UInt_t iev=0; iev<nevents; iev++) {
 1523 Double_t result = fisherCoeff[fNvars];
 1524 for (UInt_t jvar=0; jvar<fNvars; jvar++)
 1525 result += fisherCoeff[jvar]*(eventSample[iev])->GetValueFast(jvar);
 1526 if (result > xmax[ivar]) xmax[ivar]=result;
 1527 if (result < xmin[ivar]) xmin[ivar]=result;
 1531 for (UInt_t ibin=0; ibin<nBins[ivar]; ibin++) {
 1532 cutValues[ivar][ibin]=0;
 // Parallel step 1: compute bin widths and the equidistant cut-value grid
 // for each usable variable (one executor task per variable).
 1545 auto varSeeds = ROOT::TSeqU(cNvars);
 1546 auto fvarInitCuts = [
 this, &useVariable, &cutValues, &invBinWidth, &binWidth, &nBins, &xmin, &xmax](UInt_t ivar = 0){
 1548 if ( useVariable[ivar] ) {
 1561 binWidth[ivar] = ( xmax[ivar] - xmin[ivar] ) / Double_t(nBins[ivar]);
 1562 invBinWidth[ivar] = 1./binWidth[ivar];
 1563 if (ivar < fNvars) {
 1564 if (fDataSetInfo->GetVariableInfo(ivar).GetVarType() ==
 'I') { invBinWidth[ivar] = 1; binWidth[ivar] = 1; }
 1572 for (UInt_t icut=0; icut<nBins[ivar]-1; icut++) {
 1573 cutValues[ivar][icut]=xmin[ivar]+(Double_t(icut+1))*binWidth[ivar];
 1579 TMVA::Config::Instance().GetThreadExecutor().Map(fvarInitCuts, varSeeds);
 // Parallel step 2: histogram the events into the cut bins. Two strategies:
 // enough events -> partition the event loop (MapReduce over partitions);
 // otherwise -> parallelise over variables instead (see fvarFillNodeInfo).
 1585 TrainNodeInfo nodeInfo(cNvars, nBins);
 1586 UInt_t nPartitions = TMVA::Config::Instance().GetThreadExecutor().GetPoolSize();
 1590 if(eventSample.size() >= cNvars*fNCuts*nPartitions*2)
 1592 auto seeds = ROOT::TSeqU(nPartitions);
 1595 auto f = [
 this, &eventSample, &fisherCoeff, &useVariable, &invBinWidth,
 1596 &nBins, &xmin, &cNvars, &nPartitions](UInt_t partition = 0){
 // Each partition processes its contiguous slice of the event sample into
 // a private TrainNodeInfo, merged later via operator+.
 1598 UInt_t start = 1.0*partition/nPartitions*eventSample.size();
 1599 UInt_t end = (partition+1.0)/nPartitions*eventSample.size();
 1601 TrainNodeInfo nodeInfof(cNvars, nBins);
 1603 for(UInt_t iev=start; iev<end; iev++) {
 1605 Double_t eventWeight = eventSample[iev]->GetWeight();
 1606 if (eventSample[iev]->GetClass() == fSigClass) {
 1607 nodeInfof.nTotS+=eventWeight;
 1608 nodeInfof.nTotS_unWeighted++; }
 1610 nodeInfof.nTotB+=eventWeight;
 1611 nodeInfof.nTotB_unWeighted++;
 1616 for (UInt_t ivar=0; ivar < cNvars; ivar++) {
 1619 if ( useVariable[ivar] ) {
 // ivar == fNvars is the Fisher pseudo-variable: evaluate it on the fly.
 1621 if (ivar < fNvars) eventData = eventSample[iev]->GetValueFast(ivar);
 1623 eventData = fisherCoeff[fNvars];
 1624 for (UInt_t jvar=0; jvar<fNvars; jvar++)
 1625 eventData += fisherCoeff[jvar]*(eventSample[iev])->GetValueFast(jvar);
 // Clamp the bin index into [0, nBins-1] to absorb rounding at the edges.
 1630 iBin = TMath::Min(Int_t(nBins[ivar]-1),TMath::Max(0,
 int (invBinWidth[ivar]*(eventData-xmin[ivar]) ) ));
 1631 if (eventSample[iev]->GetClass() == fSigClass) {
 1632 nodeInfof.nSelS[ivar][iBin]+=eventWeight;
 1633 nodeInfof.nSelS_unWeighted[ivar][iBin]++;
 1636 nodeInfof.nSelB[ivar][iBin]+=eventWeight;
 1637 nodeInfof.nSelB_unWeighted[ivar][iBin]++;
 1639 if (DoRegression()) {
 1640 nodeInfof.target[ivar][iBin] +=eventWeight*eventSample[iev]->GetTarget(0);
 1641 nodeInfof.target2[ivar][iBin]+=eventWeight*eventSample[iev]->GetTarget(0)*eventSample[iev]->GetTarget(0);
 // Reduce the per-partition accumulators with TrainNodeInfo::operator+.
 1650 TrainNodeInfo nodeInfoInit(cNvars, nBins);
 1653 auto redfunc = [nodeInfoInit](std::vector<TrainNodeInfo> v) -> TrainNodeInfo {
 return std::accumulate(v.begin(), v.end(), nodeInfoInit); };
 1654 nodeInfo = TMVA::Config::Instance().GetThreadExecutor().MapReduce(f, seeds, redfunc);
 // Alternative (few events): one task per variable; each task loops over all
 // events but writes only its own ivar slice of the shared nodeInfo.
 // NOTE(review): the node totals (nTotS/nTotB) are incremented inside this
 // per-variable lambda — presumably guarded elsewhere (lines missing here);
 // confirm against the full source before touching this.
 1662 auto fvarFillNodeInfo = [
 this, &nodeInfo, &eventSample, &fisherCoeff, &useVariable, &invBinWidth, &nBins, &xmin](UInt_t ivar = 0){
 1664 for(UInt_t iev=0; iev<eventSample.size(); iev++) {
 1667 Double_t eventWeight = eventSample[iev]->GetWeight();
 1671 if (eventSample[iev]->GetClass() == fSigClass) {
 1672 nodeInfo.nTotS+=eventWeight;
 1673 nodeInfo.nTotS_unWeighted++; }
 1675 nodeInfo.nTotB+=eventWeight;
 1676 nodeInfo.nTotB_unWeighted++;
 1681 if ( useVariable[ivar] ) {
 1683 if (ivar < fNvars) eventData = eventSample[iev]->GetValueFast(ivar);
 1685 eventData = fisherCoeff[fNvars];
 1686 for (UInt_t jvar=0; jvar<fNvars; jvar++)
 1687 eventData += fisherCoeff[jvar]*(eventSample[iev])->GetValueFast(jvar);
 1692 iBin = TMath::Min(Int_t(nBins[ivar]-1),TMath::Max(0,
 int (invBinWidth[ivar]*(eventData-xmin[ivar]) ) ));
 1693 if (eventSample[iev]->GetClass() == fSigClass) {
 1694 nodeInfo.nSelS[ivar][iBin]+=eventWeight;
 1695 nodeInfo.nSelS_unWeighted[ivar][iBin]++;
 1698 nodeInfo.nSelB[ivar][iBin]+=eventWeight;
 1699 nodeInfo.nSelB_unWeighted[ivar][iBin]++;
 1701 if (DoRegression()) {
 1702 nodeInfo.target[ivar][iBin] +=eventWeight*eventSample[iev]->GetTarget(0);
 1703 nodeInfo.target2[ivar][iBin]+=eventWeight*eventSample[iev]->GetTarget(0)*eventSample[iev]->GetTarget(0);
 1710 TMVA::Config::Instance().GetThreadExecutor().Map(fvarFillNodeInfo, varSeeds);
 // Parallel step 3: turn the per-bin histograms into cumulative ("events
 // left of cut") distributions, with sanity checks against the totals.
 1716 auto fvarCumulative = [&nodeInfo, &useVariable, &nBins,
 this, &eventSample](UInt_t ivar = 0){
 1717 if (useVariable[ivar]) {
 1718 for (UInt_t ibin=1; ibin < nBins[ivar]; ibin++) {
 1719 nodeInfo.nSelS[ivar][ibin]+=nodeInfo.nSelS[ivar][ibin-1];
 1720 nodeInfo.nSelS_unWeighted[ivar][ibin]+=nodeInfo.nSelS_unWeighted[ivar][ibin-1];
 1721 nodeInfo.nSelB[ivar][ibin]+=nodeInfo.nSelB[ivar][ibin-1];
 1722 nodeInfo.nSelB_unWeighted[ivar][ibin]+=nodeInfo.nSelB_unWeighted[ivar][ibin-1];
 1723 if (DoRegression()) {
 1724 nodeInfo.target[ivar][ibin] +=nodeInfo.target[ivar][ibin-1] ;
 1725 nodeInfo.target2[ivar][ibin]+=nodeInfo.target2[ivar][ibin-1];
 // The last cumulative bin must contain every event in the sample.
 1728 if (nodeInfo.nSelS_unWeighted[ivar][nBins[ivar]-1] +nodeInfo.nSelB_unWeighted[ivar][nBins[ivar]-1] != eventSample.size()) {
 1729 Log() << kFATAL <<
 "Helge, you have a bug ....nodeInfo.nSelS_unw..+nodeInfo.nSelB_unw..= "
 1730 << nodeInfo.nSelS_unWeighted[ivar][nBins[ivar]-1] +nodeInfo.nSelB_unWeighted[ivar][nBins[ivar]-1]
 1731 <<
 " while eventsample size = " << eventSample.size()
 // Weighted totals are checked with a 1% tolerance (float accumulation).
 1734 double lastBins=nodeInfo.nSelS[ivar][nBins[ivar]-1] +nodeInfo.nSelB[ivar][nBins[ivar]-1];
 1735 double totalSum=nodeInfo.nTotS+nodeInfo.nTotB;
 1736 if (TMath::Abs(lastBins-totalSum)/totalSum>0.01) {
 1737 Log() << kFATAL <<
 "Helge, you have another bug ....nodeInfo.nSelS+nodeInfo.nSelB= "
 1739 <<
 " while total number of events = " << totalSum
 1745 TMVA::Config::Instance().GetThreadExecutor().Map(fvarCumulative, varSeeds);
 // Parallel step 4: per variable, scan all candidate cut bins and record the
 // one with the largest separation gain (subject to the min-node-size cuts).
 1750 auto fvarMaxSep = [&nodeInfo, &useVariable,
 this, &separationGain, &cutIndex, &nBins] (UInt_t ivar = 0){
 1751 if (useVariable[ivar]) {
 1753 for (UInt_t iBin=0; iBin<nBins[ivar]-1; iBin++) {
 // l/r = left/right of the cut; W suffix = weighted counts.
 1765 Double_t sl = nodeInfo.nSelS_unWeighted[ivar][iBin];
 1766 Double_t bl = nodeInfo.nSelB_unWeighted[ivar][iBin];
 1767 Double_t s = nodeInfo.nTotS_unWeighted;
 1768 Double_t b = nodeInfo.nTotB_unWeighted;
 1769 Double_t slW = nodeInfo.nSelS[ivar][iBin];
 1770 Double_t blW = nodeInfo.nSelB[ivar][iBin];
 1771 Double_t sW = nodeInfo.nTotS;
 1772 Double_t bW = nodeInfo.nTotB;
 1775 Double_t srW = sW-slW;
 1776 Double_t brW = bW-blW;
 // Both daughters must satisfy the minimum node size, in raw and
 // weighted counts, before the cut is even considered.
 1778 if ( ((sl+bl)>=fMinSize && (sr+br)>=fMinSize)
 1779 && ((slW+blW)>=fMinSize && (srW+brW)>=fMinSize)
 1782 if (DoRegression()) {
 1783 sepTmp = fRegType->GetSeparationGain(nodeInfo.nSelS[ivar][iBin]+nodeInfo.nSelB[ivar][iBin],
 1784 nodeInfo.target[ivar][iBin],nodeInfo.target2[ivar][iBin],
 1785 nodeInfo.nTotS+nodeInfo.nTotB,
 1786 nodeInfo.target[ivar][nBins[ivar]-1],nodeInfo.target2[ivar][nBins[ivar]-1]);
 1788 sepTmp = fSepType->GetSeparationGain(nodeInfo.nSelS[ivar][iBin], nodeInfo.nSelB[ivar][iBin], nodeInfo.nTotS, nodeInfo.nTotB);
 1790 if (separationGain[ivar] < sepTmp) {
 1791 separationGain[ivar] = sepTmp;
 1792 cutIndex[ivar] = iBin;
 1799 TMVA::Config::Instance().GetThreadExecutor().Map(fvarMaxSep, varSeeds);
 // Serial: pick the overall best variable (mxVar — declared in lines missing
 // from this extract) and configure the node with the winning cut.
 1802 for (UInt_t ivar=0; ivar < cNvars; ivar++) {
 1803 if (useVariable[ivar] ) {
 1804 if (separationGainTotal < separationGain[ivar]) {
 1805 separationGainTotal = separationGain[ivar];
 1812 if (DoRegression()) {
 // Regression: store the node response (mean target) and its RMS; the
 // almost_equal_double guard avoids a negative sqrt argument from rounding.
 1813 node->SetSeparationIndex(fRegType->GetSeparationIndex(nodeInfo.nTotS+nodeInfo.nTotB,nodeInfo.target[0][nBins[mxVar]-1],nodeInfo.target2[0][nBins[mxVar]-1]));
 1814 node->SetResponse(nodeInfo.target[0][nBins[mxVar]-1]/(nodeInfo.nTotS+nodeInfo.nTotB));
 1815 if ( almost_equal_double(nodeInfo.target2[0][nBins[mxVar]-1]/(nodeInfo.nTotS+nodeInfo.nTotB), nodeInfo.target[0][nBins[mxVar]-1]/(nodeInfo.nTotS+nodeInfo.nTotB)*nodeInfo.target[0][nBins[mxVar]-1]/(nodeInfo.nTotS+nodeInfo.nTotB))) {
 1819 node->SetRMS(TMath::Sqrt(nodeInfo.target2[0][nBins[mxVar]-1]/(nodeInfo.nTotS+nodeInfo.nTotB) - nodeInfo.target[0][nBins[mxVar]-1]/(nodeInfo.nTotS+nodeInfo.nTotB)*nodeInfo.target[0][nBins[mxVar]-1]/(nodeInfo.nTotS+nodeInfo.nTotB)));
 // Classification: cutType records whether signal goes right of the cut.
 1823 node->SetSeparationIndex(fSepType->GetSeparationIndex(nodeInfo.nTotS,nodeInfo.nTotB));
 1825 if (nodeInfo.nSelS[mxVar][cutIndex[mxVar]]/nodeInfo.nTotS > nodeInfo.nSelB[mxVar][cutIndex[mxVar]]/nodeInfo.nTotB) cutType=kTRUE;
 1826 else cutType=kFALSE;
 1829 node->SetSelector((UInt_t)mxVar);
 1830 node->SetCutValue(cutValues[mxVar][cutIndex[mxVar]]);
 1831 node->SetCutType(cutType);
 1832 node->SetSeparationGain(separationGainTotal);
 1833 if (mxVar < (Int_t) fNvars){
 1834 node->SetNFisherCoeff(0);
 // Variable importance accumulates gain^2 * (total weight)^2.
 1835 fVariableImportance[mxVar] += separationGainTotal*separationGainTotal * (nodeInfo.nTotS+nodeInfo.nTotB) * (nodeInfo.nTotS+nodeInfo.nTotB) ;
 // Fisher cut chosen: store the coefficients and spread the importance
 // over the contributing variables, weighted by coefficient^2.
 1841 node->SetNFisherCoeff(fNvars+1);
 1842 for (UInt_t ivar=0; ivar<=fNvars; ivar++) {
 1843 node->SetFisherCoeff(ivar,fisherCoeff[ivar]);
 1846 fVariableImportance[ivar] += fisherCoeff[ivar]*fisherCoeff[ivar]*separationGainTotal*separationGainTotal * (nodeInfo.nTotS+nodeInfo.nTotB) * (nodeInfo.nTotS+nodeInfo.nTotB) ;
 1852 separationGainTotal = 0;
 // Cleanup of the raw arrays allocated above (some delete[] lines are
 // outside this extract).
 1858 for (UInt_t i=0; i<cNvars; i++) {
 1865 delete [] cutValues[i];
 1876 delete [] cutValues;
 1881 delete [] useVariable;
 1882 delete [] mapVariable;
 1884 delete [] separationGain;
 1889 delete [] invBinWidth;
 1891 return separationGainTotal;
// Serial variant of TrainNodeFast: same algorithm as the multi-threaded
// version above (grid of fNCuts candidate cuts per variable, optional Fisher
// pseudo-variable, best separation gain wins) but with plain raw-array
// accumulators and single-threaded loops. Presumably the #else branch of an
// R__USE_IMT conditional — the preprocessor lines are not in this extract.
// NOTE(review): the listing has dropped lines (else-branches, closing braces,
// part of the cleanup); code is kept byte-identical, comments only.
 1896 Double_t TMVA::DecisionTree::TrainNodeFast(
 const EventConstList & eventSample,
 1897 TMVA::DecisionTreeNode *node )
 1900 Double_t separationGainTotal = -1, sepTmp;
 // fNvars+1 slots: extra one is the optional Fisher discriminant.
 1901 Double_t *separationGain =
 new Double_t[fNvars+1];
 1902 Int_t *cutIndex =
 new Int_t[fNvars+1];
 1905 for (UInt_t ivar=0; ivar <= fNvars; ivar++) {
 1906 separationGain[ivar]=-1;
 1911 Bool_t cutType = kTRUE;
 1912 Double_t nTotS, nTotB;
 1913 Int_t nTotS_unWeighted, nTotB_unWeighted;
 1914 UInt_t nevents = eventSample.size();
 // Variable selection (randomised-tree subset or all variables).
 1919 Bool_t *useVariable =
 new Bool_t[fNvars+1];
 1920 UInt_t *mapVariable =
 new UInt_t[fNvars+1];
 1922 std::vector<Double_t> fisherCoeff;
 1925 if (fRandomisedTree) {
 1926 UInt_t tmp=fUseNvars;
 1927 GetRandomisedVariables(useVariable,mapVariable,tmp);
 1930 for (UInt_t ivar=0; ivar < fNvars; ivar++) {
 1931 useVariable[ivar] = kTRUE;
 1932 mapVariable[ivar] = ivar;
 1936 useVariable[fNvars] = kFALSE;
 // Optional Fisher discriminant built from linearly-correlated variables.
 1939 Bool_t fisherOK = kFALSE;
 1940 if (fUseFisherCuts) {
 1941 useVariable[fNvars] = kTRUE;
 1945 Bool_t *useVarInFisher =
 new Bool_t[fNvars];
 1946 UInt_t *mapVarInFisher =
 new UInt_t[fNvars];
 1947 for (UInt_t ivar=0; ivar < fNvars; ivar++) {
 1948 useVarInFisher[ivar] = kFALSE;
 1949 mapVarInFisher[ivar] = ivar;
 1952 std::vector<TMatrixDSym*>* covMatrices;
 1953 covMatrices = gTools().CalcCovarianceMatrices( eventSample, 2 );
 1955 Log() << kWARNING <<
 " in TrainNodeFast, the covariance Matrices needed for the Fisher-Cuts returned error --> revert to just normal cuts for this node" << Endl;
 1958 TMatrixD *ss =
 new TMatrixD(*(covMatrices->at(0)));
 1959 TMatrixD *bb =
 new TMatrixD(*(covMatrices->at(1)));
 1960 const TMatrixD *s = gTools().GetCorrelationMatrix(ss);
 1961 const TMatrixD *b = gTools().GetCorrelationMatrix(bb);
 // Select variables whose linear correlation (signal or background)
 // exceeds fMinLinCorrForFisher.
 1963 for (UInt_t ivar=0; ivar < fNvars; ivar++) {
 1964 for (UInt_t jvar=ivar+1; jvar < fNvars; jvar++) {
 1965 if ( ( TMath::Abs( (*s)(ivar, jvar)) > fMinLinCorrForFisher) ||
 1966 ( TMath::Abs( (*b)(ivar, jvar)) > fMinLinCorrForFisher) ){
 1967 useVarInFisher[ivar] = kTRUE;
 1968 useVarInFisher[jvar] = kTRUE;
 1975 UInt_t nFisherVars = 0;
 1976 for (UInt_t ivar=0; ivar < fNvars; ivar++) {
 1979 if (useVarInFisher[ivar] && useVariable[ivar]) {
 1980 mapVarInFisher[nFisherVars++]=ivar;
 1983 if (fUseExclusiveVars) useVariable[ivar] = kFALSE;
 1988 fisherCoeff = this->GetFisherCoefficients(eventSample, nFisherVars, mapVarInFisher);
 1991 delete [] useVarInFisher;
 1992 delete [] mapVarInFisher;
 1998 UInt_t cNvars = fNvars;
 1999 if (fUseFisherCuts && fisherOK) cNvars++;
 // Raw-array accumulators, one row per variable, one column per cut bin.
 2003 UInt_t* nBins =
 new UInt_t [cNvars];
 2004 Double_t* binWidth =
 new Double_t [cNvars];
 2005 Double_t* invBinWidth =
 new Double_t [cNvars];
 2007 Double_t** nSelS =
 new Double_t* [cNvars];
 2008 Double_t** nSelB =
 new Double_t* [cNvars];
 2009 Double_t** nSelS_unWeighted =
 new Double_t* [cNvars];
 2010 Double_t** nSelB_unWeighted =
 new Double_t* [cNvars];
 2011 Double_t** target =
 new Double_t* [cNvars];
 2012 Double_t** target2 =
 new Double_t* [cNvars];
 2013 Double_t** cutValues =
 new Double_t* [cNvars];
 2016 for (UInt_t ivar=0; ivar<cNvars; ivar++) {
 2018 nBins[ivar] = fNCuts+1;
 2019 if (ivar < fNvars) {
 // Integer variables: one bin per integer value in the node's range.
 2020 if (fDataSetInfo->GetVariableInfo(ivar).GetVarType() ==
 'I') {
 2021 nBins[ivar] = node->GetSampleMax(ivar) - node->GetSampleMin(ivar) + 1;
 2027 nSelS[ivar] =
 new Double_t [nBins[ivar]];
 2028 nSelB[ivar] =
 new Double_t [nBins[ivar]];
 2029 nSelS_unWeighted[ivar] =
 new Double_t [nBins[ivar]];
 2030 nSelB_unWeighted[ivar] =
 new Double_t [nBins[ivar]];
 2031 target[ivar] =
 new Double_t [nBins[ivar]];
 2032 target2[ivar] =
 new Double_t [nBins[ivar]];
 2033 cutValues[ivar] =
 new Double_t [nBins[ivar]];
 // Value ranges; degenerate range disables the variable at this node.
 2038 Double_t *xmin =
 new Double_t[cNvars];
 2039 Double_t *xmax =
 new Double_t[cNvars];
 2042 for (UInt_t ivar=0; ivar < cNvars; ivar++) {
 2044 xmin[ivar]=node->GetSampleMin(ivar);
 2045 xmax[ivar]=node->GetSampleMax(ivar);
 2046 if (almost_equal_float(xmax[ivar], xmin[ivar])) {
 2049 useVariable[ivar]=kFALSE;
 // Fisher pseudo-variable range must be computed from the events.
 2057 for (UInt_t iev=0; iev<nevents; iev++) {
 2059 Double_t result = fisherCoeff[fNvars];
 2060 for (UInt_t jvar=0; jvar<fNvars; jvar++)
 2061 result += fisherCoeff[jvar]*(eventSample[iev])->GetValueFast(jvar);
 2062 if (result > xmax[ivar]) xmax[ivar]=result;
 2063 if (result < xmin[ivar]) xmin[ivar]=result;
 2066 for (UInt_t ibin=0; ibin<nBins[ivar]; ibin++) {
 2067 nSelS[ivar][ibin]=0;
 2068 nSelB[ivar][ibin]=0;
 2069 nSelS_unWeighted[ivar][ibin]=0;
 2070 nSelB_unWeighted[ivar][ibin]=0;
 2071 target[ivar][ibin]=0;
 2072 target2[ivar][ibin]=0;
 2073 cutValues[ivar][ibin]=0;
 // Build the equidistant cut-value grid per usable variable.
 2080 for (UInt_t ivar=0; ivar < cNvars; ivar++) {
 2082 if ( useVariable[ivar] ) {
 2095 binWidth[ivar] = ( xmax[ivar] - xmin[ivar] ) / Double_t(nBins[ivar]);
 2096 invBinWidth[ivar] = 1./binWidth[ivar];
 2097 if (ivar < fNvars) {
 2098 if (fDataSetInfo->GetVariableInfo(ivar).GetVarType() ==
 'I') { invBinWidth[ivar] = 1; binWidth[ivar] = 1; }
 2106 for (UInt_t icut=0; icut<nBins[ivar]-1; icut++) {
 2107 cutValues[ivar][icut]=xmin[ivar]+(Double_t(icut+1))*binWidth[ivar];
 // Single pass over events: accumulate totals and per-bin counts.
 2115 nTotS_unWeighted=0; nTotB_unWeighted=0;
 2116 for (UInt_t iev=0; iev<nevents; iev++) {
 2118 Double_t eventWeight = eventSample[iev]->GetWeight();
 2119 if (eventSample[iev]->GetClass() == fSigClass) {
 2121 nTotS_unWeighted++; }
 2129 for (UInt_t ivar=0; ivar < cNvars; ivar++) {
 2132 if ( useVariable[ivar] ) {
 // ivar == fNvars is the Fisher pseudo-variable, evaluated on the fly.
 2134 if (ivar < fNvars) eventData = eventSample[iev]->GetValueFast(ivar);
 2136 eventData = fisherCoeff[fNvars];
 2137 for (UInt_t jvar=0; jvar<fNvars; jvar++)
 2138 eventData += fisherCoeff[jvar]*(eventSample[iev])->GetValueFast(jvar);
 // Clamp bin index into [0, nBins-1] against edge rounding.
 2143 iBin = TMath::Min(Int_t(nBins[ivar]-1),TMath::Max(0,
 int (invBinWidth[ivar]*(eventData-xmin[ivar]) ) ));
 2144 if (eventSample[iev]->GetClass() == fSigClass) {
 2145 nSelS[ivar][iBin]+=eventWeight;
 2146 nSelS_unWeighted[ivar][iBin]++;
 2149 nSelB[ivar][iBin]+=eventWeight;
 2150 nSelB_unWeighted[ivar][iBin]++;
 2152 if (DoRegression()) {
 2153 target[ivar][iBin] +=eventWeight*eventSample[iev]->GetTarget(0);
 2154 target2[ivar][iBin]+=eventWeight*eventSample[iev]->GetTarget(0)*eventSample[iev]->GetTarget(0);
 // Convert per-bin counts into cumulative distributions + sanity checks.
 2161 for (UInt_t ivar=0; ivar < cNvars; ivar++) {
 2162 if (useVariable[ivar]) {
 2163 for (UInt_t ibin=1; ibin < nBins[ivar]; ibin++) {
 2164 nSelS[ivar][ibin]+=nSelS[ivar][ibin-1];
 2165 nSelS_unWeighted[ivar][ibin]+=nSelS_unWeighted[ivar][ibin-1];
 2166 nSelB[ivar][ibin]+=nSelB[ivar][ibin-1];
 2167 nSelB_unWeighted[ivar][ibin]+=nSelB_unWeighted[ivar][ibin-1];
 2168 if (DoRegression()) {
 2169 target[ivar][ibin] +=target[ivar][ibin-1] ;
 2170 target2[ivar][ibin]+=target2[ivar][ibin-1];
 // Last cumulative bin must contain the whole event sample.
 2173 if (nSelS_unWeighted[ivar][nBins[ivar]-1] +nSelB_unWeighted[ivar][nBins[ivar]-1] != eventSample.size()) {
 2174 Log() << kFATAL <<
 "Helge, you have a bug ....nSelS_unw..+nSelB_unw..= "
 2175 << nSelS_unWeighted[ivar][nBins[ivar]-1] +nSelB_unWeighted[ivar][nBins[ivar]-1]
 2176 <<
 " while eventsample size = " << eventSample.size()
 // Weighted totals checked with 1% tolerance (float accumulation).
 2179 double lastBins=nSelS[ivar][nBins[ivar]-1] +nSelB[ivar][nBins[ivar]-1];
 2180 double totalSum=nTotS+nTotB;
 2181 if (TMath::Abs(lastBins-totalSum)/totalSum>0.01) {
 2182 Log() << kFATAL <<
 "Helge, you have another bug ....nSelS+nSelB= "
 2184 <<
 " while total number of events = " << totalSum
 // Scan every candidate cut bin per variable for the best separation gain,
 // requiring both daughters to meet fMinSize (raw and weighted counts).
 2192 for (UInt_t ivar=0; ivar < cNvars; ivar++) {
 2193 if (useVariable[ivar]) {
 2194 for (UInt_t iBin=0; iBin<nBins[ivar]-1; iBin++) {
 // l/r = left/right of the cut; W suffix = weighted counts.
 2206 Double_t sl = nSelS_unWeighted[ivar][iBin];
 2207 Double_t bl = nSelB_unWeighted[ivar][iBin];
 2208 Double_t s = nTotS_unWeighted;
 2209 Double_t b = nTotB_unWeighted;
 2210 Double_t slW = nSelS[ivar][iBin];
 2211 Double_t blW = nSelB[ivar][iBin];
 2212 Double_t sW = nTotS;
 2213 Double_t bW = nTotB;
 2216 Double_t srW = sW-slW;
 2217 Double_t brW = bW-blW;
 2219 if ( ((sl+bl)>=fMinSize && (sr+br)>=fMinSize)
 2220 && ((slW+blW)>=fMinSize && (srW+brW)>=fMinSize)
 2223 if (DoRegression()) {
 2224 sepTmp = fRegType->GetSeparationGain(nSelS[ivar][iBin]+nSelB[ivar][iBin],
 2225 target[ivar][iBin],target2[ivar][iBin],
 2227 target[ivar][nBins[ivar]-1],target2[ivar][nBins[ivar]-1]);
 2229 sepTmp = fSepType->GetSeparationGain(nSelS[ivar][iBin], nSelB[ivar][iBin], nTotS, nTotB);
 2231 if (separationGain[ivar] < sepTmp) {
 2232 separationGain[ivar] = sepTmp;
 2233 cutIndex[ivar] = iBin;
 // Pick the overall best variable (mxVar — declared in lines missing from
 // this extract) and configure the node with the winning cut.
 2241 for (UInt_t ivar=0; ivar < cNvars; ivar++) {
 2242 if (useVariable[ivar] ) {
 2243 if (separationGainTotal < separationGain[ivar]) {
 2244 separationGainTotal = separationGain[ivar];
 2251 if (DoRegression()) {
 // Regression: node response = mean target; the almost_equal_double guard
 // avoids a negative sqrt argument caused by rounding.
 2252 node->SetSeparationIndex(fRegType->GetSeparationIndex(nTotS+nTotB,target[0][nBins[mxVar]-1],target2[0][nBins[mxVar]-1]));
 2253 node->SetResponse(target[0][nBins[mxVar]-1]/(nTotS+nTotB));
 2254 if ( almost_equal_double(target2[0][nBins[mxVar]-1]/(nTotS+nTotB), target[0][nBins[mxVar]-1]/(nTotS+nTotB)*target[0][nBins[mxVar]-1]/(nTotS+nTotB))) {
 2257 node->SetRMS(TMath::Sqrt(target2[0][nBins[mxVar]-1]/(nTotS+nTotB) - target[0][nBins[mxVar]-1]/(nTotS+nTotB)*target[0][nBins[mxVar]-1]/(nTotS+nTotB)));
 // Classification: cutType records whether signal goes right of the cut.
 2261 node->SetSeparationIndex(fSepType->GetSeparationIndex(nTotS,nTotB));
 2263 if (nSelS[mxVar][cutIndex[mxVar]]/nTotS > nSelB[mxVar][cutIndex[mxVar]]/nTotB) cutType=kTRUE;
 2264 else cutType=kFALSE;
 2267 node->SetSelector((UInt_t)mxVar);
 2268 node->SetCutValue(cutValues[mxVar][cutIndex[mxVar]]);
 2269 node->SetCutType(cutType);
 2270 node->SetSeparationGain(separationGainTotal);
 2271 if (mxVar < (Int_t) fNvars){
 2272 node->SetNFisherCoeff(0);
 // Variable importance accumulates gain^2 * (total weight)^2.
 2273 fVariableImportance[mxVar] += separationGainTotal*separationGainTotal * (nTotS+nTotB) * (nTotS+nTotB) ;
 // Fisher cut chosen: store coefficients; importance is spread over the
 // contributing variables weighted by coefficient^2.
 2279 node->SetNFisherCoeff(fNvars+1);
 2280 for (UInt_t ivar=0; ivar<=fNvars; ivar++) {
 2281 node->SetFisherCoeff(ivar,fisherCoeff[ivar]);
 2284 fVariableImportance[ivar] += fisherCoeff[ivar]*fisherCoeff[ivar]*separationGainTotal*separationGainTotal * (nTotS+nTotB) * (nTotS+nTotB) ;
 2290 separationGainTotal = 0;
 // Cleanup (some delete[] lines are outside this extract).
 2306 for (UInt_t i=0; i<cNvars; i++) {
 2309 delete [] nSelS_unWeighted[i];
 2310 delete [] nSelB_unWeighted[i];
 2311 delete [] target[i];
 2312 delete [] target2[i];
 2313 delete [] cutValues[i];
 2317 delete [] nSelS_unWeighted;
 2318 delete [] nSelB_unWeighted;
 2321 delete [] cutValues;
 2326 delete [] useVariable;
 2327 delete [] mapVariable;
 2329 delete [] separationGain;
 2334 delete [] invBinWidth;
 2336 return separationGainTotal;
// Compute Fisher linear-discriminant coefficients for the nFisherVars
// variables selected via mapVarInFisher, over the node's event sample.
// Returns a vector of size fNvars+1: per-variable coefficients (zero for
// variables not in the Fisher set) plus the offset term in slot fNvars.
// NOTE(review): this extract is missing lines — the final `return fisherCoeff;`,
// the matrix/array deletes, and the f0 sign/normalisation lines around 2525
// are not visible here.
 2345 std::vector<Double_t> TMVA::DecisionTree::GetFisherCoefficients(
 const EventConstList &eventSample, UInt_t nFisherVars, UInt_t *mapVarInFisher){
 2346 std::vector<Double_t> fisherCoeff(fNvars+1);
 // meanMatx columns: 0 = signal mean, 1 = background mean, 2 = overall mean.
 2350 TMatrixD* meanMatx =
 new TMatrixD( nFisherVars, 3 );
 // Between-class, within-class and total covariance matrices.
 2353 TMatrixD* betw =
 new TMatrixD( nFisherVars, nFisherVars );
 2354 TMatrixD* with =
 new TMatrixD( nFisherVars, nFisherVars );
 2355 TMatrixD* cov =
 new TMatrixD( nFisherVars, nFisherVars );
 // First pass: weighted sums per class to get the class means.
 2362 Double_t sumOfWeightsS = 0;
 2363 Double_t sumOfWeightsB = 0;
 2367 Double_t* sumS =
 new Double_t[nFisherVars];
 2368 Double_t* sumB =
 new Double_t[nFisherVars];
 2369 for (UInt_t ivar=0; ivar<nFisherVars; ivar++) { sumS[ivar] = sumB[ivar] = 0; }
 2371 UInt_t nevents = eventSample.size();
 2373 for (UInt_t ievt=0; ievt<nevents; ievt++) {
 2376 const Event * ev = eventSample[ievt];
 2379 Double_t weight = ev->GetWeight();
 2380 if (ev->GetClass() == fSigClass) sumOfWeightsS += weight;
 2381 else sumOfWeightsB += weight;
 2383 Double_t* sum = ev->GetClass() == fSigClass ? sumS : sumB;
 2384 for (UInt_t ivar=0; ivar<nFisherVars; ivar++) {
 2385 sum[ivar] += ev->GetValueFast( mapVarInFisher[ivar] )*weight;
 2388 for (UInt_t ivar=0; ivar<nFisherVars; ivar++) {
 2389 (*meanMatx)( ivar, 2 ) = sumS[ivar];
 2390 (*meanMatx)( ivar, 0 ) = sumS[ivar]/sumOfWeightsS;
 2392 (*meanMatx)( ivar, 2 ) += sumB[ivar];
 2393 (*meanMatx)( ivar, 1 ) = sumB[ivar]/sumOfWeightsB;
 2396 (*meanMatx)( ivar, 2 ) /= (sumOfWeightsS + sumOfWeightsB);
 // Both classes must be present, otherwise the means are undefined.
 2408 assert( sumOfWeightsS > 0 && sumOfWeightsB > 0 );
 // Second pass: within-class scatter, flattened row-major into sum2Sig/Bgd
 // (index k — its increment lines are outside this extract).
 2412 const Int_t nFisherVars2 = nFisherVars*nFisherVars;
 2413 Double_t *sum2Sig =
 new Double_t[nFisherVars2];
 2414 Double_t *sum2Bgd =
 new Double_t[nFisherVars2];
 2415 Double_t *xval =
 new Double_t[nFisherVars2];
 2416 memset(sum2Sig,0,nFisherVars2*
 sizeof(Double_t));
 2417 memset(sum2Bgd,0,nFisherVars2*
 sizeof(Double_t));
 2420 for (UInt_t ievt=0; ievt<nevents; ievt++) {
 2424 const Event* ev = eventSample.at(ievt);
 2426 Double_t weight = ev->GetWeight();
 2428 for (UInt_t x=0; x<nFisherVars; x++) {
 2429 xval[x] = ev->GetValueFast( mapVarInFisher[x] );
 2432 for (UInt_t x=0; x<nFisherVars; x++) {
 2433 for (UInt_t y=0; y<nFisherVars; y++) {
 2434 if ( ev->GetClass() == fSigClass ) sum2Sig[k] += ( (xval[x] - (*meanMatx)(x, 0))*(xval[y] - (*meanMatx)(y, 0)) )*weight;
 2435 else sum2Bgd[k] += ( (xval[x] - (*meanMatx)(x, 1))*(xval[y] - (*meanMatx)(y, 1)) )*weight;
 // Within-class covariance: per-class scatter normalised by class weight.
 2441 for (UInt_t x=0; x<nFisherVars; x++) {
 2442 for (UInt_t y=0; y<nFisherVars; y++) {
 2443 (*with)(x, y) = sum2Sig[k]/sumOfWeightsS + sum2Bgd[k]/sumOfWeightsB;
 // Between-class covariance from the class-mean deviations from the
 // overall mean, weighted by the class weights.
 2458 Double_t prodSig, prodBgd;
 2460 for (UInt_t x=0; x<nFisherVars; x++) {
 2461 for (UInt_t y=0; y<nFisherVars; y++) {
 2463 prodSig = ( ((*meanMatx)(x, 0) - (*meanMatx)(x, 2))*
 2464 ((*meanMatx)(y, 0) - (*meanMatx)(y, 2)) );
 2465 prodBgd = ( ((*meanMatx)(x, 1) - (*meanMatx)(x, 2))*
 2466 ((*meanMatx)(y, 1) - (*meanMatx)(y, 2)) );
 2468 (*betw)(x, y) = (sumOfWeightsS*prodSig + sumOfWeightsB*prodBgd) / (sumOfWeightsS + sumOfWeightsB);
 // Total covariance = within + between (kept for completeness).
 2475 for (UInt_t x=0; x<nFisherVars; x++)
 2476 for (UInt_t y=0; y<nFisherVars; y++)
 2477 (*cov)(x, y) = (*with)(x, y) + (*betw)(x, y);
 // Fisher uses the within-class matrix; warn (and eventually abort) if it
 // is numerically singular before inverting.
 2487 TMatrixD* theMat = with;
 2490 TMatrixD invCov( *theMat );
 2491 if ( TMath::Abs(invCov.Determinant()) < 10E-24 ) {
 2492 Log() << kWARNING <<
 "FisherCoeff matrix is almost singular with determinant="
 2493 << TMath::Abs(invCov.Determinant())
 2494 <<
 " did you use the variables that are linear combinations or highly correlated?"
 2497 if ( TMath::Abs(invCov.Determinant()) < 10E-120 ) {
 2498 Log() << kFATAL <<
 "FisherCoeff matrix is singular with determinant="
 2499 << TMath::Abs(invCov.Determinant())
 2500 <<
 " did you use the variables that are linear combinations?"
 // Coefficients: invCov * (mean_S - mean_B), scaled by the weight factor.
 2507 Double_t xfact = TMath::Sqrt( sumOfWeightsS*sumOfWeightsB ) / (sumOfWeightsS + sumOfWeightsB);
 2510 std::vector<Double_t> diffMeans( nFisherVars );
 2512 for (UInt_t ivar=0; ivar<=fNvars; ivar++) fisherCoeff[ivar] = 0;
 2513 for (UInt_t ivar=0; ivar<nFisherVars; ivar++) {
 2514 for (UInt_t jvar=0; jvar<nFisherVars; jvar++) {
 2515 Double_t d = (*meanMatx)(jvar, 0) - (*meanMatx)(jvar, 1);
 2516 fisherCoeff[mapVarInFisher[ivar]] += invCov(ivar, jvar)*d;
 2520 fisherCoeff[mapVarInFisher[ivar]] *= xfact;
 // Offset term (f0) built from the class means; its initialisation and
 // final scaling lines are outside this extract.
 2525 for (UInt_t ivar=0; ivar<nFisherVars; ivar++){
 2526 f0 += fisherCoeff[mapVarInFisher[ivar]]*((*meanMatx)(ivar, 0) + (*meanMatx)(ivar, 1));
 2530 fisherCoeff[fNvars] = f0;
// Find the best cut for this node by testing every possible cut position
// ("full" training, NCuts < 0): events are sorted per variable and the
// cumulative weights scanned event-by-event, rather than on a fixed grid as
// in TrainNodeFast. Classification only (uses fSepType). Returns the best
// separation gain, or 0 if no valid cut was found.
// NOTE(review): the extracted listing has gaps — some else-branches, index
// increments and closing braces are not visible; code kept byte-identical.
 2539 Double_t TMVA::DecisionTree::TrainNodeFull(
 const EventConstList & eventSample,
 2540 TMVA::DecisionTreeNode *node )
 2542 Double_t nTotS = 0.0, nTotB = 0.0;
 2543 Int_t nTotS_unWeighted = 0, nTotB_unWeighted = 0;
 // Wrapped events allow per-variable sorting with cached cumulative weights.
 2545 std::vector<TMVA::BDTEventWrapper> bdtEventSample;
 // Per-variable best cut value / gain / cut direction.
 2549 std::vector<Double_t> lCutValue( fNvars, 0.0 );
 2550 std::vector<Double_t> lSepGain( fNvars, -1.0e6 );
 2551 std::vector<Char_t> lCutType( fNvars );
 2552 lCutType.assign( fNvars, Char_t(kFALSE) );
 // Accumulate class totals and wrap every event.
 2556 for( std::vector<const TMVA::Event*>::const_iterator it = eventSample.begin(); it != eventSample.end(); ++it ) {
 2557 if((*it)->GetClass() == fSigClass) {
 2558 nTotS += (*it)->GetWeight();
 2562 nTotB += (*it)->GetWeight();
 2565 bdtEventSample.push_back(TMVA::BDTEventWrapper(*it));
 // Randomised trees: activate a random subset of fUseNvars variables.
 2568 std::vector<Char_t> useVariable(fNvars);
 2569 useVariable.assign( fNvars, Char_t(kTRUE) );
 2571 for (UInt_t ivar=0; ivar < fNvars; ivar++) useVariable[ivar]=Char_t(kFALSE);
 2572 if (fRandomisedTree) {
 2573 if (fUseNvars ==0 ) {
 // Default subset size ~ sqrt(fNvars), as in random forests.
 2575 fUseNvars = UInt_t(TMath::Sqrt(fNvars)+0.6);
 2577 Int_t nSelectedVars = 0;
 2578 while (nSelectedVars < fUseNvars) {
 2579 Double_t bla = fMyTrandom->Rndm()*fNvars;
 2580 useVariable[Int_t (bla)] = Char_t(kTRUE);
 2582 for (UInt_t ivar=0; ivar < fNvars; ivar++) {
 2583 if(useVariable[ivar] == Char_t(kTRUE)) nSelectedVars++;
 2588 for (UInt_t ivar=0; ivar < fNvars; ivar++) useVariable[ivar] = Char_t(kTRUE);
 // Main loop: per variable, sort the events and scan all cut positions.
 2590 for( UInt_t ivar = 0; ivar < fNvars; ivar++ ) {
 2591 if(!useVariable[ivar])
 continue;
 // Sorting key for BDTEventWrapper::operator< is set globally per variable.
 2593 TMVA::BDTEventWrapper::SetVarIndex(ivar);
 2595 std::sort( bdtEventSample.begin(),bdtEventSample.end() );
 // Cache running (cumulative) signal/background weights on each wrapper.
 2598 Double_t bkgWeightCtr = 0.0, sigWeightCtr = 0.0;
 2600 std::vector<TMVA::BDTEventWrapper>::iterator it = bdtEventSample.begin(), it_end = bdtEventSample.end();
 2601 for( ; it != it_end; ++it ) {
 2602 if((**it)->GetClass() == fSigClass )
 2603 sigWeightCtr += (**it)->GetWeight();
 2605 bkgWeightCtr += (**it)->GetWeight();
 2607 it->SetCumulativeWeight(
 false,bkgWeightCtr);
 2608 it->SetCumulativeWeight(
 true,sigWeightCtr);
 // fPMin: minimum relative gap between neighbouring values for a cut.
 2611 const Double_t fPMin = 1.0e-6;
 2612 Bool_t cutType = kFALSE;
 2614 Double_t separationGain = -1.0, sepTmp = 0.0, cutValue = 0.0, dVal = 0.0, norm = 0.0;
 // Scan cut positions between consecutive sorted events.
 2617 for( it = bdtEventSample.begin(); it != it_end; ++it ) {
 2618 if( index == 0 ) { ++index;
 continue; }
 2619 if( *(*it) == NULL ) {
 2620 Log() << kFATAL <<
 "In TrainNodeFull(): have a null event! Where index="
 2621 << index <<
 ", and parent node=" << node->GetParent() << Endl;
 2624 dVal = bdtEventSample[index].GetVal() - bdtEventSample[index-1].GetVal();
 2625 norm = TMath::Abs(bdtEventSample[index].GetVal() + bdtEventSample[index-1].GetVal());
 // Candidate cut must leave >= fMinSize events on each side and the two
 // neighbouring values must actually differ (relative gap > fPMin).
 2628 if( index >= fMinSize && (nTotS_unWeighted + nTotB_unWeighted) - index >= fMinSize && TMath::Abs(dVal/(0.5*norm + 1)) > fPMin ) {
 2630 sepTmp = fSepType->GetSeparationGain( it->GetCumulativeWeight(
 true), it->GetCumulativeWeight(
 false), sigWeightCtr, bkgWeightCtr );
 2631 if( sepTmp > separationGain ) {
 2632 separationGain = sepTmp;
 // Cut halfway between the two neighbouring values.
 2633 cutValue = it->GetVal() - 0.5*dVal;
 2634 Double_t nSelS = it->GetCumulativeWeight(
 true);
 2635 Double_t nSelB = it->GetCumulativeWeight(
 false);
 // cutType: kTRUE if the signal fraction left of the cut dominates.
 2638 if( nSelS/sigWeightCtr > nSelB/bkgWeightCtr ) cutType = kTRUE;
 2639 else cutType = kFALSE;
 2644 lCutType[ivar] = Char_t(cutType);
 2645 lCutValue[ivar] = cutValue;
 2646 lSepGain[ivar] = separationGain;
 // Pick the best variable overall and configure the node.
 2648 Double_t separationGain = -1.0;
 2649 Int_t iVarIndex = -1;
 2650 for( UInt_t ivar = 0; ivar < fNvars; ivar++ ) {
 2651 if( lSepGain[ivar] > separationGain ) {
 2653 separationGain = lSepGain[ivar];
 2658 if(iVarIndex >= 0) {
 2659 node->SetSelector(iVarIndex);
 2660 node->SetCutValue(lCutValue[iVarIndex]);
 2661 node->SetSeparationGain(lSepGain[iVarIndex]);
 2662 node->SetCutType(lCutType[iVarIndex]);
 // Importance accumulates gain^2 * (total weight)^2, as in TrainNodeFast.
 2663 fVariableImportance[iVarIndex] += separationGain*separationGain * (nTotS+nTotB) * (nTotS+nTotB);
 2666 separationGain = 0.0;
 2669 return separationGain;
// Descend from the root to the leaf node that event `e` falls into, following
// GoesRight() at each intermediate node (GetNodeType()==0 means "not a leaf").
// NOTE(review): the `return current;` and closing brace are outside this
// extract.
 2676 TMVA::DecisionTreeNode* TMVA::DecisionTree::GetEventNode(
 const TMVA::Event & e)
 const
 2678 TMVA::DecisionTreeNode *current = (TMVA::DecisionTreeNode*)this->GetRoot();
 2679 while(current->GetNodeType() == 0) {
 2680 current = (current->GoesRight(e)) ?
 2681 (TMVA::DecisionTreeNode*)current->GetRight() :
 2682 (TMVA::DecisionTreeNode*)current->GetLeft();
// Evaluate the tree for event `e`: walk from the root to the matching leaf
// and return the leaf's answer — the regression response for regression
// trees, otherwise the node type (+1/-1) if UseYesNoLeaf, else the purity.
// Aborts (kFATAL) on a null root or an inconsistent tree structure.
// NOTE(review): several guard/brace lines are missing in this extract.
 2693 Double_t TMVA::DecisionTree::CheckEvent(
 const TMVA::Event * e, Bool_t UseYesNoLeaf )
 const
 2695 TMVA::DecisionTreeNode *current = this->GetRoot();
 2697 Log() << kFATAL <<
 "CheckEvent: started with undefined ROOT node" <<Endl;
 2701 while (current->GetNodeType() == 0) {
 2702 current = (current->GoesRight(*e)) ?
 2703 current->GetRight() :
 2706 Log() << kFATAL <<
 "DT::CheckEvent: inconsistent tree structure" <<Endl;
 2711 if (DoRegression()) {
 2715 return current->GetResponse();
 2717 if (UseYesNoLeaf)
 return Double_t ( current->GetNodeType() );
 2718 else return current->GetPurity();
// Weighted signal purity of an event sample: sum(signal weights) /
// sum(all weights). Aborts (kFATAL) if the totals are inconsistent.
// NOTE(review): the non-positive-sumtot fallback return is outside this
// extract.
 2725 Double_t TMVA::DecisionTree::SamplePurity( std::vector<TMVA::Event*> eventSample )
 2727 Double_t sumsig=0, sumbkg=0, sumtot=0;
 2728 for (UInt_t ievt=0; ievt<eventSample.size(); ievt++) {
 2729 if (eventSample[ievt]->GetClass() != fSigClass) sumbkg+=eventSample[ievt]->GetWeight();
 2730 else sumsig+=eventSample[ievt]->GetWeight();
 2731 sumtot+=eventSample[ievt]->GetWeight();
 // Sanity check: every event was counted as exactly one of signal/background.
 2734 if (sumtot!= (sumsig+sumbkg)){
 2735 Log() << kFATAL <<
 "<SamplePurity> sumtot != sumsig+sumbkg"
 2736 << sumtot <<
 " " << sumsig <<
 " " << sumbkg << Endl;
 2738 if (sumtot>0)
 return sumsig/(sumsig + sumbkg);
// Return the per-variable importance accumulated during training
// (fVariableImportance), normalised so the entries sum to 1; all zeros if the
// total importance is (numerically) zero.
 2748 vector< Double_t > TMVA::DecisionTree::GetVariableImportance()
 2750 std::vector<Double_t> relativeImportance(fNvars);
 2752 for (UInt_t i=0; i< fNvars; i++) {
 2753 sum += fVariableImportance[i];
 2754 relativeImportance[i] = fVariableImportance[i];
 // Normalise, guarding against division by a vanishing total.
 2757 for (UInt_t i=0; i< fNvars; i++) {
 2758 if (sum > std::numeric_limits<double>::epsilon())
 2759 relativeImportance[i] /= sum;
 2761 relativeImportance[i] = 0;
 2763 return relativeImportance;
2769 Double_t TMVA::DecisionTree::GetVariableImportance( UInt_t ivar )
2771 std::vector<Double_t> relativeImportance = this->GetVariableImportance();
2772 if (ivar < fNvars)
return relativeImportance[ivar];
2774 Log() << kFATAL <<
"<GetVariableImportance>" << Endl
2775 <<
"--- ivar = " << ivar <<
" is out of range " << Endl;