Logo ROOT   6.30.04
Reference Guide
 All Namespaces Files Pages
GenerateModel.py
Go to the documentation of this file.
1 #!/usr/bin/env python
2 ## \file
3 ## \ingroup tutorial_tmva_keras
4 ## \notebook -nodraw
5 ## This tutorial shows how to define and generate a keras model for use with
6 ## TMVA.
7 ##
8 ## \macro_code
9 ##
10 ## \date 2017
11 ## \author TMVA Team
12 
13 from keras.models import Sequential
14 from keras.layers.core import Dense, Activation
15 from keras.regularizers import l2
16 from keras.optimizers import SGD
17 
# Set up the model here
# NOTE: These module-level values are read by the layer definitions below.
num_input_nodes = 4      # passed as input_dim to the first Dense layer
num_output_nodes = 2     # number of units in the output layer
num_hidden_layers = 1    # total hidden layers; the first one is always added,
                         # so values below 1 still produce one hidden layer
nodes_hidden_layer = 64  # number of units in every hidden layer
l2_val = 1e-5            # factor handed to l2() for the hidden-layer weights

model = Sequential()

# Hidden layer 1
# NOTE: The number of input nodes needs to be defined in this layer
# `kernel_regularizer` is the Keras 2 name of the former `W_regularizer`
# keyword of Dense.
model.add(Dense(nodes_hidden_layer, activation='relu',
                kernel_regularizer=l2(l2_val), input_dim=num_input_nodes))

# Hidden layers 2 to num_hidden_layers
# NOTE: Here, you can do what you want
for _ in range(num_hidden_layers - 1):
    model.add(Dense(nodes_hidden_layer, activation='relu',
                    kernel_regularizer=l2(l2_val)))

# Output layer
# NOTE: Use following output types for the different tasks
# Binary classification: 2 output nodes with 'softmax' activation
# Regression: 1 output with any activation ('linear' recommended)
# Multiclass classification: (number of classes) output nodes with 'softmax' activation
model.add(Dense(num_output_nodes, activation='softmax'))

# Compile model
# NOTE: Use following settings for the different tasks
# Any classification: 'categorical_crossentropy' is recommended loss function
# Regression: 'mean_squared_error' is recommended loss function
# `learning_rate` replaces the deprecated `lr` keyword of the optimizer.
model.compile(loss='categorical_crossentropy',
              optimizer=SGD(learning_rate=0.01),
              metrics=['accuracy'])

# Save model
# The relative path means the file lands in the current working directory.
model.save('model.h5')

# Additional information about the model
# NOTE: This is not needed to run the model

# Print summary
model.summary()

# Visualize model as graph
# NOTE: Plotting is best-effort; any failure is reported and the script
# still finishes normally.
try:
    try:
        from keras.utils import plot_model
    except ImportError:
        # Keras 1 exposed the same helper as keras.utils.visualize_util.plot
        from keras.utils.visualize_util import plot as plot_model
    plot_model(model, to_file='model.png', show_shapes=True)
except Exception as err:
    # Catch Exception instead of using a bare except so that
    # KeyboardInterrupt and SystemExit still propagate.
    print('[INFO] Failed to make model plot')
    print('[INFO] Reason: {}'.format(err))