34 #ifndef TMVA_DNN_ARCHITECTURES_REFERENCE_TENSORDATALOADER
35 #define TMVA_DNN_ARCHITECTURES_REFERENCE_TENSORDATALOADER
// NOTE(review): template header of a forward declaration whose declarator
// line is not present in this listing (presumably `class TReference;`, the
// architecture type named by the specialization that follows) — confirm
// against the full header.
43 template <
typename AReal>
// Partial specialization of TTensorDataLoader for the reference
// architecture TReference<AReal>. It owns reusable per-batch buffers
// (inputTensor, outputMatrix, weightMatrix) and an index permutation
// (fSampleIndices) that decides which samples make up each batch.
//
// NOTE(review): this listing omits declarations that the out-of-line member
// definitions rely on (fData, fNSamples, fBatchDepth, fBatchHeight,
// fBatchWidth, fBatchIndex, the IndexIterator_t alias, access specifiers and
// the closing brace) — confirm against the full header.
46 template <
typename AData,
typename AReal>
47 class TTensorDataLoader<AData, TReference<AReal>> {
// Iterator type returned by begin()/end().
49 using BatchIterator_t = TTensorBatchIterator<AData, TReference<AReal>>;
// Number of output features (columns of outputMatrix).
58 size_t fNOutputFeatures;
// Input shape. Element 0 is the divisor used by end(), i.e. presumably the
// batch size — confirm.
61 std::vector<size_t> fInputShape;
// Reusable input buffer: one matrix per batch-depth slice.
63 std::vector<TMatrixT<AReal>> inputTensor;
// Reusable output buffer, filled by CopyTensorOutput.
64 TMatrixT<AReal> outputMatrix;
// Reusable single-column weight buffer, filled by CopyTensorWeights.
65 TMatrixT<AReal> weightMatrix;
// Order in which samples are visited; permuted by Shuffle().
67 std::vector<size_t> fSampleIndices;
// nStreams is accepted but not used by this specialization's constructor.
71 TTensorDataLoader(
const AData &data,
size_t nSamples,
size_t batchDepth,
72 size_t batchHeight,
size_t batchWidth,
size_t nOutputFeatures,
73 std::vector<size_t> inputShape,
size_t nStreams = 1);
// Defaulted copy and move operations.
75 TTensorDataLoader(
const TTensorDataLoader &) =
default;
76 TTensorDataLoader(TTensorDataLoader &&) =
default;
77 TTensorDataLoader &operator=(
const TTensorDataLoader &) =
default;
78 TTensorDataLoader &operator=(TTensorDataLoader &&) =
default;
// Copy helpers, defined outside this listing. Presumably each one fills the
// given buffer from the samples starting at sampleIterator — confirm.
82 void CopyTensorInput(std::vector<TMatrixT<AReal>> &tensor, IndexIterator_t sampleIterator);
85 void CopyTensorOutput(TMatrixT<AReal> &matrix, IndexIterator_t sampleIterator);
88 void CopyTensorWeights(TMatrixT<AReal> &matrix, IndexIterator_t sampleIterator);
// Batch iteration. end() is built from fNSamples / fInputShape[0] (integer
// division), so presumably only full batches are visited and a trailing
// partial batch is skipped. NOTE(review): fInputShape[0] == 0 would divide
// by zero here.
90 BatchIterator_t begin() {
return BatchIterator_t(*
this); }
91 BatchIterator_t end() {
return BatchIterator_t(*
this, fNSamples / fInputShape[0]); }
// Randomly permute fSampleIndices using rng.
96 template<
typename RNG>
97 void Shuffle(RNG & rng);
// Fill the internal buffers with the current batch and return it as a
// TTensorBatch.
102 TTensorBatch<TReference<AReal>> GetTensorBatch();
108 template <
typename AData,
typename AReal>
109 TTensorDataLoader<AData, TReference<AReal>>::TTensorDataLoader(
const AData &data,
size_t nSamples,
size_t batchDepth,
110 size_t batchHeight,
size_t batchWidth,
size_t nOutputFeatures,
111 std::vector<size_t> inputShape,
size_t )
112 : fData(data), fNSamples(nSamples), fBatchDepth(batchDepth), fBatchHeight(batchHeight),
113 fBatchWidth(batchWidth), fNOutputFeatures(nOutputFeatures), fBatchIndex(0), fInputShape(std::move(inputShape)), inputTensor(),
114 outputMatrix(inputShape[0], nOutputFeatures), weightMatrix(inputShape[0], 1), fSampleIndices()
117 inputTensor.reserve(fBatchDepth);
118 for (
size_t i = 0; i < fBatchDepth; i++) {
119 inputTensor.emplace_back(batchHeight, batchWidth);
122 fSampleIndices.reserve(fNSamples);
123 for (
size_t i = 0; i < fNSamples; i++) {
124 fSampleIndices.push_back(i);
128 template <
typename AData,
typename AReal>
129 template <
typename RNG>
130 void TTensorDataLoader<AData, TReference<AReal>>::Shuffle(RNG & rng)
132 std::shuffle(fSampleIndices.begin(), fSampleIndices.end(), rng);
// Fill the internal buffers for the current batch and return them wrapped in
// a TTensorBatch.
//
// The batch index is wrapped modulo fNSamples / fInputShape[0]. The batch's
// first sample is taken from position fBatchIndex * fInputShape[0] of the
// (possibly shuffled) fSampleIndices. CopyTensorInput/Output/Weights then fill
// inputTensor, outputMatrix and weightMatrix starting from that position.
// These member buffers are reused on every call; whether TTensorBatch copies
// them or refers to them is not visible here.
//
// NOTE(review): if fNSamples < fInputShape[0], the modulo below divides by
// zero — confirm callers guarantee at least one full batch.
// NOTE(review): no update of fBatchIndex (and no closing brace) is visible in
// this listing. Confirm fBatchIndex is advanced before returning; otherwise
// every call yields the same batch.
135 template <
typename AData,
typename AReal>
136 auto TTensorDataLoader<AData, TReference<AReal>>::GetTensorBatch() -> TTensorBatch<TReference<AReal>>
// Wrap around so that iteration cycles through the full batches.
138 fBatchIndex %= (fNSamples / fInputShape[0]);
// Offset of this batch's first entry in the sample permutation.
140 size_t sampleIndex = fBatchIndex * fInputShape[0];
141 IndexIterator_t sampleIndexIterator = fSampleIndices.begin() + sampleIndex;
143 CopyTensorInput(inputTensor, sampleIndexIterator);
144 CopyTensorOutput(outputMatrix, sampleIndexIterator);
145 CopyTensorWeights(weightMatrix, sampleIndexIterator);
148 return TTensorBatch<TReference<AReal>>(inputTensor, outputMatrix, weightMatrix);