BatchNormLayer.h

// Author: Vladimir Ilievski

/**********************************************************************************
 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis       *
 * Package: TMVA                                                                   *
 * Class  : TBatchNormLayer                                                        *
 * Web    : http://tmva.sourceforge.net                                            *
 *                                                                                 *
 * Description:                                                                    *
 *      Batch Normalization Layer Class                                            *
 *                                                                                 *
 * Authors (alphabetical):                                                         *
 *      Vladimir Ilievski <ilievski.vladimir@live.com> - CERN, Switzerland         *
 *                                                                                 *
 * Copyright (c) 2005-2015:                                                        *
 *      CERN, Switzerland                                                          *
 *      U. of Victoria, Canada                                                     *
 *      MPI-K Heidelberg, Germany                                                  *
 *      U. of Bonn, Germany                                                        *
 *                                                                                 *
 * Redistribution and use in source and binary forms, with or without              *
 * modification, are permitted according to the terms listed in LICENSE            *
 * (http://tmva.sourceforge.net/LICENSE)                                           *
 **********************************************************************************/

#ifndef TMVA_DNN_BatchNormLayer
#define TMVA_DNN_BatchNormLayer

#include "TMVA/DNN/GeneralLayer.h"
#include "TMVA/DNN/Functions.h"

#include <iostream>
#include <iomanip>

namespace TMVA {
namespace DNN {

/** \class TBatchNormLayer

   Layer implementing Batch Normalization.

   The inputs from each batch are normalized during training to have zero mean and unit variance,
   and they are then scaled by two parameters, different for each input variable:
     - a scale factor gamma
     - an offset beta

   In addition, a running batch mean and variance are computed and stored in the class.
   During inference the inputs are not normalized using the batch mean, but using the previously
   computed running mean and variance.
   If momentum is in [0,1) the running mean and variance are exponential averages using the
   momentum value:
       running_mean = momentum * running_mean + (1 - momentum) * batch_mean
   If instead momentum < 0 the cumulative average is computed:
       running_mean = (nb / (nb + 1)) * running_mean + (1 / (nb + 1)) * batch_mean

   See more at [https://arxiv.org/pdf/1502.03167v3.pdf]
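
   As an illustration of the two update rules above, here is a minimal sketch for a single
   scalar mean (illustrative only, not the code used inside this class):

   \code
   double runningMean = 0.;
   int    nb          = 0;      // number of batches processed so far
   double momentum    = 0.99;   // use a negative value to select the cumulative average
   for (double batchMean : {0.5, 0.3, 0.4}) {
      if (momentum >= 0.)
         runningMean = momentum * runningMean + (1. - momentum) * batchMean;  // exponential average
      else
         runningMean = (nb * runningMean + batchMean) / (nb + 1.);            // cumulative average
      nb++;
   }
   \endcode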
*/
template <typename Architecture_t>
class TBatchNormLayer : public VGeneralLayer<Architecture_t> {
public:

   using Scalar_t = typename Architecture_t::Scalar_t;
   using Matrix_t = typename Architecture_t::Matrix_t;
   using Tensor_t = typename Architecture_t::Tensor_t;

   using HelperDescriptor_t = typename Architecture_t::TensorDescriptor_t;
   using BNormDescriptors_t = typename Architecture_t::BNormDescriptors_t;


private:

   Tensor_t fDerivatives; ///< First derivatives of the activations of this layer.

   int fNormAxis; ///< Normalization axis. For each element of this axis we will compute the mean and stddev.

   Scalar_t fMomentum; ///< Momentum for the running mean/variance (a negative value selects the cumulative average).
   Scalar_t fEpsilon;  ///< Small constant added to the variance for numerical stability.

   Matrix_t fMu;   ///< Mean of the current input batch.
   Matrix_t fVar;  ///< Variance of the current input batch.
   Matrix_t fIVar; ///< Cached inverse-variance term used in the backward pass.

   Matrix_t fMu_Training;  ///< Running mean accumulated during training (used at inference time).
   Matrix_t fVar_Training; ///< Running variance accumulated during training (used at inference time).

   // cached reshaped data tensor, used for cuDNN to get the correct shape
   Tensor_t fReshapedData;

   // counter of trained batches, used for computing the running mean and variance
   int fTrainedBatches = 0;

   TDescriptors * fDescriptors = nullptr;

public:
   /*! Constructor */
   TBatchNormLayer(size_t batchSize, size_t inputDepth, size_t inputHeight, size_t inputWidth,
                   const std::vector<size_t> & shape, int axis = -1, Scalar_t momentum = -1., Scalar_t epsilon = 0.0001);

   /*! Copy the batch normalization layer provided as a pointer */
   TBatchNormLayer(TBatchNormLayer<Architecture_t> *layer);

   /*! Copy Constructor */
   TBatchNormLayer(const TBatchNormLayer &);

   /*! Destructor */
   ~TBatchNormLayer();

   /*! Compute activation of the layer for the given input. The input
    * must be in 3D tensor form with the different matrices corresponding to
    * different events in the batch. Computes activations as well as
    * the first partial derivative of the activation function at those
    * activations. */
   void Forward(Tensor_t &input, bool inTraining = true);

   /*! Compute weight, bias and activation gradients. Uses the precomputed
    * first partial derivatives of the activation function computed during
    * forward propagation and modifies them. Must only be called directly
    * after the corresponding call to Forward(...). */
   void Backward(Tensor_t &gradients_backward, const Tensor_t &activations_backward);
   //             Tensor_t &inp1, Tensor_t &inp2);


   /* reset the batch counter at the end of training */
   void ResetTraining() { fTrainedBatches = 0; }

   /*! Print the layer info. */
   void Print() const;

   /*! Writes the information and the weights of the layer into an XML node. */
   virtual void AddWeightsXMLTo(void *parent);

   /*! Read the information and the weights of the layer from an XML node. */
   virtual void ReadWeightsFromXML(void *parent);

   /* initialize weights */
   virtual void Initialize();

   /* get the number of trained batches */
   const int & GetNTrainedBatches() const { return fTrainedBatches;}
   int & GetNTrainedBatches() { return fTrainedBatches;}

   /* get batch means for the training phase */
   const Matrix_t & GetBatchMean() const { return fMu;}
   Matrix_t & GetBatchMean() { return fMu;}

   /* Get the normalized batch examples */
   //const Matrix_t & GetNormedBatch() const { return fXhat;}
   //Matrix_t & GetNormedBatch() { return fXhat;}

   /* Get the batch variances for the training phase */
   const Matrix_t & GetVariance() const { return fVar;}
   Matrix_t & GetVariance() { return fVar;}

   /* Get the inverse variances of the batch computed during the training phase */
   const Matrix_t & GetIVariance() const { return fIVar;}
   Matrix_t & GetIVariance() { return fIVar;}

   /* get the vector of running means computed during the training phase */
   const Matrix_t & GetMuVector() const { return fMu_Training;}
   Matrix_t & GetMuVector() { return fMu_Training;}

   /* get the vector of running variances computed during the training phase */
   const Matrix_t & GetVarVector() const { return fVar_Training;}
   Matrix_t & GetVarVector() { return fVar_Training;}

   // Scalar_t GetWeightDecay() const { return fWeightDecay; }

   /* Get the momentum of the running mean/variance */
   Scalar_t GetMomentum() const { return fMomentum;}

   /* Get epsilon */
   Scalar_t GetEpsilon() const { return fEpsilon;}

   /* Get the normalization axis (the one along which each element is normalized) */
   Scalar_t GetNormAxis() const { return fNormAxis;}

   const Tensor_t &GetReshapedData() const { return fReshapedData; }
   Tensor_t &GetReshapedData() { return fReshapedData; }

   std::vector<Matrix_t> GetExtraLayerParameters() const {
      std::vector<Matrix_t> params(2);
      params[0] = this->GetMuVector();
      params[1] = this->GetVarVector();
      return params;
   }

   void SetExtraLayerParameters(const std::vector<Matrix_t> & params)
   {
      this->GetMuVector() = params[0];
      this->GetVarVector() = params[1];
   }

protected:
   static size_t CalculateNormDim(int axis, size_t c, size_t h, size_t w)
   {
      if (axis == -1)
         return c * h * w;
      else if (axis == 1)
         return c;
      else if (axis == 2)
         return h;
      else if (axis == 3)
         return w;
      return 0;
   }
};


//
//
//  The BatchNorm Layer Class - Implementation
//______________________________________________________________________________
template <typename Architecture_t>
TBatchNormLayer<Architecture_t>::TBatchNormLayer(size_t batchSize, size_t inputDepth, size_t inputHeight,
                                                 size_t inputWidth, const std::vector<size_t> &shape, int axis,
                                                 Scalar_t momentum, Scalar_t epsilon)
   : VGeneralLayer<Architecture_t>(batchSize, inputDepth, inputHeight, inputWidth, // batch size + input shape
                                   inputDepth, inputHeight, inputWidth,            // output shape
                                   2, 1,
                                   CalculateNormDim(axis, inputDepth, inputHeight, inputWidth), // weight tensor dim.
                                   1, 1, 1,                                        // bias
                                   shape[2], shape[0], shape[1],                   // output tensor shape as bsize, depth, hw
                                   EInitialization::kZero),
     fNormAxis(axis), fMomentum(momentum), fEpsilon(epsilon),
     fMu(1, VGeneralLayer<Architecture_t>::GetWeightsAt(0).GetNcols()), // dimension is the same as the weights
     fVar(1, VGeneralLayer<Architecture_t>::GetWeightsAt(0).GetNcols()),
     fIVar(1, VGeneralLayer<Architecture_t>::GetWeightsAt(0).GetNcols()),
     fMu_Training(1, VGeneralLayer<Architecture_t>::GetWeightsAt(0).GetNcols()),
     fVar_Training(1, VGeneralLayer<Architecture_t>::GetWeightsAt(0).GetNcols()),
     fReshapedData(1, 1, 1) // use a dummy single-element tensor
{
}
//______________________________________________________________________________
template <typename Architecture_t>
TBatchNormLayer<Architecture_t>::TBatchNormLayer(TBatchNormLayer<Architecture_t> *layer)
   : VGeneralLayer<Architecture_t>(layer)
{
   // to be implemented
   printf("Error - copy ctor not implemented\n");
}

//______________________________________________________________________________
template <typename Architecture_t>
TBatchNormLayer<Architecture_t>::TBatchNormLayer(const TBatchNormLayer &layer) : VGeneralLayer<Architecture_t>(layer)
{
   // to be implemented
   printf("Error - copy ctor not implemented\n");
}

//______________________________________________________________________________
template <typename Architecture_t>
TBatchNormLayer<Architecture_t>::~TBatchNormLayer()
{
   // release descriptors
   if (fDescriptors) {
      Architecture_t::ReleaseBNormDescriptors(fDescriptors);
      delete fDescriptors;
   }
}

template <typename Architecture_t>
auto TBatchNormLayer<Architecture_t>::Initialize() -> void
{
   Matrix_t &gamma = this->GetWeightsAt(0);
   Matrix_t &beta = this->GetWeightsAt(1);
   size_t bndim = gamma.GetNcols();

   initialize<Architecture_t>(beta, EInitialization::kZero);
   for (size_t i = 0; i < bndim; ++i) {
      gamma(0, i) = 1.;
      // assign default values for the other parameters
      fMu_Training(0, i) = 0;
      fVar_Training(0, i) = 1;
   }

   Matrix_t &dgamma = this->GetWeightGradientsAt(0);
   Matrix_t &dbeta = this->GetWeightGradientsAt(1);
   initialize<Architecture_t>(dgamma, EInitialization::kZero);
   initialize<Architecture_t>(dbeta, EInitialization::kZero);

   fTrainedBatches = 0;

   Architecture_t::InitializeBNormDescriptors(fDescriptors, this);
}

//______________________________________________________________________________
template <typename Architecture_t>
auto TBatchNormLayer<Architecture_t>::Forward(Tensor_t &x, bool inTraining) -> void
{
   Tensor_t x2;
   Tensor_t y2;
   // if the input layout differs from the cached one (e.g. for cuDNN), re-view the
   // input and output device buffers with the shape and layout of fReshapedData
   if (x.GetLayout() != fReshapedData.GetLayout()) {
      x2 = Tensor_t(x.GetDeviceBuffer(), fReshapedData.GetShape(), fReshapedData.GetLayout());
      y2 = Tensor_t(this->GetOutput().GetDeviceBuffer(), fReshapedData.GetShape(), fReshapedData.GetLayout());
   } else {
      x2 = x;
      y2 = this->GetOutput();
   }

   auto descr = static_cast<BNormDescriptors_t *>(fDescriptors);
   if (inTraining) {
      Architecture_t::BatchNormLayerForwardTraining(fNormAxis, x2, y2,
                                                    this->GetWeightsAt(0), this->GetWeightsAt(1),
                                                    this->GetBatchMean(), this->GetVariance(), this->GetIVariance(),
                                                    this->GetMuVector(),
                                                    this->GetVarVector(), this->GetNTrainedBatches(),
                                                    this->GetMomentum(), this->GetEpsilon(),
                                                    descr->HelperDescriptor);
      fTrainedBatches++;
   } else {
      // if (fTrainedBatches > 0) {
      //    Architecture_t::PrintTensor(Tensor_t(this->GetWeightsAt(0)), "bnorm gamma");
      //    Architecture_t::PrintTensor(Tensor_t(this->GetWeightsAt(1)), "bnorm beta");
      //    Architecture_t::PrintTensor(Tensor_t(this->GetMuVector()), "bnorm mu");
      //    Architecture_t::PrintTensor(Tensor_t(this->GetVarVector()), "bnorm var");
      // }
      Architecture_t::BatchNormLayerForwardInference(fNormAxis, x2, this->GetWeightsAt(0), this->GetWeightsAt(1),
                                                     y2, this->GetMuVector(), this->GetVarVector(),
                                                     this->GetEpsilon(), descr->HelperDescriptor);
      fTrainedBatches = 0;
   }
}

//______________________________________________________________________________
template <typename Architecture_t>
auto TBatchNormLayer<Architecture_t>::Backward(Tensor_t &gradients_backward,
                                               const Tensor_t &activations_backward) -> void
//                                             Tensor_t &, Tensor_t &) -> void
{
   auto descr = static_cast<BNormDescriptors_t *>(fDescriptors);


   if (activations_backward.GetLayout() != fReshapedData.GetLayout()) {
      Tensor_t x = Tensor_t(activations_backward.GetDeviceBuffer(), fReshapedData.GetShape(), fReshapedData.GetLayout());
      Tensor_t dx = Tensor_t(gradients_backward.GetDeviceBuffer(), fReshapedData.GetShape(), fReshapedData.GetLayout());
      Tensor_t dy = Tensor_t(this->GetActivationGradients().GetDeviceBuffer(), fReshapedData.GetShape(), fReshapedData.GetLayout());

      Architecture_t::BatchNormLayerBackward(fNormAxis, x, dy, dx,
                                             this->GetWeightsAt(0), // gamma (beta is not needed)
                                             this->GetWeightGradientsAt(0), this->GetWeightGradientsAt(1),
                                             this->GetBatchMean(), this->GetVariance(), this->GetIVariance(),
                                             this->GetEpsilon(), descr->HelperDescriptor);

   } else {

      Architecture_t::BatchNormLayerBackward(fNormAxis, activations_backward, // x
                                             this->GetActivationGradients(),  // dy
                                             gradients_backward,              // dx
                                             this->GetWeightsAt(0),           // gamma (beta is not needed)
                                             this->GetWeightGradientsAt(0), this->GetWeightGradientsAt(1),
                                             this->GetBatchMean(), this->GetVariance(), this->GetIVariance(),
                                             this->GetEpsilon(), descr->HelperDescriptor);
   }
}

//______________________________________________________________________________
template <typename Architecture_t>
void TBatchNormLayer<Architecture_t>::Print() const
{
   std::cout << " BATCH NORM Layer: \t";
   std::cout << " Input/Output = ( ";
   auto &shape = this->GetOutput().GetShape();
   for (size_t i = 0; i < shape.size(); ++i) {
      if (i > 0) std::cout << " , ";
      std::cout << shape[i];
   }
   std::cout << " ) ";
   std::cout << "\t Norm dim =" << std::setw(6) << this->GetWeightsAt(0).GetNcols();
   std::cout << "\t axis = " << fNormAxis << std::endl;
   std::cout << std::endl;
}

//______________________________________________________________________________

template <typename Architecture_t>
void TBatchNormLayer<Architecture_t>::AddWeightsXMLTo(void *parent)
{

   // write the momentum and epsilon attributes, the stored running mean and variance,
   // and the weight matrices (gamma and beta)

   auto layerxml = gTools().xmlengine().NewChild(parent, 0, "BatchNormLayer");


   gTools().AddAttr(layerxml, "Momentum", fMomentum);
   gTools().AddAttr(layerxml, "Epsilon", fEpsilon);

   // write stored mean and variances
   //using Scalar_t = typename Architecture_t::Scalar_t;

   this->WriteMatrixToXML(layerxml, "Training-mu", this->GetMuVector());
   this->WriteMatrixToXML(layerxml, "Training-variance", this->GetVarVector());

   // write weights (gamma and beta)
   this->WriteMatrixToXML(layerxml, "Gamma", this->GetWeightsAt(0));
   this->WriteMatrixToXML(layerxml, "Beta", this->GetWeightsAt(1));

}

//______________________________________________________________________________
template <typename Architecture_t>
void TBatchNormLayer<Architecture_t>::ReadWeightsFromXML(void *parent)
{
   // momentum and epsilon can be added after constructing the class
   gTools().ReadAttr(parent, "Momentum", fMomentum);
   gTools().ReadAttr(parent, "Epsilon", fEpsilon);
   // Read layer weights and biases from XML

   this->ReadMatrixXML(parent, "Training-mu", this->GetMuVector());
   this->ReadMatrixXML(parent, "Training-variance", this->GetVarVector());

   this->ReadMatrixXML(parent, "Gamma", this->GetWeightsAt(0));
   this->ReadMatrixXML(parent, "Beta", this->GetWeightsAt(1));
}

} // namespace DNN
} // namespace TMVA

#endif