20 #ifndef TMVA_TREEINFERENCE_FOREST 
   21 #define TMVA_TREEINFERENCE_FOREST 
   40 namespace Experimental {
 
   44 T *GetObjectSafe(TFile *f, 
const std::string &n, 
const std::string &m)
 
   46    auto v = 
reinterpret_cast<T *
>(f->Get(m.c_str()));
 
   48       throw std::runtime_error(
"Failed to read " + m + 
" from file " + n + 
".");
 
   53 bool CompareTree(
const BranchlessTree<T> &a, 
const BranchlessTree<T> &b)
 
   55    if (a.fInputs[0] == b.fInputs[0])
 
   56       return a.fThresholds[0] < b.fThresholds[0];
 
   58       return a.fInputs[0] < b.fInputs[0];
 
   /// Common part of all forests: the collection of trees, the number of input
   /// variables and the objective function applied to the summed tree responses.
   ///
   /// \tparam T Value type of inputs and predictions
   /// \tparam ForestType Type holding the trees; the generic Inference below iterates
   ///         over it and calls Inference(const T *, int) on each element
   template <typename T, typename ForestType>
   struct ForestBase {
      using Value_t = T;                  ///< Value type of inputs and predictions
      std::function<T(T)> fObjectiveFunc; ///< Applied to the sum of the tree responses of each event
      ForestType fTrees;                  ///< The trees of the forest
      int fNumInputs;                     ///< Number of input variables per event

      void Inference(const T *inputs, const int rows, bool layout, T *predictions);
   };

   /// Evaluate the forest for a batch of events.
   ///
   /// \param[in] inputs Input variables of all events
   /// \param[in] rows Number of events
   /// \param[in] layout If true, the fNumInputs variables of one event are adjacent in
   ///            memory (row-major); if false, the values of one variable for all events
   ///            are adjacent (column-major)
   /// \param[out] predictions Buffer with space for \p rows values; previous content is
   ///             overwritten
   template <typename T, typename ForestType>
   inline void ForestBase<T, ForestType>::Inference(const T *inputs, const int rows, bool layout, T *predictions)
   {
      // Distance between two variables of the same event ...
      const int strideTree = layout ? 1 : rows;
      // ... and between the first variables of two consecutive events
      const int strideBatch = layout ? fNumInputs : 1;

      for (int i = 0; i < rows; i++) {
         // Start every event from zero so that the result does not depend on what the
         // caller left in the output buffer
         T response = 0;
         for (auto &tree : fTrees) {
            response += tree.Inference(inputs + i * strideBatch, strideTree);
         }
         predictions[i] = fObjectiveFunc(response);
      }
   }
 
  100 struct BranchlessForest : 
public ForestBase<T, std::vector<BranchlessTree<T>>> {
 
  101    void Load(
const std::string &key, 
const std::string &filename, 
const int output = 0, 
const bool sortTrees = 
true);
 
  110 template <
typename T>
 
  112 BranchlessForest<T>::Load(
const std::string &key, 
const std::string &filename, 
const int output, 
const bool sortTrees)
 
  115    auto file = TFile::Open(filename.c_str(), 
"READ");
 
  118    auto maxDepth = Internal::GetObjectSafe<std::vector<int>>(file, filename, key + 
"/max_depth");
 
  119    auto numTrees = Internal::GetObjectSafe<std::vector<int>>(file, filename, key + 
"/num_trees");
 
  120    auto numInputs = Internal::GetObjectSafe<std::vector<int>>(file, filename, key + 
"/num_inputs");
 
  121    auto numOutputs = Internal::GetObjectSafe<std::vector<int>>(file, filename, key + 
"/num_outputs");
 
  122    auto objective = Internal::GetObjectSafe<std::string>(file, filename, key + 
"/objective");
 
  123    auto inputs = Internal::GetObjectSafe<std::vector<int>>(file, filename, key + 
"/inputs");
 
  124    auto outputs = Internal::GetObjectSafe<std::vector<int>>(file, filename, key + 
"/outputs");
 
  125    auto thresholds = Internal::GetObjectSafe<std::vector<T>>(file, filename, key + 
"/thresholds");
 
  127    this->fNumInputs = numInputs->at(0);
 
  128    this->fObjectiveFunc = Objectives::GetFunction<T>(*objective);
 
  129    const auto lenInputs = std::pow(2, maxDepth->at(0)) - 1;
 
  130    const auto lenThresholds = std::pow(2, maxDepth->at(0) + 1) - 1;
 
  133    if (output > numOutputs->at(0))
 
  134       throw std::runtime_error(
"Given output node of the forest is larger or equal to number of output nodes.");
 
  136    for (
int i = 0; i < numTrees->at(0); i++)
 
  137       if (outputs->at(i) == output)
 
  140       std::runtime_error(
"No trees found for given output node of the forest.");
 
  141    this->fTrees.resize(c);
 
  145    for (
int i = 0; i < numTrees->at(0); i++) {
 
  147       if (outputs->at(i) != output)
 
  151       this->fTrees[c].fTreeDepth = maxDepth->at(0);
 
  154       this->fTrees[c].fInputs.resize(lenInputs);
 
  155       for (
int j = 0; j < lenInputs; j++)
 
  156          this->fTrees[c].fInputs[j] = inputs->at(i * lenInputs + j);
 
  159       this->fTrees[c].fThresholds.resize(lenThresholds);
 
  160       for (
int j = 0; j < lenThresholds; j++)
 
  161          this->fTrees[c].fThresholds[j] = thresholds->at(i * lenThresholds + j);
 
  164       this->fTrees[c].FillSparse();
 
  171       std::sort(this->fTrees.begin(), this->fTrees.end(), Internal::CompareTree<T>);
 
  186 template <
typename T>
 
  187 struct BranchlessJittedForest : 
public ForestBase<T, std::function<void (const T *, const int, bool, T*)>> {
 
  188     std::string Load(
const std::string &key, 
const std::string &filename, 
const int output = 0, 
const bool sortTrees = 
true);
 
  189    void Inference(
const T *inputs, 
const int rows, 
bool layout, T *predictions);
 
  199 template <
typename T>
 
  201 BranchlessJittedForest<T>::Load(
const std::string &key, 
const std::string &filename, 
const int output, 
const bool sortTrees)
 
  204    auto file = TFile::Open(filename.c_str(), 
"READ");
 
  207    auto maxDepth = Internal::GetObjectSafe<std::vector<int>>(file, filename, key + 
"/max_depth");
 
  208    auto numTrees = Internal::GetObjectSafe<std::vector<int>>(file, filename, key + 
"/num_trees");
 
  209    auto numInputs = Internal::GetObjectSafe<std::vector<int>>(file, filename, key + 
"/num_inputs");
 
  210    auto numOutputs = Internal::GetObjectSafe<std::vector<int>>(file, filename, key + 
"/num_outputs");
 
  211    auto objective = Internal::GetObjectSafe<std::string>(file, filename, key + 
"/objective");
 
  212    auto inputs = Internal::GetObjectSafe<std::vector<int>>(file, filename, key + 
"/inputs");
 
  213    auto outputs = Internal::GetObjectSafe<std::vector<int>>(file, filename, key + 
"/outputs");
 
  214    auto thresholds = Internal::GetObjectSafe<std::vector<T>>(file, filename, key + 
"/thresholds");
 
  216    this->fNumInputs = numInputs->at(0);
 
  217    this->fObjectiveFunc = Objectives::GetFunction<T>(*objective);
 
  218    const auto lenInputs = std::pow(2, maxDepth->at(0)) - 1;
 
  219    const auto lenThresholds = std::pow(2, maxDepth->at(0) + 1) - 1;
 
  222    if (output > numOutputs->at(0))
 
  223       throw std::runtime_error(
"Given output node of the forest is larger or equal to number of output nodes.");
 
  225    for (
int i = 0; i < numTrees->at(0); i++)
 
  226       if (outputs->at(i) == output)
 
  229       std::runtime_error(
"No trees found for given output node of the forest.");
 
  232    std::string typeName = ROOT::Internal::GetDemangledTypeName(
typeid(T));
 
  233    if (typeName.compare(
"") == 0) {
 
  234       throw std::runtime_error(
"Failed to just-in-time compile inference code for branchless forest (typename as string)");
 
  238    std::vector<T> firstThreshold(c);
 
  239    std::vector<int> firstInput(c, -1);
 
  240    std::vector<std::string> codes(c);
 
  242    for (
int i = 0; i < numTrees->at(0); i++) {
 
  244       if (outputs->at(i) != output)
 
  248       BranchlessTree<T> tree;
 
  249       tree.fTreeDepth = maxDepth->at(0);
 
  252       tree.fInputs.resize(lenInputs);
 
  253       for (
int j = 0; j < lenInputs; j++)
 
  254          tree.fInputs[j] = inputs->at(i * lenInputs + j);
 
  257       tree.fThresholds.resize(lenThresholds);
 
  258       for (
int j = 0; j < lenThresholds; j++)
 
  259          tree.fThresholds[j] = thresholds->at(i * lenThresholds + j);
 
  265       firstThreshold[c] = tree.fThresholds[0];
 
  267           firstInput[c] = tree.fInputs[0];
 
  270       std::stringstream ss;
 
  272       codes[c] = tree.GetInferenceCode(ss.str(), typeName);
 
  278    std::vector<int> treeIndices(codes.size());
 
  279    for(
int i = 0; i < c; i++) treeIndices[i] = i;
 
  281       auto compareIndices = [&firstInput, &firstThreshold](
int i, 
int j)
 
  283                  if (firstInput[i] == firstInput[j])
 
  284                     return firstThreshold[i] < firstThreshold[j];
 
  286                     return firstInput[i] < firstInput[j];
 
  288       std::sort(treeIndices.begin(), treeIndices.end(), compareIndices);
 
  293    std::string nameSpace = uuid.AsString();
 
  294    for (
auto& v : nameSpace) {
 
  295       if (v == 
'-') v = 
'_';
 
  297    nameSpace = 
"ns_" + nameSpace;
 
  300    std::stringstream jitForest;
 
  301    jitForest << 
"#pragma cling optimize(3)\n" 
  302              << 
"namespace " << nameSpace << 
" {\n";
 
  303    for (
int i = 0; i < static_cast<int>(codes.size()); i++) {
 
  304       jitForest << codes[treeIndices[i]] << 
"\n\n";
 
  306    jitForest << 
"void Inference(const " 
  307              << typeName << 
"* inputs, const int rows, bool layout, " 
  308              << typeName << 
"* predictions)" 
  310              << 
"   const auto strideTree = layout ? 1 : rows;\n" 
  311              << 
"   const auto strideBatch = layout ? " << this->fNumInputs << 
" : 1;\n" 
  312              << 
"   for (int i = 0; i < rows; i++) {\n" 
  313              << 
"      predictions[i] = 0.0;\n";
 
  314    for (
int i = 0; i < static_cast<int>(codes.size()); i++) {
 
  315       std::stringstream ss;
 
  317       const std::string funcName = ss.str();
 
  318       jitForest << 
"      predictions[i] += " << funcName << 
"(inputs + i * strideBatch, strideTree);\n";
 
  322              << 
"} // end namespace " << nameSpace;
 
  323    const std::string jitForestStr = jitForest.str();
 
  324    const auto err = gInterpreter->Declare(jitForestStr.c_str());
 
  326       throw std::runtime_error(
"Failed to just-in-time compile inference code for branchless forest (declare function)");
 
  330    std::stringstream treesFunc;
 
  331    treesFunc << 
"#pragma cling optimize(3)\n" << nameSpace << 
"::Inference";
 
  332    const std::string treesFuncStr = treesFunc.str();
 
  333    auto ptr = gInterpreter->Calc(treesFuncStr.c_str());
 
  335       throw std::runtime_error(
"Failed to just-in-time compile inference code for branchless forest (compile function)");
 
  337    this->fTrees = 
reinterpret_cast<void (*)(
const T *, 
int, 
bool, 
float*)
>(ptr);
 
  357 template <
typename T>
 
  358 void BranchlessJittedForest<T>::Inference(
const T *inputs, 
const int rows, 
bool layout, T *predictions)
 
  360    this->fTrees(inputs, rows, layout, predictions);
 
  361    for (
int i = 0; i < rows; i++)
 
  362       predictions[i] = this->fObjectiveFunc(predictions[i]);
 
  368 #endif // TMVA_TREEINFERENCE_FOREST