Logo ROOT   6.30.04
Reference Guide
 All Namespaces Files Pages
RegressionKeras.py
Go to the documentation of this file.
1 #!/usr/bin/env python
2 ## \file
3 ## \ingroup tutorial_tmva_keras
4 ## \notebook -nodraw
5 ## This tutorial shows how to do regression in TMVA with neural networks
6 ## trained with keras.
7 ##
8 ## \macro_code
9 ##
10 ## \date 2017
11 ## \author TMVA Team
12 
13 from ROOT import TMVA, TFile, TTree, TCut
14 from subprocess import call
15 from os.path import isfile
16 
17 from keras.models import Sequential
18 from keras.layers.core import Dense, Activation
19 from keras.regularizers import l2
20 from keras.optimizers import SGD
21 
# Setup TMVA: initialise the TMVA toolkit and the Python bridge used by
# the PyKeras method.
TMVA.Tools.Instance()
TMVA.PyMethodBase.PyInitialize()

# The factory writes all training/test/evaluation output into TMVA.root.
output = TFile.Open('TMVA.root', 'RECREATE')
factory_options = (
    '!V:!Silent:Color:DrawProgressBar:'
    'Transformations=D,G:AnalysisType=Regression'
)
factory = TMVA.Factory('TMVARegression', output, factory_options)

# Load data: fetch the example regression dataset on first use, then open
# it and pick up the 'TreeR' tree used for training.
data_file_name = 'tmva_reg_example.root'
data_url = 'https://root.cern.ch/files/tmva_reg_example.root'
if not isfile(data_file_name):
    # -f: exit non-zero on HTTP errors instead of saving the error page as
    #     the .root file; -L: follow HTTP redirects to the real file.
    if call(['curl', '-f', '-L', '-O', data_url]) != 0:
        raise RuntimeError('Failed to download ' + data_url)

data = TFile.Open(data_file_name)
# TFile.Open yields a null object for unreadable/non-ROOT files; fail
# here with a clear message instead of later on a null tree.
if not data or data.IsZombie():
    raise RuntimeError('Cannot open ' + data_file_name
                       + ' (delete it to force a fresh download)')
tree = data.Get('TreeR')
if not tree:
    raise RuntimeError("Tree 'TreeR' not found in " + data_file_name)

# Register the inputs and the regression target on the dataloader: every
# branch of the tree is an input variable except 'fvalue', the target.
dataloader = TMVA.DataLoader('dataset')
branch_names = [br.GetName() for br in tree.GetListOfBranches()]
for var_name in branch_names:
    if var_name == 'fvalue':
        continue
    dataloader.AddVariable(var_name)
dataloader.AddTarget('fvalue')

# Use the whole tree (weight 1.0), with 4000 randomly chosen training events.
dataloader.AddRegressionTree(tree, 1.0)
split_options = 'nTrain_Regression=4000:SplitMode=Random:NormMode=NumEvents:!V'
dataloader.PrepareTrainingAndTestTree(TCut(''), split_options)

# Generate model

# Define model: one hidden tanh layer with L2 weight regularisation and a
# single linear output unit for the regression target.
# The input width must equal the number of input variables registered on
# the dataloader (every branch of the tree except the target 'fvalue'),
# so derive it from the tree instead of hard-coding it.
num_inputs = sum(1 for branch in tree.GetListOfBranches()
                 if branch.GetName() != 'fvalue')
model = Sequential()
# Keras 2 API: 'kernel_regularizer' replaces the Keras 1 'W_regularizer'.
model.add(Dense(64, activation='tanh', kernel_regularizer=l2(1e-5),
                input_dim=num_inputs))
model.add(Dense(1, activation='linear'))

# Set loss and optimizer ('learning_rate' replaces the deprecated 'lr').
model.compile(loss='mean_squared_error', optimizer=SGD(learning_rate=0.01))

# Store model to file; the PyKeras method is booked with
# FilenameModel=model.h5 and loads the model from there.
model.save('model.h5')
model.summary()

# Book methods: the Keras network (loaded by PyKeras from 'model.h5') and a
# gradient-boosted decision tree to compare it against.
keras_options = ('H:!V:VarTransform=D,G:FilenameModel=model.h5:'
                 'NumEpochs=20:BatchSize=32')
bdtg_options = ('!H:!V:VarTransform=D,G:NTrees=1000:BoostType=Grad:Shrinkage=0.1:'
                'UseBaggedBoost:BaggedSampleFraction=0.5:nCuts=20:MaxDepth=4')
for method_type, method_name, method_options in (
    (TMVA.Types.kPyKeras, 'PyKeras', keras_options),
    (TMVA.Types.kBDT, 'BDTG', bdtg_options),
):
    factory.BookMethod(dataloader, method_type, method_name, method_options)

# Run TMVA: train, test and evaluate every booked method.
factory.TrainAllMethods()
factory.TestAllMethods()
factory.EvaluateAllMethods()

# Close the output file explicitly so TMVA.root is completely written
# before the script ends, instead of relying on exit-time cleanup.
output.Close()