#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
#include <numpy/arrayobject.h>
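
// RAII guard for the CPython GIL: PyGILState_Ensure() in the constructor
// acquires the interpreter lock and PyGILState_Release() in the destructor
// returns it, so any scope holding a PyGILRAII stays GIL-safe even on early
// returns.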
   PyGILState_STATE m_GILState;

public:
   PyGILRAII() : m_GILState(PyGILState_Ensure()) {}
   ~PyGILRAII() { PyGILState_Release(m_GILState); }
REGISTER_METHOD(PyRandomForest)

ClassImp(MethodPyRandomForest);
MethodPyRandomForest::MethodPyRandomForest(const TString &jobName,
                                           const TString &methodTitle,
                                           DataSetInfo &dsi,
                                           const TString &theOption) :
   PyMethodBase(jobName, Types::kPyRandomForest, methodTitle, dsi, theOption),
   fMinWeightFractionLeaf(0),
   fMaxFeatures("'auto'"),
   fMaxLeafNodes("None"),
MethodPyRandomForest::MethodPyRandomForest(DataSetInfo &theData, const TString &theWeightFile)
   : PyMethodBase(Types::kPyRandomForest, theData, theWeightFile),
   fMinWeightFractionLeaf(0),
   fMaxFeatures("'auto'"),
   fMaxLeafNodes("None"),
   fRandomState("None"),
MethodPyRandomForest::~MethodPyRandomForest(void)
Bool_t MethodPyRandomForest::HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t)
{
   if (type == Types::kClassification && numberClasses == 2) return kTRUE;
   if (type == Types::kMulticlass && numberClasses >= 2) return kTRUE;
   return kFALSE;
}
void MethodPyRandomForest::DeclareOptions()
{
   MethodBase::DeclareCompatibilityOptions();
   DeclareOptionRef(fNestimators, "NEstimators", "Integer, optional (default=10). The number of trees in the forest.");
   DeclareOptionRef(fCriterion, "Criterion", "String, optional (default='gini') \
      The function to measure the quality of a split. Supported criteria are \
      'gini' for the Gini impurity and 'entropy' for the information gain. \
      Note: this parameter is tree-specific.");
   DeclareOptionRef(fMaxDepth, "MaxDepth", "integer or None, optional (default=None) \
      The maximum depth of the tree. If None, then nodes are expanded until \
      all leaves are pure or until all leaves contain less than \
      min_samples_split samples. \
      Ignored if ``max_leaf_nodes`` is not None.");
   DeclareOptionRef(fMinSamplesSplit, "MinSamplesSplit", "integer, optional (default=2) \
      The minimum number of samples required to split an internal node.");
   DeclareOptionRef(fMinSamplesLeaf, "MinSamplesLeaf", "integer, optional (default=1) \
      The minimum number of samples in newly created leaves. A split is \
      discarded if, after the split, one of the leaves would contain less than \
      ``min_samples_leaf`` samples.");
   DeclareOptionRef(fMinWeightFractionLeaf, "MinWeightFractionLeaf", "float, optional (default=0.) \
      The minimum weighted fraction of the input samples required to be at a \
      leaf node.");
   DeclareOptionRef(fMaxFeatures, "MaxFeatures", "The number of features to consider when looking for the best split");
   DeclareOptionRef(fMaxLeafNodes, "MaxLeafNodes", "int or None, optional (default=None) \
      Grow trees with ``max_leaf_nodes`` in best-first fashion. \
      Best nodes are defined as relative reduction in impurity. \
      If None then unlimited number of leaf nodes. \
      If not None then ``max_depth`` will be ignored.");
   DeclareOptionRef(fBootstrap, "Bootstrap", "boolean, optional (default=True) \
      Whether bootstrap samples are used when building trees.");
   DeclareOptionRef(fOobScore, "OoBScore", "bool. Whether to use out-of-bag samples to estimate \
      the generalization error.");
   DeclareOptionRef(fNjobs, "NJobs", "integer, optional (default=1) \
      The number of jobs to run in parallel for both `fit` and `predict`. \
      If -1, then the number of jobs is set to the number of cores.");
   DeclareOptionRef(fRandomState, "RandomState", "int, RandomState instance or None, optional (default=None) \
      If int, random_state is the seed used by the random number generator; \
      If RandomState instance, random_state is the random number generator; \
      If None, the random number generator is the RandomState instance used \
      by `np.random`.");
   DeclareOptionRef(fVerbose, "Verbose", "int, optional (default=0) \
      Controls the verbosity of the tree building process.");
   DeclareOptionRef(fWarmStart, "WarmStart", "bool, optional (default=False) \
      When set to ``True``, reuse the solution of the previous call to fit \
      and add more estimators to the ensemble, otherwise, just fit a whole \
      new forest.");
   DeclareOptionRef(fClassWeight, "ClassWeight", "dict, list of dicts, \"auto\", \"subsample\" or None, optional \
      Weights associated with classes in the form ``{class_label: weight}``. \
      If not given, all classes are supposed to have weight one. For \
      multi-output problems, a list of dicts can be provided in the same \
      order as the columns of y. \
      The \"auto\" mode uses the values of y to automatically adjust \
      weights inversely proportional to class frequencies in the input data. \
      The \"subsample\" mode is the same as \"auto\" except that weights are \
      computed based on the bootstrap sample for every tree grown. \
      For multi-output, the weights of each column of y will be multiplied. \
      Note that these weights will be multiplied with sample_weight (passed \
      through the fit method) if sample_weight is specified.");
   DeclareOptionRef(fFilenameClassifier, "FilenameClassifier",
      "Store trained classifier in this file");
}
void MethodPyRandomForest::ProcessOptions()
{
   if (fNestimators <= 0) {
      Log() << kFATAL << " NEstimators <=0... that does not work !! " << Endl;
   }
   pNestimators = Eval(Form("%i", fNestimators));
   PyDict_SetItemString(fLocalNS, "nEstimators", pNestimators);
   if (fCriterion != "gini" && fCriterion != "entropy") {
      Log() << kFATAL << Form(" Criterion = %s... that does not work !! ", fCriterion.Data())
            << " The options are `gini` or `entropy`." << Endl;
   }
   pCriterion = Eval(Form("'%s'", fCriterion.Data()));
   PyDict_SetItemString(fLocalNS, "criterion", pCriterion);
   pMaxDepth = Eval(fMaxDepth);
   PyDict_SetItemString(fLocalNS, "maxDepth", pMaxDepth);
   if (!pMaxDepth) {
      Log() << kFATAL << Form(" MaxDepth = %s... that does not work !! ", fMaxDepth.Data())
            << " The options are None or integer." << Endl;
   }
   if (fMinSamplesSplit < 0) {
      Log() << kFATAL << " MinSamplesSplit < 0... that does not work !! " << Endl;
   }
   pMinSamplesSplit = Eval(Form("%i", fMinSamplesSplit));
   PyDict_SetItemString(fLocalNS, "minSamplesSplit", pMinSamplesSplit);
   if (fMinSamplesLeaf < 0) {
      Log() << kFATAL << " MinSamplesLeaf < 0... that does not work !! " << Endl;
   }
   pMinSamplesLeaf = Eval(Form("%i", fMinSamplesLeaf));
   PyDict_SetItemString(fLocalNS, "minSamplesLeaf", pMinSamplesLeaf);
   if (fMinWeightFractionLeaf < 0) {
      Log() << kERROR << " MinWeightFractionLeaf < 0... that does not work !! " << Endl;
   }
   pMinWeightFractionLeaf = Eval(Form("%f", fMinWeightFractionLeaf));
   PyDict_SetItemString(fLocalNS, "minWeightFractionLeaf", pMinWeightFractionLeaf);
   if (fMaxFeatures == "auto" || fMaxFeatures == "sqrt" || fMaxFeatures == "log2") {
      fMaxFeatures = Form("'%s'", fMaxFeatures.Data());
   }
   pMaxFeatures = Eval(fMaxFeatures);
   PyDict_SetItemString(fLocalNS, "maxFeatures", pMaxFeatures);
   if (!pMaxFeatures) {
      Log() << kFATAL << Form(" MaxFeatures = %s... that does not work !! ", fMaxFeatures.Data())
            << "int, float, string or None, optional (default='auto')"
            << "The number of features to consider when looking for the best split:"
            << "If int, then consider `max_features` features at each split."
            << "If float, then `max_features` is a percentage and"
            << "`int(max_features * n_features)` features are considered at each split."
            << "If 'auto', then `max_features=sqrt(n_features)`."
            << "If 'sqrt', then `max_features=sqrt(n_features)`."
            << "If 'log2', then `max_features=log2(n_features)`."
            << "If None, then `max_features=n_features`." << Endl;
   }
   pMaxLeafNodes = Eval(fMaxLeafNodes);
   if (!pMaxLeafNodes) {
      Log() << kFATAL << Form(" MaxLeafNodes = %s... that does not work !! ", fMaxLeafNodes.Data())
            << " The options are None or integer." << Endl;
   }
   PyDict_SetItemString(fLocalNS, "maxLeafNodes", pMaxLeafNodes);
   pRandomState = Eval(fRandomState);
   if (!pRandomState) {
      Log() << kFATAL << Form(" RandomState = %s... that does not work !! ", fRandomState.Data())
            << "If int, random_state is the seed used by the random number generator;"
            << "If RandomState instance, random_state is the random number generator;"
            << "If None, the random number generator is the RandomState instance used by `np.random`." << Endl;
   }
   PyDict_SetItemString(fLocalNS, "randomState", pRandomState);
   pClassWeight = Eval(fClassWeight);
   if (!pClassWeight) {
      Log() << kFATAL << Form(" ClassWeight = %s... that does not work !! ", fClassWeight.Data())
            << "dict, list of dicts, 'auto', 'subsample' or None, optional" << Endl;
   }
   PyDict_SetItemString(fLocalNS, "classWeight", pClassWeight);
   if (fNjobs < 0) {
      Log() << kFATAL << Form(" NJobs = %i... that does not work !! ", fNjobs)
            << "Value has to be greater than zero." << Endl;
   }
   pNjobs = Eval(Form("%i", fNjobs));
   PyDict_SetItemString(fLocalNS, "nJobs", pNjobs);
   pBootstrap = Eval(Form("%i", UInt_t(fBootstrap)));
   PyDict_SetItemString(fLocalNS, "bootstrap", pBootstrap);
   pOobScore = Eval(Form("%i", UInt_t(fOobScore)));
   PyDict_SetItemString(fLocalNS, "oobScore", pOobScore);
   pVerbose = Eval(Form("%i", fVerbose));
   PyDict_SetItemString(fLocalNS, "verbose", pVerbose);
   pWarmStart = Eval(Form("%i", UInt_t(fWarmStart)));
   PyDict_SetItemString(fLocalNS, "warmStart", pWarmStart);
   if (fFilenameClassifier.IsNull())
      fFilenameClassifier = GetWeightFileDir() + "/PyRFModel_" + GetName() + ".PyData";
}
void MethodPyRandomForest::Init()
{
   TMVA::Internal::PyGILRAII raii;

   PyRunString("import sklearn.ensemble");
   fNvars = GetNVariables();
   fNoutputs = DataInfo().GetNClasses();
}
void MethodPyRandomForest::Train()
{
   int fNrowsTraining = Data()->GetNTrainingEvents();
   npy_intp dimsData[2];
   dimsData[0] = fNrowsTraining;
   dimsData[1] = fNvars;
   PyArrayObject *fTrainData = (PyArrayObject *)PyArray_SimpleNew(2, dimsData, NPY_FLOAT);
   PyDict_SetItemString(fLocalNS, "trainData", (PyObject *)fTrainData);
   float *TrainData = (float *)(PyArray_DATA(fTrainData));

   npy_intp dimsClasses = (npy_intp)fNrowsTraining;
   PyArrayObject *fTrainDataClasses = (PyArrayObject *)PyArray_SimpleNew(1, &dimsClasses, NPY_FLOAT);
   PyDict_SetItemString(fLocalNS, "trainDataClasses", (PyObject *)fTrainDataClasses);
   float *TrainDataClasses = (float *)(PyArray_DATA(fTrainDataClasses));

   PyArrayObject *fTrainDataWeights = (PyArrayObject *)PyArray_SimpleNew(1, &dimsClasses, NPY_FLOAT);
   PyDict_SetItemString(fLocalNS, "trainDataWeights", (PyObject *)fTrainDataWeights);
   float *TrainDataWeights = (float *)(PyArray_DATA(fTrainDataWeights));
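   // The buffers above alias numpy arrays in C (row-major) order, so event i,
   // variable j belongs at index j + i * fNvars of the flat TrainData buffer,
   // matching the (fNrowsTraining, fNvars) shape declared in dimsData.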
   for (int i = 0; i < fNrowsTraining; i++) {
      const TMVA::Event *e = Data()->GetTrainingEvent(i);
      for (UInt_t j = 0; j < fNvars; j++) {
         TrainData[j + i * fNvars] = e->GetValue(j);
      }
      TrainDataClasses[i] = e->GetClass();
      TrainDataWeights[i] = e->GetWeight();
   }
   PyRunString("classifier = sklearn.ensemble.RandomForestClassifier(bootstrap=bootstrap, class_weight=classWeight, criterion=criterion, max_depth=maxDepth, max_features=maxFeatures, max_leaf_nodes=maxLeafNodes, min_samples_leaf=minSamplesLeaf, min_samples_split=minSamplesSplit, min_weight_fraction_leaf=minWeightFractionLeaf, n_estimators=nEstimators, n_jobs=nJobs, oob_score=oobScore, random_state=randomState, verbose=verbose, warm_start=warmStart)",
               "Failed to setup classifier");
   PyRunString("dump = classifier.fit(trainData, trainDataClasses, trainDataWeights)", "Failed to train classifier");
   fClassifier = PyDict_GetItemString(fLocalNS, "classifier");
   if (fClassifier == 0) {
      Log() << kFATAL << "Can't create classifier object from RandomForestClassifier" << Endl;
   }
   if (IsModelPersistence()) {
      Log() << gTools().Color("bold") << "Saving state file: " << gTools().Color("reset") << fFilenameClassifier << Endl;
      Serialize(fFilenameClassifier, fClassifier);
   }
}
void MethodPyRandomForest::TestClassification()
{
   MethodBase::TestClassification();
}
std::vector<Double_t> MethodPyRandomForest::GetMvaValues(Long64_t firstEvt, Long64_t lastEvt, Bool_t logProgress)
{
   if (fClassifier == 0) ReadModelFromFile();

   Long64_t nEvents = Data()->GetNEvents();
   if (firstEvt > lastEvt || lastEvt > nEvents) lastEvt = nEvents;
   if (firstEvt < 0) firstEvt = 0;
   nEvents = lastEvt - firstEvt;
   Timer timer(nEvents, GetName(), kTRUE);

   if (logProgress)
      Log() << kHEADER << Form("[%s] : ", DataInfo().GetName())
            << "Evaluation of " << GetMethodName() << " on "
            << (Data()->GetCurrentType() == Types::kTraining ? "training" : "testing")
            << " sample (" << nEvents << " events)" << Endl;
   npy_intp dims[2];
   dims[0] = (npy_intp)nEvents;
   dims[1] = (npy_intp)fNvars;
   PyArrayObject *pEvent = (PyArrayObject *)PyArray_SimpleNew(2, dims, NPY_FLOAT);
   float *pValue = (float *)(PyArray_DATA(pEvent));
   for (Int_t ievt = 0; ievt < nEvents; ievt++) {
      Data()->SetCurrentEvent(ievt);
      const TMVA::Event *e = Data()->GetEvent();
      for (UInt_t i = 0; i < fNvars; i++) {
         pValue[ievt * fNvars + i] = e->GetValue(i);
      }
   }
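   // predict_proba returns an (nEvents, fNoutputs) array. PyArray_DATA
   // exposes it as a flat, row-major C buffer, so the probability of class c
   // for event i sits at proba[fNoutputs * i + c]; TMVA::Types::kSignal
   // selects the signal column below.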
   PyArrayObject *result = (PyArrayObject *)PyObject_CallMethod(fClassifier, const_cast<char *>("predict_proba"), const_cast<char *>("(O)"), pEvent);
   double *proba = (double *)(PyArray_DATA(result));
   if (Long64_t(mvaValues.size()) != nEvents) mvaValues.resize(nEvents);
   for (int i = 0; i < nEvents; ++i) {
      mvaValues[i] = proba[fNoutputs * i + TMVA::Types::kSignal];
   }

   if (logProgress) {
      Log() << kINFO << "Elapsed time for evaluation of " << nEvents << " events: "
            << timer.GetElapsedTime() << " " << Endl;
   }

   return mvaValues;
}
Double_t MethodPyRandomForest::GetMvaValue(Double_t *errLower, Double_t *errUpper)
{
   NoErrorCalc(errLower, errUpper);

   if (fClassifier == 0) ReadModelFromFile();

   const TMVA::Event *e = Data()->GetEvent();
   npy_intp dims[2];
   dims[0] = 1;
   dims[1] = (npy_intp)fNvars;
   PyArrayObject *pEvent = (PyArrayObject *)PyArray_SimpleNew(2, dims, NPY_FLOAT);
   float *pValue = (float *)(PyArray_DATA(pEvent));
   for (UInt_t i = 0; i < fNvars; i++) pValue[i] = e->GetValue(i);
   PyArrayObject *result = (PyArrayObject *)PyObject_CallMethod(fClassifier, const_cast<char *>("predict_proba"), const_cast<char *>("(O)"), pEvent);
   double *proba = (double *)(PyArray_DATA(result));

   Double_t mvaValue = proba[TMVA::Types::kSignal];

   return mvaValue;
}
std::vector<Float_t> &MethodPyRandomForest::GetMulticlassValues()
{
   if (fClassifier == 0) ReadModelFromFile();

   const TMVA::Event *e = Data()->GetEvent();
   npy_intp dims[2];
   dims[0] = 1;
   dims[1] = (npy_intp)fNvars;
   PyArrayObject *pEvent = (PyArrayObject *)PyArray_SimpleNew(2, dims, NPY_FLOAT);
   float *pValue = (float *)(PyArray_DATA(pEvent));
   for (UInt_t i = 0; i < fNvars; i++) pValue[i] = e->GetValue(i);
   PyArrayObject *result = (PyArrayObject *)PyObject_CallMethod(fClassifier, const_cast<char *>("predict_proba"), const_cast<char *>("(O)"), pEvent);
   double *proba = (double *)(PyArray_DATA(result));
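   // Note on column order (an inference, not stated in this file):
   // scikit-learn sorts predict_proba columns by classifier.classes_, and
   // since fit() received the TMVA class indices 0..fNoutputs-1 as labels,
   // column i corresponds directly to TMVA class i.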
   if (UInt_t(classValues.size()) != fNoutputs) classValues.resize(fNoutputs);
   for (UInt_t i = 0; i < fNoutputs; i++) classValues[i] = proba[i];

   return classValues;
}
void MethodPyRandomForest::ReadModelFromFile()
{
   if (!PyIsInitialized()) {
      PyInitialize();
   }

   Log() << gTools().Color("bold") << "Loading state file: " << gTools().Color("reset") << fFilenameClassifier << Endl;
   Int_t err = UnSerialize(fFilenameClassifier, &fClassifier);
   if (err != 0) {
      Log() << kFATAL << Form("Failed to load classifier from file (error code: %i): %s", err, fFilenameClassifier.Data()) << Endl;
   }

   PyDict_SetItemString(fLocalNS, "classifier", fClassifier);
   fNvars = GetNVariables();
   fNoutputs = DataInfo().GetNClasses();
}
const Ranking *MethodPyRandomForest::CreateRanking()
{
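   // feature_importances_ is scikit-learn's impurity-based importance
   // attribute: one entry per input feature, normalized to sum to one, which
   // maps directly onto TMVA's per-variable ranking.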
   PyArrayObject *pRanking = (PyArrayObject *)PyObject_GetAttrString(fClassifier, "feature_importances_");
   if (pRanking == 0) Log() << kFATAL << "Failed to get ranking from classifier" << Endl;

   fRanking = new Ranking(GetName(), "Variable Importance");
   Double_t *rankingData = (Double_t *)PyArray_DATA(pRanking);
   for (UInt_t iVar = 0; iVar < fNvars; iVar++) {
      fRanking->AddRank(Rank(GetInputLabel(iVar), rankingData[iVar]));
   }

   return fRanking;
}
void MethodPyRandomForest::GetHelpMessage() const
{
   Log() << "A random forest is a meta estimator that fits a number of decision" << Endl;
   Log() << "tree classifiers on various sub-samples of the dataset and uses" << Endl;
   Log() << "averaging to improve the predictive accuracy and control over-fitting." << Endl;
   Log() << "Check out the scikit-learn documentation for more information." << Endl;
}