15 namespace Experimental {
/// Kind of TMVA analysis a model was trained for, decoded by ParseXMLConfig
/// from the "AnalysisType" entry of the "GeneralInfo" node.
///
/// The underlying type is fixed to unsigned int and every enumerator value is
/// spelled out, so the numeric values (0..3) are stable and visible at a glance.
enum AnalysisType : unsigned int {
   Undefined = 0,      ///< Default; a parsed config still in this state is rejected.
   Classification = 1, ///< Output is a single score per event.
   Regression = 2,     ///< Output comes from TMVA::Reader::EvaluateRegression.
   Multiclass = 3      ///< Output has one value per class.
};
// Configuration extracted from a TMVA XML file by ParseXMLConfig.
// NOTE(review): the enclosing struct declaration line is not part of this excerpt — confirm against the full file.
// Number of input variables, read from the "NVar" attribute of the "Variables" node.
24 unsigned int numVariables;
// Input variable titles ("Title" attribute), stored at the position given by "VarIndex".
25 std::vector<std::string> variables;
// Input variable expressions ("Expression" attribute), stored at the position given by "VarIndex".
26 std::vector<std::string> expressions;
// Number of output classes, read from the "NClass" attribute of the "Classes" node.
27 unsigned int numClasses;
// Output class names ("Name" attribute of each child of the "Classes" node), in file order.
28 std::vector<std::string> classes;
// Analysis type decoded from the "AnalysisType" entry of the "GeneralInfo" node.
29 AnalysisType analysisType;
// Default state: no variables, no classes, Undefined analysis type (expressions is default-constructed empty).
// ParseXMLConfig throws if numVariables, numClasses or analysisType is still at this default after parsing.
// NOTE(review): the constructor declaration that introduces this member-initializer list is not part of this excerpt.
31 : numVariables(0), variables(std::vector<std::string>(0)), numClasses(0), classes(std::vector<std::string>(0)),
32 analysisType(Internal::AnalysisType::Undefined)
/// Parse a TMVA XML file and extract the input variables, the output classes
/// and the analysis type into an XMLConfig.
///
/// @param filename Path of the TMVA XML file.
/// @return The XMLConfig filled in below (the return statement is not visible in this excerpt).
/// @throws std::runtime_error if the file cannot be opened/parsed, or if the
///         variables, classes or analysis type are missing or inconsistent.
///
/// NOTE(review): several lines of this function are missing from this excerpt
/// (e.g. the declarations of `c`, `xml` and `ss`, the guard around the first
/// throw, closing braces and the return) — confirm against the full file.
38 inline XMLConfig ParseXMLConfig(
const std::string &filename)
// Parse the file into an XML document.
44 auto xmldoc = xml.ParseFile(filename.c_str());
// NOTE(review): this throw is presumably guarded by a check that xmldoc is null;
// the guard and the `std::stringstream ss;` declaration are not visible here.
47 ss <<
"Failed to open TMVA XML file "
49 throw std::runtime_error(ss.str());
// Walk the direct children of the document's root element.
51 auto mainNode = xml.DocGetRootElement(xmldoc);
52 for (
auto node = xml.GetChild(mainNode); node; node = xml.GetNext(node)) {
53 const auto nodeName = std::string(xml.GetNodeName(node));
// "Variables": NVar gives the count; each child supplies VarIndex, Title and Expression.
55 if (nodeName.compare(
"Variables") == 0) {
56 c.numVariables = std::atoi(xml.GetAttr(node,
"NVar"));
57 c.variables = std::vector<std::string>(c.numVariables);
58 c.expressions = std::vector<std::string>(c.numVariables);
59 for (
auto thisNode = xml.GetChild(node); thisNode; thisNode = xml.GetNext(thisNode)) {
// NOTE(review): VarIndex comes straight from the file and is not range-checked
// against numVariables before indexing c.variables / c.expressions, so a
// malformed file can write out of bounds. std::atoi also yields 0 for
// non-numeric text, which would silently overwrite slot 0.
// NOTE(review): if GetAttr returns nullptr for a missing attribute, passing it to
// std::atoi or assigning it to a std::string is undefined — confirm TXMLEngine behavior.
60 const auto iVariable = std::atoi(xml.GetAttr(thisNode,
"VarIndex"));
61 c.variables[iVariable] = xml.GetAttr(thisNode,
"Title");
62 c.expressions[iVariable] = xml.GetAttr(thisNode,
"Expression");
// "Classes": NClass gives the count; each child contributes its Name, in file order.
66 else if (nodeName.compare(
"Classes") == 0) {
67 c.numClasses = std::atoi(xml.GetAttr(node,
"NClass"));
68 for (
auto thisNode = xml.GetChild(node); thisNode; thisNode = xml.GetNext(thisNode)) {
69 c.classes.push_back(xml.GetAttr(thisNode,
"Name"));
// "GeneralInfo": find the child whose "name" attribute is "AnalysisType" and keep its "value".
73 else if (nodeName.compare(
"GeneralInfo") == 0) {
74 std::string analysisType =
"";
75 for (
auto thisNode = xml.GetChild(node); thisNode; thisNode = xml.GetNext(thisNode)) {
76 if (std::string(
"AnalysisType").compare(xml.GetAttr(thisNode,
"name")) == 0) {
77 analysisType = xml.GetAttr(thisNode,
"value");
// Map the string to the enum. Only these three values are recognised; no other
// branch is visible in this excerpt, so any other string presumably leaves
// c.analysisType as Undefined, which is rejected below.
80 if (analysisType.compare(
"Classification") == 0) {
81 c.analysisType = Internal::AnalysisType::Classification;
82 }
else if (analysisType.compare(
"Regression") == 0) {
83 c.analysisType = Internal::AnalysisType::Regression;
84 }
else if (analysisType.compare(
"Multiclass") == 0) {
85 c.analysisType = Internal::AnalysisType::Multiclass;
// Validate the parsed configuration.
// NOTE(review): in the visible code c.variables is only ever resized to
// c.numVariables, so the size comparison cannot fail on its own; only the
// numVariables == 0 check (e.g. no "Variables" node) is effective here.
92 if (c.numVariables != c.variables.size() || c.numVariables == 0) {
94 ss <<
"Failed to parse input variables from TMVA config " << filename <<
".";
95 throw std::runtime_error(ss.str());
// Classes are appended one per child, so a mismatch with NClass is detected here.
97 if (c.numClasses != c.classes.size() || c.numClasses == 0) {
99 ss <<
"Failed to parse output classes from TMVA config " << filename <<
".";
100 throw std::runtime_error(ss.str());
// Reject a config whose analysis type was absent or not recognised.
102 if (c.analysisType == Internal::AnalysisType::Undefined) {
103 std::stringstream ss;
104 ss <<
"Failed to parse analysis type from TMVA config " << filename <<
".";
105 throw std::runtime_error(ss.str());
// NOTE(review): no release of xmldoc (e.g. TXMLEngine::FreeDoc) is visible in this
// excerpt; confirm the document is freed on every path, including the throws above.
// RReader state.
// NOTE(review): the class declaration and access specifier preceding these members are not part of this excerpt.
// Owned TMVA::Reader, created with the "Silent" option in the constructor.
116 std::unique_ptr<Reader> fReader;
// Input buffer: the address of each element is registered with fReader via
// AddVariable, and the batched Compute writes each entry's inputs here before
// evaluating. The Reader presumably keeps these pointers, so the vector must
// not be resized or reallocated after construction — confirm.
117 std::vector<float> fValues;
// Input variable titles from the config (returned by GetVariableNames).
118 std::vector<std::string> fVariables;
// Input variable expressions from the config, registered with fReader in index order.
119 std::vector<std::string> fExpressions;
// Number of output classes from the config; used to size Multiclass outputs.
120 unsigned int fNumClasses;
// Method name used to book and evaluate the model in fReader.
121 const char *name =
"RReader";
// Analysis type from the config; selects the evaluation routine in Compute.
122 Internal::AnalysisType fAnalysisType;
/// Create a model from a TMVA XML file.
///
/// Parses the config, registers one input per variable expression with a
/// TMVA::Reader, and books the method from the same file.
///
/// @param path Path of the TMVA XML file.
/// @throws std::runtime_error if the config cannot be parsed (from ParseXMLConfig).
///
/// NOTE(review): the opening brace and the loop/constructor closing braces are
/// not part of this excerpt.
126 RReader(
const std::string &path)
// Load the configuration (ParseXMLConfig throws on malformed files).
129 auto c = Internal::ParseXMLConfig(path);
// NOTE(review): c is not used after these assignments in the visible code, so the
// vectors could be moved (std::move(c.variables)) instead of copied.
130 fVariables = c.variables;
131 fExpressions = c.expressions;
132 fAnalysisType = c.analysisType;
133 fNumClasses = c.numClasses;
// Create the reader with the "Silent" option.
136 fReader = std::make_unique<Reader>(
"Silent");
// One float slot per variable; each slot's address is bound to the matching expression.
137 const auto numVars = fVariables.size();
138 fValues = std::vector<float>(numVars);
139 for (std::size_t i = 0; i < numVars; i++) {
140 fReader->AddVariable(TString(fExpressions[i]), &fValues[i]);
// Book the method under `name` from the same XML file.
142 fReader->BookMVA(name, path.c_str());
/// Compute the model output for a single event.
///
/// @param x Input values, one per variable (x.size() must equal the number of variables).
/// @return Classification: a one-element vector holding the EvaluateMVA score;
///         Regression: the result of EvaluateRegression;
///         Multiclass: the result of EvaluateMulticlass.
/// @throws std::runtime_error if x has the wrong size or the analysis type is undefined.
///
/// NOTE(review): lines are missing from this excerpt, including the body of the
/// loop below (presumably copying x into fValues) and several closing braces.
146 std::vector<float> Compute(
const std::vector<float> &x)
148 if (x.size() != fVariables.size())
149 throw std::runtime_error(
"Size of input vector is not equal to number of variables.");
152 for (std::size_t i = 0; i < x.size(); i++) {
// Serialise evaluation on ROOT's global core mutex (write lock), presumably because
// the shared Reader and its input buffer are not safe for concurrent use — confirm.
// NOTE(review): if fValues is filled before this lock is taken, concurrent callers
// sharing one RReader race on it.
157 R__WRITE_LOCKGUARD(ROOT::gCoreMutex);
161 if (fAnalysisType == Internal::AnalysisType::Classification) {
162 return std::vector<float>({
static_cast<float>(fReader->EvaluateMVA(name))});
165 else if (fAnalysisType == Internal::AnalysisType::Regression) {
166 return fReader->EvaluateRegression(name);
169 else if (fAnalysisType == Internal::AnalysisType::Multiclass) {
170 return fReader->EvaluateMulticlass(name);
// Reached when fAnalysisType matches none of the branches above.
174 throw std::runtime_error(
"RReader has undefined analysis type.");
// NOTE(review): unreachable after the throw above; presumably only there to silence a
// missing-return warning.
175 return std::vector<float>();
/// Compute model outputs for a batch of events.
///
/// @param x Rank-2 input tensor indexed as x(entry, variable); its second
///          dimension must equal the number of variables.
/// @return Presumably the tensor y built below (the return statement is not part of
///         this excerpt): shape {numEntries} for Classification and Regression,
///         {numEntries, numClasses} for Multiclass.
/// @throws std::runtime_error if x is not rank 2 or its second dimension does not
///         match the number of variables.
///
/// NOTE(review): x is only read in the visible code; taking it by const reference
/// would also accept const tensors — confirm RTensor's accessors work on a const object.
/// NOTE(review): the tail of this function (the Multiclass loop body, closing braces
/// and the return) is not part of this excerpt.
180 RTensor<float> Compute(RTensor<float> &x)
183 const auto shape = x.GetShape();
184 if (shape.size() != 2)
185 throw std::runtime_error(
"Can only compute model outputs for input tensor of rank 2.");
187 const auto numEntries = shape[0];
188 const auto numVars = shape[1];
189 if (numVars != fVariables.size())
190 throw std::runtime_error(
"Second dimension of input tensor is not equal to number of variables.");
// One output per entry, or one per entry and class for Multiclass.
193 unsigned int numClasses = 1;
194 if (fAnalysisType == Internal::AnalysisType::Multiclass)
195 numClasses = fNumClasses;
// Allocate flat, then reshape to {numEntries, numClasses} for Multiclass.
196 RTensor<float> y({numEntries * numClasses});
197 if (fAnalysisType == Internal::AnalysisType::Multiclass)
198 y = y.Reshape({numEntries, numClasses});
// Evaluate entries one at a time through the Reader's input buffer fValues.
201 for (std::size_t i = 0; i < numEntries; i++) {
202 for (std::size_t j = 0; j < numVars; j++) {
203 fValues[j] = x(i, j);
// NOTE(review): fValues is written above before this lock is taken; if the same
// RReader is used concurrently, those writes race with other callers. Taking the
// lock before filling fValues would cover them.
205 R__WRITE_LOCKGUARD(ROOT::gCoreMutex);
207 if (fAnalysisType == Internal::AnalysisType::Classification) {
208 y(i) = fReader->EvaluateMVA(name);
// Regression: only the first regression target is stored.
211 else if (fAnalysisType == Internal::AnalysisType::Regression) {
212 y(i) = fReader->EvaluateRegression(name)[0];
215 else if (fAnalysisType == Internal::AnalysisType::Multiclass) {
216 const auto p = fReader->EvaluateMulticlass(name);
217 for (std::size_t k = 0; k < numClasses; k++)
225 std::vector<std::string> GetVariableNames() {
return fVariables; }
231 #endif // TMVA_RREADER