Logo ROOT   6.30.04
Reference Guide
 All Namespaces Files Pages
RReader.hxx
Go to the documentation of this file.
1 #ifndef TMVA_RREADER
2 #define TMVA_RREADER
3 
4 #include "TString.h"
5 #include "TXMLEngine.h"
6 #include "ROOT/RMakeUnique.hxx"
7 
8 #include "TMVA/RTensor.hxx"
9 #include "TMVA/Reader.h"
10 
11 #include <memory> // std::unique_ptr
12 #include <sstream> // std::stringstream
13 
14 namespace TMVA {
15 namespace Experimental {
16 
17 namespace Internal {
18 
/// Analysis types distinguished internally by the reader.
/// `Undefined` is the zero value and marks a type that has not been determined.
enum AnalysisType : unsigned int {
   Undefined = 0,
   Classification = 1,
   Regression = 2,
   Multiclass = 3
};
21 
22 /// Container for information extracted from TMVA XML config
23 struct XMLConfig {
24  unsigned int numVariables;
25  std::vector<std::string> variables;
26  std::vector<std::string> expressions;
27  unsigned int numClasses;
28  std::vector<std::string> classes;
29  AnalysisType analysisType;
30  XMLConfig()
31  : numVariables(0), variables(std::vector<std::string>(0)), numClasses(0), classes(std::vector<std::string>(0)),
32  analysisType(Internal::AnalysisType::Undefined)
33  {
34  }
35 };
36 
37 /// Parse TMVA XML config
38 inline XMLConfig ParseXMLConfig(const std::string &filename)
39 {
40  XMLConfig c;
41 
42  // Parse XML file and find root node
43  TXMLEngine xml;
44  auto xmldoc = xml.ParseFile(filename.c_str());
45  if (xmldoc == 0) {
46  std::stringstream ss;
47  ss << "Failed to open TMVA XML file "
48  << filename << ".";
49  throw std::runtime_error(ss.str());
50  }
51  auto mainNode = xml.DocGetRootElement(xmldoc);
52  for (auto node = xml.GetChild(mainNode); node; node = xml.GetNext(node)) {
53  const auto nodeName = std::string(xml.GetNodeName(node));
54  // Read out input variables
55  if (nodeName.compare("Variables") == 0) {
56  c.numVariables = std::atoi(xml.GetAttr(node, "NVar"));
57  c.variables = std::vector<std::string>(c.numVariables);
58  c.expressions = std::vector<std::string>(c.numVariables);
59  for (auto thisNode = xml.GetChild(node); thisNode; thisNode = xml.GetNext(thisNode)) {
60  const auto iVariable = std::atoi(xml.GetAttr(thisNode, "VarIndex"));
61  c.variables[iVariable] = xml.GetAttr(thisNode, "Title");
62  c.expressions[iVariable] = xml.GetAttr(thisNode, "Expression");
63  }
64  }
65  // Read out output classes
66  else if (nodeName.compare("Classes") == 0) {
67  c.numClasses = std::atoi(xml.GetAttr(node, "NClass"));
68  for (auto thisNode = xml.GetChild(node); thisNode; thisNode = xml.GetNext(thisNode)) {
69  c.classes.push_back(xml.GetAttr(thisNode, "Name"));
70  }
71  }
72  // Read out analysis type
73  else if (nodeName.compare("GeneralInfo") == 0) {
74  std::string analysisType = "";
75  for (auto thisNode = xml.GetChild(node); thisNode; thisNode = xml.GetNext(thisNode)) {
76  if (std::string("AnalysisType").compare(xml.GetAttr(thisNode, "name")) == 0) {
77  analysisType = xml.GetAttr(thisNode, "value");
78  }
79  }
80  if (analysisType.compare("Classification") == 0) {
81  c.analysisType = Internal::AnalysisType::Classification;
82  } else if (analysisType.compare("Regression") == 0) {
83  c.analysisType = Internal::AnalysisType::Regression;
84  } else if (analysisType.compare("Multiclass") == 0) {
85  c.analysisType = Internal::AnalysisType::Multiclass;
86  }
87  }
88  }
89  xml.FreeDoc(xmldoc);
90 
91  // Error-handling
92  if (c.numVariables != c.variables.size() || c.numVariables == 0) {
93  std::stringstream ss;
94  ss << "Failed to parse input variables from TMVA config " << filename << ".";
95  throw std::runtime_error(ss.str());
96  }
97  if (c.numClasses != c.classes.size() || c.numClasses == 0) {
98  std::stringstream ss;
99  ss << "Failed to parse output classes from TMVA config " << filename << ".";
100  throw std::runtime_error(ss.str());
101  }
102  if (c.analysisType == Internal::AnalysisType::Undefined) {
103  std::stringstream ss;
104  ss << "Failed to parse analysis type from TMVA config " << filename << ".";
105  throw std::runtime_error(ss.str());
106  }
107 
108  return c;
109 }
110 
111 } // namespace Internal
112 
113 /// TMVA::Reader legacy interface
114 class RReader {
115 private:
116  std::unique_ptr<Reader> fReader;
117  std::vector<float> fValues;
118  std::vector<std::string> fVariables;
119  std::vector<std::string> fExpressions;
120  unsigned int fNumClasses;
121  const char *name = "RReader";
122  Internal::AnalysisType fAnalysisType;
123 
124 public:
125  /// Create TMVA model from XML file
126  RReader(const std::string &path)
127  {
128  // Load config
129  auto c = Internal::ParseXMLConfig(path);
130  fVariables = c.variables;
131  fExpressions = c.expressions;
132  fAnalysisType = c.analysisType;
133  fNumClasses = c.numClasses;
134 
135  // Setup reader
136  fReader = std::make_unique<Reader>("Silent");
137  const auto numVars = fVariables.size();
138  fValues = std::vector<float>(numVars);
139  for (std::size_t i = 0; i < numVars; i++) {
140  fReader->AddVariable(TString(fExpressions[i]), &fValues[i]);
141  }
142  fReader->BookMVA(name, path.c_str());
143  }
144 
145  /// Compute model prediction on vector
146  std::vector<float> Compute(const std::vector<float> &x)
147  {
148  if (x.size() != fVariables.size())
149  throw std::runtime_error("Size of input vector is not equal to number of variables.");
150 
151  // Copy over inputs to memory used by TMVA reader
152  for (std::size_t i = 0; i < x.size(); i++) {
153  fValues[i] = x[i];
154  }
155 
156  // Take lock to protect model evaluation
157  R__WRITE_LOCKGUARD(ROOT::gCoreMutex);
158 
159  // Evaluate TMVA model
160  // Classification
161  if (fAnalysisType == Internal::AnalysisType::Classification) {
162  return std::vector<float>({static_cast<float>(fReader->EvaluateMVA(name))});
163  }
164  // Regression
165  else if (fAnalysisType == Internal::AnalysisType::Regression) {
166  return fReader->EvaluateRegression(name);
167  }
168  // Multiclass
169  else if (fAnalysisType == Internal::AnalysisType::Multiclass) {
170  return fReader->EvaluateMulticlass(name);
171  }
172  // Throw error
173  else {
174  throw std::runtime_error("RReader has undefined analysis type.");
175  return std::vector<float>();
176  }
177  }
178 
179  /// Compute model prediction on input RTensor
180  RTensor<float> Compute(RTensor<float> &x)
181  {
182  // Error-handling for input tensor
183  const auto shape = x.GetShape();
184  if (shape.size() != 2)
185  throw std::runtime_error("Can only compute model outputs for input tensor of rank 2.");
186 
187  const auto numEntries = shape[0];
188  const auto numVars = shape[1];
189  if (numVars != fVariables.size())
190  throw std::runtime_error("Second dimension of input tensor is not equal to number of variables.");
191 
192  // Define shape of output tensor based on analysis type
193  unsigned int numClasses = 1;
194  if (fAnalysisType == Internal::AnalysisType::Multiclass)
195  numClasses = fNumClasses;
196  RTensor<float> y({numEntries * numClasses});
197  if (fAnalysisType == Internal::AnalysisType::Multiclass)
198  y = y.Reshape({numEntries, numClasses});
199 
200  // Fill output tensor
201  for (std::size_t i = 0; i < numEntries; i++) {
202  for (std::size_t j = 0; j < numVars; j++) {
203  fValues[j] = x(i, j);
204  }
205  R__WRITE_LOCKGUARD(ROOT::gCoreMutex);
206  // Classification
207  if (fAnalysisType == Internal::AnalysisType::Classification) {
208  y(i) = fReader->EvaluateMVA(name);
209  }
210  // Regression
211  else if (fAnalysisType == Internal::AnalysisType::Regression) {
212  y(i) = fReader->EvaluateRegression(name)[0];
213  }
214  // Multiclass
215  else if (fAnalysisType == Internal::AnalysisType::Multiclass) {
216  const auto p = fReader->EvaluateMulticlass(name);
217  for (std::size_t k = 0; k < numClasses; k++)
218  y(i, k) = p[k];
219  }
220  }
221 
222  return y;
223  }
224 
225  std::vector<std::string> GetVariableNames() { return fVariables; }
226 };
227 
228 } // namespace Experimental
229 } // namespace TMVA
230 
231 #endif // TMVA_RREADER