Logo ROOT   6.30.04
Reference Guide
 All Namespaces Files Pages
tmva101_Training.py
Go to the documentation of this file.
1 ## \file
2 ## \ingroup tutorial_tmva
3 ## \notebook -nodraw
4 ## This tutorial shows how you can train a machine learning model with any package,
5 ## reading the training data directly from ROOT files. Using XGBoost, we illustrate
6 ## how you can convert an externally trained model into a format that can be serialized
7 ## and read by the fast tree inference engine offered by TMVA.
8 ##
9 ## \macro_code
10 ## \macro_output
11 ##
12 ## \date August 2019
13 ## \author Stefan Wunsch
14 
15 import ROOT
16 import numpy as np
17 import pickle
18 
19 from tmva100_DataPreparation import variables
20 
21 
def load_data(signal_filename, background_filename):
    """Load signal and background events and build a class-balanced training set.

    Both files are read through the ``Events`` tree. Only the input columns
    listed in ``variables`` (imported from ``tmva100_DataPreparation``) are
    read and stacked into a feature matrix.

    Args:
        signal_filename: Path of the ROOT file with the signal events.
        background_filename: Path of the ROOT file with the background events.

    Returns:
        A tuple ``(x, y, w)``:
        ``x`` has one row per event (all signal rows first, then background)
        and one column per entry of ``variables``;
        ``y`` is 1.0 for signal rows and 0.0 for background rows;
        ``w`` holds per-event weights, ``num_all / num_sig`` for signal and
        ``num_all / num_bkg`` for background, so both classes carry the same
        total weight.

    Raises:
        ValueError: If the signal or the background sample has no events.
    """
    # Only convert the columns used as inputs. Reading every branch is wasteful
    # and can fail on branch types that AsNumpy cannot convert.
    columns = list(variables)
    data_sig = ROOT.RDataFrame("Events", signal_filename).AsNumpy(columns=columns)
    data_bkg = ROOT.RDataFrame("Events", background_filename).AsNumpy(columns=columns)

    # Convert inputs to a (num_events, num_variables) matrix readable by ML tools
    x_sig = np.column_stack([data_sig[var] for var in columns])
    x_bkg = np.column_stack([data_bkg[var] for var in columns])

    num_sig = x_sig.shape[0]
    num_bkg = x_bkg.shape[0]
    # A missing class would yield a one-class training set; fail early with a
    # clear message instead of an obscure error later in the classifier.
    if num_sig == 0 or num_bkg == 0:
        raise ValueError(
            f"Both samples must contain events "
            f"(signal: {num_sig}, background: {num_bkg})"
        )

    x = np.vstack([x_sig, x_bkg])

    # Create labels: 1 for signal, 0 for background (same row order as x)
    y = np.hstack([np.ones(num_sig), np.zeros(num_bkg)])

    # Compute weights balancing both classes: each class sums to num_all
    num_all = num_sig + num_bkg
    w = np.hstack([np.full(num_sig, num_all / num_sig),
                   np.full(num_bkg, num_all / num_bkg)])

    return x, y, w
42 
if __name__ == "__main__":
    # Load data: features x, labels y (1 = signal, 0 = background) and
    # per-event weights w that balance the two classes
    x, y, w = load_data("train_signal.root", "train_background.root")

    # Fit xgboost model. The weights are passed by keyword: newer xgboost
    # releases deprecate passing sample_weight positionally to fit().
    from xgboost import XGBClassifier

    bdt = XGBClassifier(max_depth=3, n_estimators=500)
    bdt.fit(x, y, sample_weight=w)

    # Save model in TMVA format (key "myBDT" in the file "tmva101.root")
    ROOT.TMVA.Experimental.SaveXGBoost(bdt, "myBDT", "tmva101.root")