Logo ROOT   6.30.04
Reference Guide
 All Namespaces Files Pages
LossFunction.h
Go to the documentation of this file.
1 // @(#)root/tmva $Id$
2 // Author: Andreas Hoecker, Peter Speckmayer, Joerg Stelzer, Helge Voss, Jan Therhaag
3 
4 /**********************************************************************************
5  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6  * Package: TMVA *
7  * Class : LossFunction *
8  * Web : http://tmva.sourceforge.net *
9  * *
10  * Description: *
11  * LossFunction and associated classes *
12  * *
13  * Authors (alphabetical): *
14  * Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland *
15  * Joerg Stelzer <Joerg.Stelzer@cern.ch> - CERN, Switzerland *
16  * Peter Speckmayer <Peter.Speckmayer@cern.ch> - CERN, Switzerland *
17  * Jan Therhaag <Jan.Therhaag@cern.ch> - U of Bonn, Germany *
18  * Helge Voss <Helge.Voss@cern.ch> - MPI-K Heidelberg, Germany *
19  * *
20  * Copyright (c) 2005-2011: *
21  * CERN, Switzerland *
22  * U. of Victoria, Canada *
23  * MPI-K Heidelberg, Germany *
24  * U. of Bonn, Germany *
25  * *
26  * Redistribution and use in source and binary forms, with or without *
27  * modification, are permitted according to the terms listed in LICENSE *
28  * (http://mva.sourceforge.net/license.txt) *
29  **********************************************************************************/
30 
31 #ifndef ROOT_TMVA_LossFunction
32 #define ROOT_TMVA_LossFunction
33 
34 //#include <iosfwd>
35 #include <vector>
36 #include <map>
37 #include "TMVA/Event.h"
38 
39 #include "TMVA/Types.h"
40 
41 
42 namespace TMVA {
43 
44  ///////////////////////////////////////////////////////////////////////////////////////////////
45  // Data Structure used by LossFunction and LossFunctionBDT to calculate errors, targets, etc
46  ///////////////////////////////////////////////////////////////////////////////////////////////
47 
48  class LossFunctionEventInfo{
49 
50  public:
51  LossFunctionEventInfo(){
52  trueValue = 0.;
53  predictedValue = 0.;
54  weight = 0.;
55  };
56  LossFunctionEventInfo(Double_t trueValue_, Double_t predictedValue_, Double_t weight_){
57  trueValue = trueValue_;
58  predictedValue = predictedValue_;
59  weight = weight_;
60  }
61  ~LossFunctionEventInfo(){};
62 
63  Double_t trueValue;
64  Double_t predictedValue;
65  Double_t weight;
66  };
67 
68 
69  ///////////////////////////////////////////////////////////////////////////////////////////////
70  // Loss Function interface defining base class for general error calculations in
71  // regression/classification
72  ///////////////////////////////////////////////////////////////////////////////////////////////
73 
74  class LossFunction {
75 
76  public:
77 
78  // constructors
79  LossFunction(){};
80  virtual ~LossFunction(){};
81 
82  // abstract methods that need to be implemented
83  virtual Double_t CalculateLoss(LossFunctionEventInfo& e) = 0;
84  virtual Double_t CalculateNetLoss(std::vector<LossFunctionEventInfo>& evs) = 0;
85  virtual Double_t CalculateMeanLoss(std::vector<LossFunctionEventInfo>& evs) = 0;
86 
87  virtual TString Name() = 0;
88  virtual Int_t Id() = 0;
89  };
90 
91  ///////////////////////////////////////////////////////////////////////////////////////////////
92  // Loss Function interface for boosted decision trees. Inherits from LossFunction
93  ///////////////////////////////////////////////////////////////////////////////////////////////
94 
95  /* Must inherit LossFunction with the virtual keyword so that we only have to implement
96  * the LossFunction interface once.
97  *
98  * LossFunction
99  * / \
100  *SomeLossFunction LossFunctionBDT
101  * \ /
102  * \ /
103  * SomeLossFunctionBDT
104  *
105  * Without the virtual keyword the two would point to their own LossFunction objects
106  * and SomeLossFunctionBDT would have to implement the virtual functions of LossFunction twice, once
107  * for each object. See diagram below.
108  *
109  * LossFunction LossFunction
110  * | |
111  *SomeLossFunction LossFunctionBDT
112  * \ /
113  * \ /
114  * SomeLossFunctionBDT
115  *
116  * Multiple inheritance is often frowned upon. To avoid this, We could make LossFunctionBDT separate
117  * from LossFunction but it really is a type of loss function.
118  * We could also put LossFunction into LossFunctionBDT. In either of these scenarios, if you are doing
119  * different regression methods and want to compare the Loss this makes it more convoluted.
120  * I think that multiple inheritance seems justified in this case, but we could change it if it's a problem.
121  * Usually it isn't a big deal with interfaces and this results in the simplest code in this case.
122  */
123 
124  class LossFunctionBDT : public virtual LossFunction{
125 
126  public:
127 
128  // constructors
129  LossFunctionBDT(){};
130  virtual ~LossFunctionBDT(){};
131 
132  // abstract methods that need to be implemented
133  virtual void Init(std::map<const TMVA::Event*, LossFunctionEventInfo>& evinfomap, std::vector<double>& boostWeights) = 0;
134  virtual void SetTargets(std::vector<const TMVA::Event*>& evs, std::map< const TMVA::Event*, LossFunctionEventInfo >& evinfomap) = 0;
135  virtual Double_t Target(LossFunctionEventInfo& e) = 0;
136  virtual Double_t Fit(std::vector<LossFunctionEventInfo>& evs) = 0;
137 
138  };
139 
140  ///////////////////////////////////////////////////////////////////////////////////////////////
141  // Huber loss function for regression error calculations
142  ///////////////////////////////////////////////////////////////////////////////////////////////
143 
144  class HuberLossFunction : public virtual LossFunction{
145 
146  public:
147  HuberLossFunction();
148  HuberLossFunction(Double_t quantile);
149  ~HuberLossFunction();
150 
151  // The LossFunction methods
152  Double_t CalculateLoss(LossFunctionEventInfo& e);
153  Double_t CalculateNetLoss(std::vector<LossFunctionEventInfo>& evs);
154  Double_t CalculateMeanLoss(std::vector<LossFunctionEventInfo>& evs);
155 
156  // We go ahead and implement the simple ones
157  TString Name(){ return TString("Huber"); };
158  Int_t Id(){ return 0; } ;
159 
160  // Functions needed beyond the interface
161  void Init(std::vector<LossFunctionEventInfo>& evs);
162  Double_t CalculateQuantile(std::vector<LossFunctionEventInfo>& evs, Double_t whichQuantile, Double_t sumOfWeights, bool abs);
163  Double_t CalculateSumOfWeights(const std::vector<LossFunctionEventInfo>& evs);
164  void SetTransitionPoint(std::vector<LossFunctionEventInfo>& evs);
165  void SetSumOfWeights(std::vector<LossFunctionEventInfo>& evs);
166 
167  protected:
168  Double_t fQuantile;
169  Double_t fTransitionPoint;
170  Double_t fSumOfWeights;
171  };
172 
173  ///////////////////////////////////////////////////////////////////////////////////////////////
174  // Huber loss function with boosted decision tree functionality
175  ///////////////////////////////////////////////////////////////////////////////////////////////
176 
177  // The bdt loss function implements the LossFunctionBDT interface and inherits the HuberLossFunction
178  // functionality.
179  class HuberLossFunctionBDT : public LossFunctionBDT, public HuberLossFunction{
180 
181  public:
182  HuberLossFunctionBDT();
183  HuberLossFunctionBDT(Double_t quantile):HuberLossFunction(quantile){};
184  ~HuberLossFunctionBDT(){};
185 
186  // The LossFunctionBDT methods
187  void Init(std::map<const TMVA::Event*, LossFunctionEventInfo>& evinfomap, std::vector<double>& boostWeights);
188  void SetTargets(std::vector<const TMVA::Event*>& evs, std::map< const TMVA::Event*, LossFunctionEventInfo >& evinfomap);
189  Double_t Target(LossFunctionEventInfo& e);
190  Double_t Fit(std::vector<LossFunctionEventInfo>& evs);
191 
192  private:
193  // some data fields
194  };
195 
196  ///////////////////////////////////////////////////////////////////////////////////////////////
197  // LeastSquares loss function for regression error calculations
198  ///////////////////////////////////////////////////////////////////////////////////////////////
199 
200  class LeastSquaresLossFunction : public virtual LossFunction{
201 
202  public:
203  LeastSquaresLossFunction(){};
204  ~LeastSquaresLossFunction(){};
205 
206  // The LossFunction methods
207  Double_t CalculateLoss(LossFunctionEventInfo& e);
208  Double_t CalculateNetLoss(std::vector<LossFunctionEventInfo>& evs);
209  Double_t CalculateMeanLoss(std::vector<LossFunctionEventInfo>& evs);
210 
211  // We go ahead and implement the simple ones
212  TString Name(){ return TString("LeastSquares"); };
213  Int_t Id(){ return 1; } ;
214  };
215 
216  ///////////////////////////////////////////////////////////////////////////////////////////////
217  // Least Squares loss function with boosted decision tree functionality
218  ///////////////////////////////////////////////////////////////////////////////////////////////
219 
220  // The bdt loss function implements the LossFunctionBDT interface and inherits the LeastSquaresLossFunction
221  // functionality.
222  class LeastSquaresLossFunctionBDT : public LossFunctionBDT, public LeastSquaresLossFunction{
223 
224  public:
225  LeastSquaresLossFunctionBDT(){};
226  ~LeastSquaresLossFunctionBDT(){};
227 
228  // The LossFunctionBDT methods
229  void Init(std::map<const TMVA::Event*, LossFunctionEventInfo>& evinfomap, std::vector<double>& boostWeights);
230  void SetTargets(std::vector<const TMVA::Event*>& evs, std::map< const TMVA::Event*, LossFunctionEventInfo >& evinfomap);
231  Double_t Target(LossFunctionEventInfo& e);
232  Double_t Fit(std::vector<LossFunctionEventInfo>& evs);
233  };
234 
235  ///////////////////////////////////////////////////////////////////////////////////////////////
236  // Absolute Deviation loss function for regression error calculations
237  ///////////////////////////////////////////////////////////////////////////////////////////////
238 
239  class AbsoluteDeviationLossFunction : public virtual LossFunction{
240 
241  public:
242  AbsoluteDeviationLossFunction(){};
243  ~AbsoluteDeviationLossFunction(){};
244 
245  // The LossFunction methods
246  Double_t CalculateLoss(LossFunctionEventInfo& e);
247  Double_t CalculateNetLoss(std::vector<LossFunctionEventInfo>& evs);
248  Double_t CalculateMeanLoss(std::vector<LossFunctionEventInfo>& evs);
249 
250  // We go ahead and implement the simple ones
251  TString Name(){ return TString("AbsoluteDeviation"); };
252  Int_t Id(){ return 2; } ;
253  };
254 
255  ///////////////////////////////////////////////////////////////////////////////////////////////
256  // Absolute Deviation loss function with boosted decision tree functionality
257  ///////////////////////////////////////////////////////////////////////////////////////////////
258 
259  // The bdt loss function implements the LossFunctionBDT interface and inherits the AbsoluteDeviationLossFunction
260  // functionality.
261  class AbsoluteDeviationLossFunctionBDT : public LossFunctionBDT, public AbsoluteDeviationLossFunction{
262 
263  public:
264  AbsoluteDeviationLossFunctionBDT(){};
265  ~AbsoluteDeviationLossFunctionBDT(){};
266 
267  // The LossFunctionBDT methods
268  void Init(std::map<const TMVA::Event*, LossFunctionEventInfo>& evinfomap, std::vector<double>& boostWeights);
269  void SetTargets(std::vector<const TMVA::Event*>& evs, std::map< const TMVA::Event*, LossFunctionEventInfo >& evinfomap);
270  Double_t Target(LossFunctionEventInfo& e);
271  Double_t Fit(std::vector<LossFunctionEventInfo>& evs);
272  };
273 }
274 
275 #endif