from os.path import isfile
from subprocess import call

from ROOT import TMVA, TFile, TTree, TCut

from keras.models import Sequential
from keras.layers import Dense, Activation
from keras.regularizers import l2
from keras.optimizers import SGD
# Initialise TMVA's Python bridge; required before booking the PyKeras method.
TMVA.PyMethodBase.PyInitialize()

# All TMVA results (trained-method outputs, evaluation histograms) are written
# to TMVA.root; 'RECREATE' overwrites any file left by a previous run.
output = TFile.Open('TMVA.root', 'RECREATE')
factory = TMVA.Factory(
    'TMVAClassification', output,
    '!V:!Silent:Color:DrawProgressBar:Transformations=D,G:AnalysisType=Classification')
# Fetch the example dataset on first run only.
# -L makes curl follow HTTP redirects; without it a redirect response would be
# saved under the data file's name. -f makes curl exit non-zero on HTTP errors
# instead of reporting success. A non-zero exit code is turned into an error
# here rather than letting the later TFile.Open fail obscurely.
if not isfile('tmva_class_example.root'):
    status = call(['curl', '-f', '-L', '-O',
                   'http://root.cern.ch/files/tmva_class_example.root'])
    if status != 0:
        raise RuntimeError(
            'failed to download tmva_class_example.root (curl exit code %d)' % status)
# Open the example file and fetch the signal and background trees.
# TFile.Open yields a null (falsy) object when the file cannot be opened, and
# Get yields a null object for a missing key; fail early with a clear message.
data = TFile.Open('tmva_class_example.root')
if not data or data.IsZombie():
    raise RuntimeError('cannot open tmva_class_example.root')
signal = data.Get('TreeS')
background = data.Get('TreeB')
if not signal or not background:
    raise RuntimeError('TreeS/TreeB not found in tmva_class_example.root')

# Every branch of the signal tree becomes a TMVA input variable.
# NOTE(review): assumes TreeB has the same branches as TreeS — confirm.
dataloader = TMVA.DataLoader('dataset')
for branch in signal.GetListOfBranches():
    dataloader.AddVariable(branch.GetName())

# Both trees enter with a global event weight of 1.0.
dataloader.AddSignalTree(signal, 1.0)
dataloader.AddBackgroundTree(background, 1.0)
# Empty cut (all events used); 4000 training events per class, random split.
dataloader.PrepareTrainingAndTestTree(
    TCut(''),
    'nTrain_Signal=4000:nTrain_Background=4000:SplitMode=Random:NormMode=NumEvents:!V')
# Network used by the PyKeras method: one hidden ReLU layer and a two-unit
# softmax output (one unit per class: signal, background).

# One network input per branch registered as a TMVA variable above, so the
# input layer stays in sync with the dataloader instead of a hard-coded 4.
n_inputs = signal.GetListOfBranches().GetEntries()

# The model object must exist before layers can be added to it.
model = Sequential()
# NOTE(review): `W_regularizer` and `lr` are Keras 1 spellings. Keras 2 renamed
# `W_regularizer` to `kernel_regularizer`, and later Keras renamed `lr` to
# `learning_rate`. Update them if this script targets a newer Keras.
model.add(Dense(64, activation='relu', W_regularizer=l2(1e-5), input_dim=n_inputs))
model.add(Dense(2, activation='softmax'))

model.compile(loss='categorical_crossentropy',
              optimizer=SGD(lr=0.01),
              metrics=['accuracy'])

# The PyKeras method is booked with FilenameModel=model.h5 and loads the
# (still untrained) model from this file; training happens in TrainAllMethods.
model.save('model.h5')
# Book two classifiers on the same dataloader: a Fisher discriminant as a
# baseline, and the Keras network saved to model.h5. Both apply the D,G input
# transformations (decorrelation followed by Gaussianisation).
factory.BookMethod(
    dataloader, TMVA.Types.kFisher, 'Fisher',
    '!H:!V:Fisher:VarTransform=D,G')
factory.BookMethod(
    dataloader, TMVA.Types.kPyKeras, 'PyKeras',
    'H:!V:VarTransform=D,G:FilenameModel=model.h5:NumEpochs=20:BatchSize=32')
# Train every booked method, apply them to the test sample, then compute
# the comparison metrics; results are written into TMVA.root.
factory.TrainAllMethods()
factory.TestAllMethods()
factory.EvaluateAllMethods()

# Close the output file explicitly so all results are flushed to TMVA.root
# (same order as the C++ TMVAClassification tutorial: close, then clean up).
output.Close()