ROOT 6.30.04
Reference Guide
 All Namespaces Files Pages
ApplicationClassificationKeras.py
Go to the documentation of this file.
1 #!/usr/bin/env python
2 ## \file
3 ## \ingroup tutorial_tmva_keras
4 ## \notebook -nodraw
5 ## This tutorial shows how to apply a trained model to new data.
6 ##
7 ## \macro_code
8 ##
9 ## \date 2017
10 ## \author TMVA Team
11 
12 from ROOT import TMVA, TFile, TString
13 from array import array
14 from subprocess import call
15 from os.path import isfile
16 
# Setup TMVA
# Initialise the TMVA tool singleton and the Python interpreter used by the
# PyMVA method family (PyKeras is one of them) before creating the reader.
TMVA.Tools.Instance()
TMVA.PyMethodBase.PyInitialize()

# "Color" enables coloured log output; "!Silent" keeps the reader's messages on.
reader_options = "Color:!Silent"
reader = TMVA.Reader(reader_options)
21 
# Load data
# Download the example file once and reuse the local copy on later runs.
#   -f : fail on HTTP errors instead of writing the error page to disk
#   -L : follow redirects (without it curl saves the redirect response body
#        under the .root name, which then fails obscurely in TFile.Open)
#   -O : keep the remote file name
# curl's exit status is checked so a failed download stops the script here.
if not isfile('tmva_class_example.root'):
    status = call(['curl', '-fLO', 'https://root.cern.ch/files/tmva_class_example.root'])
    if status != 0:
        raise RuntimeError(
            'Downloading tmva_class_example.root failed (curl exit code %d)' % status)

# TFile.Open yields a null object (falsy in PyROOT) or a zombie file on failure.
data = TFile.Open('tmva_class_example.root')
if not data or data.IsZombie():
    raise RuntimeError('Could not open tmva_class_example.root')

# Signal and background samples; Get() returns a null object if a key is missing.
signal = data.Get('TreeS')
background = data.Get('TreeB')
if not signal or not background:
    raise RuntimeError('tmva_class_example.root lacks the TreeS/TreeB trees')
29 
# Register every branch of the signal tree as a reader input variable.
# Each variable gets one single-element float32 buffer ('f' matches the
# Float_t* that TMVA.Reader.AddVariable expects). The same buffer object is
# handed to the reader and bound to the branch in both trees, so loading a
# tree entry writes directly into what the reader evaluates.
# NOTE(review): assumes every branch in TreeS is a Float_t branch also present
# in TreeB — confirm against the input file.
branches = {}
for br in signal.GetListOfBranches():
    var_name = br.GetName()
    buf = branches[var_name] = array('f', [-999])
    reader.AddVariable(var_name, buf)
    for tree in (signal, background):
        tree.SetBranchAddress(var_name, buf)
37 
# Book methods
# Attach the trained classifier under the method name 'PyKeras'; that name is
# what EvaluateMVA() is called with when scoring entries.
weights_file = TString('dataset/weights/TMVAClassification_PyKeras.weights.xml')
reader.BookMVA('PyKeras', weights_file)
40 
# Print some example classifications
# For the first 20 entries of each sample: GetEntry() fills the branch buffers
# shared with the reader, then EvaluateMVA() scores that entry. The two
# listings are separated by a single blank line.
examples = (
    ('Some signal example classifications:', signal),
    ('Some background example classifications:', background),
)
for index, (heading, tree) in enumerate(examples):
    if index > 0:
        print('')
    print(heading)
    for entry in range(20):
        tree.GetEntry(entry)
        print(reader.EvaluateMVA('PyKeras'))