Bool_t MethodRXGB::IsModuleLoaded = ROOT::R::TRInterface::Instance().Require("xgboost");
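
/// Standard constructor: imports from R the xgboost functions used by the method.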
MethodRXGB::MethodRXGB(const TString &jobName,
                       const TString &methodTitle,
                       DataSetInfo &dsi,
                       const TString &theOption) : RMethodBase(jobName, Types::kRXGB, methodTitle, dsi, theOption),
   fModel(nullptr),
   predict("predict", "xgboost"),
   xgbtrain("xgboost"),
   xgbdmatrix("xgb.DMatrix"),
   xgbsave("xgb.save"),
   xgbload("xgb.load"),
   asfactor("as.factor"),
   asmatrix("as.matrix")
{
}
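
/// Constructor used when the method is instantiated from a weight file.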
MethodRXGB::MethodRXGB(DataSetInfo &theData, const TString &theWeightFile)
   : RMethodBase(Types::kRXGB, theData, theWeightFile),
     fModel(nullptr),
     predict("predict", "xgboost"),
     xgbtrain("xgboost"),
     xgbdmatrix("xgb.DMatrix"),
     xgbsave("xgb.save"),
     xgbload("xgb.load"),
     asfactor("as.factor"),
     asmatrix("as.matrix")
{
}
MethodRXGB::~MethodRXGB(void)
{
   if (fModel) delete fModel;
}
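
/// RXGB supports only two-class classification.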
Bool_t MethodRXGB::HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t)
{
   if (type == Types::kClassification && numberClasses == 2) return kTRUE;
   return kFALSE;
}
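
/// Initialization: checks that the R xgboost package could be loaded and
/// converts the class factor ("signal"/"background") into the numeric
/// labels (1/0) expected by xgboost.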
void MethodRXGB::Init()
{
   if (!IsModuleLoaded) {
      Error("Init", "R's package xgboost cannot be loaded.");
      Log() << kFATAL << "R's package xgboost cannot be loaded." << Endl;
      return;
   }
   // xgboost needs numeric class labels: signal = 1, background = 0
   UInt_t size = fFactorTrain.size();
   fFactorNumeric.resize(size);

   for (UInt_t i = 0; i < size; i++) {
      if (fFactorTrain[i] == "signal") fFactorNumeric[i] = 1;
      else                             fFactorNumeric[i] = 0;
   }
}
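
/// Train the booster: the training data frame is converted to an xgb.DMatrix
/// and passed to the R training function together with the numeric labels,
/// the event weights and the booked hyper-parameters.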
void MethodRXGB::Train()
{
   if (Data()->GetNTrainingEvents() == 0) Log() << kFATAL << "<Train> Data() has zero events" << Endl;

   ROOT::R::TRObject dmatrix = xgbdmatrix(ROOT::R::Label["data"] = asmatrix(fDfTrain),
                                          ROOT::R::Label["label"] = fFactorNumeric);
   ROOT::R::TRDataFrame params;
   params["eta"]       = fEta;
   params["max.depth"] = fMaxDepth;

   SEXP Model = xgbtrain(ROOT::R::Label["data"]    = dmatrix,
                         ROOT::R::Label["label"]   = fFactorNumeric,
                         ROOT::R::Label["weight"]  = fWeightTrain,
                         ROOT::R::Label["nrounds"] = fNRounds,
                         ROOT::R::Label["params"]  = params);

   fModel = new ROOT::R::TRObject(Model);
   if (IsModelPersistence()) {
      TString path = GetWeightFileDir() + "/" + GetName() + ".RData";
      Log() << gTools().Color("bold") << "--- Saving State File In:" << gTools().Color("reset") << path << Endl;
      xgbsave(Model, path);
   }
}
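
/// Declare the configurable options of the method (NRounds, Eta, MaxDepth).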
void MethodRXGB::DeclareOptions()
{
   DeclareOptionRef(fNRounds, "NRounds", "Maximum number of boosting iterations");
   DeclareOptionRef(fEta, "Eta", "Step-size shrinkage used in the update to prevent overfitting. After each boosting step the new feature weights can be obtained directly, and Eta shrinks them to make the boosting process more conservative.");
   DeclareOptionRef(fMaxDepth, "MaxDepth", "Maximum depth of a tree");
}
void MethodRXGB::ProcessOptions()
{
}
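
/// Apply the standard classification tests of MethodBase.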
void MethodRXGB::TestClassification()
{
   Log() << kINFO << "Testing Classification RXGB METHOD" << Endl;
   MethodBase::TestClassification();
}
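
/// Return the MVA value for the current event: the event is wrapped into a
/// one-row R data frame and scored with predict() on the trained model.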
Double_t MethodRXGB::GetMvaValue(Double_t *errLower, Double_t *errUpper)
{
   NoErrorCalc(errLower, errUpper);

   Double_t mvaValue;
   const TMVA::Event *ev = GetEvent();
   const UInt_t nvar = DataInfo().GetNVariables();
   ROOT::R::TRDataFrame fDfEvent;
   for (UInt_t i = 0; i < nvar; i++) {
      fDfEvent[DataInfo().GetListOfVariables()[i].Data()] = ev->GetValues()[i];
   }
   // if using a persisted model, load it first
   if (IsModelPersistence()) ReadStateFromFile();

   mvaValue = (Double_t)predict(*fModel, xgbdmatrix(ROOT::R::Label["data"] = asmatrix(fDfEvent)));
   return mvaValue;
}
std::vector<Double_t> MethodRXGB::GetMvaValues(Long64_t firstEvt, Long64_t lastEvt, Bool_t logProgress)
{
   Long64_t nEvents = Data()->GetNEvents();
   if (firstEvt > lastEvt || lastEvt > nEvents) lastEvt = nEvents;
   if (firstEvt < 0) firstEvt = 0;

   nEvents = lastEvt - firstEvt;

   UInt_t nvars = Data()->GetNVariables();

   // timer to report the elapsed evaluation time
   Timer timer(nEvents, GetName(), kTRUE);
   if (logProgress)
      Log() << kINFO << Form("Dataset[%s] : ", DataInfo().GetName()) << "Evaluation of " << GetMethodName() << " on "
            << (Data()->GetCurrentType() == Types::kTraining ? "training" : "testing") << " sample (" << nEvents
            << " events)" << Endl;

   // fill an R data frame with the event data, one column per input variable
   std::vector<std::vector<Float_t>> inputData(nvars);
   for (UInt_t i = 0; i < nvars; i++) {
      inputData[i] = std::vector<Float_t>(nEvents);
   }

   for (Int_t ievt = firstEvt; ievt < lastEvt; ievt++) {
      Data()->SetCurrentEvent(ievt);
      const TMVA::Event *e = Data()->GetEvent();
      assert(nvars == e->GetNVariables());
      for (UInt_t i = 0; i < nvars; i++) {
         inputData[i][ievt - firstEvt] = e->GetValue(i); // index relative to firstEvt
      }
   }

   ROOT::R::TRDataFrame evtData;
   for (UInt_t i = 0; i < nvars; i++) {
      evtData[DataInfo().GetListOfVariables()[i].Data()] = inputData[i];
   }
   // if using a persisted model, load it first
   if (IsModelPersistence()) ReadModelFromFile();

   // evaluate all events in a single predict() call
   std::vector<Double_t> mvaValues(nEvents);
   ROOT::R::TRObject pred = predict(*fModel, xgbdmatrix(ROOT::R::Label["data"] = asmatrix(evtData)));
   mvaValues = pred.As<std::vector<Double_t>>();

   if (logProgress) {
      Log() << kINFO << Form("Dataset[%s] : ", DataInfo().GetName()) << "Elapsed time for evaluation of " << nEvents
            << " events: " << timer.GetElapsedTime() << " " << Endl;
   }

   return mvaValues;
}
void MethodRXGB::GetHelpMessage() const
{
   Log() << Endl;
   Log() << gTools().Color("bold") << "--- Short description:" << gTools().Color("reset") << Endl;
   Log() << "Decision Trees and Rule-Based Models " << Endl;
   Log() << Endl;
   Log() << gTools().Color("bold") << "--- Performance optimisation:" << gTools().Color("reset") << Endl;
   Log() << Endl;
   Log() << gTools().Color("bold") << "--- Performance tuning via configuration options:" << gTools().Color("reset") << Endl;
   Log() << "<None>" << Endl;
}
void TMVA::MethodRXGB::ReadModelFromFile()
{
   ROOT::R::TRInterface::Instance().Require("xgboost");
   TString path = GetWeightFileDir() + "/" + GetName() + ".RData";
   Log() << gTools().Color("bold") << "--- Loading State File From:" << gTools().Color("reset") << path << Endl;

   SEXP Model = xgbload(path);
   fModel = new ROOT::R::TRObject(Model);
}
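
/// Creation of a standalone response class is not implemented for this R-based method.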
void TMVA::MethodRXGB::MakeClass(const TString &) const
{
}