Logo ROOT   6.30.04
Reference Guide
 All Namespaces Files Pages
RuleFit.h
Go to the documentation of this file.
1 // @(#)root/tmva $Id$
2 // Author: Andreas Hoecker, Joerg Stelzer, Fredrik Tegenfeldt, Helge Voss
3 
4 /**********************************************************************************
5  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6  * Package: TMVA *
7  * Class : RuleFit *
8  * Web : http://tmva.sourceforge.net *
9  * *
10  * Description: *
11  * A class implementing various fits of rule ensembles *
12  * *
13  * Authors (alphabetical): *
14  * Fredrik Tegenfeldt <Fredrik.Tegenfeldt@cern.ch> - Iowa State U., USA *
15  * Helge Voss <Helge.Voss@cern.ch> - MPI-KP Heidelberg, Ger. *
16  * *
17  * Copyright (c) 2005: *
18  * CERN, Switzerland *
19  * Iowa State U. *
20  * MPI-K Heidelberg, Germany *
21  * *
22  * Redistribution and use in source and binary forms, with or without *
23  * modification, are permitted according to the terms listed in LICENSE *
24  * (http://tmva.sourceforge.net/LICENSE) *
25  **********************************************************************************/
26 
27 #ifndef ROOT_TMVA_RuleFit
28 #define ROOT_TMVA_RuleFit
29 
30 #include "TMVA/DecisionTree.h"
31 #include "TMVA/RuleEnsemble.h"
32 #include "TMVA/RuleFitParams.h"
33 #include "TMVA/Event.h"
34 
35 #include <algorithm>
36 #include <random>
37 
38 namespace TMVA {
39 
40 
41  class MethodBase;
42  class MethodRuleFit;
43  class MsgLogger;
44 
45  class RuleFit {
46 
47  public:
48 
49  // main constructor
50  RuleFit( const TMVA::MethodBase *rfbase );
51 
52  // empty constructor
53  RuleFit( void );
54 
55  virtual ~RuleFit( void );
56 
57  void InitNEveEff();
58  void InitPtrs( const TMVA::MethodBase *rfbase );
59  void Initialize( const TMVA::MethodBase *rfbase );
60 
61  void SetMsgType( EMsgType t );
62 
63  void SetTrainingEvents( const std::vector<const TMVA::Event *> & el );
64 
65  void ReshuffleEvents()
66  {
67  std::shuffle(fTrainingEventsRndm.begin(), fTrainingEventsRndm.end(), fRNGEngine);
68  }
69 
70  void SetMethodBase( const MethodBase *rfbase );
71 
72  // make the forest of trees for rule generation
73  void MakeForest();
74 
75  // build a tree
76  void BuildTree( TMVA::DecisionTree *dt );
77 
78  // save event weights
79  void SaveEventWeights();
80 
81  // restore saved event weights
82  void RestoreEventWeights();
83 
84  // boost events based on the given tree
85  void Boost( TMVA::DecisionTree *dt );
86 
87  // calculate and print some statistics on the given forest
88  void ForestStatistics();
89 
90  // calculate the discriminating variable for the given event
91  Double_t EvalEvent( const Event& e );
92 
93  // calculate sum of
94  Double_t CalcWeightSum( const std::vector<const TMVA::Event *> *events, UInt_t neve=0 );
95 
96  // do the fitting of the coefficients
97  void FitCoefficients();
98 
99  // calculate variable and rule importance from a set of events
100  void CalcImportance();
101 
102  // set usage of linear term
103  void SetModelLinear() { fRuleEnsemble.SetModelLinear(); }
104  // set usage of rules
105  void SetModelRules() { fRuleEnsemble.SetModelRules(); }
106  // set usage of linear term
107  void SetModelFull() { fRuleEnsemble.SetModelFull(); }
108  // set minimum importance allowed
109  void SetImportanceCut( Double_t minimp=0 ) { fRuleEnsemble.SetImportanceCut(minimp); }
110  // set minimum rule distance - see RuleEnsemble
111  void SetRuleMinDist( Double_t d ) { fRuleEnsemble.SetRuleMinDist(d); }
112  // set path related parameters
113  void SetGDTau( Double_t t=0.0 ) { fRuleFitParams.SetGDTau(t); }
114  void SetGDPathStep( Double_t s=0.01 ) { fRuleFitParams.SetGDPathStep(s); }
115  void SetGDNPathSteps( Int_t n=100 ) { fRuleFitParams.SetGDNPathSteps(n); }
116  // make visualization histograms
117  void SetVisHistsUseImp( Bool_t f ) { fVisHistsUseImp = f; }
118  void UseImportanceVisHists() { fVisHistsUseImp = kTRUE; }
119  void UseCoefficientsVisHists() { fVisHistsUseImp = kFALSE; }
120  void MakeVisHists();
121  void FillVisHistCut(const Rule * rule, std::vector<TH2F *> & hlist);
122  void FillVisHistCorr(const Rule * rule, std::vector<TH2F *> & hlist);
123  void FillCut(TH2F* h2,const TMVA::Rule *rule,Int_t vind);
124  void FillLin(TH2F* h2,Int_t vind);
125  void FillCorr(TH2F* h2,const TMVA::Rule *rule,Int_t v1, Int_t v2);
126  void NormVisHists(std::vector<TH2F *> & hlist);
127  void MakeDebugHists();
128  Bool_t GetCorrVars(TString & title, TString & var1, TString & var2);
129  // accessors
130  UInt_t GetNTreeSample() const { return fNTreeSample; }
131  Double_t GetNEveEff() const { return fNEveEffTrain; } // reweighted number of events = sum(wi)
132  const Event* GetTrainingEvent(UInt_t i) const { return static_cast< const Event *>(fTrainingEvents[i]); }
133  Double_t GetTrainingEventWeight(UInt_t i) const { return fTrainingEvents[i]->GetWeight(); }
134 
135  // const Event* GetTrainingEvent(UInt_t i, UInt_t isub) const { return &(fTrainingEvents[fSubsampleEvents[isub]])[i]; }
136 
137  const std::vector< const TMVA::Event * > & GetTrainingEvents() const { return fTrainingEvents; }
138  // const std::vector< Int_t > & GetSubsampleEvents() const { return fSubsampleEvents; }
139 
140  // void GetSubsampleEvents(Int_t sub, UInt_t & ibeg, UInt_t & iend) const;
141  void GetRndmSampleEvents(std::vector< const TMVA::Event * > & evevec, UInt_t nevents);
142  //
143  const std::vector< const TMVA::DecisionTree *> & GetForest() const { return fForest; }
144  const RuleEnsemble & GetRuleEnsemble() const { return fRuleEnsemble; }
145  RuleEnsemble * GetRuleEnsemblePtr() { return &fRuleEnsemble; }
146  const RuleFitParams & GetRuleFitParams() const { return fRuleFitParams; }
147  RuleFitParams * GetRuleFitParamsPtr() { return &fRuleFitParams; }
148  const MethodRuleFit * GetMethodRuleFit() const { return fMethodRuleFit; }
149  const MethodBase * GetMethodBase() const { return fMethodBase; }
150 
151  private:
152 
153  // copy constructor
154  RuleFit( const RuleFit & other );
155 
156  // copy method
157  void Copy( const RuleFit & other );
158 
159  std::vector<const TMVA::Event *> fTrainingEvents; // all training events
160  std::vector<const TMVA::Event *> fTrainingEventsRndm; // idem, but randomly shuffled
161  std::vector<Double_t> fEventWeights; // original weights of the events - follows fTrainingEvents
162  UInt_t fNTreeSample; // number of events in sub sample = frac*neve
163 
164  Double_t fNEveEffTrain; // reweighted number of events = sum(wi)
165  std::vector< const TMVA::DecisionTree *> fForest; // the input forest of decision trees
166  RuleEnsemble fRuleEnsemble; // the ensemble of rules
167  RuleFitParams fRuleFitParams; // fit rule parameters
168  const MethodRuleFit *fMethodRuleFit; // pointer the method which initialized this RuleFit instance
169  const MethodBase *fMethodBase; // pointer the method base which initialized this RuleFit instance
170  Bool_t fVisHistsUseImp; // if true, use importance as weight; else coef in vis hists
171 
172  mutable MsgLogger* fLogger; // message logger
173  MsgLogger& Log() const { return *fLogger; }
174 
175  static const Int_t randSEED = 0; // set to 1 for debugging purposes or to zero for random seeds
176  std::default_random_engine fRNGEngine;
177 
178  ClassDef(RuleFit,0); // Calculations for Friedman's RuleFit method
179  };
180 }
181 
182 #endif