Logo ROOT   6.30.04
Reference Guide
 All Namespaces Files Pages
MulticlassKeras.py
Go to the documentation of this file.
1 #!/usr/bin/env python
2 ## \file
3 ## \ingroup tutorial_tmva_keras
4 ## \notebook -nodraw
5 ## This tutorial shows how to do multiclass classification in TMVA with neural
6 ## networks trained with Keras.
7 ##
8 ## \macro_code
9 ##
10 ## \date 2017
11 ## \author TMVA Team
12 
13 from ROOT import TMVA, TFile, TTree, TCut, gROOT
14 from os.path import isfile
15 
16 from keras.models import Sequential
17 from keras.layers.core import Dense, Activation
18 from keras.regularizers import l2
19 from keras.optimizers import SGD
20 
# Setup TMVA: create the TMVA tool singleton and prepare the Python-based
# methods (required for the kPyKeras method booked later).
TMVA.Tools.Instance()
TMVA.PyMethodBase.PyInitialize()

# Factory configuration, one option per entry; joined with ':' as TMVA expects.
factory_options = ':'.join([
    '!V',
    '!Silent',
    'Color',
    'DrawProgressBar',
    'Transformations=D,G',
    'AnalysisType=multiclass',
])

# Results are written to TMVA.root; 'RECREATE' overwrites any previous file.
output = TFile.Open('TMVA.root', 'RECREATE')
factory = TMVA.Factory('TMVAClassification', output, factory_options)
28 
# Load data. On first run the example input file does not exist yet, so it is
# generated by the createData.C macro shipped with the ROOT tutorials.
if not isfile('tmva_example_multiple_background.root'):
    macro_path = str(gROOT.GetTutorialDir()) + '/tmva/createData.C'
    print(macro_path)
    gROOT.ProcessLine('.L {}'.format(macro_path))
    gROOT.ProcessLine('create_MultipleBackground(4000)')

data = TFile.Open('tmva_example_multiple_background.root')
signal = data.Get('TreeS')
background0 = data.Get('TreeB0')
background1 = data.Get('TreeB1')
background2 = data.Get('TreeB2')

dataloader = TMVA.DataLoader('dataset')

# Every branch of the signal tree becomes one input variable.
for br in signal.GetListOfBranches():
    dataloader.AddVariable(br.GetName())

# One class per tree, registered in this fixed order (presumably this order
# defines the class indexing TMVA uses for the multiclass output - verify).
class_trees = (
    (signal, 'Signal'),
    (background0, 'Background_0'),
    (background1, 'Background_1'),
    (background2, 'Background_2'),
)
for tree, class_name in class_trees:
    dataloader.AddTree(tree, class_name)

dataloader.PrepareTrainingAndTestTree(
    TCut(''), 'SplitMode=Random:NormMode=NumEvents:!V')
52 
# Generate model
#
# The network's input width must equal the number of variables registered
# with the DataLoader (one per branch of the signal tree), so derive it from
# the tree instead of hard-coding it. The output width must equal the number
# of classes added via AddTree (Signal + three backgrounds).
num_variables = signal.GetListOfBranches().GetEntries()
num_classes = 4

# Define model: one hidden ReLU layer with a small L2 weight penalty and a
# softmax output giving one probability per class.
# `kernel_regularizer` replaces the Keras 1 keyword `W_regularizer`, which
# current Keras versions no longer accept.
model = Sequential()
model.add(Dense(32, activation='relu', kernel_regularizer=l2(1e-5),
                input_dim=num_variables))
model.add(Dense(num_classes, activation='softmax'))

# Set loss and optimizer.
# `learning_rate` replaces the deprecated `lr` keyword, which recent Keras
# optimizers reject.
model.compile(loss='categorical_crossentropy',
              optimizer=SGD(learning_rate=0.01),
              metrics=['accuracy'])

# Store model to file; the PyKeras method loads it via FilenameModel=model.h5.
model.save('model.h5')
model.summary()
66 
# Book methods: a Fisher discriminant as a simple baseline, and the Keras
# model stored in model.h5. Both apply the D,G variable transformations
# (decorrelation followed by Gaussianisation) before training.
factory.BookMethod(dataloader, TMVA.Types.kFisher, 'Fisher',
                   '!H:!V:Fisher:VarTransform=D,G')
factory.BookMethod(dataloader, TMVA.Types.kPyKeras, "PyKeras",
                   'H:!V:VarTransform=D,G:FilenameModel=model.h5:NumEpochs=20:BatchSize=32')

# Run TMVA
factory.TrainAllMethods()
factory.TestAllMethods()
factory.EvaluateAllMethods()

# Close the output file explicitly so everything written to TMVA.root is
# flushed now rather than relying on interpreter-exit cleanup.
output.Close()