Logo ROOT   6.30.04
Reference Guide
 All Namespaces Files Pages
DataLoader.h
Go to the documentation of this file.
1 // @(#)root/tmva/tmva/dnn:$Id$
2 // Author: Simon Pfreundschuh 06/06/17
3 
4 /*************************************************************************
5  * Copyright (C) 2016, Simon Pfreundschuh *
6  * All rights reserved. *
7  * *
8  * For the licensing terms see $ROOTSYS/LICENSE. *
9  * For the list of contributors see $ROOTSYS/README/CREDITS. *
10  *************************************************************************/
11 
12 /////////////////////////////////////////////////////////////////////
// Partial specialization of the TDataLoader class to adapt it to  //
// the TMatrix class. The data transfer is kept simple, since this  //
// implementation (intended as a reference and fallback) is not     //
// optimized for performance.                                       //
17 /////////////////////////////////////////////////////////////////////
18 
19 #ifndef TMVA_DNN_ARCHITECTURES_REFERENCE_DATALOADER
20 #define TMVA_DNN_ARCHITECTURES_REFERENCE_DATALOADER
21 
#include "TMVA/DNN/DataLoader.h"

#include <algorithm>
#include <random>
#include <stdexcept>
#include <vector>
25 
26 namespace TMVA {
27 namespace DNN {
28 
29 template <typename AReal>
30 class TReference;
31 
32 template <typename AData, typename AReal>
33 class TDataLoader<AData, TReference<AReal>> {
34 private:
35  using BatchIterator_t = TBatchIterator<AData, TReference<AReal>>;
36 
37  const AData &fData;
38 
39  size_t fNSamples;
40  size_t fBatchSize;
41  size_t fNInputFeatures;
42  size_t fNOutputFeatures;
43  size_t fBatchIndex;
44 
45  TMatrixT<AReal> inputMatrix;
46  TMatrixT<AReal> outputMatrix;
47  TMatrixT<AReal> weightMatrix;
48 
49  std::vector<size_t> fSampleIndices; ///< Ordering of the samples in the epoch.
50 
51 public:
52  TDataLoader(const AData &data, size_t nSamples, size_t batchSize, size_t nInputFeatures, size_t nOutputFeatures,
53  size_t nthreads = 1);
54  TDataLoader(const TDataLoader &) = default;
55  TDataLoader(TDataLoader &&) = default;
56  TDataLoader &operator=(const TDataLoader &) = default;
57  TDataLoader &operator=(TDataLoader &&) = default;
58 
59  /** Copy input matrix into the given host buffer. Function to be specialized by
60  * the architecture-specific backend. */
61  void CopyInput(TMatrixT<AReal> &matrix, IndexIterator_t begin);
62  /** Copy output matrix into the given host buffer. Function to be specialized
63  * by the architecture-spcific backend. */
64  void CopyOutput(TMatrixT<AReal> &matrix, IndexIterator_t begin);
65  /** Copy weight matrix into the given host buffer. Function to be specialized
66  * by the architecture-spcific backend. */
67  void CopyWeights(TMatrixT<AReal> &matrix, IndexIterator_t begin);
68 
69  BatchIterator_t begin() { return BatchIterator_t(*this); }
70  BatchIterator_t end() { return BatchIterator_t(*this, fNSamples / fBatchSize); }
71 
72  /** Shuffle the order of the samples in the batch. The shuffling is indirect,
73  * i.e. only the indices are shuffled. No input data is moved by this
74  * routine. */
75  void Shuffle();
76 
77  /** Return the next batch from the training set. The TDataLoader object
78  * keeps an internal counter that cycles over the batches in the training
79  * set. */
80  TBatch<TReference<AReal>> GetBatch();
81 };
82 
83 template <typename AData, typename AReal>
84 TDataLoader<AData, TReference<AReal>>::TDataLoader(const AData &data, size_t nSamples, size_t batchSize,
85  size_t nInputFeatures, size_t nOutputFeatures, size_t /*nthreads*/)
86  : fData(data), fNSamples(nSamples), fBatchSize(batchSize), fNInputFeatures(nInputFeatures),
87  fNOutputFeatures(nOutputFeatures), fBatchIndex(0), inputMatrix(batchSize, nInputFeatures),
88  outputMatrix(batchSize, nOutputFeatures), weightMatrix(batchSize, 1), fSampleIndices()
89 {
90  fSampleIndices.reserve(fNSamples);
91  for (size_t i = 0; i < fNSamples; i++) {
92  fSampleIndices.push_back(i);
93  }
94 }
95 
96 template <typename AData, typename AReal>
97 TBatch<TReference<AReal>> TDataLoader<AData, TReference<AReal>>::GetBatch()
98 {
99  fBatchIndex %= (fNSamples / fBatchSize); // Cycle through samples.
100 
101  size_t sampleIndex = fBatchIndex * fBatchSize;
102  IndexIterator_t sampleIndexIterator = fSampleIndices.begin() + sampleIndex;
103 
104  CopyInput(inputMatrix, sampleIndexIterator);
105  CopyOutput(outputMatrix, sampleIndexIterator);
106  CopyWeights(weightMatrix, sampleIndexIterator);
107 
108  fBatchIndex++;
109 
110  return TBatch<TReference<AReal>>(inputMatrix, outputMatrix, weightMatrix);
111 }
112 
113 //______________________________________________________________________________
114 template <typename AData, typename AReal>
115 void TDataLoader<AData, TReference<AReal>>::Shuffle()
116 {
117  std::shuffle(fSampleIndices.begin(), fSampleIndices.end(), std::default_random_engine{});
118 }
119 
120 } // namespace DNN
121 } // namespace TMVA
122 
123 #endif // TMVA_DNN_ARCHITECTURES_REFERENCE_DATALOADER