39 #pragma GCC diagnostic push
40 #pragma GCC diagnostic ignored "-Wshadow"
41 #pragma GCC diagnostic ignored "-Wunused-parameter"
43 #include <arrow/table.h>
44 #include <arrow/stl.h>
46 #pragma GCC diagnostic pop
59 struct RootConversionTraits {};
61 #define ROOT_ARROW_STL_CONVERSION(c_type, ArrowType_) \
63 struct RootConversionTraits<c_type> { \
64 using ArrowType = ::arrow::ArrowType_; \
67 ROOT_ARROW_STL_CONVERSION(
bool, BooleanType)
68 ROOT_ARROW_STL_CONVERSION(int8_t, Int8Type)
69 ROOT_ARROW_STL_CONVERSION(int16_t, Int16Type)
70 ROOT_ARROW_STL_CONVERSION(int32_t, Int32Type)
71 ROOT_ARROW_STL_CONVERSION(Long64_t, Int64Type)
72 ROOT_ARROW_STL_CONVERSION(uint8_t, UInt8Type)
73 ROOT_ARROW_STL_CONVERSION(uint16_t, UInt16Type)
74 ROOT_ARROW_STL_CONVERSION(uint32_t, UInt32Type)
75 ROOT_ARROW_STL_CONVERSION(ULong64_t, UInt64Type)
76 ROOT_ARROW_STL_CONVERSION(
float, FloatType)
77 ROOT_ARROW_STL_CONVERSION(
double, DoubleType)
78 ROOT_ARROW_STL_CONVERSION(std::
string, StringType)
81 class ArrayPtrVisitor : public ::arrow::ArrayVisitor {
85 bool fCachedBool{
false};
87 RVec<float> fCachedRVecFloat;
88 RVec<double> fCachedRVecDouble;
89 RVec<ULong64_t> fCachedRVecULong64;
90 RVec<UInt_t> fCachedRVecUInt;
91 RVec<Long64_t> fCachedRVecLong64;
92 RVec<Int_t> fCachedRVecInt;
93 std::string fCachedString;
95 ULong64_t fCurrentEntry;
98 void *getTypeErasedPtrFrom(arrow::ListArray
const &array, int32_t entry, RVec<T> &cache)
100 using ArrowType =
typename RootConversionTraits<T>::ArrowType;
101 using ArrayType =
typename arrow::TypeTraits<ArrowType>::ArrayType;
102 auto values =
reinterpret_cast<ArrayType *
>(array.values().get());
103 auto offset = array.value_offset(entry);
106 RVec<T> tmp(reinterpret_cast<T *>((
void *)values->raw_values()) + offset, array.value_length(entry));
107 std::swap(cache, tmp);
108 return (
void *)(&cache);
112 ArrayPtrVisitor(
void **result) : fResult{result}, fCurrentEntry{0} {}
114 void SetEntry(ULong64_t entry) { fCurrentEntry = entry; }
117 virtual arrow::Status Visit(arrow::Int32Array
const &array)
final
119 *fResult = (
void *)(array.raw_values() + fCurrentEntry);
120 return arrow::Status::OK();
123 virtual arrow::Status Visit(arrow::Int64Array
const &array)
final
125 *fResult = (
void *)(array.raw_values() + fCurrentEntry);
126 return arrow::Status::OK();
130 virtual arrow::Status Visit(arrow::UInt32Array
const &array)
final
132 *fResult = (
void *)(array.raw_values() + fCurrentEntry);
133 return arrow::Status::OK();
136 virtual arrow::Status Visit(arrow::UInt64Array
const &array)
final
138 *fResult = (
void *)(array.raw_values() + fCurrentEntry);
139 return arrow::Status::OK();
142 virtual arrow::Status Visit(arrow::FloatArray
const &array)
final
144 *fResult = (
void *)(array.raw_values() + fCurrentEntry);
145 return arrow::Status::OK();
148 virtual arrow::Status Visit(arrow::DoubleArray
const &array)
final
150 *fResult = (
void *)(array.raw_values() + fCurrentEntry);
151 return arrow::Status::OK();
154 virtual arrow::Status Visit(arrow::BooleanArray
const &array)
final
156 fCachedBool = array.Value(fCurrentEntry);
157 *fResult =
reinterpret_cast<void *
>(&fCachedBool);
158 return arrow::Status::OK();
161 virtual arrow::Status Visit(arrow::StringArray
const &array)
final
163 fCachedString = array.GetString(fCurrentEntry);
164 *fResult =
reinterpret_cast<void *
>(&fCachedString);
165 return arrow::Status::OK();
168 virtual arrow::Status Visit(arrow::ListArray
const &array)
final
170 switch (array.value_type()->id()) {
171 case arrow::Type::FLOAT: {
172 *fResult = getTypeErasedPtrFrom(array, fCurrentEntry, fCachedRVecFloat);
173 return arrow::Status::OK();
175 case arrow::Type::DOUBLE: {
176 *fResult = getTypeErasedPtrFrom(array, fCurrentEntry, fCachedRVecDouble);
177 return arrow::Status::OK();
179 case arrow::Type::UINT32: {
180 *fResult = getTypeErasedPtrFrom(array, fCurrentEntry, fCachedRVecUInt);
181 return arrow::Status::OK();
183 case arrow::Type::UINT64: {
184 *fResult = getTypeErasedPtrFrom(array, fCurrentEntry, fCachedRVecULong64);
185 return arrow::Status::OK();
187 case arrow::Type::INT32: {
188 *fResult = getTypeErasedPtrFrom(array, fCurrentEntry, fCachedRVecInt);
189 return arrow::Status::OK();
191 case arrow::Type::INT64: {
192 *fResult = getTypeErasedPtrFrom(array, fCurrentEntry, fCachedRVecLong64);
193 return arrow::Status::OK();
195 default:
return arrow::Status::TypeError(
"Type not supported");
199 using ::arrow::ArrayVisitor::Visit;
205 std::vector<void *> fValuesPtrPerSlot;
206 std::vector<ULong64_t> fLastEntryPerSlot;
207 std::vector<ULong64_t> fLastChunkPerSlot;
208 std::vector<ULong64_t> fFirstEntryPerChunk;
209 std::vector<ArrayPtrVisitor> fArrayVisitorPerSlot;
213 std::vector<ULong64_t> fChunkIndex;
214 arrow::ArrayVector fChunks;
217 TValueGetter(
size_t slots, arrow::ArrayVector chunks)
218 : fValuesPtrPerSlot(slots, nullptr), fLastEntryPerSlot(slots, 0), fLastChunkPerSlot(slots, 0), fChunks{chunks}
220 fChunkIndex.reserve(fChunks.size());
222 for (
auto &chunk : chunks) {
223 fFirstEntryPerChunk.push_back(next);
224 next += chunk->length();
225 fChunkIndex.push_back(next);
227 for (
size_t si = 0, se = fValuesPtrPerSlot.size(); si != se; ++si) {
228 fArrayVisitorPerSlot.push_back(ArrayPtrVisitor{fValuesPtrPerSlot.data() + si});
233 std::vector<void *> SlotPtrs()
235 std::vector<void *> result;
236 for (
size_t i = 0; i < fValuesPtrPerSlot.size(); ++i) {
237 result.push_back(fValuesPtrPerSlot.data() + i);
244 void UncachedSlotLookup(
unsigned int slot, ULong64_t entry)
250 assert(slot < fLastChunkPerSlot.size());
251 if (fLastEntryPerSlot[slot] < entry) {
252 ci = fLastChunkPerSlot.at(slot);
255 for (
size_t ce = fChunkIndex.size(); ci != ce; ++ci) {
256 if (entry < fChunkIndex[ci]) {
257 assert(slot < fLastChunkPerSlot.size());
258 fLastChunkPerSlot[slot] = ci;
265 auto chunk = fChunks.at(fLastChunkPerSlot[slot]);
266 assert(slot < fArrayVisitorPerSlot.size());
267 fArrayVisitorPerSlot[slot].SetEntry(entry - fFirstEntryPerChunk[fLastChunkPerSlot[slot]]);
268 fLastEntryPerSlot[slot] = entry;
269 auto status = chunk->Accept(fArrayVisitorPerSlot.data() + slot);
271 std::string msg =
"Could not get pointer for slot ";
272 msg += std::to_string(slot) +
" looking at entry " + std::to_string(entry);
273 throw std::runtime_error(msg);
278 void SetEntry(
unsigned int slot, ULong64_t entry)
281 if (fLastEntryPerSlot[slot] == entry) {
284 UncachedSlotLookup(slot, entry);
296 class RDFTypeNameGetter :
public ::arrow::TypeVisitor {
298 std::vector<std::string> fTypeName;
301 arrow::Status Visit(
const arrow::Int64Type &)
override
303 fTypeName.push_back(
"Long64_t");
304 return arrow::Status::OK();
306 arrow::Status Visit(
const arrow::Int32Type &)
override
308 fTypeName.push_back(
"Int_t");
309 return arrow::Status::OK();
311 arrow::Status Visit(
const arrow::UInt64Type &)
override
313 fTypeName.push_back(
"ULong64_t");
314 return arrow::Status::OK();
316 arrow::Status Visit(
const arrow::UInt32Type &)
override
318 fTypeName.push_back(
"UInt_t");
319 return arrow::Status::OK();
321 arrow::Status Visit(
const arrow::FloatType &)
override
323 fTypeName.push_back(
"float");
324 return arrow::Status::OK();
326 arrow::Status Visit(
const arrow::DoubleType &)
override
328 fTypeName.push_back(
"double");
329 return arrow::Status::OK();
331 arrow::Status Visit(
const arrow::StringType &)
override
333 fTypeName.push_back(
"string");
334 return arrow::Status::OK();
336 arrow::Status Visit(
const arrow::BooleanType &)
override
338 fTypeName.push_back(
"bool");
339 return arrow::Status::OK();
341 arrow::Status Visit(
const arrow::ListType &l)
override
347 fTypeName.push_back(
"ROOT::VecOps::RVec<%s>");
348 return l.value_type()->Accept(
this);
353 std::string result =
"%s";
355 for (
size_t i = 0; i < fTypeName.size(); ++i) {
356 snprintf(buffer, 8192, result.c_str(), fTypeName[i].c_str());
362 using ::arrow::TypeVisitor::Visit;
366 class VerifyValidColumnType :
public ::arrow::TypeVisitor {
369 virtual arrow::Status Visit(
const arrow::Int64Type &)
override {
return arrow::Status::OK(); }
370 virtual arrow::Status Visit(
const arrow::UInt64Type &)
override {
return arrow::Status::OK(); }
371 virtual arrow::Status Visit(
const arrow::Int32Type &)
override {
return arrow::Status::OK(); }
372 virtual arrow::Status Visit(
const arrow::UInt32Type &)
override {
return arrow::Status::OK(); }
373 virtual arrow::Status Visit(
const arrow::FloatType &)
override {
return arrow::Status::OK(); }
374 virtual arrow::Status Visit(
const arrow::DoubleType &)
override {
return arrow::Status::OK(); }
375 virtual arrow::Status Visit(
const arrow::StringType &)
override {
return arrow::Status::OK(); }
376 virtual arrow::Status Visit(
const arrow::BooleanType &)
override {
return arrow::Status::OK(); }
377 virtual arrow::Status Visit(
const arrow::ListType &)
override {
return arrow::Status::OK(); }
379 using ::arrow::TypeVisitor::Visit;
387 RArrowDS::RArrowDS(std::shared_ptr<arrow::Table> inTable, std::vector<std::string>
const &inColumns)
388 : fTable{inTable}, fColumnNames{inColumns}
390 auto &columnNames = fColumnNames;
391 auto &table = fTable;
392 auto &index = fGetterIndex;
395 auto filterWantedColumns = [&columnNames, &table]() {
396 if (columnNames.empty()) {
397 for (
auto &field : table->schema()->fields()) {
398 columnNames.push_back(field->name());
403 auto getRecordsFirstColumn = [&columnNames, &table]() {
404 if (columnNames.empty()) {
405 throw std::runtime_error(
"At least one column required");
407 const auto name = columnNames.front();
408 const auto columnIdx = table->schema()->GetFieldIndex(name);
409 return table->column(columnIdx)->length();
413 auto verifyColumnSize = [](std::shared_ptr<arrow::Column> column,
int nRecords) {
414 if (column->length() != nRecords) {
415 std::string msg =
"Column ";
416 msg += column->name() +
" has a different number of entries.";
417 throw std::runtime_error(msg);
422 auto verifyColumnType = [](std::shared_ptr<arrow::Column> column) {
423 auto verifyType = std::make_unique<VerifyValidColumnType>();
424 auto result = column->type()->Accept(verifyType.get());
425 if (result.ok() ==
false) {
426 std::string msg =
"Column ";
427 msg += column->name() +
" contains an unsupported type.";
428 throw std::runtime_error(msg);
434 auto addColumnToGetterIndex = [&index](
int columnId) { index.push_back(std::make_pair(columnId, index.size())); };
438 auto resetGetterIndex = [&index]() { index.clear(); };
441 filterWantedColumns();
443 auto nRecords = getRecordsFirstColumn();
444 for (
auto &columnName : fColumnNames) {
445 auto columnIdx = fTable->schema()->GetFieldIndex(columnName);
446 addColumnToGetterIndex(columnIdx);
448 auto column = fTable->column(columnIdx);
449 verifyColumnSize(column, nRecords);
450 verifyColumnType(column);
456 RArrowDS::~RArrowDS()
460 const std::vector<std::string> &RArrowDS::GetColumnNames()
const
465 std::vector<std::pair<ULong64_t, ULong64_t>> RArrowDS::GetEntryRanges()
467 auto entryRanges(std::move(fEntryRanges));
471 std::string RArrowDS::GetTypeName(std::string_view colName)
const
473 auto field = fTable->schema()->GetFieldByName(std::string(colName));
475 std::string msg =
"The dataset does not have column ";
477 throw std::runtime_error(msg);
479 RDFTypeNameGetter typeGetter;
480 auto status = field->type()->Accept(&typeGetter);
481 if (status.ok() ==
false) {
482 std::string msg =
"RArrowDS does not support a column of type ";
483 msg += field->type()->name();
484 throw std::runtime_error(msg);
486 return typeGetter.result();
489 bool RArrowDS::HasColumn(std::string_view colName)
const
491 auto field = fTable->schema()->GetFieldByName(std::string(colName));
498 bool RArrowDS::SetEntry(
unsigned int slot, ULong64_t entry)
500 for (
auto link : fGetterIndex) {
501 auto &getter = fValueGetters[link.second];
502 getter->SetEntry(slot, entry);
507 void RArrowDS::InitSlot(
unsigned int slot, ULong64_t entry)
509 for (
auto link : fGetterIndex) {
510 auto &getter = fValueGetters[link.second];
511 getter->UncachedSlotLookup(slot, entry);
515 void splitInEqualRanges(std::vector<std::pair<ULong64_t, ULong64_t>> &ranges,
int nRecords,
unsigned int nSlots)
518 const auto chunkSize = nRecords / nSlots;
519 const auto remainder = 1U == nSlots ? 0 : nRecords % nSlots;
522 for (
auto i : ROOT::TSeqU(nSlots)) {
525 ranges.emplace_back(start, end);
528 ranges.back().second += remainder;
531 int getNRecords(std::shared_ptr<arrow::Table> &table, std::vector<std::string> &columnNames)
533 auto index = table->schema()->GetFieldIndex(columnNames.front());
534 return table->column(index)->length();
537 void RArrowDS::SetNSlots(
unsigned int nSlots)
539 assert(0U == fNSlots &&
"Setting the number of slots even if the number of slots is different from zero.");
542 auto nColumns = fGetterIndex.size();
544 fValueGetters.clear();
545 for (
size_t ci = 0; ci != nColumns; ++ci) {
546 auto chunkedArray = fTable->column(fGetterIndex[ci].first)->data();
547 fValueGetters.emplace_back(std::make_unique<ROOT::Internal::RDF::TValueGetter>(nSlots, chunkedArray->chunks()));
553 std::vector<void *> RArrowDS::GetColumnReadersImpl(std::string_view colName,
const std::type_info &)
555 auto &index = fGetterIndex;
556 auto findGetterIndex = [&index](
unsigned int column) {
557 for (
auto &entry : index) {
558 if (entry.first == column) {
562 throw std::runtime_error(
"No column found at index " + std::to_string(column));
565 const int columnIdx = fTable->schema()->GetFieldIndex(std::string(colName));
566 const int getterIdx = findGetterIndex(columnIdx);
567 assert(getterIdx != -1);
568 assert((
unsigned int)getterIdx < fValueGetters.size());
569 return fValueGetters[getterIdx]->SlotPtrs();
572 void RArrowDS::Initialise()
574 auto nRecords = getNRecords(fTable, fColumnNames);
575 splitInEqualRanges(fEntryRanges, nRecords, fNSlots);
// NOTE(review): the body of this definition is not present in this listing.
// Presumably it returns a short human-readable label identifying this data
// source — confirm against the full source.
578 std::string RArrowDS::GetLabel()
587 RDataFrame MakeArrowDataFrame(std::shared_ptr<arrow::Table> table, std::vector<std::string>
const &columnNames)
589 ROOT::RDataFrame tdf(std::make_unique<RArrowDS>(table, columnNames));