ROOT 6.30.04
Reference Guide
 All Namespaces Files Pages
ApplicationRegressionKeras.py
Go to the documentation of this file.
#!/usr/bin/env python
## \file
## \ingroup tutorial_tmva_keras
## \notebook -nodraw
## This tutorial shows how to apply a trained model to new data (regression).
##
## It opens the regression example tree (downloading it if needed), feeds
## every non-target branch to a TMVA::Reader booked with the PyKeras weights
## file, and prints the true target next to the regression output for the
## first few events.
##
## \macro_code
##
## \date 2017
## \author TMVA Team

from ROOT import TMVA, TFile, TString
from array import array
from subprocess import call
from os.path import isfile

# Input data, model and output settings
DATA_FILE = 'tmva_reg_example.root'
DATA_URL = 'https://root.cern.ch/files/' + DATA_FILE
TREE_NAME = 'TreeR'
TARGET_BRANCH = 'fvalue'
METHOD_NAME = 'PyKeras'
WEIGHTS_FILE = 'dataset/weights/TMVARegression_PyKeras.weights.xml'
N_EXAMPLES = 20

# Setup TMVA
TMVA.Tools.Instance()
TMVA.PyMethodBase.PyInitialize()
reader = TMVA.Reader("Color:!Silent")

# Load data
if not isfile(DATA_FILE):
    # -L follows HTTP redirects and -f makes curl fail on HTTP errors, so an
    # HTML redirect/error page is never saved in place of the ROOT file.
    if call(['curl', '-L', '-f', '-O', DATA_URL]) != 0:
        raise RuntimeError('Failed to download {}'.format(DATA_URL))

data = TFile.Open(DATA_FILE)
# TFile.Open returns a null (falsy) object on failure.
if not data or data.IsZombie():
    raise RuntimeError('Cannot open {}'.format(DATA_FILE))
tree = data.Get(TREE_NAME)
if not tree:
    raise RuntimeError('Tree {} not found in {}'.format(TREE_NAME, DATA_FILE))

# Attach a one-element float buffer to every branch. All branches except the
# target are registered as reader inputs, in tree order.
# NOTE(review): this assumes the tree's branch order matches the variable
# order used at training time -- confirm against the training script.
branches = {}
for branch in tree.GetListOfBranches():
    name = branch.GetName()
    branches[name] = array('f', [-999])
    tree.SetBranchAddress(name, branches[name])
    if name != TARGET_BRANCH:
        reader.AddVariable(name, branches[name])

if TARGET_BRANCH not in branches:
    raise RuntimeError('Target branch {} not found in tree {}'.format(TARGET_BRANCH, TREE_NAME))

# Book methods
if not isfile(WEIGHTS_FILE):
    raise RuntimeError('Weights file {} not found (train the PyKeras regression model first)'.format(WEIGHTS_FILE))
reader.BookMVA(METHOD_NAME, TString(WEIGHTS_FILE))

# Print some example regressions. Never read past the end of the tree:
# GetEntry beyond the last entry would leave stale values in the buffers.
print('Some example regressions:')
for i in range(min(N_EXAMPLES, tree.GetEntries())):
    tree.GetEntry(i)
    # EvaluateRegression is the TMVA::Reader API for regression methods. It
    # returns one value per target (after any inverse target transformation),
    # so take the first (and only) target.
    prediction = reader.EvaluateRegression(METHOD_NAME)[0]
    print('True/MVA value: {}/{}'.format(branches[TARGET_BRANCH][0], prediction))

data.Close()