Logo ROOT   6.30.04
Reference Guide
 All Namespaces Files Pages
RBDT.hxx
Go to the documentation of this file.
1 /**********************************************************************************
2  * Project: ROOT - a Root-integrated toolkit for multivariate data analysis *
3  * Package: TMVA *
4  * Web : http://tmva.sourceforge.net *
5  * *
6  * Description: *
7  * *
8  * Authors: *
9  * Stefan Wunsch (stefan.wunsch@cern.ch) *
10  * *
11  * Copyright (c) 2019: *
12  * CERN, Switzerland *
13  * *
14  * Redistribution and use in source and binary forms, with or without *
15  * modification, are permitted according to the terms listed in LICENSE *
16  * (http://tmva.sourceforge.net/LICENSE) *
17  **********************************************************************************/
18 
19 #ifndef TMVA_RBDT
20 #define TMVA_RBDT
21 
#include "TMVA/RTensor.hxx"
#include "TFile.h"

#include <memory>    // std::unique_ptr
#include <sstream>   // std::stringstream
#include <stdexcept> // std::runtime_error
#include <string>
#include <vector>
29 
30 namespace TMVA {
31 namespace Experimental {
32 
33 /// Fast boosted decision tree inference
34 template <typename Backend = BranchlessJittedForest<float>>
35 class RBDT {
36 public:
37  using Value_t = typename Backend::Value_t;
38  using Backend_t = Backend;
39 
40 private:
41  int fNumOutputs;
42  bool fNormalizeOutputs;
43  std::vector<Backend_t> fBackends;
44 
45 public:
46  /// Construct backends from model in ROOT file
47  RBDT(const std::string &key, const std::string &filename)
48  {
49  // Get number of output nodes of the forest
50  auto file = TFile::Open(filename.c_str(), "READ");
51  auto numOutputs = Internal::GetObjectSafe<std::vector<int>>(file, filename, key + "/num_outputs");
52  fNumOutputs = numOutputs->at(0);
53  delete numOutputs;
54 
55  // Get objective and decide whether to normalize output nodes for example in the multiclass case
56  auto objective = Internal::GetObjectSafe<std::string>(file, filename, key + "/objective");
57  if (objective->compare("softmax") == 0)
58  fNormalizeOutputs = true;
59  else
60  fNormalizeOutputs = false;
61  delete objective;
62  file->Close();
63 
64  // Initialize backends
65  fBackends = std::vector<Backend_t>(fNumOutputs);
66  for (int i = 0; i < fNumOutputs; i++)
67  fBackends[i].Load(key, filename, i);
68  }
69 
70  /// Compute model prediction on a single event
71  ///
72  /// The method is intended to be used with std::vectors-like containers,
73  /// for example RVecs.
74  template <typename Vector>
75  Vector Compute(const Vector &x)
76  {
77  Vector y;
78  y.resize(fNumOutputs);
79  for (int i = 0; i < fNumOutputs; i++)
80  fBackends[i].Inference(&x[0], 1, true, &y[i]);
81  if (fNormalizeOutputs) {
82  Value_t s = 0.0;
83  for (int i = 0; i < fNumOutputs; i++)
84  s += y[i];
85  for (int i = 0; i < fNumOutputs; i++)
86  y[i] /= s;
87  }
88  return y;
89  }
90 
91  /// Compute model prediction on a single event
92  std::vector<Value_t> Compute(const std::vector<Value_t> &x) { return this->Compute<std::vector<Value_t>>(x); }
93 
94  /// Compute model prediction on input RTensor
95  RTensor<Value_t> Compute(const RTensor<Value_t> &x)
96  {
97  const auto rows = x.GetShape()[0];
98  RTensor<Value_t> y({rows, static_cast<std::size_t>(fNumOutputs)}, MemoryLayout::ColumnMajor);
99  const bool layout = x.GetMemoryLayout() == MemoryLayout::ColumnMajor ? false : true;
100  for (int i = 0; i < fNumOutputs; i++)
101  fBackends[i].Inference(x.GetData(), rows, layout, &y(0, i));
102  if (fNormalizeOutputs) {
103  Value_t s;
104  for (int i = 0; i < static_cast<int>(rows); i++) {
105  s = 0.0;
106  for (int j = 0; j < fNumOutputs; j++)
107  s += y(i, j);
108  for (int j = 0; j < fNumOutputs; j++)
109  y(i, j) /= s;
110  }
111  }
112  return y;
113  }
114 };
115 
116 extern template class TMVA::Experimental::RBDT<TMVA::Experimental::BranchlessForest<float>>;
117 extern template class TMVA::Experimental::RBDT<TMVA::Experimental::BranchlessJittedForest<float>>;
118 
119 } // namespace Experimental
120 } // namespace TMVA
121 
122 #endif // TMVA_RBDT