Logo ROOT   6.30.04
Reference Guide
 All Namespaces Files Pages
ClassificationKeras.py
Go to the documentation of this file.
1 #!/usr/bin/env python
2 ## \file
3 ## \ingroup tutorial_tmva_keras
4 ## \notebook -nodraw
5 ## This tutorial shows how to do classification in TMVA with neural networks
6 ## trained with keras.
7 ##
8 ## \macro_code
9 ##
10 ## \date 2017
11 ## \author TMVA Team
12 
13 from ROOT import TMVA, TFile, TTree, TCut
14 from subprocess import call
15 from os.path import isfile
16 
17 from keras.models import Sequential
18 from keras.layers import Dense, Activation
19 from keras.regularizers import l2
20 from keras.optimizers import SGD
21 
# Setup TMVA
# Initialise the TMVA tools instance and the PyMethodBase Python interface
# before the factory is created.
TMVA.Tools.Instance()
TMVA.PyMethodBase.PyInitialize()

# ROOT file handed to the factory; RECREATE overwrites an existing TMVA.root.
output = TFile.Open('TMVA.root', 'RECREATE')
factory = TMVA.Factory(
    'TMVAClassification',
    output,
    # Adjacent literals are joined at compile time into one option string.
    '!V:!Silent:Color:DrawProgressBar:'
    'Transformations=D,G:AnalysisType=Classification',
)
29 
# Load data
# Fetch the example input only when no local copy exists yet.
if not isfile('tmva_class_example.root'):
    # -L makes curl follow HTTP redirects; without it a redirect response
    # could be saved under the .root file name instead of the real data.
    status = call(['curl', '-L', '-O',
                   'http://root.cern.ch/files/tmva_class_example.root'])
    # call() returns curl's exit code; stop here rather than failing later
    # with a confusing error on a missing or broken input file.
    if status != 0:
        raise RuntimeError(
            'Download of tmva_class_example.root failed '
            '(curl exit code {})'.format(status))

data = TFile.Open('tmva_class_example.root')
# A failed open yields a null/zombie file object; report it explicitly.
if not data or data.IsZombie():
    raise RuntimeError('Could not open tmva_class_example.root')

signal = data.Get('TreeS')
background = data.Get('TreeB')
# Get() yields a null object for a missing key; both trees are required below.
if not signal or not background:
    raise RuntimeError(
        'tmva_class_example.root does not contain both TreeS and TreeB')
37 
dataloader = TMVA.DataLoader('dataset')

# Register one input variable per branch of the signal tree.
for branch in signal.GetListOfBranches():
    variable_name = branch.GetName()
    dataloader.AddVariable(variable_name)

# Both trees are added with 1.0 as their second argument.
dataloader.AddSignalTree(signal, 1.0)
dataloader.AddBackgroundTree(background, 1.0)

# An empty cut applies no selection; the option string asks for 4000
# training events per class, chosen with SplitMode=Random.
dataloader.PrepareTrainingAndTestTree(
    TCut(''),
    'nTrain_Signal=4000:nTrain_Background=4000:'
    'SplitMode=Random:NormMode=NumEvents:!V',
)
46 
# Generate model

# Define model: a 64-unit ReLU hidden layer followed by a 2-unit softmax
# output layer.
# NOTE: input_dim=4 is hard-coded and must equal the number of variables
# registered with the dataloader (one per branch of the signal tree).
model = Sequential()
# kernel_regularizer is the Keras 2+ name of the weight-penalty argument;
# the Keras 1 spelling W_regularizer is not accepted by current releases.
model.add(Dense(64, activation='relu', kernel_regularizer=l2(1e-5),
                input_dim=4))
model.add(Dense(2, activation='softmax'))

# Set loss and optimizer
# learning_rate replaces the deprecated lr keyword of the SGD optimizer.
model.compile(loss='categorical_crossentropy',
              optimizer=SGD(learning_rate=0.01), metrics=['accuracy'])

# Store model to file
# The file name must match FilenameModel in the PyKeras booking options.
model.save('model.h5')
model.summary()
61 
# Book methods
# A Fisher discriminant is booked next to the Keras network.
factory.BookMethod(
    dataloader,
    TMVA.Types.kFisher,
    'Fisher',
    '!H:!V:Fisher:VarTransform=D,G',
)
# The PyKeras method reads the network from model.h5 (FilenameModel) and is
# configured for 20 epochs with a batch size of 32.
factory.BookMethod(
    dataloader,
    TMVA.Types.kPyKeras,
    'PyKeras',
    'H:!V:VarTransform=D,G:'
    'FilenameModel=model.h5:NumEpochs=20:BatchSize=32',
)
67 
# Run training, test and evaluation
# The three factory stages run strictly in this order.
for stage in (factory.TrainAllMethods,
              factory.TestAllMethods,
              factory.EvaluateAllMethods):
    stage()