Logo ROOT   6.30.04
Reference Guide
 All Namespaces Files Pages
RArrowDS.cxx
Go to the documentation of this file.
1 // Author: Giulio Eulisse CERN 2/2018
2 
3 /*************************************************************************
4  * Copyright (C) 1995-2018, Rene Brun and Fons Rademakers. *
5  * All rights reserved. *
6  * *
7  * For the licensing terms see $ROOTSYS/LICENSE. *
8  * For the list of contributors see $ROOTSYS/README/CREDITS. *
9  *************************************************************************/
10 
11 // clang-format off
12 /** \class ROOT::RDF::RArrowDS
13  \ingroup dataframe
14  \brief RDataFrame data source class to interface with Apache Arrow.
15 
16 The RArrowDS implements a proxy RDataSource to be able to use Apache Arrow
17 tables with RDataFrame.
18 
An RDataFrame that adapts an arrow::Table can be constructed using the factory method
ROOT::RDF::MakeArrowDataFrame, which accepts two parameters:
1. An arrow::Table smart pointer. 2. The names of the columns to use (all the columns of the table if empty).
22 
23 The types of the columns are derived from the types in the associated
24 arrow::Schema.
25 
26 */
27 // clang-format on
28 
29 #include <ROOT/RDF/Utils.hxx>
30 #include <ROOT/TSeq.hxx>
31 #include <ROOT/RArrowDS.hxx>
32 #include <ROOT/RMakeUnique.hxx>
33 
34 #include <algorithm>
35 #include <sstream>
36 #include <string>
37 
38 #if defined(__GNUC__)
39 #pragma GCC diagnostic push
40 #pragma GCC diagnostic ignored "-Wshadow"
41 #pragma GCC diagnostic ignored "-Wunused-parameter"
42 #endif
43 #include <arrow/table.h>
44 #include <arrow/stl.h>
45 #if defined(__GNUC__)
46 #pragma GCC diagnostic pop
47 #endif
48 
49 namespace ROOT {
50 namespace Internal {
51 namespace RDF {
52 
53 // This is needed by Arrow 0.12.0 which dropped
54 //
55 // using ArrowType = ArrowType_;
56 //
57 // from ARROW_STL_CONVERSION
/// Primary template of the C++ type -> Arrow type mapping.
/// It is deliberately empty: only explicit specialisations provide the
/// nested `ArrowType` alias.
template <class T>
struct RootConversionTraits {
};
60 
// Emits the RootConversionTraits specialisation for the C++ type `CppType`,
// whose nested `ArrowType` alias names `::arrow::ArrowTypeName`.
#define ROOT_ARROW_STL_CONVERSION(CppType, ArrowTypeName) \
   template <>                                             \
   struct RootConversionTraits<CppType> {                  \
      using ArrowType = ::arrow::ArrowTypeName;            \
   };
66 
67 ROOT_ARROW_STL_CONVERSION(bool, BooleanType)
68 ROOT_ARROW_STL_CONVERSION(int8_t, Int8Type)
69 ROOT_ARROW_STL_CONVERSION(int16_t, Int16Type)
70 ROOT_ARROW_STL_CONVERSION(int32_t, Int32Type)
71 ROOT_ARROW_STL_CONVERSION(Long64_t, Int64Type)
72 ROOT_ARROW_STL_CONVERSION(uint8_t, UInt8Type)
73 ROOT_ARROW_STL_CONVERSION(uint16_t, UInt16Type)
74 ROOT_ARROW_STL_CONVERSION(uint32_t, UInt32Type)
75 ROOT_ARROW_STL_CONVERSION(ULong64_t, UInt64Type)
76 ROOT_ARROW_STL_CONVERSION(float, FloatType)
77 ROOT_ARROW_STL_CONVERSION(double, DoubleType)
78 ROOT_ARROW_STL_CONVERSION(std::string, StringType)
79 
80 // Per slot visitor of an Array.
81 class ArrayPtrVisitor : public ::arrow::ArrayVisitor {
82 private:
83  /// The pointer to update.
84  void **fResult;
85  bool fCachedBool{false}; // Booleans need to be unpacked, so we use a cached entry.
86  // FIXME: I should really use a variant here
87  RVec<float> fCachedRVecFloat;
88  RVec<double> fCachedRVecDouble;
89  RVec<ULong64_t> fCachedRVecULong64;
90  RVec<UInt_t> fCachedRVecUInt;
91  RVec<Long64_t> fCachedRVecLong64;
92  RVec<Int_t> fCachedRVecInt;
93  std::string fCachedString;
94  /// The entry in the array which should be looked up.
95  ULong64_t fCurrentEntry;
96 
97  template <typename T>
98  void *getTypeErasedPtrFrom(arrow::ListArray const &array, int32_t entry, RVec<T> &cache)
99  {
100  using ArrowType = typename RootConversionTraits<T>::ArrowType;
101  using ArrayType = typename arrow::TypeTraits<ArrowType>::ArrayType;
102  auto values = reinterpret_cast<ArrayType *>(array.values().get());
103  auto offset = array.value_offset(entry);
104  // Here the cast to void* is a worksround while we figure out the
105  // issues we have with long long types, signed and unsigned.
106  RVec<T> tmp(reinterpret_cast<T *>((void *)values->raw_values()) + offset, array.value_length(entry));
107  std::swap(cache, tmp);
108  return (void *)(&cache);
109  }
110 
111 public:
112  ArrayPtrVisitor(void **result) : fResult{result}, fCurrentEntry{0} {}
113 
114  void SetEntry(ULong64_t entry) { fCurrentEntry = entry; }
115 
116  /// Check if we are asking the same entry as before.
117  virtual arrow::Status Visit(arrow::Int32Array const &array) final
118  {
119  *fResult = (void *)(array.raw_values() + fCurrentEntry);
120  return arrow::Status::OK();
121  }
122 
123  virtual arrow::Status Visit(arrow::Int64Array const &array) final
124  {
125  *fResult = (void *)(array.raw_values() + fCurrentEntry);
126  return arrow::Status::OK();
127  }
128 
129  /// Check if we are asking the same entry as before.
130  virtual arrow::Status Visit(arrow::UInt32Array const &array) final
131  {
132  *fResult = (void *)(array.raw_values() + fCurrentEntry);
133  return arrow::Status::OK();
134  }
135 
136  virtual arrow::Status Visit(arrow::UInt64Array const &array) final
137  {
138  *fResult = (void *)(array.raw_values() + fCurrentEntry);
139  return arrow::Status::OK();
140  }
141 
142  virtual arrow::Status Visit(arrow::FloatArray const &array) final
143  {
144  *fResult = (void *)(array.raw_values() + fCurrentEntry);
145  return arrow::Status::OK();
146  }
147 
148  virtual arrow::Status Visit(arrow::DoubleArray const &array) final
149  {
150  *fResult = (void *)(array.raw_values() + fCurrentEntry);
151  return arrow::Status::OK();
152  }
153 
154  virtual arrow::Status Visit(arrow::BooleanArray const &array) final
155  {
156  fCachedBool = array.Value(fCurrentEntry);
157  *fResult = reinterpret_cast<void *>(&fCachedBool);
158  return arrow::Status::OK();
159  }
160 
161  virtual arrow::Status Visit(arrow::StringArray const &array) final
162  {
163  fCachedString = array.GetString(fCurrentEntry);
164  *fResult = reinterpret_cast<void *>(&fCachedString);
165  return arrow::Status::OK();
166  }
167 
168  virtual arrow::Status Visit(arrow::ListArray const &array) final
169  {
170  switch (array.value_type()->id()) {
171  case arrow::Type::FLOAT: {
172  *fResult = getTypeErasedPtrFrom(array, fCurrentEntry, fCachedRVecFloat);
173  return arrow::Status::OK();
174  }
175  case arrow::Type::DOUBLE: {
176  *fResult = getTypeErasedPtrFrom(array, fCurrentEntry, fCachedRVecDouble);
177  return arrow::Status::OK();
178  }
179  case arrow::Type::UINT32: {
180  *fResult = getTypeErasedPtrFrom(array, fCurrentEntry, fCachedRVecUInt);
181  return arrow::Status::OK();
182  }
183  case arrow::Type::UINT64: {
184  *fResult = getTypeErasedPtrFrom(array, fCurrentEntry, fCachedRVecULong64);
185  return arrow::Status::OK();
186  }
187  case arrow::Type::INT32: {
188  *fResult = getTypeErasedPtrFrom(array, fCurrentEntry, fCachedRVecInt);
189  return arrow::Status::OK();
190  }
191  case arrow::Type::INT64: {
192  *fResult = getTypeErasedPtrFrom(array, fCurrentEntry, fCachedRVecLong64);
193  return arrow::Status::OK();
194  }
195  default: return arrow::Status::TypeError("Type not supported");
196  }
197  }
198 
199  using ::arrow::ArrayVisitor::Visit;
200 };
201 
202 /// Helper class which keeps track for each slot where to get the entry.
203 class TValueGetter {
204 private:
205  std::vector<void *> fValuesPtrPerSlot;
206  std::vector<ULong64_t> fLastEntryPerSlot;
207  std::vector<ULong64_t> fLastChunkPerSlot;
208  std::vector<ULong64_t> fFirstEntryPerChunk;
209  std::vector<ArrayPtrVisitor> fArrayVisitorPerSlot;
210  /// Since data can be chunked in different arrays we need to construct an
211  /// index which contains the first element of each chunk, so that we can
212  /// quickly move to the correct chunk.
213  std::vector<ULong64_t> fChunkIndex;
214  arrow::ArrayVector fChunks;
215 
216 public:
217  TValueGetter(size_t slots, arrow::ArrayVector chunks)
218  : fValuesPtrPerSlot(slots, nullptr), fLastEntryPerSlot(slots, 0), fLastChunkPerSlot(slots, 0), fChunks{chunks}
219  {
220  fChunkIndex.reserve(fChunks.size());
221  size_t next = 0;
222  for (auto &chunk : chunks) {
223  fFirstEntryPerChunk.push_back(next);
224  next += chunk->length();
225  fChunkIndex.push_back(next);
226  }
227  for (size_t si = 0, se = fValuesPtrPerSlot.size(); si != se; ++si) {
228  fArrayVisitorPerSlot.push_back(ArrayPtrVisitor{fValuesPtrPerSlot.data() + si});
229  }
230  }
231 
232  /// This returns the ptr to the ptr to actual data.
233  std::vector<void *> SlotPtrs()
234  {
235  std::vector<void *> result;
236  for (size_t i = 0; i < fValuesPtrPerSlot.size(); ++i) {
237  result.push_back(fValuesPtrPerSlot.data() + i);
238  }
239  return result;
240  }
241 
242  // Convenience method to avoid code duplication between
243  // SetEntry and InitSlot
244  void UncachedSlotLookup(unsigned int slot, ULong64_t entry)
245  {
246  // If entry is greater than the previous one,
247  // we can skip all the chunks before the last one we
248  // queried.
249  size_t ci = 0;
250  assert(slot < fLastChunkPerSlot.size());
251  if (fLastEntryPerSlot[slot] < entry) {
252  ci = fLastChunkPerSlot.at(slot);
253  }
254 
255  for (size_t ce = fChunkIndex.size(); ci != ce; ++ci) {
256  if (entry < fChunkIndex[ci]) {
257  assert(slot < fLastChunkPerSlot.size());
258  fLastChunkPerSlot[slot] = ci;
259  break;
260  }
261  }
262 
263  // Update the pointer to the requested entry.
264  // Notice that we need to find the entry
265  auto chunk = fChunks.at(fLastChunkPerSlot[slot]);
266  assert(slot < fArrayVisitorPerSlot.size());
267  fArrayVisitorPerSlot[slot].SetEntry(entry - fFirstEntryPerChunk[fLastChunkPerSlot[slot]]);
268  fLastEntryPerSlot[slot] = entry;
269  auto status = chunk->Accept(fArrayVisitorPerSlot.data() + slot);
270  if (!status.ok()) {
271  std::string msg = "Could not get pointer for slot ";
272  msg += std::to_string(slot) + " looking at entry " + std::to_string(entry);
273  throw std::runtime_error(msg);
274  }
275  }
276 
277  /// Set the current entry to be retrieved
278  void SetEntry(unsigned int slot, ULong64_t entry)
279  {
280  // Same entry as before
281  if (fLastEntryPerSlot[slot] == entry) {
282  return;
283  }
284  UncachedSlotLookup(slot, entry);
285  }
286 };
287 
288 } // namespace RDF
289 } // namespace Internal
290 
291 namespace RDF {
292 
293 /// Helper to get the contents of a given column
294 
295 /// Helper to get the human readable name of type
296 class RDFTypeNameGetter : public ::arrow::TypeVisitor {
297 private:
298  std::vector<std::string> fTypeName;
299 
300 public:
301  arrow::Status Visit(const arrow::Int64Type &) override
302  {
303  fTypeName.push_back("Long64_t");
304  return arrow::Status::OK();
305  }
306  arrow::Status Visit(const arrow::Int32Type &) override
307  {
308  fTypeName.push_back("Int_t");
309  return arrow::Status::OK();
310  }
311  arrow::Status Visit(const arrow::UInt64Type &) override
312  {
313  fTypeName.push_back("ULong64_t");
314  return arrow::Status::OK();
315  }
316  arrow::Status Visit(const arrow::UInt32Type &) override
317  {
318  fTypeName.push_back("UInt_t");
319  return arrow::Status::OK();
320  }
321  arrow::Status Visit(const arrow::FloatType &) override
322  {
323  fTypeName.push_back("float");
324  return arrow::Status::OK();
325  }
326  arrow::Status Visit(const arrow::DoubleType &) override
327  {
328  fTypeName.push_back("double");
329  return arrow::Status::OK();
330  }
331  arrow::Status Visit(const arrow::StringType &) override
332  {
333  fTypeName.push_back("string");
334  return arrow::Status::OK();
335  }
336  arrow::Status Visit(const arrow::BooleanType &) override
337  {
338  fTypeName.push_back("bool");
339  return arrow::Status::OK();
340  }
341  arrow::Status Visit(const arrow::ListType &l) override
342  {
343  /// Recursively visit List types and map them to
344  /// an RVec. We accumulate the result of the recursion on
345  /// fTypeName so that we can create the actual type
346  /// when the recursion is done.
347  fTypeName.push_back("ROOT::VecOps::RVec<%s>");
348  return l.value_type()->Accept(this);
349  }
350  std::string result()
351  {
352  // This recursively builds a nested type.
353  std::string result = "%s";
354  char buffer[8192];
355  for (size_t i = 0; i < fTypeName.size(); ++i) {
356  snprintf(buffer, 8192, result.c_str(), fTypeName[i].c_str());
357  result = buffer;
358  }
359  return result;
360  }
361 
362  using ::arrow::TypeVisitor::Visit;
363 };
364 
365 /// Helper to determine if a given Column is a supported type.
366 class VerifyValidColumnType : public ::arrow::TypeVisitor {
367 private:
368 public:
369  virtual arrow::Status Visit(const arrow::Int64Type &) override { return arrow::Status::OK(); }
370  virtual arrow::Status Visit(const arrow::UInt64Type &) override { return arrow::Status::OK(); }
371  virtual arrow::Status Visit(const arrow::Int32Type &) override { return arrow::Status::OK(); }
372  virtual arrow::Status Visit(const arrow::UInt32Type &) override { return arrow::Status::OK(); }
373  virtual arrow::Status Visit(const arrow::FloatType &) override { return arrow::Status::OK(); }
374  virtual arrow::Status Visit(const arrow::DoubleType &) override { return arrow::Status::OK(); }
375  virtual arrow::Status Visit(const arrow::StringType &) override { return arrow::Status::OK(); }
376  virtual arrow::Status Visit(const arrow::BooleanType &) override { return arrow::Status::OK(); }
377  virtual arrow::Status Visit(const arrow::ListType &) override { return arrow::Status::OK(); }
378 
379  using ::arrow::TypeVisitor::Visit;
380 };
381 
382 ////////////////////////////////////////////////////////////////////////
383 /// Constructor to create an Arrow RDataSource for RDataFrame.
384 /// \param[in] table the arrow Table to observe.
385 /// \param[in] columns the name of the columns to use
386 /// In case columns is empty, we use all the columns found in the table
387 RArrowDS::RArrowDS(std::shared_ptr<arrow::Table> inTable, std::vector<std::string> const &inColumns)
388  : fTable{inTable}, fColumnNames{inColumns}
389 {
390  auto &columnNames = fColumnNames;
391  auto &table = fTable;
392  auto &index = fGetterIndex;
393  // We want to allow people to specify which columns they
394  // need so that we can think of upfront IO optimizations.
395  auto filterWantedColumns = [&columnNames, &table]() {
396  if (columnNames.empty()) {
397  for (auto &field : table->schema()->fields()) {
398  columnNames.push_back(field->name());
399  }
400  }
401  };
402 
403  auto getRecordsFirstColumn = [&columnNames, &table]() {
404  if (columnNames.empty()) {
405  throw std::runtime_error("At least one column required");
406  }
407  const auto name = columnNames.front();
408  const auto columnIdx = table->schema()->GetFieldIndex(name);
409  return table->column(columnIdx)->length();
410  };
411 
412  // All columns are supposed to have the same number of entries.
413  auto verifyColumnSize = [](std::shared_ptr<arrow::Column> column, int nRecords) {
414  if (column->length() != nRecords) {
415  std::string msg = "Column ";
416  msg += column->name() + " has a different number of entries.";
417  throw std::runtime_error(msg);
418  }
419  };
420 
421  /// For the moment we support only a few native types.
422  auto verifyColumnType = [](std::shared_ptr<arrow::Column> column) {
423  auto verifyType = std::make_unique<VerifyValidColumnType>();
424  auto result = column->type()->Accept(verifyType.get());
425  if (result.ok() == false) {
426  std::string msg = "Column ";
427  msg += column->name() + " contains an unsupported type.";
428  throw std::runtime_error(msg);
429  }
430  };
431 
432  /// This is used to create an index between the columnId
433  /// and the associated getter.
434  auto addColumnToGetterIndex = [&index](int columnId) { index.push_back(std::make_pair(columnId, index.size())); };
435 
436  /// Assuming we can get called more than once, we need to
437  /// reset the getter index each time.
438  auto resetGetterIndex = [&index]() { index.clear(); };
439 
440  /// This is what initialization actually does
441  filterWantedColumns();
442  resetGetterIndex();
443  auto nRecords = getRecordsFirstColumn();
444  for (auto &columnName : fColumnNames) {
445  auto columnIdx = fTable->schema()->GetFieldIndex(columnName);
446  addColumnToGetterIndex(columnIdx);
447 
448  auto column = fTable->column(columnIdx);
449  verifyColumnSize(column, nRecords);
450  verifyColumnType(column);
451  }
452 }
453 
454 ////////////////////////////////////////////////////////////////////////
455 /// Destructor.
456 RArrowDS::~RArrowDS()
457 {
458 }
459 
460 const std::vector<std::string> &RArrowDS::GetColumnNames() const
461 {
462  return fColumnNames;
463 }
464 
465 std::vector<std::pair<ULong64_t, ULong64_t>> RArrowDS::GetEntryRanges()
466 {
467  auto entryRanges(std::move(fEntryRanges)); // empty fEntryRanges
468  return entryRanges;
469 }
470 
471 std::string RArrowDS::GetTypeName(std::string_view colName) const
472 {
473  auto field = fTable->schema()->GetFieldByName(std::string(colName));
474  if (!field) {
475  std::string msg = "The dataset does not have column ";
476  msg += colName;
477  throw std::runtime_error(msg);
478  }
479  RDFTypeNameGetter typeGetter;
480  auto status = field->type()->Accept(&typeGetter);
481  if (status.ok() == false) {
482  std::string msg = "RArrowDS does not support a column of type ";
483  msg += field->type()->name();
484  throw std::runtime_error(msg);
485  }
486  return typeGetter.result();
487 }
488 
489 bool RArrowDS::HasColumn(std::string_view colName) const
490 {
491  auto field = fTable->schema()->GetFieldByName(std::string(colName));
492  if (!field) {
493  return false;
494  }
495  return true;
496 }
497 
498 bool RArrowDS::SetEntry(unsigned int slot, ULong64_t entry)
499 {
500  for (auto link : fGetterIndex) {
501  auto &getter = fValueGetters[link.second];
502  getter->SetEntry(slot, entry);
503  }
504  return true;
505 }
506 
507 void RArrowDS::InitSlot(unsigned int slot, ULong64_t entry)
508 {
509  for (auto link : fGetterIndex) {
510  auto &getter = fValueGetters[link.second];
511  getter->UncachedSlotLookup(slot, entry);
512  }
513 }
514 
515 void splitInEqualRanges(std::vector<std::pair<ULong64_t, ULong64_t>> &ranges, int nRecords, unsigned int nSlots)
516 {
517  ranges.clear();
518  const auto chunkSize = nRecords / nSlots;
519  const auto remainder = 1U == nSlots ? 0 : nRecords % nSlots;
520  auto start = 0UL;
521  auto end = 0UL;
522  for (auto i : ROOT::TSeqU(nSlots)) {
523  start = end;
524  end += chunkSize;
525  ranges.emplace_back(start, end);
526  (void)i;
527  }
528  ranges.back().second += remainder;
529 }
530 
531 int getNRecords(std::shared_ptr<arrow::Table> &table, std::vector<std::string> &columnNames)
532 {
533  auto index = table->schema()->GetFieldIndex(columnNames.front());
534  return table->column(index)->length();
535 };
536 
537 void RArrowDS::SetNSlots(unsigned int nSlots)
538 {
539  assert(0U == fNSlots && "Setting the number of slots even if the number of slots is different from zero.");
540  fNSlots = nSlots;
541  // We dump all the previous getters structures and we rebuild it.
542  auto nColumns = fGetterIndex.size();
543 
544  fValueGetters.clear();
545  for (size_t ci = 0; ci != nColumns; ++ci) {
546  auto chunkedArray = fTable->column(fGetterIndex[ci].first)->data();
547  fValueGetters.emplace_back(std::make_unique<ROOT::Internal::RDF::TValueGetter>(nSlots, chunkedArray->chunks()));
548  }
549 }
550 
551 /// This needs to return a pointer to the pointer each value getter
552 /// will point to.
553 std::vector<void *> RArrowDS::GetColumnReadersImpl(std::string_view colName, const std::type_info &)
554 {
555  auto &index = fGetterIndex;
556  auto findGetterIndex = [&index](unsigned int column) {
557  for (auto &entry : index) {
558  if (entry.first == column) {
559  return entry.second;
560  }
561  }
562  throw std::runtime_error("No column found at index " + std::to_string(column));
563  };
564 
565  const int columnIdx = fTable->schema()->GetFieldIndex(std::string(colName));
566  const int getterIdx = findGetterIndex(columnIdx);
567  assert(getterIdx != -1);
568  assert((unsigned int)getterIdx < fValueGetters.size());
569  return fValueGetters[getterIdx]->SlotPtrs();
570 }
571 
572 void RArrowDS::Initialise()
573 {
574  auto nRecords = getNRecords(fTable, fColumnNames);
575  splitInEqualRanges(fEntryRanges, nRecords, fNSlots);
576 }
577 
578 std::string RArrowDS::GetLabel()
579 {
580  return "ArrowDS";
581 }
582 
583 /// Creates a RDataFrame using an arrow::Table as input.
584 /// \param[in] table the arrow Table to observe.
585 /// \param[in] columnNames the name of the columns to use
586 /// In case columnNames is empty, we use all the columns found in the table
587 RDataFrame MakeArrowDataFrame(std::shared_ptr<arrow::Table> table, std::vector<std::string> const &columnNames)
588 {
589  ROOT::RDataFrame tdf(std::make_unique<RArrowDS>(table, columnNames));
590  return tdf;
591 }
592 
593 } // namespace RDF
594 
595 } // namespace ROOT