20 #ifndef TMVA_TREEINFERENCE_BRANCHLESSTREE
21 #define TMVA_TREEINFERENCE_BRANCHLESSTREE
29 namespace Experimental {
/// \brief Recursively fill the sparse part of a tree stored in topological order
///
/// Visits node \p thisIndex, whose parent is \p lastIndex. If the parent is marked
/// with input -1, this node takes over the parent's threshold. Unless the node lies
/// on the last level, it is itself marked with -1 so that the marking keeps
/// propagating downwards. The recursion then continues into the children
/// 2*thisIndex+1 and 2*thisIndex+2 until \p maxTreeDepth is reached.
///
/// \tparam T Value type of the thresholds
/// \param[in] thisIndex Index of the node to fill
/// \param[in] lastIndex Index of the parent of \p thisIndex
/// \param[in] treeDepth Depth of \p thisIndex (the root has depth 0)
/// \param[in] maxTreeDepth Depth of the tree; nodes at this depth have no children
/// \param[in,out] thresholds Thresholds/scores of all nodes
/// \param[in,out] inputs Cut variable indices of all nodes, -1 marking nodes without a cut
template <typename T>
void RecursiveFill(int thisIndex, int lastIndex, int treeDepth, int maxTreeDepth, std::vector<T> &thresholds,
                   std::vector<int> &inputs)
{
   const bool hasChildren = treeDepth < maxTreeDepth;

   // A parent marked with -1 has no cut of its own: copy its threshold down to this
   // node. The marker is only propagated while deeper levels still follow; nodes on
   // the last level keep their input entry untouched.
   if (inputs[lastIndex] == -1) {
      thresholds.at(thisIndex) = thresholds.at(lastIndex);
      if (hasChildren)
         inputs.at(thisIndex) = -1;
   }

   if (!hasChildren)
      return;

   // Descend into both children of this node (topological order: 2i+1, then 2i+2).
   const int firstChild = 2 * thisIndex + 1;
   for (int child = firstChild; child <= firstChild + 1; ++child)
      RecursiveFill<T>(child, thisIndex, treeDepth + 1, maxTreeDepth, thresholds, inputs);
}
/// \class BranchlessTree
/// \brief Branchless representation of a decision tree stored in topological order
///
/// Nodes are stored level by level. The children of node i are 2*i+1 (taken when
/// the selected input is not greater than the threshold) and 2*i+2 (taken when it
/// is greater). After FillSparse() the tree is a full binary tree of depth
/// fTreeDepth, so Inference() always performs exactly fTreeDepth comparisons.
///
/// \tparam T Value type for the computation (usually a floating point type)
template <typename T>
struct BranchlessTree {
   /// Number of cut levels evaluated by Inference(). Initialized to 0 so that a
   /// default-constructed tree does not carry an indeterminate depth.
   int fTreeDepth = 0;
   std::vector<T> fThresholds; ///< Cut thresholds, or the score if the node is a leaf
   std::vector<int> fInputs;   ///< Index of the input variable cut on by each node; -1 marks a node without a cut

   inline T Inference(const T *input, const int stride);
   inline void FillSparse();
   inline std::string GetInferenceCode(const std::string &funcName, const std::string &typeName);
};
77 inline T BranchlessTree<T>::Inference(
const T *input,
const int stride)
80 for (
int level = 0; level < fTreeDepth; ++level) {
81 index = 2 * index + 1 + (input[fInputs[index] * stride] > fThresholds[index]);
83 return fThresholds[index];
93 inline void BranchlessTree<T>::FillSparse()
96 Internal::RecursiveFill<T>(1, 0, 1, fTreeDepth, fThresholds, fInputs);
97 Internal::RecursiveFill<T>(2, 0, 1, fTreeDepth, fThresholds, fInputs);
100 std::replace(fInputs.begin(), fInputs.end(), -1.0, 0.0);
109 template <
typename T>
110 inline std::string BranchlessTree<T>::GetInferenceCode(
const std::string& funcName,
const std::string& typeName)
112 std::stringstream ss;
115 ss <<
"inline " << typeName <<
" " << funcName <<
"(const " << typeName <<
"* input, const int stride)";
121 ss <<
" const int inputs[" << fInputs.size() <<
"] = {";
122 int last =
static_cast<int>(fInputs.size() - 1);
123 for (
int i = 0; i < last + 1; i++) {
125 if (i != last) ss <<
", ";
129 ss <<
" const " << typeName <<
" thresholds[" << fThresholds.size() <<
"] = {";
130 last =
static_cast<int>(fThresholds.size() - 1);
131 for (
int i = 0; i < last + 1; i++) {
132 ss << fThresholds[i];
133 if (i != last) ss <<
", ";
138 ss <<
" int index = 0;\n";
139 for (
int level = 0; level < fTreeDepth; ++level) {
140 ss <<
" index = 2 * index + 1 + (input[inputs[index] * stride] > thresholds[index]);\n";
142 ss <<
" return thresholds[index];\n";
151 #endif // TMVA_TREEINFERENCE_BRANCHLESSTREE