20 #ifndef TMVA_TREEINFERENCE_FOREST
21 #define TMVA_TREEINFERENCE_FOREST

#include <cstddef>
#include <memory>
40 namespace Experimental {
44 T *GetObjectSafe(TFile *f,
const std::string &n,
const std::string &m)
46 auto v =
reinterpret_cast<T *
>(f->Get(m.c_str()));
48 throw std::runtime_error(
"Failed to read " + m +
" from file " + n +
".");
53 bool CompareTree(
const BranchlessTree<T> &a,
const BranchlessTree<T> &b)
55 if (a.fInputs[0] == b.fInputs[0])
56 return a.fThresholds[0] < b.fThresholds[0];
58 return a.fInputs[0] < b.fInputs[0];
/// Declaration fragment of the forest base class template.
/// NOTE(review): the `struct` header, the closing brace and any further members (the out-of-line
/// Inference definition also uses `fTrees` and `fNumInputs`) are not visible in this excerpt.
///
/// \tparam T Value type of the inputs and predictions
/// \tparam ForestType Type holding the trees; the derived forests in this file pass either a vector
///         of trees or a callable
66 template <
typename T,
typename ForestType>
/// Function applied to the summed tree responses of a row to obtain the final prediction
69 std::function<T(T)> fObjectiveFunc;
/// Evaluate the forest for `rows` inputs and write one prediction per row to `predictions`.
/// `layout` selects the strides used to walk `inputs` (see the out-of-line definition).
73 void Inference(
const T *inputs,
const int rows,
bool layout, T *predictions);
82 template <
typename T,
typename ForestType>
83 inline void ForestBase<T, ForestType>::Inference(
const T *inputs,
const int rows,
bool layout, T *predictions)
85 const auto strideTree = layout ? 1 : rows;
86 const auto strideBatch = layout ? fNumInputs : 1;
87 for (
int i = 0; i < rows; i++) {
89 for (
auto &tree : fTrees) {
90 predictions[i] += tree.Inference(inputs + i * strideBatch, strideTree);
92 predictions[i] = fObjectiveFunc(predictions[i]);
/// Forest storing its trees as a vector of BranchlessTree objects.
/// NOTE(review): the `template <typename T>` header and the closing brace of this declaration are
/// not visible in this excerpt.
100 struct BranchlessForest :
public ForestBase<T, std::vector<BranchlessTree<T>>> {
/// Load the forest parameters stored under `key` in the ROOT file `filename`.
/// Only trees belonging to the output node `output` are loaded; `sortTrees` requests sorting of the
/// loaded trees by the cut of their first node.
101 void Load(
const std::string &key,
const std::string &filename,
const int output = 0,
const bool sortTrees =
true);
110 template <
typename T>
112 BranchlessForest<T>::Load(
const std::string &key,
const std::string &filename,
const int output,
const bool sortTrees)
115 auto file = TFile::Open(filename.c_str(),
"READ");
118 auto maxDepth = Internal::GetObjectSafe<std::vector<int>>(file, filename, key +
"/max_depth");
119 auto numTrees = Internal::GetObjectSafe<std::vector<int>>(file, filename, key +
"/num_trees");
120 auto numInputs = Internal::GetObjectSafe<std::vector<int>>(file, filename, key +
"/num_inputs");
121 auto numOutputs = Internal::GetObjectSafe<std::vector<int>>(file, filename, key +
"/num_outputs");
122 auto objective = Internal::GetObjectSafe<std::string>(file, filename, key +
"/objective");
123 auto inputs = Internal::GetObjectSafe<std::vector<int>>(file, filename, key +
"/inputs");
124 auto outputs = Internal::GetObjectSafe<std::vector<int>>(file, filename, key +
"/outputs");
125 auto thresholds = Internal::GetObjectSafe<std::vector<T>>(file, filename, key +
"/thresholds");
127 this->fNumInputs = numInputs->at(0);
128 this->fObjectiveFunc = Objectives::GetFunction<T>(*objective);
129 const auto lenInputs = std::pow(2, maxDepth->at(0)) - 1;
130 const auto lenThresholds = std::pow(2, maxDepth->at(0) + 1) - 1;
133 if (output > numOutputs->at(0))
134 throw std::runtime_error(
"Given output node of the forest is larger or equal to number of output nodes.");
136 for (
int i = 0; i < numTrees->at(0); i++)
137 if (outputs->at(i) == output)
140 std::runtime_error(
"No trees found for given output node of the forest.");
141 this->fTrees.resize(c);
145 for (
int i = 0; i < numTrees->at(0); i++) {
147 if (outputs->at(i) != output)
151 this->fTrees[c].fTreeDepth = maxDepth->at(0);
154 this->fTrees[c].fInputs.resize(lenInputs);
155 for (
int j = 0; j < lenInputs; j++)
156 this->fTrees[c].fInputs[j] = inputs->at(i * lenInputs + j);
159 this->fTrees[c].fThresholds.resize(lenThresholds);
160 for (
int j = 0; j < lenThresholds; j++)
161 this->fTrees[c].fThresholds[j] = thresholds->at(i * lenThresholds + j);
164 this->fTrees[c].FillSparse();
171 std::sort(this->fTrees.begin(), this->fTrees.end(), Internal::CompareTree<T>);
/// Forest whose trees are represented by a single callable evaluating a whole batch.
/// NOTE(review): the closing brace of this declaration is not visible in this excerpt.
///
/// \tparam T Value type of the inputs and predictions
186 template <
typename T>
187 struct BranchlessJittedForest :
public ForestBase<T, std::function<void (const T *, const int, bool, T*)>> {
/// Load the forest parameters stored under `key` in the ROOT file `filename`, generate inference code
/// for the trees of output node `output` and attach the compiled function to the forest.
/// Returns a string - presumably the generated source code; confirm against the definition.
188 std::string Load(
const std::string &key,
const std::string &filename,
const int output = 0,
const bool sortTrees =
true);
/// Evaluate the compiled forest for `rows` inputs and apply the objective function to each prediction
189 void Inference(
const T *inputs,
const int rows,
bool layout, T *predictions);
199 template <
typename T>
201 BranchlessJittedForest<T>::Load(
const std::string &key,
const std::string &filename,
const int output,
const bool sortTrees)
204 auto file = TFile::Open(filename.c_str(),
"READ");
207 auto maxDepth = Internal::GetObjectSafe<std::vector<int>>(file, filename, key +
"/max_depth");
208 auto numTrees = Internal::GetObjectSafe<std::vector<int>>(file, filename, key +
"/num_trees");
209 auto numInputs = Internal::GetObjectSafe<std::vector<int>>(file, filename, key +
"/num_inputs");
210 auto numOutputs = Internal::GetObjectSafe<std::vector<int>>(file, filename, key +
"/num_outputs");
211 auto objective = Internal::GetObjectSafe<std::string>(file, filename, key +
"/objective");
212 auto inputs = Internal::GetObjectSafe<std::vector<int>>(file, filename, key +
"/inputs");
213 auto outputs = Internal::GetObjectSafe<std::vector<int>>(file, filename, key +
"/outputs");
214 auto thresholds = Internal::GetObjectSafe<std::vector<T>>(file, filename, key +
"/thresholds");
216 this->fNumInputs = numInputs->at(0);
217 this->fObjectiveFunc = Objectives::GetFunction<T>(*objective);
218 const auto lenInputs = std::pow(2, maxDepth->at(0)) - 1;
219 const auto lenThresholds = std::pow(2, maxDepth->at(0) + 1) - 1;
222 if (output > numOutputs->at(0))
223 throw std::runtime_error(
"Given output node of the forest is larger or equal to number of output nodes.");
225 for (
int i = 0; i < numTrees->at(0); i++)
226 if (outputs->at(i) == output)
229 std::runtime_error(
"No trees found for given output node of the forest.");
232 std::string typeName = ROOT::Internal::GetDemangledTypeName(
typeid(T));
233 if (typeName.compare(
"") == 0) {
234 throw std::runtime_error(
"Failed to just-in-time compile inference code for branchless forest (typename as string)");
238 std::vector<T> firstThreshold(c);
239 std::vector<int> firstInput(c, -1);
240 std::vector<std::string> codes(c);
242 for (
int i = 0; i < numTrees->at(0); i++) {
244 if (outputs->at(i) != output)
248 BranchlessTree<T> tree;
249 tree.fTreeDepth = maxDepth->at(0);
252 tree.fInputs.resize(lenInputs);
253 for (
int j = 0; j < lenInputs; j++)
254 tree.fInputs[j] = inputs->at(i * lenInputs + j);
257 tree.fThresholds.resize(lenThresholds);
258 for (
int j = 0; j < lenThresholds; j++)
259 tree.fThresholds[j] = thresholds->at(i * lenThresholds + j);
265 firstThreshold[c] = tree.fThresholds[0];
267 firstInput[c] = tree.fInputs[0];
270 std::stringstream ss;
272 codes[c] = tree.GetInferenceCode(ss.str(), typeName);
278 std::vector<int> treeIndices(codes.size());
279 for(
int i = 0; i < c; i++) treeIndices[i] = i;
281 auto compareIndices = [&firstInput, &firstThreshold](
int i,
int j)
283 if (firstInput[i] == firstInput[j])
284 return firstThreshold[i] < firstThreshold[j];
286 return firstInput[i] < firstInput[j];
288 std::sort(treeIndices.begin(), treeIndices.end(), compareIndices);
293 std::string nameSpace = uuid.AsString();
294 for (
auto& v : nameSpace) {
295 if (v ==
'-') v =
'_';
297 nameSpace =
"ns_" + nameSpace;
300 std::stringstream jitForest;
301 jitForest <<
"#pragma cling optimize(3)\n"
302 <<
"namespace " << nameSpace <<
" {\n";
303 for (
int i = 0; i < static_cast<int>(codes.size()); i++) {
304 jitForest << codes[treeIndices[i]] <<
"\n\n";
306 jitForest <<
"void Inference(const "
307 << typeName <<
"* inputs, const int rows, bool layout, "
308 << typeName <<
"* predictions)"
310 <<
" const auto strideTree = layout ? 1 : rows;\n"
311 <<
" const auto strideBatch = layout ? " << this->fNumInputs <<
" : 1;\n"
312 <<
" for (int i = 0; i < rows; i++) {\n"
313 <<
" predictions[i] = 0.0;\n";
314 for (
int i = 0; i < static_cast<int>(codes.size()); i++) {
315 std::stringstream ss;
317 const std::string funcName = ss.str();
318 jitForest <<
" predictions[i] += " << funcName <<
"(inputs + i * strideBatch, strideTree);\n";
322 <<
"} // end namespace " << nameSpace;
323 const std::string jitForestStr = jitForest.str();
324 const auto err = gInterpreter->Declare(jitForestStr.c_str());
326 throw std::runtime_error(
"Failed to just-in-time compile inference code for branchless forest (declare function)");
330 std::stringstream treesFunc;
331 treesFunc <<
"#pragma cling optimize(3)\n" << nameSpace <<
"::Inference";
332 const std::string treesFuncStr = treesFunc.str();
333 auto ptr = gInterpreter->Calc(treesFuncStr.c_str());
335 throw std::runtime_error(
"Failed to just-in-time compile inference code for branchless forest (compile function)");
337 this->fTrees =
reinterpret_cast<void (*)(
const T *,
int,
bool,
float*)
>(ptr);
357 template <
typename T>
358 void BranchlessJittedForest<T>::Inference(
const T *inputs,
const int rows,
bool layout, T *predictions)
360 this->fTrees(inputs, rows, layout, predictions);
361 for (
int i = 0; i < rows; i++)
362 predictions[i] = this->fObjectiveFunc(predictions[i]);
368 #endif // TMVA_TREEINFERENCE_FOREST