Logo ROOT   6.30.04
Reference Guide
 All Namespaces Files Pages
MethodRXGB.cxx
Go to the documentation of this file.
1 // @(#)root/tmva/rmva $Id$
2 // Author: Omar Zapata,Lorenzo Moneta, Sergei Gleyzer 2015
3 
4 /**********************************************************************************
5  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6  * Package: TMVA *
7  * Class : MethodRXGB *
8  * Web : http://oproject.org *
9  * *
10  * Description: *
11  * R eXtreme Gradient Boosting *
12  * *
13  * *
14  * Redistribution and use in source and binary forms, with or without *
15  * modification, are permitted according to the terms listed in LICENSE *
16  * (http://tmva.sourceforge.net/LICENSE) *
17  * *
18  **********************************************************************************/
19 
20 #include <iomanip>
21 
22 #include "TMath.h"
23 #include "Riostream.h"
24 #include "TMatrix.h"
25 #include "TMatrixD.h"
26 #include "TVectorD.h"
27 
29 #include "TMVA/MethodRXGB.h"
30 #include "TMVA/Tools.h"
31 #include "TMVA/Config.h"
32 #include "TMVA/Ranking.h"
33 #include "TMVA/Types.h"
34 #include "TMVA/PDF.h"
35 #include "TMVA/ClassifierFactory.h"
36 
37 #include "TMVA/Results.h"
38 #include "TMVA/Timer.h"
39 
40 using namespace TMVA;
41 
42 REGISTER_METHOD(RXGB)
43 
44 ClassImp(MethodRXGB);
45 
46 //creating an Instance
47 Bool_t MethodRXGB::IsModuleLoaded = ROOT::R::TRInterface::Instance().Require("xgboost");
48 
49 //_______________________________________________________________________
50 MethodRXGB::MethodRXGB(const TString &jobName,
51  const TString &methodTitle,
52  DataSetInfo &dsi,
53  const TString &theOption) : RMethodBase(jobName, Types::kRXGB, methodTitle, dsi, theOption),
54  fNRounds(10),
55  fEta(0.3),
56  fMaxDepth(6),
57  predict("predict", "xgboost"),
58  xgbtrain("xgboost"),
59  xgbdmatrix("xgb.DMatrix"),
60  xgbsave("xgb.save"),
61  xgbload("xgb.load"),
62  asfactor("as.factor"),
63  asmatrix("as.matrix"),
64  fModel(NULL)
65 {
66  // standard constructor for the RXGB
67 
68 }
69 
70 //_______________________________________________________________________
71 MethodRXGB::MethodRXGB(DataSetInfo &theData, const TString &theWeightFile)
72  : RMethodBase(Types::kRXGB, theData, theWeightFile),
73  fNRounds(10),
74  fEta(0.3),
75  fMaxDepth(6),
76  predict("predict", "xgboost"),
77  xgbtrain("xgboost"),
78  xgbdmatrix("xgb.DMatrix"),
79  xgbsave("xgb.save"),
80  xgbload("xgb.load"),
81  asfactor("as.factor"),
82  asmatrix("as.matrix"),
83  fModel(NULL)
84 {
85 
86 }
87 
88 
89 //_______________________________________________________________________
90 MethodRXGB::~MethodRXGB(void)
91 {
92  if (fModel) delete fModel;
93 }
94 
95 //_______________________________________________________________________
96 Bool_t MethodRXGB::HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t /*numberTargets*/)
97 {
98  if (type == Types::kClassification && numberClasses == 2) return kTRUE;
99  return kFALSE;
100 }
101 
102 
103 //_______________________________________________________________________
104 void MethodRXGB::Init()
105 {
106 
107  if (!IsModuleLoaded) {
108  Error("Init", "R's package xgboost can not be loaded.");
109  Log() << kFATAL << " R's package xgboost can not be loaded."
110  << Endl;
111  return;
112  }
113  //factors creations
114  //xgboost require a numeric factor then background=0 signal=1 from fFactorTrain
115  UInt_t size = fFactorTrain.size();
116  fFactorNumeric.resize(size);
117 
118  for (UInt_t i = 0; i < size; i++) {
119  if (fFactorTrain[i] == "signal") fFactorNumeric[i] = 1;
120  else fFactorNumeric[i] = 0;
121  }
122 
123 
124 
125 }
126 
127 void MethodRXGB::Train()
128 {
129  if (Data()->GetNTrainingEvents() == 0) Log() << kFATAL << "<Train> Data() has zero events" << Endl;
130  ROOT::R::TRObject dmatrix = xgbdmatrix(ROOT::R::Label["data"] = asmatrix(fDfTrain), ROOT::R::Label["label"] = fFactorNumeric);
131  ROOT::R::TRDataFrame params;
132  params["eta"] = fEta;
133  params["max.depth"] = fMaxDepth;
134 
135  SEXP Model = xgbtrain(ROOT::R::Label["data"] = dmatrix,
136  ROOT::R::Label["label"] = fFactorNumeric,
137  ROOT::R::Label["weight"] = fWeightTrain,
138  ROOT::R::Label["nrounds"] = fNRounds,
139  ROOT::R::Label["params"] = params);
140 
141  fModel = new ROOT::R::TRObject(Model);
142  if (IsModelPersistence())
143  {
144  TString path = GetWeightFileDir() + "/" + GetName() + ".RData";
145  Log() << Endl;
146  Log() << gTools().Color("bold") << "--- Saving State File In:" << gTools().Color("reset") << path << Endl;
147  Log() << Endl;
148  xgbsave(Model, path);
149  }
150 }
151 
152 //_______________________________________________________________________
153 void MethodRXGB::DeclareOptions()
154 {
155  DeclareOptionRef(fNRounds, "NRounds", "The max number of iterations");
156  DeclareOptionRef(fEta, "Eta", "Step size shrinkage used in update to prevents overfitting. After each boosting step, we can directly get the weights of new features. and eta actually shrinks the feature weights to make the boosting process more conservative.");
157  DeclareOptionRef(fMaxDepth, "MaxDepth", "Maximum depth of the tree");
158 }
159 
160 //_______________________________________________________________________
161 void MethodRXGB::ProcessOptions()
162 {
163 }
164 
165 //_______________________________________________________________________
166 void MethodRXGB::TestClassification()
167 {
168  Log() << kINFO << "Testing Classification RXGB METHOD " << Endl;
169  MethodBase::TestClassification();
170 }
171 
172 
173 //_______________________________________________________________________
174 Double_t MethodRXGB::GetMvaValue(Double_t *errLower, Double_t *errUpper)
175 {
176  NoErrorCalc(errLower, errUpper);
177  Double_t mvaValue;
178  const TMVA::Event *ev = GetEvent();
179  const UInt_t nvar = DataInfo().GetNVariables();
180  ROOT::R::TRDataFrame fDfEvent;
181  for (UInt_t i = 0; i < nvar; i++) {
182  fDfEvent[DataInfo().GetListOfVariables()[i].Data()] = ev->GetValues()[i];
183  }
184  //if using persistence model
185  if (IsModelPersistence()) ReadStateFromFile();
186 
187  mvaValue = (Double_t)predict(*fModel, xgbdmatrix(ROOT::R::Label["data"] = asmatrix(fDfEvent)));
188  return mvaValue;
189 }
190 
191 ////////////////////////////////////////////////////////////////////////////////
192 /// get all the MVA values for the events of the current Data type
193 std::vector<Double_t> MethodRXGB::GetMvaValues(Long64_t firstEvt, Long64_t lastEvt, Bool_t logProgress)
194 {
195  Long64_t nEvents = Data()->GetNEvents();
196  if (firstEvt > lastEvt || lastEvt > nEvents) lastEvt = nEvents;
197  if (firstEvt < 0) firstEvt = 0;
198 
199  nEvents = lastEvt-firstEvt;
200 
201  UInt_t nvars = Data()->GetNVariables();
202 
203  // use timer
204  Timer timer( nEvents, GetName(), kTRUE );
205  if (logProgress)
206  Log() << kINFO<<Form("Dataset[%s] : ",DataInfo().GetName())<< "Evaluation of " << GetMethodName() << " on "
207  << (Data()->GetCurrentType()==Types::kTraining?"training":"testing") << " sample (" << nEvents << " events)" << Endl;
208 
209 
210  // fill R DATA FRAME with events data
211  std::vector<std::vector<Float_t> > inputData(nvars);
212  for (UInt_t i = 0; i < nvars; i++) {
213  inputData[i] = std::vector<Float_t>(nEvents);
214  }
215 
216  for (Int_t ievt=firstEvt; ievt<lastEvt; ievt++) {
217  Data()->SetCurrentEvent(ievt);
218  const TMVA::Event *e = Data()->GetEvent();
219  assert(nvars == e->GetNVariables());
220  for (UInt_t i = 0; i < nvars; i++) {
221  inputData[i][ievt] = e->GetValue(i);
222  }
223  // if (ievt%100 == 0)
224  // std::cout << "Event " << ievt << " type" << DataInfo().IsSignal(e) << " : " << pValue[ievt*nvars] << " " << pValue[ievt*nvars+1] << " " << pValue[ievt*nvars+2] << std::endl;
225  }
226 
227  ROOT::R::TRDataFrame evtData;
228  for (UInt_t i = 0; i < nvars; i++) {
229  evtData[DataInfo().GetListOfVariables()[i].Data()] = inputData[i];
230  }
231  //if using persistence model
232  if (IsModelPersistence()) ReadModelFromFile();
233 
234  std::vector<Double_t> mvaValues(nEvents);
235  ROOT::R::TRObject pred = predict(*fModel, xgbdmatrix(ROOT::R::Label["data"] = asmatrix(evtData)));
236  mvaValues = pred.As<std::vector<Double_t>>();
237 
238  if (logProgress) {
239  Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Elapsed time for evaluation of " << nEvents << " events: "
240  << timer.GetElapsedTime() << " " << Endl;
241  }
242 
243  return mvaValues;
244 
245 }
246 //_______________________________________________________________________
247 void MethodRXGB::GetHelpMessage() const
248 {
249 // get help message text
250 //
251 // typical length of text line:
252 // "|--------------------------------------------------------------|"
253  Log() << Endl;
254  Log() << gTools().Color("bold") << "--- Short description:" << gTools().Color("reset") << Endl;
255  Log() << Endl;
256  Log() << "Decision Trees and Rule-Based Models " << Endl;
257  Log() << Endl;
258  Log() << gTools().Color("bold") << "--- Performance optimisation:" << gTools().Color("reset") << Endl;
259  Log() << Endl;
260  Log() << Endl;
261  Log() << gTools().Color("bold") << "--- Performance tuning via configuration options:" << gTools().Color("reset") << Endl;
262  Log() << Endl;
263  Log() << "<None>" << Endl;
264 }
265 
266 //_______________________________________________________________________
267 void TMVA::MethodRXGB::ReadModelFromFile()
268 {
269  ROOT::R::TRInterface::Instance().Require("RXGB");
270  TString path = GetWeightFileDir() + "/" + GetName() + ".RData";
271  Log() << Endl;
272  Log() << gTools().Color("bold") << "--- Loading State File From:" << gTools().Color("reset") << path << Endl;
273  Log() << Endl;
274 
275  SEXP Model = xgbload(path);
276  fModel = new ROOT::R::TRObject(Model);
277 
278 }
279 
280 //_______________________________________________________________________
281 void TMVA::MethodRXGB::MakeClass(const TString &/*theClassFileName*/) const
282 {
283 }