18 #ifndef TMVA_DNN_DATALOADER
19 #define TMVA_DNN_DATALOADER
// Input-data layouts used with the data loader (presumably as its Data_t template
// argument — assumption from naming, confirm against the CopyInput/CopyOutput/CopyWeights
// specializations).
// MatrixInput_t: three matrices held by const reference; presumably (inputs, outputs,
// weights) in that order — TODO confirm.
38 using MatrixInput_t = std::tuple<const TMatrixT<Double_t> &,
const TMatrixT<Double_t> &,
const TMatrixT<Double_t> &>;
// NOTE(review): the `using <name> =` head of the following alias is missing from this
// listing. The tuple pairs a vector of event pointers with its DataSetInfo, both held by
// const reference.
40 std::tuple<const std::vector<Event *> &,
const DataSetInfo &>;
// Iterator into a vector of sample indices; the loader passes one of these to the copy
// hooks to mark the first sample of a batch.
42 using IndexIterator_t =
typename std::vector<size_t>::iterator;
// TBatch: bundles the input, output and weight matrices that make up one batch.
// NOTE(review): this listing is incomplete. The `class TBatch` head, the access
// specifiers, the opening/closing braces and any lines between those shown are missing.
53 template <
typename AArchitecture>
// Matrix type supplied by the architecture backend.
58 using Matrix_t =
typename AArchitecture::Matrix_t;
// The three matrices are stored by value and initialised from the constructor arguments.
60 Matrix_t fInputMatrix;
61 Matrix_t fOutputMatrix;
62 Matrix_t fWeightMatrix;
// Builds a batch from the input, output and weight matrices, in that order.
65 TBatch(Matrix_t &, Matrix_t &, Matrix_t &);
// Copy and move operations are the compiler-generated member-wise defaults.
66 TBatch(
const TBatch &) =
default;
67 TBatch( TBatch &&) =
default;
68 TBatch & operator=(
const TBatch &) =
default;
69 TBatch & operator=( TBatch &&) =
default;
// Accessors return non-const references, so callers may modify the stored matrices.
// Returns the matrix holding the input data of the batch.
72 Matrix_t &GetInput() {
return fInputMatrix; }
// Returns the matrix holding the output (target) data of the batch.
74 Matrix_t &GetOutput() {
return fOutputMatrix; }
// Returns the matrix holding the per-sample weights of the batch.
76 Matrix_t &GetWeights() {
return fWeightMatrix; }
/** Forward declaration: TBatchIterator refers to TDataLoader before the loader
 *  class itself is defined. */
template <typename Data_t, typename AArchitecture>
class TDataLoader;
// TBatchIterator: iterator over the batches of a TDataLoader; the loader's begin()/end()
// return instances of it.
// NOTE(review): this listing is incomplete. The class head, the access specifiers, the
// declaration of the fBatchIndex member and the closing braces are missing.
89 template<
typename Data_t,
typename AArchitecture>
// Non-owning reference to the loader, so the loader must outlive every iterator over it.
94 TDataLoader<Data_t, AArchitecture> & fDataLoader;
// Creates an iterator positioned at batch `index` (default: the first batch).
99 TBatchIterator(TDataLoader<Data_t, AArchitecture> & dataLoader,
size_t index = 0)
100 : fDataLoader(dataLoader), fBatchIndex(index)
// Dereferencing asks the loader for a batch via GetBatch(). It does not pass this
// iterator's fBatchIndex, so which batch is returned depends on the loader's own state.
105 TBatch<AArchitecture> operator*() {
return fDataLoader.GetBatch();}
// Pre-increment advances only this iterator's position.
// NOTE(review): returns a copy rather than TBatchIterator&, unlike the conventional
// prefix ++.
106 TBatchIterator operator++() {fBatchIndex++;
return *
this;}
// Iterators compare by batch position only; the loader reference is not compared.
// NOTE(review): could be declared const.
107 bool operator!=(
const TBatchIterator & other) {
108 return fBatchIndex != other.fBatchIndex;
// TDataLoader: produces batches from Data_t. Each batch is staged in a host buffer and
// then copied into a device buffer, with one host/device buffer pair per stream.
// NOTE(review): this listing is incomplete. The class head, the access specifiers, the
// declarations of fData, fNSamples, fBatchSize, fBatchIndex and fNStreams, the Shuffle()
// declaration and several braces are missing.
127 template<
typename Data_t,
typename AArchitecture>
// Buffer, matrix and iterator types supplied by (or derived from) the architecture.
132 using HostBuffer_t =
typename AArchitecture::HostBuffer_t;
133 using DeviceBuffer_t =
typename AArchitecture::DeviceBuffer_t;
134 using Matrix_t =
typename AArchitecture::Matrix_t;
135 using BatchIterator_t = TBatchIterator<Data_t, AArchitecture>;
// Number of input / output features per sample (column counts of the batch matrices).
141 size_t fNInputFeatures;
142 size_t fNOutputFeatures;
// One device/host buffer pair per stream. Each buffer holds the inputs, outputs and
// weights of one batch back to back.
146 std::vector<DeviceBuffer_t> fDeviceBuffers;
147 std::vector<HostBuffer_t> fHostBuffers;
// Order in which samples are visited. The constructor fills it with 0..nSamples-1;
// Shuffle() permutes it.
149 std::vector<size_t> fSampleIndices;
// nStreams is the number of host/device buffer pairs. It must be >= 1, because a batch
// selects its pair via fBatchIndex % fNStreams.
153 TDataLoader(
const Data_t & data,
size_t nSamples,
size_t batchSize,
154 size_t nInputFeatures,
size_t nOutputFeatures,
size_t nStreams = 1);
// Defaulted copy/move operations.
// NOTE(review): if fData is declared as a reference (its declaration is not shown
// here), the defaulted assignment operators are implicitly defined as deleted.
155 TDataLoader(
const TDataLoader &) =
default;
156 TDataLoader( TDataLoader &&) =
default;
157 TDataLoader & operator=(
const TDataLoader &) =
default;
158 TDataLoader & operator=( TDataLoader &&) =
default;
// Copy the inputs / outputs / weights of the batchSize samples whose indices start at
// `begin` into `buffer`. No generic definition appears here; presumably each
// architecture backend specializes them — TODO confirm.
162 void CopyInput(HostBuffer_t &buffer, IndexIterator_t begin,
size_t batchSize);
165 void CopyOutput(HostBuffer_t &buffer, IndexIterator_t begin,
size_t batchSize);
168 void CopyWeights(HostBuffer_t &buffer, IndexIterator_t begin,
size_t batchSize);
// Iteration covers the complete batches of one epoch: positions
// [0, fNSamples / fBatchSize). Integer division means a trailing partial batch is never
// produced.
// NOTE(review): the braces of end() are missing from this listing.
170 BatchIterator_t begin() {
return TBatchIterator<Data_t, AArchitecture>(*this);}
171 BatchIterator_t end()
173 return TBatchIterator<Data_t, AArchitecture>(*
this, fNSamples / fBatchSize);
// Builds and returns the batch at the loader's current batch index. The index is
// wrapped modulo the number of complete batches.
184 TBatch<AArchitecture> GetBatch();
191 template <
typename AArchitecture>
192 TBatch<AArchitecture>::TBatch(Matrix_t &inputMatrix, Matrix_t &outputMatrix, Matrix_t &weightMatrix)
193 : fInputMatrix(inputMatrix), fOutputMatrix(outputMatrix), fWeightMatrix(weightMatrix)
201 template<
typename Data_t,
typename AArchitecture>
202 TDataLoader<Data_t, AArchitecture>::TDataLoader(
203 const Data_t & data,
size_t nSamples,
size_t batchSize,
204 size_t nInputFeatures,
size_t nOutputFeatures,
size_t nStreams)
205 : fData(data), fNSamples(nSamples), fBatchSize(batchSize),
206 fNInputFeatures(nInputFeatures), fNOutputFeatures(nOutputFeatures),
207 fBatchIndex(0), fNStreams(nStreams), fDeviceBuffers(), fHostBuffers(),
210 size_t inputMatrixSize = fBatchSize * fNInputFeatures;
211 size_t outputMatrixSize = fBatchSize * fNOutputFeatures;
212 size_t weightMatrixSize = fBatchSize;
214 for (
size_t i = 0; i < fNStreams; i++)
216 fHostBuffers.push_back(HostBuffer_t(inputMatrixSize + outputMatrixSize + weightMatrixSize));
217 fDeviceBuffers.push_back(DeviceBuffer_t(inputMatrixSize + outputMatrixSize + weightMatrixSize));
220 fSampleIndices.reserve(fNSamples);
221 for (
size_t i = 0; i < fNSamples; i++) {
222 fSampleIndices.push_back(i);
227 template<
typename Data_t,
typename AArchitecture>
228 TBatch<AArchitecture> TDataLoader<Data_t, AArchitecture>::GetBatch()
230 fBatchIndex %= (fNSamples / fBatchSize);
233 size_t inputMatrixSize = fBatchSize * fNInputFeatures;
234 size_t outputMatrixSize = fBatchSize * fNOutputFeatures;
235 size_t weightMatrixSize = fBatchSize;
237 size_t streamIndex = fBatchIndex % fNStreams;
238 HostBuffer_t & hostBuffer = fHostBuffers[streamIndex];
239 DeviceBuffer_t & deviceBuffer = fDeviceBuffers[streamIndex];
241 HostBuffer_t inputHostBuffer = hostBuffer.GetSubBuffer(0, inputMatrixSize);
242 HostBuffer_t outputHostBuffer = hostBuffer.GetSubBuffer(inputMatrixSize,
244 HostBuffer_t weightHostBuffer = hostBuffer.GetSubBuffer(inputMatrixSize + outputMatrixSize, weightMatrixSize);
246 DeviceBuffer_t inputDeviceBuffer = deviceBuffer.GetSubBuffer(0, inputMatrixSize);
247 DeviceBuffer_t outputDeviceBuffer = deviceBuffer.GetSubBuffer(inputMatrixSize,
249 DeviceBuffer_t weightDeviceBuffer = deviceBuffer.GetSubBuffer(inputMatrixSize + outputMatrixSize, weightMatrixSize);
251 size_t sampleIndex = fBatchIndex * fBatchSize;
252 IndexIterator_t sampleIndexIterator = fSampleIndices.begin() + sampleIndex;
254 CopyInput(inputHostBuffer, sampleIndexIterator, fBatchSize);
255 CopyOutput(outputHostBuffer, sampleIndexIterator, fBatchSize);
256 CopyWeights(weightHostBuffer, sampleIndexIterator, fBatchSize);
258 deviceBuffer.CopyFrom(hostBuffer);
259 Matrix_t inputMatrix(inputDeviceBuffer, fBatchSize, fNInputFeatures);
260 Matrix_t outputMatrix(outputDeviceBuffer, fBatchSize, fNOutputFeatures);
261 Matrix_t weightMatrix(weightDeviceBuffer, fBatchSize, fNOutputFeatures);
264 return TBatch<AArchitecture>(inputMatrix, outputMatrix, weightMatrix);
268 template<
typename Data_t,
typename AArchitecture>
269 void TDataLoader<Data_t, AArchitecture>::Shuffle()
271 std::shuffle(fSampleIndices.begin(), fSampleIndices.end(), std::default_random_engine{});