Logo ROOT   6.30.04
Reference Guide
 All Namespaces Files Pages
tmva100_DataPreparation.py
Go to the documentation of this file.
1 ## \file
2 ## \ingroup tutorial_tmva
3 ## \notebook -nodraw
4 ## This tutorial illustrates how to prepare ROOT datasets to be nicely readable
5 ## by most machine learning methods. This requires filtering the initial complex
6 ## datasets and writing the data in a flat format.
7 ##
8 ## \macro_code
9 ## \macro_output
10 ##
11 ## \date August 2019
12 ## \author Stefan Wunsch
13 
14 import ROOT
15 
16 
def filter_events(df):
    """
    Restrict the dataset to the events that shall be used for training.

    Applies a single named filter to `df` that requires `nElectron>=2` and
    `nMuon>=2`, and returns the object produced by `df.Filter`.
    """
    selection = "nElectron>=2 && nMuon>=2"
    # The second argument names the filter.
    filter_name = "At least two electrons and two muons"
    return df.Filter(selection, filter_name)
22 
23 
def define_variables(df):
    """
    Add the flat columns which shall be used for training.

    Each new column is defined as element 0 or element 1 of `Muon_pt` or
    `Electron_pt`. The definitions are chained in order, each `Define` call
    being made on the result of the previous one, and the final result is
    returned.
    """
    column_definitions = (
        ("Muon_pt_1", "Muon_pt[0]"),
        ("Muon_pt_2", "Muon_pt[1]"),
        ("Electron_pt_1", "Electron_pt[0]"),
        ("Electron_pt_2", "Electron_pt[1]"),
    )
    node = df
    for column_name, expression in column_definitions:
        node = node.Define(column_name, expression)
    return node
32 
33 
# Names of the training columns; the script section below passes this list as
# the set of columns to write into the output files.
variables = [
    "Muon_pt_1",
    "Muon_pt_2",
    "Electron_pt_1",
    "Electron_pt_2",
]
35 
36 
if __name__ == "__main__":
    # Process the signal sample and the background sample in turn; each one
    # produces a "train_<label>.root" and a "test_<label>.root" file.
    for filename, label in [["SMHiggsToZZTo4L.root", "signal"], ["ZZTo2e2mu.root", "background"]]:
        print(">>> Extract the training and testing events for {} from the {} dataset.".format(
            label, filename))

        # Load dataset, filter the required events and define the training variables
        filepath = "root://eospublic.cern.ch//eos/root-eos/cms_opendata_2012_nanoaod/" + filename
        df = ROOT.RDataFrame("Events", filepath)
        df = filter_events(df)
        df = define_variables(df)

        # Book cutflow report
        report = df.Report()

        # Split dataset by event number for training and testing
        columns = ROOT.std.vector["string"](variables)
        df.Filter("event % 2 == 0", "Select events with even event number for training")\
          .Snapshot("Events", "train_" + label + ".root", columns)
        # Odd event numbers are written to the test file, so the filter name
        # says "testing" (it previously repeated the "training" wording).
        df.Filter("event % 2 == 1", "Select events with odd event number for testing")\
          .Snapshot("Events", "test_" + label + ".root", columns)

        # Print cutflow report
        report.Print()