31 namespace Experimental {
34 template <
typename Backend = BranchlessJittedForest<
float>>
37 using Value_t =
typename Backend::Value_t;
38 using Backend_t = Backend;
42 bool fNormalizeOutputs;
43 std::vector<Backend_t> fBackends;
47 RBDT(
const std::string &key,
const std::string &filename)
50 auto file = TFile::Open(filename.c_str(),
"READ");
51 auto numOutputs = Internal::GetObjectSafe<std::vector<int>>(file, filename, key +
"/num_outputs");
52 fNumOutputs = numOutputs->at(0);
56 auto objective = Internal::GetObjectSafe<std::string>(file, filename, key +
"/objective");
57 if (objective->compare(
"softmax") == 0)
58 fNormalizeOutputs =
true;
60 fNormalizeOutputs =
false;
65 fBackends = std::vector<Backend_t>(fNumOutputs);
66 for (
int i = 0; i < fNumOutputs; i++)
67 fBackends[i].Load(key, filename, i);
74 template <
typename Vector>
75 Vector Compute(
const Vector &x)
78 y.resize(fNumOutputs);
79 for (
int i = 0; i < fNumOutputs; i++)
80 fBackends[i].Inference(&x[0], 1,
true, &y[i]);
81 if (fNormalizeOutputs) {
83 for (
int i = 0; i < fNumOutputs; i++)
85 for (
int i = 0; i < fNumOutputs; i++)
92 std::vector<Value_t> Compute(
const std::vector<Value_t> &x) {
return this->Compute<std::vector<Value_t>>(x); }
95 RTensor<Value_t> Compute(
const RTensor<Value_t> &x)
97 const auto rows = x.GetShape()[0];
98 RTensor<Value_t> y({rows,
static_cast<std::size_t
>(fNumOutputs)}, MemoryLayout::ColumnMajor);
99 const bool layout = x.GetMemoryLayout() == MemoryLayout::ColumnMajor ?
false :
true;
100 for (
int i = 0; i < fNumOutputs; i++)
101 fBackends[i].Inference(x.GetData(), rows, layout, &y(0, i));
102 if (fNormalizeOutputs) {
104 for (
int i = 0; i < static_cast<int>(rows); i++) {
106 for (
int j = 0; j < fNumOutputs; j++)
108 for (
int j = 0; j < fNumOutputs; j++)
116 extern template class TMVA::Experimental::RBDT<TMVA::Experimental::BranchlessForest<float>>;
117 extern template class TMVA::Experimental::RBDT<TMVA::Experimental::BranchlessJittedForest<float>>;