Logo ROOT   6.30.04
Reference Guide
 All Namespaces Files Pages
tmva102_Testing.py
Go to the documentation of this file.
## \file
## \ingroup tutorial_tmva
## \notebook -nodraw
## This tutorial illustrates how you can test a trained BDT model using the fast
## tree inference engine offered by TMVA and external tools such as scikit-learn.
##
## \macro_code
## \macro_output
##
## \date August 2019
## \author Stefan Wunsch

import ROOT
import pickle

from sklearn.metrics import roc_curve, auc

from tmva100_DataPreparation import variables
from tmva101_Training import load_data


# Load the test dataset with the loader from the training tutorial. It yields
# the inputs `x`, the true labels `y_true` and the event weights `w`.
x, y_true, w = load_data("test_signal.root", "test_background.root")

# Load the BDT named "myBDT" that the training tutorial stored in tmva101.root.
bdt = ROOT.TMVA.Experimental.RBDT[""]("myBDT", "tmva101.root")

# Make predictions with TMVA's fast tree inference engine.
y_pred = bdt.Compute(x)

# Compute the weighted ROC curve and the area under it using scikit-learn.
fpr, tpr, _ = roc_curve(y_true, y_pred, sample_weight=w)
# roc_curve already returns the false-positive rates in monotonic order, so no
# reordering is needed. The former `reorder=True` argument of `auc` was
# deprecated in scikit-learn 0.20 and removed in 0.22, where passing it raises
# a TypeError.
# NOTE(review): this assumes the weights `w` are non-negative. With negative
# event weights `fpr` may not be monotonic, and `auc` would then raise.
score = auc(fpr, tpr)

# Plot the ROC curve and quote the AUC in the graph title.
c = ROOT.TCanvas("roc", "", 600, 600)
g = ROOT.TGraph(len(fpr), fpr, tpr)
g.SetTitle("AUC = {:.2f}".format(score))
g.SetLineWidth(3)
g.SetLineColor(ROOT.kRed)
g.Draw("AC")
g.GetXaxis().SetRangeUser(0, 1)
g.GetYaxis().SetRangeUser(0, 1)
g.GetXaxis().SetTitle("False-positive rate")
g.GetYaxis().SetTitle("True-positive rate")
c.Draw()