Logo ROOT   6.30.04
Reference Guide
 All Namespaces Files Pages
Forest.hxx
Go to the documentation of this file.
1 /**********************************************************************************
2  * Project: ROOT - a Root-integrated toolkit for multivariate data analysis *
3  * Package: TMVA *
4  * Web : http://tmva.sourceforge.net *
5  * *
6  * Description: *
7  * *
8  * Authors: *
9  * Stefan Wunsch (stefan.wunsch@cern.ch) *
10  * Luca Zampieri (luca.zampieri@alumni.epfl.ch) *
11  * *
12  * Copyright (c) 2019: *
13  * CERN, Switzerland *
14  * *
15  * Redistribution and use in source and binary forms, with or without *
16  * modification, are permitted according to the terms listed in LICENSE *
17  * (http://tmva.sourceforge.net/LICENSE) *
18  **********************************************************************************/
19 
20 #ifndef TMVA_TREEINFERENCE_FOREST
21 #define TMVA_TREEINFERENCE_FOREST
22 
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <functional>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <typeinfo>
#include <utility>
#include <vector>

#include "TFile.h"
#include "TDirectory.h"
#include "TInterpreter.h"
#include "TUUID.h"
#include "TGenericClassInfo.h" // ROOT::Internal::GetDemangledTypeName

#include "BranchlessTree.hxx"
#include "Objectives.hxx"
38 
39 namespace TMVA {
40 namespace Experimental {
41 
42 namespace Internal {
43 template <typename T>
44 T *GetObjectSafe(TFile *f, const std::string &n, const std::string &m)
45 {
46  auto v = reinterpret_cast<T *>(f->Get(m.c_str()));
47  if (v == nullptr)
48  throw std::runtime_error("Failed to read " + m + " from file " + n + ".");
49  return v;
50 }
51 
52 template <typename T>
53 bool CompareTree(const BranchlessTree<T> &a, const BranchlessTree<T> &b)
54 {
55  if (a.fInputs[0] == b.fInputs[0])
56  return a.fThresholds[0] < b.fThresholds[0];
57  else
58  return a.fInputs[0] < b.fInputs[0];
59 }
60 } // namespace Internal
61 
/// Forest base class
///
/// \tparam T Value type for the computation (usually floating point type)
/// \tparam ForestType Type of the collection of trees
template <typename T, typename ForestType>
struct ForestBase {
   using Value_t = T;
   std::function<T(T)> fObjectiveFunc; ///< Objective function
   ForestType fTrees;                  ///< Store the forest, either as vector or jitted function
   int fNumInputs = 0;                 ///< Number of input variables (0 until a model is loaded)

   void Inference(const T *inputs, const int rows, bool layout, T *predictions);
};

/// Perform inference of the forest on a batch of inputs
///
/// For every event, the responses of all trees are summed and the sum is passed through the
/// objective function.
///
/// \param[in] inputs Pointer to data containing the inputs
/// \param[in] rows Number of events in inputs vector
/// \param[in] layout Row major (true) or column major (false) memory layout
/// \param[out] predictions Pointer to the buffer to be filled with the predictions (at least `rows` values)
template <typename T, typename ForestType>
inline void ForestBase<T, ForestType>::Inference(const T *inputs, const int rows, bool layout, T *predictions)
{
   // strideTree: distance between consecutive variables of one event.
   // strideBatch: distance between the first variables of consecutive events.
   const int strideTree = layout ? 1 : rows;
   const int strideBatch = layout ? fNumInputs : 1;
   // Advance the event pointer incrementally instead of computing i * strideBatch, which overflows
   // int once rows * fNumInputs exceeds INT_MAX
   const T *event = inputs;
   for (int i = 0; i < rows; i++, event += strideBatch) {
      T sum = 0;
      for (auto &tree : fTrees)
         sum += tree.Inference(event, strideTree);
      predictions[i] = fObjectiveFunc(sum);
   }
}
95 
96 /// Forest using branchless trees
97 ///
98 /// \tparam T Value type for the computation (usually floating point type)
99 template <typename T>
100 struct BranchlessForest : public ForestBase<T, std::vector<BranchlessTree<T>>> {
101  void Load(const std::string &key, const std::string &filename, const int output = 0, const bool sortTrees = true);
102 };
103 
104 /// Load parameters from a ROOT file to the branchless trees
105 ///
106 /// \param[in] key Name of folder in the ROOT file containing the model parameters
107 /// \param[in] filename Filename of the ROOT file
108 /// \param[in] output Load trees corresponding to the given output node of the forest
109 /// \param[in] sortTrees Flag to indicate sorting the input trees by the cut value of the first node of each tree
110 template <typename T>
111 inline void
112 BranchlessForest<T>::Load(const std::string &key, const std::string &filename, const int output, const bool sortTrees)
113 {
114  // Open input file and get folder from key
115  auto file = TFile::Open(filename.c_str(), "READ");
116 
117  // Load parameters from file
118  auto maxDepth = Internal::GetObjectSafe<std::vector<int>>(file, filename, key + "/max_depth");
119  auto numTrees = Internal::GetObjectSafe<std::vector<int>>(file, filename, key + "/num_trees");
120  auto numInputs = Internal::GetObjectSafe<std::vector<int>>(file, filename, key + "/num_inputs");
121  auto numOutputs = Internal::GetObjectSafe<std::vector<int>>(file, filename, key + "/num_outputs");
122  auto objective = Internal::GetObjectSafe<std::string>(file, filename, key + "/objective");
123  auto inputs = Internal::GetObjectSafe<std::vector<int>>(file, filename, key + "/inputs");
124  auto outputs = Internal::GetObjectSafe<std::vector<int>>(file, filename, key + "/outputs");
125  auto thresholds = Internal::GetObjectSafe<std::vector<T>>(file, filename, key + "/thresholds");
126 
127  this->fNumInputs = numInputs->at(0);
128  this->fObjectiveFunc = Objectives::GetFunction<T>(*objective);
129  const auto lenInputs = std::pow(2, maxDepth->at(0)) - 1;
130  const auto lenThresholds = std::pow(2, maxDepth->at(0) + 1) - 1;
131 
132  // Find number of trees corresponding to given output node
133  if (output > numOutputs->at(0))
134  throw std::runtime_error("Given output node of the forest is larger or equal to number of output nodes.");
135  int c = 0;
136  for (int i = 0; i < numTrees->at(0); i++)
137  if (outputs->at(i) == output)
138  c++;
139  if (c == 0)
140  std::runtime_error("No trees found for given output node of the forest.");
141  this->fTrees.resize(c);
142 
143  // Load parameters in trees
144  c = 0;
145  for (int i = 0; i < numTrees->at(0); i++) {
146  // Select only trees for the given output node of the forest
147  if (outputs->at(i) != output)
148  continue;
149 
150  // Set tree depth
151  this->fTrees[c].fTreeDepth = maxDepth->at(0);
152 
153  // Set feature indices
154  this->fTrees[c].fInputs.resize(lenInputs);
155  for (int j = 0; j < lenInputs; j++)
156  this->fTrees[c].fInputs[j] = inputs->at(i * lenInputs + j);
157 
158  // Set threshold values
159  this->fTrees[c].fThresholds.resize(lenThresholds);
160  for (int j = 0; j < lenThresholds; j++)
161  this->fTrees[c].fThresholds[j] = thresholds->at(i * lenThresholds + j);
162 
163  // Fill sparse trees fully
164  this->fTrees[c].FillSparse();
165 
166  c++;
167  }
168 
169  // Sort trees by first cut variable and threshold value
170  if (sortTrees)
171  std::sort(this->fTrees.begin(), this->fTrees.end(), Internal::CompareTree<T>);
172 
173  // Clean-up
174  delete maxDepth;
175  delete numTrees;
176  delete numInputs;
177  delete objective;
178  delete inputs;
179  delete thresholds;
180  file->Close();
181 }
182 
183 /// Forest using branchless jitted trees
184 ///
185 /// \tparam T Value type for the computation (usually floating point type)
186 template <typename T>
187 struct BranchlessJittedForest : public ForestBase<T, std::function<void (const T *, const int, bool, T*)>> {
188  std::string Load(const std::string &key, const std::string &filename, const int output = 0, const bool sortTrees = true);
189  void Inference(const T *inputs, const int rows, bool layout, T *predictions);
190 };
191 
192 /// Load parameters from a ROOT file to the branchless trees
193 ///
194 /// \param[in] key Name of folder in the ROOT file containing the model parameters
195 /// \param[in] filename Filename of the ROOT file
196 /// \param[in] output Load trees corresponding to the given output node of the forest
197 /// \param[in] sortTrees Flag to indicate sorting the input trees by the cut value of the first node of each tree
198 /// \param[out] Return jitted code as string
199 template <typename T>
200 inline std::string
201 BranchlessJittedForest<T>::Load(const std::string &key, const std::string &filename, const int output, const bool sortTrees)
202 {
203  // Open input file and get folder from key
204  auto file = TFile::Open(filename.c_str(), "READ");
205 
206  // Load parameters from file
207  auto maxDepth = Internal::GetObjectSafe<std::vector<int>>(file, filename, key + "/max_depth");
208  auto numTrees = Internal::GetObjectSafe<std::vector<int>>(file, filename, key + "/num_trees");
209  auto numInputs = Internal::GetObjectSafe<std::vector<int>>(file, filename, key + "/num_inputs");
210  auto numOutputs = Internal::GetObjectSafe<std::vector<int>>(file, filename, key + "/num_outputs");
211  auto objective = Internal::GetObjectSafe<std::string>(file, filename, key + "/objective");
212  auto inputs = Internal::GetObjectSafe<std::vector<int>>(file, filename, key + "/inputs");
213  auto outputs = Internal::GetObjectSafe<std::vector<int>>(file, filename, key + "/outputs");
214  auto thresholds = Internal::GetObjectSafe<std::vector<T>>(file, filename, key + "/thresholds");
215 
216  this->fNumInputs = numInputs->at(0);
217  this->fObjectiveFunc = Objectives::GetFunction<T>(*objective);
218  const auto lenInputs = std::pow(2, maxDepth->at(0)) - 1;
219  const auto lenThresholds = std::pow(2, maxDepth->at(0) + 1) - 1;
220 
221  // Find number of trees corresponding to given output node
222  if (output > numOutputs->at(0))
223  throw std::runtime_error("Given output node of the forest is larger or equal to number of output nodes.");
224  int c = 0;
225  for (int i = 0; i < numTrees->at(0); i++)
226  if (outputs->at(i) == output)
227  c++;
228  if (c == 0)
229  std::runtime_error("No trees found for given output node of the forest.");
230 
231  // Get typename of template argument as string
232  std::string typeName = ROOT::Internal::GetDemangledTypeName(typeid(T));
233  if (typeName.compare("") == 0) {
234  throw std::runtime_error("Failed to just-in-time compile inference code for branchless forest (typename as string)");
235  }
236 
237  // Load parameters in trees
238  std::vector<T> firstThreshold(c);
239  std::vector<int> firstInput(c, -1);
240  std::vector<std::string> codes(c);
241  c = 0;
242  for (int i = 0; i < numTrees->at(0); i++) {
243  // Select only trees for the given output node of the forest
244  if (outputs->at(i) != output)
245  continue;
246 
247  // Set tree depth
248  BranchlessTree<T> tree;
249  tree.fTreeDepth = maxDepth->at(0);
250 
251  // Set feature indices
252  tree.fInputs.resize(lenInputs);
253  for (int j = 0; j < lenInputs; j++)
254  tree.fInputs[j] = inputs->at(i * lenInputs + j);
255 
256  // Set threshold values
257  tree.fThresholds.resize(lenThresholds);
258  for (int j = 0; j < lenThresholds; j++)
259  tree.fThresholds[j] = thresholds->at(i * lenThresholds + j);
260 
261  // Fill sparse trees fully
262  tree.FillSparse();
263 
264  // Save first threshold and input index for ordering the trees later
265  firstThreshold[c] = tree.fThresholds[0];
266  if (lenInputs != 0)
267  firstInput[c] = tree.fInputs[0];
268 
269  // Save code for jitting
270  std::stringstream ss;
271  ss << "tree" << c;
272  codes[c] = tree.GetInferenceCode(ss.str(), typeName);
273 
274  c++;
275  }
276 
277  // Sort trees by first cut variable and threshold value
278  std::vector<int> treeIndices(codes.size());
279  for(int i = 0; i < c; i++) treeIndices[i] = i;
280  if (sortTrees) {
281  auto compareIndices = [&firstInput, &firstThreshold](int i, int j)
282  {
283  if (firstInput[i] == firstInput[j])
284  return firstThreshold[i] < firstThreshold[j];
285  else
286  return firstInput[i] < firstInput[j];
287  };
288  std::sort(treeIndices.begin(), treeIndices.end(), compareIndices);
289  }
290 
291  // Get unique ID for a private namespace
292  TUUID uuid;
293  std::string nameSpace = uuid.AsString();
294  for (auto& v : nameSpace) {
295  if (v == '-') v = '_';
296  }
297  nameSpace = "ns_" + nameSpace;
298 
299  // JIT the forest
300  std::stringstream jitForest;
301  jitForest << "#pragma cling optimize(3)\n"
302  << "namespace " << nameSpace << " {\n";
303  for (int i = 0; i < static_cast<int>(codes.size()); i++) {
304  jitForest << codes[treeIndices[i]] << "\n\n";
305  }
306  jitForest << "void Inference(const "
307  << typeName << "* inputs, const int rows, bool layout, "
308  << typeName << "* predictions)"
309  << "\n{\n"
310  << " const auto strideTree = layout ? 1 : rows;\n"
311  << " const auto strideBatch = layout ? " << this->fNumInputs << " : 1;\n"
312  << " for (int i = 0; i < rows; i++) {\n"
313  << " predictions[i] = 0.0;\n";
314  for (int i = 0; i < static_cast<int>(codes.size()); i++) {
315  std::stringstream ss;
316  ss << "tree" << i;
317  const std::string funcName = ss.str();
318  jitForest << " predictions[i] += " << funcName << "(inputs + i * strideBatch, strideTree);\n";
319  }
320  jitForest << " }\n"
321  << "}\n"
322  << "} // end namespace " << nameSpace;
323  const std::string jitForestStr = jitForest.str();
324  const auto err = gInterpreter->Declare(jitForestStr.c_str());
325  if (err == 0) {
326  throw std::runtime_error("Failed to just-in-time compile inference code for branchless forest (declare function)");
327  }
328 
329  // Get function pointer and attach pointer to the forest
330  std::stringstream treesFunc;
331  treesFunc << "#pragma cling optimize(3)\n" << nameSpace << "::Inference";
332  const std::string treesFuncStr = treesFunc.str();
333  auto ptr = gInterpreter->Calc(treesFuncStr.c_str());
334  if (ptr == 0) {
335  throw std::runtime_error("Failed to just-in-time compile inference code for branchless forest (compile function)");
336  }
337  this->fTrees = reinterpret_cast<void (*)(const T *, int, bool, float*)>(ptr);
338 
339  // Clean-up
340  delete maxDepth;
341  delete numTrees;
342  delete numInputs;
343  delete objective;
344  delete inputs;
345  delete thresholds;
346  file->Close();
347 
348  return jitForestStr;
349 }
350 
351 /// Perform inference of the forest with the jitted branchless implementation on a batch of inputs
352 ///
353 /// \param[in] inputs Pointer to data containing the inputs
354 /// \param[in] rows Number of events in inputs vector
355 /// \param[in] layout Row major (true) or column major (false) memory layout
356 /// \param[in] predictions Pointer to the buffer to be filled with the predictions
357 template <typename T>
358 void BranchlessJittedForest<T>::Inference(const T *inputs, const int rows, bool layout, T *predictions)
359 {
360  this->fTrees(inputs, rows, layout, predictions);
361  for (int i = 0; i < rows; i++)
362  predictions[i] = this->fObjectiveFunc(predictions[i]);
363 }
364 
365 } // namespace Experimental
366 } // namespace TMVA
367 
368 #endif // TMVA_TREEINFERENCE_FOREST