19 #ifndef TMVA_DNN_ARCHITECTURES_REFERENCE_DATALOADER
20 #define TMVA_DNN_ARCHITECTURES_REFERENCE_DATALOADER
// Forward declaration of the reference (TMatrixT-based) architecture, needed
// to name TReference<AReal> in the TDataLoader specialization below.
template <typename AReal>
class TReference;
32 template <
typename AData,
typename AReal>
33 class TDataLoader<AData, TReference<AReal>> {
35 using BatchIterator_t = TBatchIterator<AData, TReference<AReal>>;
41 size_t fNInputFeatures;
42 size_t fNOutputFeatures;
45 TMatrixT<AReal> inputMatrix;
46 TMatrixT<AReal> outputMatrix;
47 TMatrixT<AReal> weightMatrix;
49 std::vector<size_t> fSampleIndices;
52 TDataLoader(
const AData &data,
size_t nSamples,
size_t batchSize,
size_t nInputFeatures,
size_t nOutputFeatures,
54 TDataLoader(
const TDataLoader &) =
default;
55 TDataLoader(TDataLoader &&) =
default;
56 TDataLoader &operator=(
const TDataLoader &) =
default;
57 TDataLoader &operator=(TDataLoader &&) =
default;
61 void CopyInput(TMatrixT<AReal> &matrix, IndexIterator_t begin);
64 void CopyOutput(TMatrixT<AReal> &matrix, IndexIterator_t begin);
67 void CopyWeights(TMatrixT<AReal> &matrix, IndexIterator_t begin);
69 BatchIterator_t begin() {
return BatchIterator_t(*
this); }
70 BatchIterator_t end() {
return BatchIterator_t(*
this, fNSamples / fBatchSize); }
80 TBatch<TReference<AReal>> GetBatch();
83 template <
typename AData,
typename AReal>
84 TDataLoader<AData, TReference<AReal>>::TDataLoader(
const AData &data,
size_t nSamples,
size_t batchSize,
85 size_t nInputFeatures,
size_t nOutputFeatures,
size_t )
86 : fData(data), fNSamples(nSamples), fBatchSize(batchSize), fNInputFeatures(nInputFeatures),
87 fNOutputFeatures(nOutputFeatures), fBatchIndex(0), inputMatrix(batchSize, nInputFeatures),
88 outputMatrix(batchSize, nOutputFeatures), weightMatrix(batchSize, 1), fSampleIndices()
90 fSampleIndices.reserve(fNSamples);
91 for (
size_t i = 0; i < fNSamples; i++) {
92 fSampleIndices.push_back(i);
96 template <
typename AData,
typename AReal>
97 TBatch<TReference<AReal>> TDataLoader<AData, TReference<AReal>>::GetBatch()
99 fBatchIndex %= (fNSamples / fBatchSize);
101 size_t sampleIndex = fBatchIndex * fBatchSize;
102 IndexIterator_t sampleIndexIterator = fSampleIndices.begin() + sampleIndex;
104 CopyInput(inputMatrix, sampleIndexIterator);
105 CopyOutput(outputMatrix, sampleIndexIterator);
106 CopyWeights(weightMatrix, sampleIndexIterator);
110 return TBatch<TReference<AReal>>(inputMatrix, outputMatrix, weightMatrix);
114 template <
typename AData,
typename AReal>
115 void TDataLoader<AData, TReference<AReal>>::Shuffle()
117 std::shuffle(fSampleIndices.begin(), fSampleIndices.end(), std::default_random_engine{});
123 #endif // TMVA_DNN_ARCHITECTURES_REFERENCE_DATALOADER