Logo ROOT   6.30.04
Reference Guide
 All Namespaces Files Pages
BranchlessTree.hxx
Go to the documentation of this file.
1 /**********************************************************************************
2  * Project: ROOT - a Root-integrated toolkit for multivariate data analysis *
3  * Package: TMVA *
4  * Web : http://tmva.sourceforge.net *
5  * *
6  * Description: *
7  * *
8  * Authors: *
9  * Stefan Wunsch (stefan.wunsch@cern.ch) *
10  * Luca Zampieri (luca.zampieri@alumni.epfl.ch) *
11  * *
12  * Copyright (c) 2019: *
13  * CERN, Switzerland *
14  * *
15  * Redistribution and use in source and binary forms, with or without *
16  * modification, are permitted according to the terms listed in LICENSE *
17  * (http://tmva.sourceforge.net/LICENSE) *
18  **********************************************************************************/
19 
20 #ifndef TMVA_TREEINFERENCE_BRANCHLESSTREE
21 #define TMVA_TREEINFERENCE_BRANCHLESSTREE
22 
#include <vector>
#include <algorithm>
#include <limits>
#include <locale>
#include <string>
#include <sstream>
27 
28 namespace TMVA {
29 namespace Experimental {
30 
namespace Internal {

/// Fill the empty nodes of a sparse tree recursively
///
/// The tree is stored in topological (breadth-first) order: the children of node `i`
/// are `2 * i + 1` and `2 * i + 2`. A node whose parent is marked as a leaf (feature
/// index -1 in `inputs`) takes over the parent's threshold / score. It is itself marked
/// as a leaf unless it lies in the final layer, which has no entry in `inputs`.
///
/// \param[in] thisIndex Index of the node to fill
/// \param[in] lastIndex Index of the parent node of `thisIndex`
/// \param[in] treeDepth Depth of `thisIndex` in the tree (the root has depth 0)
/// \param[in] maxTreeDepth Depth of the full tree
/// \param[in,out] thresholds Cut thresholds / leaf scores, one entry per node of the full tree
/// \param[in,out] inputs Cut variables of the nodes above the final layer; -1 marks a leaf
template <typename T>
void RecursiveFill(int thisIndex, int lastIndex, int treeDepth, int maxTreeDepth, std::vector<T> &thresholds,
                   std::vector<int> &inputs)
{
   const bool isInnerLayer = treeDepth < maxTreeDepth;

   // If we are downstream of a leaf in a sparse branch, copy the parent's value
   // and mark this node as a leaf again. Bounds-checked like all other accesses here,
   // so an inconsistent tree raises std::out_of_range instead of reading out of bounds.
   if (inputs.at(lastIndex) == -1) {
      thresholds.at(thisIndex) = thresholds.at(lastIndex);
      // Don't access the feature vector in the last layer of the tree since we
      // don't store these values in the inputs vector
      if (isInnerLayer)
         inputs.at(thisIndex) = -1;
   }

   // Fill the children of this node if we are not in the final layer of the tree
   if (isInnerLayer) {
      RecursiveFill<T>(2 * thisIndex + 1, thisIndex, treeDepth + 1, maxTreeDepth, thresholds, inputs);
      RecursiveFill<T>(2 * thisIndex + 2, thisIndex, treeDepth + 1, maxTreeDepth, thresholds, inputs);
   }
}

} // namespace Internal
56 
57 /// \class BranchlessTree
58 /// \brief Branchless representation of a decision tree using topological ordering
59 ///
60 /// \tparam T Value type for the computation (usually floating point type)
61 template <typename T>
62 struct BranchlessTree {
63  int fTreeDepth; ///< Depth of the tree
64  std::vector<T> fThresholds; ///< Cut thresholds or scores if corresponding node is a leaf
65  std::vector<int> fInputs; ///< Cut variables / inputs
66 
67  inline T Inference(const T *input, const int stride);
68  inline void FillSparse();
69  inline std::string GetInferenceCode(const std::string& funcName, const std::string& typeName);
70 };
71 
72 /// Perform inference on a single input vector
73 /// \param[in] input Pointer to data containing the input values
74 /// \param[in] stride Stride to go from one input variable to the next one
75 /// \param[out] Tree score, result of the inference
76 template <typename T>
77 inline T BranchlessTree<T>::Inference(const T *input, const int stride)
78 {
79  int index = 0;
80  for (int level = 0; level < fTreeDepth; ++level) {
81  index = 2 * index + 1 + (input[fInputs[index] * stride] > fThresholds[index]);
82  }
83  return fThresholds[index];
84 }
85 
86 /// Fill nodes of a sparse tree forming a full tree
87 ///
88 /// Sparse parts of the tree are marked with -1 values in the feature vector. The
89 /// algorithm fills these parts up with the last threshold value so that the result
90 /// of the inference stays the same but the computation always traverses the full tree,
91 /// which is needed to avoid branching logic.
92 template <typename T>
93 inline void BranchlessTree<T>::FillSparse()
94 {
95  // Fill threshold / leaf values recursively
96  Internal::RecursiveFill<T>(1, 0, 1, fTreeDepth, fThresholds, fInputs);
97  Internal::RecursiveFill<T>(2, 0, 1, fTreeDepth, fThresholds, fInputs);
98 
99  // Replace feature indices of -1 with 0
100  std::replace(fInputs.begin(), fInputs.end(), -1.0, 0.0);
101 }
102 
103 /// Get code for compiling the inference function of the branchless tree with
104 /// the current thresholds and cut variables
105 ///
106 /// \param[in] funcName Name of the function
107 /// \param[in] typeName Name of the type used for the computation
108 /// \param[out] Code of the inference function as string
109 template <typename T>
110 inline std::string BranchlessTree<T>::GetInferenceCode(const std::string& funcName, const std::string& typeName)
111 {
112  std::stringstream ss;
113 
114  // Build signature
115  ss << "inline " << typeName << " " << funcName << "(const " << typeName << "* input, const int stride)";
116 
117  // Function body
118  ss << "\n{\n";
119 
120  // Hard-code thresholds and cut variables
121  ss << " const int inputs[" << fInputs.size() << "] = {";
122  int last = static_cast<int>(fInputs.size() - 1);
123  for (int i = 0; i < last + 1; i++) {
124  ss << fInputs[i];
125  if (i != last) ss << ", ";
126  }
127  ss << "};\n";
128 
129  ss << " const " << typeName << " thresholds[" << fThresholds.size() << "] = {";
130  last = static_cast<int>(fThresholds.size() - 1);
131  for (int i = 0; i < last + 1; i++) {
132  ss << fThresholds[i];
133  if (i != last) ss << ", ";
134  }
135  ss << "};\n";
136 
137  // Add inference code
138  ss << " int index = 0;\n";
139  for (int level = 0; level < fTreeDepth; ++level) {
140  ss << " index = 2 * index + 1 + (input[inputs[index] * stride] > thresholds[index]);\n";
141  }
142  ss << " return thresholds[index];\n";
143  ss << "}";
144 
145  return ss.str();
146 }
147 
148 } // namespace Experimental
149 } // namespace TMVA
150 
151 #endif // TMVA_TREEINFERENCE_BRANCHLESSTREE