Logo ROOT   6.30.04
Reference Guide
 All Namespaces Files Pages
RTensorUtils.hxx
Go to the documentation of this file.
1 #ifndef TMVA_RTENSOR_UTILS
2 #define TMVA_RTENSOR_UTILS
3 
#include <algorithm>
#include <cstddef>
#include <stdexcept>
#include <string>
#include <vector>

#include "TMVA/RTensor.hxx"
#include "ROOT/RDataFrame.hxx"
10 
11 namespace TMVA {
12 namespace Experimental {
13 
14 /// \brief Convert the content of an RDataFrame to an RTensor
15 /// \param[in] dataframe RDataFrame node
16 /// \param[in] columns Vector of column names
17 /// \param[in] layout Memory layout
18 /// \return RTensor with content from selected columns
19 template <typename T, typename U>
20 RTensor<T>
21 AsTensor(U &dataframe, std::vector<std::string> columns = {}, MemoryLayout layout = MemoryLayout::RowMajor)
22 {
23  // If no columns are specified, get all columns from dataframe
24  if (columns.size() == 0) {
25  columns = dataframe.GetColumnNames();
26  }
27 
28  // Book actions to read-out columns of dataframe in vectors
29  using ResultPtr = ROOT::RDF::RResultPtr<std::vector<T>>;
30  std::vector<ResultPtr> resultPtrs;
31  for (auto &col : columns) {
32  resultPtrs.emplace_back(dataframe.template Take<T>(col));
33  }
34 
35  // Copy data to tensor based on requested memory layout
36  const auto numCols = resultPtrs.size();
37  const auto numEntries = resultPtrs[0]->size();
38  RTensor<T> x({numEntries, numCols}, layout);
39  const auto data = x.GetData();
40  if (layout == MemoryLayout::RowMajor) {
41  for (std::size_t i = 0; i < numEntries; i++) {
42  const auto entry = data + numCols * i;
43  for (std::size_t j = 0; j < numCols; j++) {
44  entry[j] = resultPtrs[j]->at(i);
45  }
46  }
47  } else if (layout == MemoryLayout::ColumnMajor) {
48  for (std::size_t i = 0; i < numCols; i++) {
49  // TODO: Replace by RVec<T>::insert as soon as available.
50  std::memcpy(data + numEntries * i, &resultPtrs[i]->at(0), numEntries * sizeof(T));
51  }
52  } else {
53  throw std::runtime_error("Memory layout is not known.");
54  }
55 
56  // Remove dimensions of 1
57  x.Squeeze();
58 
59  return x;
60 }
61 
62 } // namespace TMVA::Experimental
63 } // namespace TMVA
64 
65 #endif