RuleFit.cxx
// @(#)root/tmva $Id$
// Author: Andreas Hoecker, Joerg Stelzer, Fredrik Tegenfeldt, Helge Voss

/**********************************************************************************
 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis      *
 * Package: TMVA                                                                  *
 * Class  : RuleFit                                                               *
 * Web    : http://tmva.sourceforge.net                                           *
 *                                                                                *
 * Description:                                                                   *
 *      A class implementing various fits of rule ensembles                      *
 *                                                                                *
 * Authors (alphabetical):                                                        *
 *      Fredrik Tegenfeldt <Fredrik.Tegenfeldt@cern.ch> - Iowa State U., USA      *
 *                                                                                *
 * Copyright (c) 2005:                                                            *
 *      CERN, Switzerland                                                         *
 *      Iowa State U.                                                             *
 *      MPI-K Heidelberg, Germany                                                 *
 *                                                                                *
 * Redistribution and use in source and binary forms, with or without            *
 * modification, are permitted according to the terms listed in LICENSE          *
 * (http://tmva.sourceforge.net/LICENSE)                                          *
 **********************************************************************************/

/*! \class TMVA::RuleFit
\ingroup TMVA
A class implementing various fits of rule ensembles
*/
#include "TMVA/RuleFit.h"

#include "TMVA/DataSet.h"
#include "TMVA/DecisionTree.h"
#include "TMVA/Event.h"
#include "TMVA/Factory.h" // for root base dir
#include "TMVA/GiniIndex.h"
#include "TMVA/MethodBase.h"
#include "TMVA/MethodRuleFit.h"
#include "TMVA/MsgLogger.h"
#include "TMVA/Timer.h"
#include "TMVA/Tools.h"
#include "TMVA/Types.h"
#include "TMVA/SeparationBase.h"

#include "TDirectory.h"
#include "TH2F.h"
#include "TFile.h"
#include "TKey.h"
#include "TRandom3.h"
#include "TROOT.h" // for gROOT
#include "TTree.h" // for the debug ntuple written by MakeDebugHists()

#include <algorithm>
#include <random>

ClassImp(TMVA::RuleFit);

////////////////////////////////////////////////////////////////////////////////
/// constructor

TMVA::RuleFit::RuleFit(const MethodBase *rfbase)
   : fVisHistsUseImp( kTRUE )
   , fLogger(new MsgLogger("RuleFit"))
{
   Initialize(rfbase);
   fRNGEngine.seed(randSEED);
}

////////////////////////////////////////////////////////////////////////////////
/// default constructor

TMVA::RuleFit::RuleFit()
   : fNTreeSample(0)
   , fNEveEffTrain(0)
   , fMethodRuleFit(0)
   , fMethodBase(0)
   , fVisHistsUseImp(kTRUE)
   , fLogger(new MsgLogger("RuleFit"))
{
   fRNGEngine.seed(randSEED);
}

////////////////////////////////////////////////////////////////////////////////
/// destructor

TMVA::RuleFit::~RuleFit()
{
   delete fLogger;
}

////////////////////////////////////////////////////////////////////////////////
/// init effective number of events (using event weights)

void TMVA::RuleFit::InitNEveEff()
{
   UInt_t neve = fTrainingEvents.size();
   if (neve==0) return;

   fNEveEffTrain = CalcWeightSum( &fTrainingEvents );
}

////////////////////////////////////////////////////////////////////////////////
/// initialize pointers

void TMVA::RuleFit::InitPtrs( const MethodBase *rfbase )
{
   this->SetMethodBase(rfbase);
   fRuleEnsemble.Initialize( this );
   fRuleFitParams.SetRuleFit( this );
}

////////////////////////////////////////////////////////////////////////////////
/// initialize the parameters of the RuleFit method and make rules

void TMVA::RuleFit::Initialize( const MethodBase *rfbase )
{
   InitPtrs(rfbase);

   if (fMethodRuleFit){
      fMethodRuleFit->Data()->SetCurrentType(Types::kTraining);
      // GetNTrainingEvents() returns a Long64_t; use the same type for the loop
      Long64_t nevents = fMethodRuleFit->Data()->GetNTrainingEvents();
      std::vector<const TMVA::Event*> tmp;
      for (Long64_t ievt=0; ievt<nevents; ievt++) {
         const Event *event = fMethodRuleFit->GetEvent(ievt);
         tmp.push_back(event);
      }
      SetTrainingEvents( tmp );
   }
   // SetTrainingEvents( fMethodRuleFit->GetTrainingEvents() );

   InitNEveEff();

   MakeForest();

   // make the model - rules + linear terms (if fDoLinear is true)
   fRuleEnsemble.MakeModel();

   // init rulefit params
   fRuleFitParams.Init();
}

////////////////////////////////////////////////////////////////////////////////
/// set MethodBase

void TMVA::RuleFit::SetMethodBase( const MethodBase *rfbase )
{
   fMethodBase    = rfbase;
   fMethodRuleFit = dynamic_cast<const MethodRuleFit *>(rfbase);
}

////////////////////////////////////////////////////////////////////////////////
/// copy method

void TMVA::RuleFit::Copy( const RuleFit& other )
{
   if (this != &other) {
      fMethodRuleFit  = other.GetMethodRuleFit();
      fMethodBase     = other.GetMethodBase();
      fTrainingEvents = other.GetTrainingEvents();
      // fSubsampleEvents = other.GetSubsampleEvents();

      fForest       = other.GetForest();
      fRuleEnsemble = other.GetRuleEnsemble();
   }
}

////////////////////////////////////////////////////////////////////////////////
/// calculate the sum of weights

Double_t TMVA::RuleFit::CalcWeightSum( const std::vector<const Event *> *events, UInt_t neve )
{
   if (events==0) return 0.0;
   if (neve==0) neve=events->size();

   Double_t sumw=0;
   for (UInt_t ie=0; ie<neve; ie++) {
      sumw += ((*events)[ie])->GetWeight();
   }
   return sumw;
}

////////////////////////////////////////////////////////////////////////////////
/// set the minimum message type of the logger for this class and all sub-tools

void TMVA::RuleFit::SetMsgType( EMsgType t )
{
   fLogger->SetMinType(t);
   fRuleEnsemble.SetMsgType(t);
   fRuleFitParams.SetMsgType(t);
}

////////////////////////////////////////////////////////////////////////////////
/// build the decision tree using fNTreeSample events from fTrainingEventsRndm

void TMVA::RuleFit::BuildTree( DecisionTree *dt )
{
   if (dt==0) return;
   if (fMethodRuleFit==0) {
      Log() << kFATAL << "RuleFit::BuildTree() - attempting to build a tree without a MethodRuleFit" << Endl;
   }
   std::vector<const Event *> evevec;
   for (UInt_t ie=0; ie<fNTreeSample; ie++) {
      evevec.push_back(fTrainingEventsRndm[ie]);
   }
   dt->BuildTree(evevec);
   if (fMethodRuleFit->GetPruneMethod() != DecisionTree::kNoPruning) {
      dt->SetPruneMethod(fMethodRuleFit->GetPruneMethod());
      dt->SetPruneStrength(fMethodRuleFit->GetPruneStrength());
      dt->PruneTree();
   }
}

////////////////////////////////////////////////////////////////////////////////
/// make a forest of decision trees

void TMVA::RuleFit::MakeForest()
{
   if (fMethodRuleFit==0) {
      Log() << kFATAL << "RuleFit::MakeForest() - attempting to build a forest without a MethodRuleFit" << Endl;
   }
   Log() << kDEBUG << "Creating a forest with " << fMethodRuleFit->GetNTrees() << " decision trees" << Endl;
   Log() << kDEBUG << "Each tree is built using a random subsample with " << fNTreeSample << " events" << Endl;

   Timer timer( fMethodRuleFit->GetNTrees(), "RuleFit" );

   // Double_t fsig;
   Int_t nsig,nbkg;

   TRandom3 rndGen;

   // First save all event weights.
   // Weights are modified by the boosting;
   // we do not want those modified weights for the later fitting.
   Bool_t useBoost = fMethodRuleFit->UseBoost(); // AdaBoost (true) or random forest (false)

   if (useBoost) SaveEventWeights();

   for (Int_t i=0; i<fMethodRuleFit->GetNTrees(); i++) {
      // timer.DrawProgressBar(i);
      if (!useBoost) ReshuffleEvents();
      nsig=0;
      nbkg=0;
      for (UInt_t ie = 0; ie<fNTreeSample; ie++) {
         if (fMethodBase->DataInfo().IsSignal(fTrainingEventsRndm[ie])) nsig++; // ignore weights here
         else nbkg++;
      }
      // fsig = Double_t(nsig)/Double_t(nsig+nbkg);
      // do not implement the above in this release... just set it to default

      DecisionTree *dt=nullptr;
      Bool_t tryAgain=kTRUE;
      Int_t ntries=0;
      const Int_t ntriesMax=10;
      Double_t frnd = 0.;
      while (tryAgain) {
         frnd = 100*rndGen.Uniform( fMethodRuleFit->GetMinFracNEve(), 0.5*fMethodRuleFit->GetMaxFracNEve() );
         Int_t iclass = 0; // event class being treated as signal during training
         Bool_t useRandomisedTree = !useBoost;
         dt = new DecisionTree( fMethodRuleFit->GetSeparationBase(), frnd, fMethodRuleFit->GetNCuts(), &(fMethodRuleFit->DataInfo()), iclass, useRandomisedTree);
         dt->SetNVars(fMethodBase->GetNvar());

         BuildTree(dt); // reads fNTreeSample events from fTrainingEventsRndm
         if (dt->GetNNodes()<3) {
            delete dt;
            dt=0;
         }
         ntries++;
         tryAgain = ((dt==0) && (ntries<ntriesMax));
      }
      if (dt) {
         fForest.push_back(dt);
         if (useBoost) Boost(dt);

         Log() << kDEBUG << "Built tree with minimum cut at N = " << frnd << "% events"
               << " => N(nodes) = " << fForest.back()->GetNNodes()
               << " ; n(tries) = " << ntries
               << Endl;
      } else {
         Log() << kWARNING << "------------------------------------------------------------------" << Endl;
         Log() << kWARNING << " Failed growing a tree even after " << ntriesMax << " trials" << Endl;
         Log() << kWARNING << " Possible solutions: " << Endl;
         Log() << kWARNING << "   1. increase the number of training events" << Endl;
         Log() << kWARNING << "   2. set a lower min fraction cut (fEventsMin)" << Endl;
         Log() << kWARNING << "   3. maybe also decrease the max fraction cut (fEventsMax)" << Endl;
         Log() << kWARNING << " If the above warning occurs only rarely, it can be ignored" << Endl;
         Log() << kWARNING << "------------------------------------------------------------------" << Endl;
      }
   }

   // now restore the event weights
   if (useBoost) RestoreEventWeights();

   // print statistics on the forest created
   ForestStatistics();
}
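
// --------------------------------------------------------------------------
// Illustrative sketch (assumed option values, not part of the class): each
// tree draws its own minimum-node-size cut, in percent, uniformly between
// MinFracNEve and half of MaxFracNEve:
//
//    TRandom3 rndGen;
//    Double_t minFrac = 0.1, maxFrac = 0.9;                    // hypothetical
//    Double_t frnd = 100*rndGen.Uniform(minFrac, 0.5*maxFrac); // in [10,45] %
//
// Varying this cut from tree to tree (on top of reshuffling or boosting)
// helps decorrelate the trees before rules are extracted from them.
// --------------------------------------------------------------------------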

////////////////////////////////////////////////////////////////////////////////
/// save event weights - must be done before making the forest

void TMVA::RuleFit::SaveEventWeights()
{
   fEventWeights.clear();
   for (std::vector<const Event*>::iterator e=fTrainingEvents.begin(); e!=fTrainingEvents.end(); ++e) {
      Double_t w = (*e)->GetBoostWeight();
      fEventWeights.push_back(w);
   }
}

////////////////////////////////////////////////////////////////////////////////
/// restore event weights - must be done after the forest is grown and boosted

void TMVA::RuleFit::RestoreEventWeights()
{
   UInt_t ie=0;
   if (fEventWeights.size() != fTrainingEvents.size()) {
      Log() << kERROR << "RuleFit::RestoreEventWeights() called without having called SaveEventWeights() before!" << Endl;
      return;
   }
   for (std::vector<const Event*>::iterator e=fTrainingEvents.begin(); e!=fTrainingEvents.end(); ++e) {
      (*e)->SetBoostWeight(fEventWeights[ie]);
      ie++;
   }
}

////////////////////////////////////////////////////////////////////////////////
/// Boost the events. The algorithm below is AdaBoost;
/// see MethodBDT for details.
/// It is, more or less, a copy of MethodBDT::AdaBoost().

void TMVA::RuleFit::Boost( DecisionTree *dt )
{
   Double_t sumw=0;      // sum of initial weights - all events
   Double_t sumwfalse=0; // idem, only misclassified events

   std::vector<Char_t> correctSelected; // <--- booleans stored

   for (std::vector<const Event*>::iterator e=fTrainingEvents.begin(); e!=fTrainingEvents.end(); ++e) {
      Bool_t isSignalType = (dt->CheckEvent(*e,kTRUE) > 0.5 );
      Double_t w = (*e)->GetWeight();
      sumw += w;

      if (isSignalType == fMethodBase->DataInfo().IsSignal(*e)) { // correctly classified
         correctSelected.push_back(kTRUE);
      }
      else { // misclassified
         sumwfalse += w;
         correctSelected.push_back(kFALSE);
      }
   }
   // misclassification error
   Double_t err = sumwfalse/sumw;
   // calculate the boost weight for misclassified events;
   // use the exponent = 1.0 for now -
   // one could have w = ((1-err)/err)^beta
   Double_t boostWeight = (err>0 ? (1.0-err)/err : 1000.0);
   Double_t newSumw=0.0;
   UInt_t ie=0;
   // set the new weight for misclassified events
   for (std::vector<const Event*>::iterator e=fTrainingEvents.begin(); e!=fTrainingEvents.end(); ++e) {
      if (!correctSelected[ie])
         (*e)->SetBoostWeight( (*e)->GetBoostWeight() * boostWeight);
      newSumw += (*e)->GetWeight();
      ie++;
   }
   // reweight all events so that the sum of weights is unchanged
   Double_t scale = sumw/newSumw;
   for (std::vector<const Event*>::iterator e=fTrainingEvents.begin(); e!=fTrainingEvents.end(); ++e) {
      (*e)->SetBoostWeight( (*e)->GetBoostWeight() * scale);
   }
   Log() << kDEBUG << "boostWeight = " << boostWeight << " scale = " << scale << Endl;
}
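
// --------------------------------------------------------------------------
// Worked example (illustrative sketch, not part of the class): with a
// weighted misclassification fraction err = 0.2, the update above gives
//
//    Double_t err         = 0.2;
//    Double_t boostWeight = (1.0 - err)/err;   // = 4
//
// i.e. misclassified events are upweighted fourfold, after which the global
// factor sumw/newSumw rescales all weights so their sum is unchanged. The
// commented-out exponent would generalize this to pow((1-err)/err, beta).
// --------------------------------------------------------------------------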

////////////////////////////////////////////////////////////////////////////////
/// summary of statistics of all trees
/// - end-nodes: average and spread

void TMVA::RuleFit::ForestStatistics()
{
   UInt_t ntrees = fForest.size();
   if (ntrees==0) return;
   const DecisionTree *tree;
   Double_t sumn2 = 0;
   Double_t sumn  = 0;
   Double_t nd;
   for (UInt_t i=0; i<ntrees; i++) {
      tree = fForest[i];
      nd = Double_t(tree->GetNNodes());
      sumn  += nd;
      sumn2 += nd*nd;
   }
   Double_t sig = TMath::Sqrt( gTools().ComputeVariance( sumn2, sumn, ntrees ));
   Log() << kVERBOSE << "Nodes in trees: average & std dev = " << sumn/ntrees << " , " << sig << Endl;
}
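
// --------------------------------------------------------------------------
// Sketch of the spread computation above, assuming gTools().ComputeVariance
// implements the usual unbiased estimator from the accumulated sums:
//
//    // var = (sumn2 - sumn*sumn/ntrees) / (ntrees - 1)
//    Double_t var = (sumn2 - sumn*sumn/ntrees) / (ntrees - 1.0);
//    Double_t sig = TMath::Sqrt(var);
//
// For example, three trees with 5, 7 and 9 nodes give sumn = 21,
// sumn2 = 155, var = (155 - 147)/2 = 4 and sig = 2.
// --------------------------------------------------------------------------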

////////////////////////////////////////////////////////////////////////////////
/// Fit the coefficients for the rule ensemble

void TMVA::RuleFit::FitCoefficients()
{
   Log() << kVERBOSE << "Fitting rule/linear terms" << Endl;
   fRuleFitParams.MakeGDPath();
}

////////////////////////////////////////////////////////////////////////////////
/// calculates the importance of each rule

void TMVA::RuleFit::CalcImportance()
{
   Log() << kVERBOSE << "Calculating importance" << Endl;
   fRuleEnsemble.CalcImportance();
   fRuleEnsemble.CleanupRules();
   fRuleEnsemble.CleanupLinear();
   fRuleEnsemble.CalcVarImportance();
   Log() << kVERBOSE << "Filling rule statistics" << Endl;
   fRuleEnsemble.RuleResponseStats();
}

////////////////////////////////////////////////////////////////////////////////
/// evaluate single event

Double_t TMVA::RuleFit::EvalEvent( const Event& e )
{
   return fRuleEnsemble.EvalEvent( e );
}

////////////////////////////////////////////////////////////////////////////////
/// set the training events randomly

void TMVA::RuleFit::SetTrainingEvents( const std::vector<const Event *>& el )
{
   if (fMethodRuleFit==0) Log() << kFATAL << "RuleFit::SetTrainingEvents - MethodRuleFit not initialized" << Endl;
   UInt_t neve = el.size();
   if (neve==0) Log() << kWARNING << "An empty sample of training events was given" << Endl;

   // copy the vector
   fTrainingEvents.clear();
   fTrainingEventsRndm.clear();
   for (UInt_t i=0; i<neve; i++) {
      fTrainingEvents.push_back(el[i]);
      fTrainingEventsRndm.push_back(el[i]);
   }

   // re-shuffle the vector, i.e. recreate it in a random order
   std::shuffle(fTrainingEventsRndm.begin(), fTrainingEventsRndm.end(), fRNGEngine);

   // number of events per tree (a fixed fraction of the total sample)
   fNTreeSample = static_cast<UInt_t>(neve*fMethodRuleFit->GetTreeEveFrac());
   Log() << kDEBUG << "Number of events per tree : " << fNTreeSample
         << " ( N(events) = " << neve << " )"
         << " randomly drawn without replacement" << Endl;
}
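
// --------------------------------------------------------------------------
// Illustrative example (hypothetical numbers): with neve = 10000 training
// events and a TreeEveFrac of 0.3,
//
//    UInt_t nTreeSample = static_cast<UInt_t>(10000*0.3);   // = 3000
//
// each tree is then grown on the first 3000 entries of the freshly shuffled
// fTrainingEventsRndm, i.e. a random draw without replacement.
// --------------------------------------------------------------------------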

////////////////////////////////////////////////////////////////////////////////
/// draw a random subsample of the training events without replacement

void TMVA::RuleFit::GetRndmSampleEvents(std::vector< const Event * > & evevec, UInt_t nevents)
{
   ReshuffleEvents();
   if ((nevents<fTrainingEventsRndm.size()) && (nevents>0)) {
      evevec.resize(nevents);
      for (UInt_t ie=0; ie<nevents; ie++) {
         evevec[ie] = fTrainingEventsRndm[ie];
      }
   }
   else {
      Log() << kWARNING << "GetRndmSampleEvents() : requested subsample size larger than total size (BUG!)." << Endl;
   }
}
////////////////////////////////////////////////////////////////////////////////
/// normalize rule importance hists
///
/// if all weights are positive, the scale will be 1/maxweight
/// if minimum weight < 0, then the scale will be 1/max(maxweight,abs(minweight))

void TMVA::RuleFit::NormVisHists(std::vector<TH2F *> & hlist)
{
   if (hlist.empty()) return;

   Double_t wmin=0;
   Double_t wmax=0;
   Double_t w,wm;
   Double_t awmin;
   Double_t scale;
   for (UInt_t i=0; i<hlist.size(); i++) {
      TH2F *hs = hlist[i];
      w  = hs->GetMaximum();
      wm = hs->GetMinimum();
      if (i==0) {
         wmin=wm;
         wmax=w;
      }
      else {
         if (w>wmax)  wmax=w;
         if (wm<wmin) wmin=wm;
      }
   }
   awmin = TMath::Abs(wmin);
   Double_t usemin,usemax;
   if (awmin>wmax) {
      scale  = 1.0/awmin;
      usemin = -1.0;
      usemax = scale*wmax;
   }
   else {
      scale  = 1.0/wmax;
      usemin = scale*wmin;
      usemax = 1.0;
   }

   for (UInt_t i=0; i<hlist.size(); i++) {
      TH2F *hs = hlist[i];
      hs->Scale(scale);
      hs->SetMinimum(usemin);
      hs->SetMaximum(usemax);
   }
}
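
// --------------------------------------------------------------------------
// Worked example (illustrative only): if the global extrema over all
// histograms are wmin = -2 and wmax = 5, then |wmin| < wmax and
//
//    Double_t scale  = 1.0/5.0;    // 1/wmax
//    Double_t usemin = -2.0/5.0;   // scale*wmin = -0.4
//    Double_t usemax = 1.0;
//
// With wmin = -8 instead, |wmin| > wmax gives scale = 1/8 and a display
// range of [-1, 5/8]; either way the largest magnitude is pinned to +-1.
// --------------------------------------------------------------------------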

////////////////////////////////////////////////////////////////////////////////
/// fill the cut range on variable vind of the given rule into the histogram

void TMVA::RuleFit::FillCut(TH2F* h2, const Rule *rule, Int_t vind)
{
   if (rule==0) return;
   if (h2==0) return;

   Double_t rmin, rmax;
   Bool_t dormin,dormax;
   Bool_t ruleHasVar = rule->GetRuleCut()->GetCutRange(vind,rmin,rmax,dormin,dormax);
   if (!ruleHasVar) return;

   Int_t firstbin = h2->GetBin(1,1,1);
   if (firstbin<0) firstbin=0;
   Int_t lastbin = h2->GetBin(h2->GetNbinsX(),1,1);
   Int_t binmin = (dormin ? h2->FindBin(rmin,0.5):firstbin);
   Int_t binmax = (dormax ? h2->FindBin(rmax,0.5):lastbin);
   Int_t fbin;
   Double_t xbinw  = h2->GetXaxis()->GetBinWidth(firstbin);
   Double_t fbmin  = h2->GetXaxis()->GetBinLowEdge(binmin-firstbin+1);
   Double_t lbmax  = h2->GetXaxis()->GetBinLowEdge(binmax-firstbin+1)+xbinw;
   Double_t fbfrac = (dormin ? ((fbmin+xbinw-rmin)/xbinw):1.0);
   Double_t lbfrac = (dormax ? ((rmax-lbmax+xbinw)/xbinw):1.0);
   Double_t f;
   Double_t xc;
   Double_t val;

   for (Int_t bin = binmin; bin<binmax+1; bin++) {
      fbin = bin-firstbin+1;
      if (bin==binmin) {
         f = fbfrac;
      }
      else if (bin==binmax) {
         f = lbfrac;
      }
      else {
         f = 1.0;
      }
      xc = h2->GetXaxis()->GetBinCenter(fbin);

      if (fVisHistsUseImp) {
         val = rule->GetImportance();
      }
      else {
         val = rule->GetCoefficient()*rule->GetSupport();
      }
      h2->Fill(xc,0.5,val*f);
   }
}
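
// --------------------------------------------------------------------------
// Illustrative example of the edge-bin weighting above (hypothetical
// numbers): a cut starting at rmin = 1.25 in a bin spanning [1.0, 2.0]
// (so fbmin = 1.0, xbinw = 1.0) covers 75% of that bin:
//
//    Double_t fbfrac = (fbmin + xbinw - rmin)/xbinw;   // (1+1-1.25)/1 = 0.75
//
// so the first bin receives only 0.75 of the rule's weight (importance or
// coefficient*support), while fully covered interior bins receive f = 1.
// --------------------------------------------------------------------------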

////////////////////////////////////////////////////////////////////////////////
/// fill the linear term for variable vind into the histogram

void TMVA::RuleFit::FillLin(TH2F* h2,Int_t vind)
{
   if (h2==0) return;
   if (!fRuleEnsemble.DoLinear()) return;

   Int_t firstbin = 1;
   Int_t lastbin = h2->GetNbinsX();
   Double_t xc;
   Double_t val;
   if (fVisHistsUseImp) {
      val = fRuleEnsemble.GetLinImportance(vind);
   }
   else {
      val = fRuleEnsemble.GetLinCoefficients(vind);
   }
   for (Int_t bin = firstbin; bin<lastbin+1; bin++) {
      xc = h2->GetXaxis()->GetBinCenter(bin);
      h2->Fill(xc,0.5,val);
   }
}

////////////////////////////////////////////////////////////////////////////////
/// fill rule correlation between vx and vy, weighted with either the importance or the coefficient

void TMVA::RuleFit::FillCorr(TH2F* h2,const Rule *rule,Int_t vx, Int_t vy)
{
   if (rule==0) return;
   if (h2==0) return;
   Double_t val;
   if (fVisHistsUseImp) {
      val = rule->GetImportance();
   }
   else {
      val = rule->GetCoefficient()*rule->GetSupport();
   }

   Double_t rxmin, rxmax, rymin, rymax;
   Bool_t dorxmin, dorxmax, dorymin, dorymax;

   // get the range in the rule for X and Y
   Bool_t ruleHasVarX = rule->GetRuleCut()->GetCutRange(vx,rxmin,rxmax,dorxmin,dorxmax);
   Bool_t ruleHasVarY = rule->GetRuleCut()->GetCutRange(vy,rymin,rymax,dorymin,dorymax);
   if (!(ruleHasVarX || ruleHasVarY)) return;
   // min max of varX and varY in hist
   Double_t vxmin = (dorxmin ? rxmin:h2->GetXaxis()->GetXmin());
   Double_t vxmax = (dorxmax ? rxmax:h2->GetXaxis()->GetXmax());
   Double_t vymin = (dorymin ? rymin:h2->GetYaxis()->GetXmin());
   Double_t vymax = (dorymax ? rymax:h2->GetYaxis()->GetXmax());
   // min max bin in X and Y
   Int_t binxmin = h2->GetXaxis()->FindBin(vxmin);
   Int_t binxmax = h2->GetXaxis()->FindBin(vxmax);
   Int_t binymin = h2->GetYaxis()->FindBin(vymin);
   Int_t binymax = h2->GetYaxis()->FindBin(vymax);
   // bin widths
   Double_t xbinw = h2->GetXaxis()->GetBinWidth(binxmin);
   Double_t ybinw = h2->GetYaxis()->GetBinWidth(binymin);
   Double_t xbinmin = h2->GetXaxis()->GetBinLowEdge(binxmin);
   Double_t xbinmax = h2->GetXaxis()->GetBinLowEdge(binxmax)+xbinw;
   Double_t ybinmin = h2->GetYaxis()->GetBinLowEdge(binymin);
   Double_t ybinmax = h2->GetYaxis()->GetBinLowEdge(binymax)+ybinw;
   // fraction of the edge bins covered by the cut
   Double_t fxbinmin = (dorxmin ? ((xbinmin+xbinw-vxmin)/xbinw):1.0);
   Double_t fxbinmax = (dorxmax ? ((vxmax-xbinmax+xbinw)/xbinw):1.0);
   Double_t fybinmin = (dorymin ? ((ybinmin+ybinw-vymin)/ybinw):1.0);
   Double_t fybinmax = (dorymax ? ((vymax-ybinmax+ybinw)/ybinw):1.0);

   Double_t fx,fy;
   Double_t xc,yc;
   // fill histo
   for (Int_t binx = binxmin; binx<binxmax+1; binx++) {
      if (binx==binxmin) {
         fx = fxbinmin;
      }
      else if (binx==binxmax) {
         fx = fxbinmax;
      }
      else {
         fx = 1.0;
      }
      xc = h2->GetXaxis()->GetBinCenter(binx);
      for (Int_t biny = binymin; biny<binymax+1; biny++) {
         if (biny==binymin) {
            fy = fybinmin;
         }
         else if (biny==binymax) {
            fy = fybinmax;
         }
         else {
            fy = 1.0;
         }
         yc = h2->GetYaxis()->GetBinCenter(biny);
         h2->Fill(xc,yc,val*fx*fy);
      }
   }
}

////////////////////////////////////////////////////////////////////////////////
/// helper routine to MakeVisHists() - fills for all variables

void TMVA::RuleFit::FillVisHistCut(const Rule* rule, std::vector<TH2F *> & hlist)
{
   Int_t nhists = hlist.size();
   Int_t nvar   = fMethodBase->GetNvar();
   if (nhists!=nvar) Log() << kFATAL << "BUG TRAP: number of hists is not equal to the number of variables!" << Endl;

   std::vector<Int_t> vindex;
   TString hstr;
   // not a nice way to do a check...
   for (Int_t ih=0; ih<nhists; ih++) {
      hstr = hlist[ih]->GetTitle();
      for (Int_t iv=0; iv<nvar; iv++) {
         if (fMethodBase->GetInputTitle(iv) == hstr)
            vindex.push_back(iv);
      }
   }

   for (Int_t iv=0; iv<nvar; iv++) {
      if (rule) {
         if (rule->ContainsVariable(vindex[iv])) {
            FillCut(hlist[iv],rule,vindex[iv]);
         }
      }
      else {
         FillLin(hlist[iv],vindex[iv]);
      }
   }
}
////////////////////////////////////////////////////////////////////////////////
/// helper routine to MakeVisHists() - fills for all correlation plots

void TMVA::RuleFit::FillVisHistCorr(const Rule * rule, std::vector<TH2F *> & hlist)
{
   if (rule==0) return;
   Double_t ruleimp = rule->GetImportance();
   if (!(ruleimp>0)) return;
   if (ruleimp<fRuleEnsemble.GetImportanceCut()) return;

   Int_t nhists = hlist.size();
   Int_t nvar   = fMethodBase->GetNvar();
   Int_t ncorr  = (nvar*(nvar+1)/2)-nvar;
   if (nhists!=ncorr) Log() << kERROR << "BUG TRAP: number of corr hists is not correct! ncorr = "
                            << ncorr << " nvar = " << nvar << " nhists = " << nhists << Endl;

   std::vector< std::pair<Int_t,Int_t> > vindex;
   TString hstr, var1, var2;
   Int_t iv1=0,iv2=0;
   // not a nice way to do a check...
   for (Int_t ih=0; ih<nhists; ih++) {
      hstr = hlist[ih]->GetName();
      if (GetCorrVars( hstr, var1, var2 )) {
         iv1 = fMethodBase->DataInfo().FindVarIndex( var1 );
         iv2 = fMethodBase->DataInfo().FindVarIndex( var2 );
         vindex.push_back( std::pair<Int_t,Int_t>(iv2,iv1) ); // pair X, Y
      }
      else {
         Log() << kERROR << "BUG TRAP: should not be here - failed getting var1 and var2" << Endl;
      }
   }

   for (Int_t ih=0; ih<nhists; ih++) {
      if ( (rule->ContainsVariable(vindex[ih].first)) ||
           (rule->ContainsVariable(vindex[ih].second)) ) {
         FillCorr(hlist[ih],rule,vindex[ih].first,vindex[ih].second);
      }
   }
}
////////////////////////////////////////////////////////////////////////////////
/// get the first and second variable names from the histogram title

Bool_t TMVA::RuleFit::GetCorrVars(TString & title, TString & var1, TString & var2)
{
   var1="";
   var2="";
   if (!title.BeginsWith("scat_")) return kFALSE;

   TString titleCopy = title(5,title.Length());
   if (titleCopy.Index("_RF2D")>=0) titleCopy.Remove(titleCopy.Index("_RF2D"));

   Int_t splitPos = titleCopy.Index("_vs_");
   if (splitPos>=0) { // there is a "_vs_" in the string
      var1 = titleCopy(0,splitPos);
      var2 = titleCopy(splitPos+4, titleCopy.Length());
      return kTRUE;
   }
   else {
      var1 = titleCopy;
      return kFALSE;
   }
}
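
// --------------------------------------------------------------------------
// Usage sketch (hypothetical variable names): a correlation histogram named
// "scat_var1_vs_var2_RF2D" is decoded as
//
//    TString name = "scat_var1_vs_var2_RF2D", v1, v2;
//    Bool_t ok = GetCorrVars(name, v1, v2);   // called on a RuleFit instance
//    // ok == kTRUE, v1 == "var1", v2 == "var2"
//
// The "scat_" prefix and "_RF2D" suffix are stripped before the remainder
// is split at "_vs_"; without a "_vs_" the whole remainder ends up in var1
// and kFALSE is returned.
// --------------------------------------------------------------------------
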
////////////////////////////////////////////////////////////////////////////////
/// this will create histograms visualizing the rule ensemble

void TMVA::RuleFit::MakeVisHists()
{
   const TString directories[5] = { "InputVariables_Id",
                                    "InputVariables_Deco",
                                    "InputVariables_PCA",
                                    "InputVariables_Gauss",
                                    "InputVariables_Gauss_Deco" };

   const TString corrDirName = "CorrelationPlots";

   TDirectory* rootDir = fMethodBase->GetFile();
   TDirectory* varDir  = 0;
   TDirectory* corrDir = 0;

   TDirectory* methodDir = fMethodBase->BaseDir();
   TString varDirName;

   Bool_t done=(rootDir==0);
   Int_t type=0;
   if (done) {
      Log() << kWARNING << "No basedir - BUG??" << Endl;
      return;
   }
   while (!done) {
      varDir = (TDirectory*)rootDir->Get( directories[type] );
      type++;
      done = ((varDir!=0) || (type>4));
   }
   if (varDir==0) {
      Log() << kWARNING << "No input variable directory found - BUG?" << Endl;
      return;
   }
   corrDir = (TDirectory*)varDir->Get( corrDirName );
   if (corrDir==0) {
      Log() << kWARNING << "No correlation directory found : " << corrDirName << Endl;
      Log() << kWARNING << "Check for other warnings related to correlation histograms" << Endl;
      return;
   }
   if (methodDir==0) {
      Log() << kWARNING << "No rulefit method directory found - BUG?" << Endl;
      return;
   }

   varDirName = varDir->GetName();
   varDir->cd();

   // how many plots are in the var directory?
   Int_t noPlots = ((varDir->GetListOfKeys())->GetEntries()) / 2;
   Log() << kDEBUG << "Got number of plots = " << noPlots << Endl;

   // loop over all objects in the directory
   std::vector<TH2F *> h1Vector;
   std::vector<TH2F *> h2CorrVector;
   TIter next(varDir->GetListOfKeys());
   TKey *key;
   while ((key = (TKey*)next())) {
      // make sure that we only look at histograms
      TClass *cl = gROOT->GetClass(key->GetClassName());
      if (!cl->InheritsFrom(TH1F::Class())) continue;
      TH1F *sig = (TH1F*)key->ReadObj();
      TString hname = sig->GetName();
      Log() << kDEBUG << "Got histogram : " << hname << Endl;

      // check for all signal histograms
      if (hname.Contains("__S")){ // found a new signal plot
         TString htitle = sig->GetTitle();
         htitle.ReplaceAll("signal","");
         TString newname = hname;
         newname.ReplaceAll("__Signal","__RF");
         newname.ReplaceAll("__S","__RF");

         methodDir->cd();
         TH2F *newhist = new TH2F(newname,htitle,sig->GetNbinsX(),sig->GetXaxis()->GetXmin(),sig->GetXaxis()->GetXmax(),
                                  1,sig->GetYaxis()->GetXmin(),sig->GetYaxis()->GetXmax());
         varDir->cd();
         h1Vector.push_back( newhist );
      }
   }

   corrDir->cd();
   TString var1,var2;
   TIter nextCorr(corrDir->GetListOfKeys());
   while ((key = (TKey*)nextCorr())) {
      // make sure that we only look at histograms
      TClass *cl = gROOT->GetClass(key->GetClassName());
      if (!cl->InheritsFrom(TH2F::Class())) continue;
      TH2F *sig = (TH2F*)key->ReadObj();
      TString hname = sig->GetName();

      // check for all signal histograms
      if ((hname.Contains("scat_")) && (hname.Contains("_Signal"))) {
         Log() << kDEBUG << "Got histogram (2D) : " << hname << Endl;
         TString htitle = sig->GetTitle();
         htitle.ReplaceAll("(Signal)","");
         TString newname = hname;
         newname.ReplaceAll("_Signal","_RF2D");

         methodDir->cd();
         const Int_t rebin=2;
         TH2F *newhist = new TH2F(newname,htitle,
                                  sig->GetNbinsX()/rebin,sig->GetXaxis()->GetXmin(),sig->GetXaxis()->GetXmax(),
                                  sig->GetNbinsY()/rebin,sig->GetYaxis()->GetXmin(),sig->GetYaxis()->GetXmax());
         if (GetCorrVars( newname, var1, var2 )) {
            Int_t iv1 = fMethodBase->DataInfo().FindVarIndex(var1);
            Int_t iv2 = fMethodBase->DataInfo().FindVarIndex(var2);
            if (iv1<0) {
               sig->GetYaxis()->SetTitle(var1);
            }
            else {
               sig->GetYaxis()->SetTitle(fMethodBase->GetInputTitle(iv1));
            }
            if (iv2<0) {
               sig->GetXaxis()->SetTitle(var2);
            }
            else {
               sig->GetXaxis()->SetTitle(fMethodBase->GetInputTitle(iv2));
            }
         }
         corrDir->cd();
         h2CorrVector.push_back( newhist );
      }
   }

   varDir->cd();
   // fill rules
   UInt_t nrules = fRuleEnsemble.GetNRules();
   const Rule *rule;
   for (UInt_t i=0; i<nrules; i++) {
      rule = fRuleEnsemble.GetRulesConst(i);
      FillVisHistCut(rule, h1Vector);
   }
   // fill linear terms and normalise hists
   FillVisHistCut(0, h1Vector);
   NormVisHists(h1Vector);

   corrDir->cd();
   // fill rules
   for (UInt_t i=0; i<nrules; i++) {
      rule = fRuleEnsemble.GetRulesConst(i);
      FillVisHistCorr(rule, h2CorrVector);
   }
   NormVisHists(h2CorrVector);

   // write histograms to file
   methodDir->cd();
   for (UInt_t i=0; i<h1Vector.size(); i++)     h1Vector[i]->Write();
   for (UInt_t i=0; i<h2CorrVector.size(); i++) h2CorrVector[i]->Write();
}

////////////////////////////////////////////////////////////////////////////////
/// this will create histograms intended mainly for debugging and for the curious user

void TMVA::RuleFit::MakeDebugHists()
{
   TDirectory* methodDir = fMethodBase->BaseDir();
   if (methodDir==0) {
      Log() << kWARNING << "<MakeDebugHists> No rulefit method directory found - bug?" << Endl;
      return;
   }

   methodDir->cd();
   std::vector<Double_t> distances;
   std::vector<Double_t> fncuts;
   std::vector<Double_t> fnvars;
   const Rule *ruleA;
   const Rule *ruleB;
   Double_t dABmin=1000000.0;
   Double_t dABmax=-1.0;
   UInt_t nrules = fRuleEnsemble.GetNRules();
   for (UInt_t i=0; i<nrules; i++) {
      ruleA = fRuleEnsemble.GetRulesConst(i);
      for (UInt_t j=i+1; j<nrules; j++) {
         ruleB = fRuleEnsemble.GetRulesConst(j);
         Double_t dAB = ruleA->RuleDist( *ruleB, kTRUE );
         if (dAB>-0.5) { // RuleDist() returns a negative value for incomparable rules
            UInt_t nc = ruleA->GetNcuts();
            UInt_t nv = ruleA->GetNumVarsUsed();
            distances.push_back(dAB);
            fncuts.push_back(static_cast<Double_t>(nc));
            fnvars.push_back(static_cast<Double_t>(nv));
            if (dAB<dABmin) dABmin=dAB;
            if (dAB>dABmax) dABmax=dAB;
         }
      }
   }

   TH1F *histDist = new TH1F("RuleDist","Rule distances",100,dABmin,dABmax);
   TTree *distNtuple = new TTree("RuleDistNtuple","RuleDist ntuple");
   Double_t ntDist;
   Double_t ntNcuts;
   Double_t ntNvars;
   distNtuple->Branch("dist", &ntDist,  "dist/D");
   distNtuple->Branch("ncuts",&ntNcuts, "ncuts/D");
   distNtuple->Branch("nvars",&ntNvars, "nvars/D");

   for (UInt_t i=0; i<distances.size(); i++) {
      histDist->Fill(distances[i]);
      ntDist  = distances[i];
      ntNcuts = fncuts[i];
      ntNvars = fnvars[i];
      distNtuple->Fill();
   }
   distNtuple->Write();
}