# Train and evaluate two TMVA classifiers -- a Fisher discriminant and a small
# Keras network booked through the PyKeras method -- on the TMVA
# classification example trees ('TreeS' / 'TreeB').
#
# Files written to the working directory: 'TMVA.root' (TMVA output),
# 'model.h5' (the untrained Keras model handed to PyKeras) and, if it is not
# already present, the downloaded input file 'tmva_class_example.root'.

from os.path import isfile
from subprocess import call

from ROOT import TMVA, TFile, TTree, TCut

from keras.models import Sequential
from keras.layers import Dense, Activation
from keras.regularizers import l2
from keras.optimizers import SGD

DATA_FILE = 'tmva_class_example.root'
DATA_URL = 'http://root.cern.ch/files/' + DATA_FILE

# --- Set up TMVA -----------------------------------------------------------
TMVA.PyMethodBase.PyInitialize()

output = TFile.Open('TMVA.root', 'RECREATE')
# The option string is handed to TMVA verbatim.
factory = TMVA.Factory('TMVAClassification', output,
                       '!V:!Silent:Color:DrawProgressBar:Transformations=D,G:AnalysisType=Classification')

# --- Load data -------------------------------------------------------------
# Fetch the example input only when it is not already in the working
# directory. `curl -O` stores it under the remote file name, i.e. DATA_FILE.
if not isfile(DATA_FILE):
    status = call(['curl', '-O', DATA_URL])
    if status != 0:
        # Fail here with a clear message rather than later on a file that
        # could not be opened.
        raise RuntimeError(
            'curl exited with status %d while downloading %s'
            % (status, DATA_URL))

data = TFile.Open(DATA_FILE)
signal = data.Get('TreeS')
background = data.Get('TreeB')

dataloader = TMVA.DataLoader('dataset')
# Register every branch of the signal tree as an input variable.
for branch in signal.GetListOfBranches():
    dataloader.AddVariable(branch.GetName())

# Both trees are added with a weight of 1.0.
dataloader.AddSignalTree(signal, 1.0)
dataloader.AddBackgroundTree(background, 1.0)
# Empty TCut: no selection is applied when building the train/test samples.
dataloader.PrepareTrainingAndTestTree(TCut(''),
                                      'nTrain_Signal=4000:nTrain_Background=4000:SplitMode=Random:NormMode=NumEvents:!V')

# --- Define the Keras model ------------------------------------------------
# The model object has to exist before layers can be added to it.
model = Sequential()
# input_dim=4 has to match the number of variables registered on the
# dataloader above -- presumably the example trees carry four branches;
# TODO confirm against the input file.
# NOTE(review): `W_regularizer` and `lr` are the keyword spellings this script
# was written against; newer Keras releases appear to use `kernel_regularizer`
# and `learning_rate` instead -- confirm against the installed Keras version
# before changing them.
model.add(Dense(64, activation='relu', W_regularizer=l2(1e-5), input_dim=4))
model.add(Dense(2, activation='softmax'))

# Loss, optimizer and reported metric.
model.compile(loss='categorical_crossentropy',
              optimizer=SGD(lr=0.01), metrics=['accuracy', ])

# Store the model on disk; the PyKeras booking below refers to it through
# FilenameModel=model.h5.
model.save('model.h5')

# --- Book methods ----------------------------------------------------------
factory.BookMethod(dataloader, TMVA.Types.kFisher, 'Fisher',
                   '!H:!V:Fisher:VarTransform=D,G')
factory.BookMethod(dataloader, TMVA.Types.kPyKeras, 'PyKeras',
                   'H:!V:VarTransform=D,G:FilenameModel=model.h5:NumEpochs=20:BatchSize=32')

# --- Run training, testing and evaluation ----------------------------------
factory.TrainAllMethods()
factory.TestAllMethods()
factory.EvaluateAllMethods()