47 Bool_t MethodRXGB::IsModuleLoaded = ROOT::R::TRInterface::Instance().Require("xgboost");
 
   // Constructor taking a job name, a method title and an option string. These are forwarded,
   // together with Types::kRXGB, to the RMethodBase constructor.
   //
   // NOTE(review): this excerpt is a numbered listing with gaps. The embedded numbers jump
   // 51 -> 53, 53 -> 57, 57 -> 59, 59 -> 62 and stop at 63, so part of the parameter list,
   // several member initialisers and the constructor body are not visible here.
   50 MethodRXGB::MethodRXGB(const TString &jobName,
 
   51                        const TString &methodTitle,
 
   // NOTE(review): 'dsi' is passed to RMethodBase below but is not declared in the visible
   // parameters; its declaration presumably sits on the missing listing line 52.
   53                        const TString &theOption) : RMethodBase(jobName, Types::kRXGB, methodTitle, dsi, theOption),
 
   // The members below are initialised from string literals. Train() and the prediction
   // methods invoke them like functions, so they presumably wrap the named R functions;
   // confirm against the class header.
   57    predict("predict", "xgboost"),
 
   59    xgbdmatrix("xgb.DMatrix"),
 
   62    asfactor("as.factor"),
 
   // NOTE(review): the initialiser list ends on a trailing comma in this excerpt; the
   // remaining initialisers and the body are not shown.
   63    asmatrix("as.matrix"),
 
   // Constructor taking a DataSetInfo and a weight-file name. Both are forwarded, together
   // with Types::kRXGB, to the RMethodBase constructor.
   //
   // NOTE(review): this excerpt is a numbered listing with gaps. The embedded numbers jump
   // 72 -> 76, 76 -> 78, 78 -> 81 and stop at 82, so several member initialisers and the
   // constructor body are not visible here.
   71 MethodRXGB::MethodRXGB(DataSetInfo &theData, 
const TString &theWeightFile)
 
   72    : RMethodBase(Types::kRXGB, theData, theWeightFile),
 
   // Same string-literal initialisation of predict / xgbdmatrix / asfactor / asmatrix as in
   // the other constructor of this class.
   76      predict(
"predict", 
"xgboost"),
 
   78      xgbdmatrix(
"xgb.DMatrix"),
 
   81      asfactor(
"as.factor"),
 
   // NOTE(review): the initialiser list ends on a trailing comma in this excerpt; the
   // remaining initialisers and the body are not shown.
   82      asmatrix(
"as.matrix"),
 
   90 MethodRXGB::~MethodRXGB(
void)
 
   92    if (fModel) 
delete fModel;
 
   96 Bool_t MethodRXGB::HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t )
 
   98    if (type == Types::kClassification && numberClasses == 2) 
return kTRUE;
 
  104 void     MethodRXGB::Init()
 
  107    if (!IsModuleLoaded) {
 
  108       Error(
"Init", 
"R's package xgboost can not be loaded.");
 
  109       Log() << kFATAL << 
" R's package xgboost can not be loaded." 
  115    UInt_t size = fFactorTrain.size();
 
  116    fFactorNumeric.resize(size);
 
  118    for (UInt_t i = 0; i < size; i++) {
 
  119       if (fFactorTrain[i] == 
"signal") fFactorNumeric[i] = 1;
 
  120       else fFactorNumeric[i] = 0;
 
  127 void MethodRXGB::Train()
 
  129    if (Data()->GetNTrainingEvents() == 0) Log() << kFATAL << 
"<Train> Data() has zero events" << Endl;
 
  130    ROOT::R::TRObject dmatrix = xgbdmatrix(ROOT::R::Label[
"data"] = asmatrix(fDfTrain), ROOT::R::Label[
"label"] = fFactorNumeric);
 
  131    ROOT::R::TRDataFrame params;
 
  132    params[
"eta"] = fEta;
 
  133    params[
"max.depth"] = fMaxDepth;
 
  135    SEXP Model = xgbtrain(ROOT::R::Label[
"data"] = dmatrix,
 
  136                          ROOT::R::Label[
"label"] = fFactorNumeric,
 
  137                          ROOT::R::Label[
"weight"] = fWeightTrain,
 
  138                          ROOT::R::Label[
"nrounds"] = fNRounds,
 
  139                          ROOT::R::Label[
"params"] = params);
 
  141    fModel = 
new ROOT::R::TRObject(Model);
 
  142    if (IsModelPersistence())
 
  144         TString path = GetWeightFileDir() +  
"/" + GetName() + 
".RData";
 
  146         Log() << gTools().Color(
"bold") << 
"--- Saving State File In:" << gTools().Color(
"reset") << path << Endl;
 
  148         xgbsave(Model, path);
 
  153 void MethodRXGB::DeclareOptions()
 
  155    DeclareOptionRef(fNRounds, 
"NRounds", 
"The max number of iterations");
 
  156    DeclareOptionRef(fEta, 
"Eta", 
"Step size shrinkage used in update to prevents overfitting. After each boosting step, we can directly get the weights of new features. and eta actually shrinks the feature weights to make the boosting process more conservative.");
 
  157    DeclareOptionRef(fMaxDepth, 
"MaxDepth", 
"Maximum depth of the tree");
 
  // Option post-processing hook of MethodRXGB.
  // NOTE(review): only the signature is visible in this excerpt (the embedded listing
  // numbers jump from 161 to 166), so nothing can be said here about what the body does.
  161 void MethodRXGB::ProcessOptions()
 
  166 void MethodRXGB::TestClassification()
 
  168    Log() << kINFO << 
"Testing Classification RXGB METHOD  " << Endl;
 
  169    MethodBase::TestClassification();
 
  174 Double_t MethodRXGB::GetMvaValue(Double_t *errLower, Double_t *errUpper)
 
  176    NoErrorCalc(errLower, errUpper);
 
  178    const TMVA::Event *ev = GetEvent();
 
  179    const UInt_t nvar = DataInfo().GetNVariables();
 
  180    ROOT::R::TRDataFrame fDfEvent;
 
  181    for (UInt_t i = 0; i < nvar; i++) {
 
  182       fDfEvent[DataInfo().GetListOfVariables()[i].Data()] = ev->GetValues()[i];
 
  185    if (IsModelPersistence()) ReadStateFromFile();
 
  187    mvaValue = (Double_t)predict(*fModel, xgbdmatrix(ROOT::R::Label[
"data"] = asmatrix(fDfEvent)));
 
  193 std::vector<Double_t> MethodRXGB::GetMvaValues(Long64_t firstEvt, Long64_t lastEvt, Bool_t logProgress)
 
  195    Long64_t nEvents = Data()->GetNEvents();
 
  196    if (firstEvt > lastEvt || lastEvt > nEvents) lastEvt = nEvents;
 
  197    if (firstEvt < 0) firstEvt = 0;
 
  199    nEvents = lastEvt-firstEvt; 
 
  201    UInt_t nvars = Data()->GetNVariables();
 
  204    Timer timer( nEvents, GetName(), kTRUE );
 
  206       Log() << kINFO<<Form(
"Dataset[%s] : ",DataInfo().GetName())<< 
"Evaluation of " << GetMethodName() << 
" on " 
  207             << (Data()->GetCurrentType()==Types::kTraining?
"training":
"testing") << 
" sample (" << nEvents << 
" events)" << Endl;
 
  211    std::vector<std::vector<Float_t> > inputData(nvars);
 
  212    for (UInt_t i = 0; i < nvars; i++) {
 
  213       inputData[i] =  std::vector<Float_t>(nEvents); 
 
  216    for (Int_t ievt=firstEvt; ievt<lastEvt; ievt++) {
 
  217      Data()->SetCurrentEvent(ievt);
 
  218       const TMVA::Event *e = Data()->GetEvent();
 
  219       assert(nvars == e->GetNVariables());
 
  220       for (UInt_t i = 0; i < nvars; i++) {
 
  221          inputData[i][ievt] = e->GetValue(i);
 
  227    ROOT::R::TRDataFrame evtData;
 
  228    for (UInt_t i = 0; i < nvars; i++) {
 
  229       evtData[DataInfo().GetListOfVariables()[i].Data()] = inputData[i];
 
  232    if (IsModelPersistence()) ReadModelFromFile();
 
  234    std::vector<Double_t> mvaValues(nEvents); 
 
  235    ROOT::R::TRObject pred = predict(*fModel, xgbdmatrix(ROOT::R::Label[
"data"] = asmatrix(evtData)));
 
  236    mvaValues = pred.As<std::vector<Double_t>>();
 
  239       Log() << kINFO <<Form(
"Dataset[%s] : ",DataInfo().GetName())<< 
"Elapsed time for evaluation of " << nEvents <<  
" events: " 
  240             << timer.GetElapsedTime() << 
"       " << Endl;
 
  247 void MethodRXGB::GetHelpMessage()
 const 
  254    Log() << gTools().Color(
"bold") << 
"--- Short description:" << gTools().Color(
"reset") << Endl;
 
  256    Log() << 
"Decision Trees and Rule-Based Models " << Endl;
 
  258    Log() << gTools().Color(
"bold") << 
"--- Performance optimisation:" << gTools().Color(
"reset") << Endl;
 
  261    Log() << gTools().Color(
"bold") << 
"--- Performance tuning via configuration options:" << gTools().Color(
"reset") << Endl;
 
  263    Log() << 
"<None>" << Endl;
 
  267 void TMVA::MethodRXGB::ReadModelFromFile()
 
  269    ROOT::R::TRInterface::Instance().Require(
"RXGB");
 
  270    TString path = GetWeightFileDir() +  
"/" + GetName() + 
".RData";
 
  272    Log() << gTools().Color(
"bold") << 
"--- Loading State File From:" << gTools().Color(
"reset") << path << Endl;
 
  275    SEXP Model = xgbload(path);
 
  276    fModel = 
new ROOT::R::TRObject(Model);
 
  281 void TMVA::MethodRXGB::MakeClass(
const TString &)
 const