27 #ifndef TMVA_DNN_BatchNormLayer
28 #define TMVA_DNN_BatchNormLayer
62 template <
typename Architecture_t>
// Batch-normalization layer. During training Forward() computes per-batch
// mean/variance and updates running statistics; at inference time it uses
// the stored running (mu, var) vectors instead (see the two Forward paths
// calling BatchNormLayerForwardTraining / BatchNormLayerForwardInference).
63 class TBatchNormLayer :
public VGeneralLayer<Architecture_t> {
// Architecture-provided scalar / matrix / tensor types.
66 using Scalar_t =
typename Architecture_t::Scalar_t;
67 using Matrix_t =
typename Architecture_t::Matrix_t;
68 using Tensor_t =
typename Architecture_t::Tensor_t;
// Backend-specific descriptor types consumed by the batch-norm kernels.
70 using HelperDescriptor_t =
typename Architecture_t::TensorDescriptor_t;
71 using BNormDescriptors_t =
typename Architecture_t::BNormDescriptors_t;
// Derivatives tensor; its role is not visible in this chunk — TODO confirm.
76 Tensor_t fDerivatives;
// Running mean / variance vectors accumulated over training batches;
// consumed via GetMuVector()/GetVarVector() in the inference Forward path
// and (de)serialized as "Training-mu" / "Training-variance" in XML.
87 Matrix_t fMu_Training;
88 Matrix_t fVar_Training;
// Scratch tensor whose (shape, layout) defines what the backend kernels
// expect; Forward()/Backward() create zero-copy views matching it.
91 Tensor_t fReshapedData;
// Number of batches processed since the last ResetTraining().
94 int fTrainedBatches = 0;
// Backend descriptor bundle, created in Initialize() and released in the
// destructor via Architecture_t::ReleaseBNormDescriptors().
96 TDescriptors * fDescriptors =
nullptr;
// Main constructor. axis selects the dimension to normalize over
// (default -1); momentum controls the running-statistics update
// (-1 presumably selects a cumulative average — TODO confirm);
// epsilon is added to the variance for numerical stability.
100 TBatchNormLayer(
size_t batchSize,
size_t inputDepth,
size_t inputHeight,
size_t inputWidth,
101 const std::vector<size_t> & shape,
int axis = -1, Scalar_t momentum = -1., Scalar_t epsilon = 0.0001);
// Copy constructors (both print an error — deep copy not implemented).
104 TBatchNormLayer(TBatchNormLayer<Architecture_t> *layer);
107 TBatchNormLayer(
const TBatchNormLayer &);
// Compute the layer output. inTraining selects batch statistics
// (and updates the running averages) vs. the stored running statistics.
117 void Forward(Tensor_t &input,
bool inTraining =
true);
// Propagate gradients through the layer; also fills the gamma/beta
// weight gradients (slots 0 and 1).
123 void Backward(Tensor_t &gradients_backward,
const Tensor_t &activations_backward);
// Reset the trained-batch counter (e.g. when starting a new pass).
128 void ResetTraining() { fTrainedBatches = 0; }
// XML (de)serialization of weights and running statistics.
134 virtual void AddWeightsXMLTo(
void *parent);
137 virtual void ReadWeightsFromXML(
void *parent);
// Initialize gamma/beta, their gradients, the running statistics and
// the backend descriptors.
140 virtual void Initialize();
// --- Accessors -------------------------------------------------------
// Batch counter.
143 const int & GetNTrainedBatches()
const {
return fTrainedBatches;}
144 int & GetNTrainedBatches() {
return fTrainedBatches;}
// Per-batch mean / variance / inverse-variance (fMu, fVar, fIVar are
// declared on lines not visible in this chunk).
147 const Matrix_t & GetBatchMean()
const {
return fMu;}
148 Matrix_t & GetBatchMean() {
return fMu;}
155 const Matrix_t & GetVariance()
const {
return fVar;}
156 Matrix_t & GetVariance() {
return fVar;}
159 const Matrix_t & GetIVariance()
const {
return fIVar;}
160 Matrix_t & GetIVariance() {
return fIVar;}
// Running statistics used at inference time.
163 const Matrix_t & GetMuVector()
const {
return fMu_Training;}
164 Matrix_t & GetMuVector() {
return fMu_Training;}
167 const Matrix_t & GetVarVector()
const {
return fVar_Training;}
168 Matrix_t & GetVarVector() {
return fVar_Training;}
// Hyper-parameters (members declared on lines not visible here).
173 Scalar_t GetMomentum()
const {
return fMomentum;}
176 Scalar_t GetEpsilon()
const {
return fEpsilon;}
179 Scalar_t GetNormAxis()
const {
return fNormAxis;}
// Scratch reshape tensor (note: returned as Matrix_t& although the
// member is a Tensor_t — relies on an implicit conversion; TODO confirm).
181 const Matrix_t &GetReshapedData()
const {
return fReshapedData; }
182 Matrix_t &GetReshapedData() {
return fReshapedData; }
// Export the non-trainable state: params[0] = running mean,
// params[1] = running variance (return statement not visible here).
184 std::vector<Matrix_t> GetExtraLayerParameters()
const {
185 std::vector<Matrix_t> params(2);
186 params[0] = this->GetMuVector();
187 params[1] = this->GetVarVector();
// Restore the non-trainable state saved by GetExtraLayerParameters().
191 void SetExtraLayerParameters(
const std::vector<Matrix_t> & params)
193 this->GetMuVector() = params[0];
194 this->GetVarVector() = params[1];
// Static helper: size of the normalized dimension for a given axis and
// input (c, h, w); its body is not visible in this chunk.
198 static size_t CalculateNormDim(
int axis,
size_t c,
size_t h,
size_t w)
217 template <
typename Architecture_t>
// Main constructor. The number of columns of weight matrix 0 (gamma) is
// CalculateNormDim(axis, depth, height, width); every statistics vector
// below is sized 1 x that same column count.
218 TBatchNormLayer<Architecture_t>::TBatchNormLayer(
size_t batchSize,
size_t inputDepth,
size_t inputHeight,
219 size_t inputWidth,
const std::vector<size_t> &shape,
int axis,
220 Scalar_t momentum, Scalar_t epsilon)
221 : VGeneralLayer<Architecture_t>(batchSize, inputDepth, inputHeight, inputWidth,
222 inputDepth, inputHeight, inputWidth,
224 CalculateNormDim(axis, inputDepth, inputHeight, inputWidth),
// NOTE(review): shape components are passed reordered here — confirm
// the intended (depth, height, width) layout of the shape vector.
226 shape[2], shape[0], shape[1],
227 EInitialization::kZero),
228 fNormAxis(axis), fMomentum(momentum), fEpsilon(epsilon),
// Per-batch statistics (mean, variance, inverse variance).
229 fMu(1, VGeneralLayer<Architecture_t>::GetWeightsAt(0).GetNcols()),
230 fVar(1, VGeneralLayer<Architecture_t>::GetWeightsAt(0).GetNcols()),
231 fIVar(1, VGeneralLayer<Architecture_t>::GetWeightsAt(0).GetNcols()),
// Running statistics used at inference time.
232 fMu_Training(1, VGeneralLayer<Architecture_t>::GetWeightsAt(0).GetNcols()),
233 fVar_Training(1, VGeneralLayer<Architecture_t>::GetWeightsAt(0).GetNcols()),
240 template <
typename Architecture_t>
241 TBatchNormLayer<Architecture_t>::TBatchNormLayer(TBatchNormLayer<Architecture_t> *layer)
242 : VGeneralLayer<Architecture_t>(layer)
245 printf(
"Error - copy ctor not implmented\n");
249 template <
typename Architecture_t>
250 TBatchNormLayer<Architecture_t>::TBatchNormLayer(
const TBatchNormLayer &layer) : VGeneralLayer<Architecture_t>(layer)
253 printf(
"Error - copy ctor not implmented\n");
257 template <
typename Architecture_t>
// Destructor: release the backend batch-norm descriptors allocated in
// Initialize(). (The surrounding braces / null-check are on lines not
// visible in this chunk.)
258 TBatchNormLayer<Architecture_t>::~TBatchNormLayer()
262 Architecture_t::ReleaseBNormDescriptors(fDescriptors);
267 template <
typename Architecture_t>
// Initialize the trainable parameters and statistics:
//   beta (weights slot 1) is zero-initialized; the running mean is set to 0
//   and the running variance to 1 per element; both weight-gradient
//   matrices are zeroed; finally the backend descriptors are created.
//   (gamma's per-element initialization is on a line not visible here —
//   presumably set to 1 inside the loop; TODO confirm.)
268 auto TBatchNormLayer<Architecture_t>::Initialize() ->
void
270 Matrix_t &gamma = this->GetWeightsAt(0);
271 Matrix_t &beta = this->GetWeightsAt(1);
272 size_t bndim = gamma.GetNcols();
274 initialize<Architecture_t>(beta, EInitialization::kZero);
275 for (
size_t i = 0; i < bndim; ++i) {
// Running statistics start as a standard normal: mean 0, variance 1.
278 fMu_Training(0,i) = 0;
279 fVar_Training(0,i) = 1;
// Zero the gamma/beta gradient accumulators.
282 Matrix_t &dgamma = this->GetWeightGradientsAt(0);
283 Matrix_t &dbeta = this->GetWeightGradientsAt(1);
284 initialize<Architecture_t>(dgamma, EInitialization::kZero);
285 initialize<Architecture_t>(dbeta, EInitialization::kZero);
// Create the backend descriptor bundle for this layer.
289 Architecture_t::InitializeBNormDescriptors(fDescriptors,
this);
293 template <
typename Architecture_t>
// Forward pass. x2/y2 (declared on lines not visible here) are views of the
// input and output; when the input layout differs from fReshapedData's
// layout, zero-copy views over the same device buffers are created with the
// expected (shape, layout). The training branch computes batch statistics
// and updates the running averages; the inference branch uses the stored
// running (mu, var) vectors. (The if(inTraining) dispatch and else branches
// are on lines not visible in this chunk.)
294 auto TBatchNormLayer<Architecture_t>::Forward(Tensor_t &x,
bool inTraining) ->
void
298 if (x.GetLayout() != fReshapedData.GetLayout()) {
// Reinterpret input/output buffers in the kernel's expected layout
// without copying (same device buffer, new shape/layout).
299 x2 = Tensor_t(x.GetDeviceBuffer(), fReshapedData.GetShape(), fReshapedData.GetLayout());
300 y2 = Tensor_t(this->GetOutput().GetDeviceBuffer(), fReshapedData.GetShape(), fReshapedData.GetLayout());
304 y2 = this->GetOutput();
// Downcast the generic descriptor bundle to the batch-norm variant.
307 auto descr =
static_cast<BNormDescriptors_t *
> (fDescriptors);
// Training path: compute per-batch mean/variance/inverse-variance and
// fold them into the running statistics (momentum-weighted).
309 Architecture_t::BatchNormLayerForwardTraining(fNormAxis, x2, y2,
310 this->GetWeightsAt(0), this->GetWeightsAt(1),
311 this->GetBatchMean(), this->GetVariance(), this->GetIVariance(),
313 this->GetVarVector(), this->GetNTrainedBatches(),
314 this->GetMomentum(), this->GetEpsilon(),
315 descr->HelperDescriptor);
// Inference path: normalize with the stored running statistics.
326 Architecture_t::BatchNormLayerForwardInference(fNormAxis, x2, this->GetWeightsAt(0), this->GetWeightsAt(1),
327 y2, this->GetMuVector(), this->GetVarVector(),
328 this->GetEpsilon(), descr->HelperDescriptor);
335 template <
typename Architecture_t>
// Backward pass: compute the gradient w.r.t. the input
// (gradients_backward) and fill the gamma/beta weight gradients, using the
// batch statistics cached by the training Forward pass. As in Forward(),
// when the stored activations' layout differs from fReshapedData's layout,
// zero-copy views in the kernel's expected layout are built first.
336 auto TBatchNormLayer<Architecture_t>::Backward(Tensor_t &gradients_backward,
337 const Tensor_t & activations_backward ) ->
void
340 auto descr =
static_cast<BNormDescriptors_t *
> (fDescriptors);
343 if (activations_backward.GetLayout() != fReshapedData.GetLayout()) {
// x = forward input, dx = gradient to produce, dy = incoming gradient;
// all are views over existing device buffers (no copies).
344 Tensor_t x = Tensor_t(activations_backward.GetDeviceBuffer(), fReshapedData.GetShape(), fReshapedData.GetLayout());
345 Tensor_t dx = Tensor_t(gradients_backward.GetDeviceBuffer(), fReshapedData.GetShape(), fReshapedData.GetLayout());
346 Tensor_t dy = Tensor_t(this->GetActivationGradients().GetDeviceBuffer(), fReshapedData.GetShape(), fReshapedData.GetLayout());
348 Architecture_t::BatchNormLayerBackward(fNormAxis, x, dy, dx,
349 this->GetWeightsAt(0),
350 this->GetWeightGradientsAt(0), this->GetWeightGradientsAt(1),
351 this->GetBatchMean(), this->GetVariance(), this->GetIVariance(),
352 this->GetEpsilon(), descr->HelperDescriptor);
// Layouts already match: call the kernel on the tensors directly.
356 Architecture_t::BatchNormLayerBackward(fNormAxis, activations_backward,
357 this->GetActivationGradients(),
359 this->GetWeightsAt(0),
360 this->GetWeightGradientsAt(0), this->GetWeightGradientsAt(1),
361 this->GetBatchMean(), this->GetVariance(), this->GetIVariance(),
362 this->GetEpsilon(), descr->HelperDescriptor);
367 template <
typename Architecture_t>
// Print a one-line summary of the layer: output shape, size of the
// normalized dimension (gamma's column count) and the normalization axis.
368 void TBatchNormLayer<Architecture_t>::Print()
const
370 std::cout <<
" BATCH NORM Layer: \t";
371 std::cout <<
" Input/Output = ( " ;
372 auto &shape = this->GetOutput().GetShape();
373 for (
size_t i = 0; i < shape.size(); ++i) {
// Comma-separate all but the first dimension.
374 if (i > 0) std::cout <<
" , ";
375 std::cout << shape[i];
378 std::cout <<
"\t Norm dim =" << std::setw(6) << this->GetWeightsAt(0).GetNcols();
379 std::cout <<
"\t axis = " << fNormAxis << std::endl;
380 std::cout << std::endl;
385 template <
typename Architecture_t>
// Serialize the layer to XML under a "BatchNormLayer" child node:
// hyper-parameters as attributes, then the running statistics and the
// trainable gamma/beta matrices.
386 void TBatchNormLayer<Architecture_t>::AddWeightsXMLTo(
void *parent)
391 auto layerxml = gTools().xmlengine().NewChild(parent, 0,
"BatchNormLayer");
394 gTools().AddAttr(layerxml,
"Momentum", fMomentum);
395 gTools().AddAttr(layerxml,
"Epsilon", fEpsilon);
// Running statistics needed to reproduce inference-mode output.
400 this->WriteMatrixToXML(layerxml,
"Training-mu", this->GetMuVector());
401 this->WriteMatrixToXML(layerxml,
"Training-variance", this->GetVarVector());
// Trainable scale (gamma, weights 0) and shift (beta, weights 1).
404 this->WriteMatrixToXML(layerxml,
"Gamma", this->GetWeightsAt(0));
405 this->WriteMatrixToXML(layerxml,
"Beta", this->GetWeightsAt(1));
410 template <
typename Architecture_t>
// Deserialize the layer from XML: the exact mirror of AddWeightsXMLTo()
// (same attribute and matrix names, read directly from the parent node).
411 void TBatchNormLayer<Architecture_t>::ReadWeightsFromXML(
void *parent)
414 gTools().ReadAttr(parent,
"Momentum", fMomentum);
415 gTools().ReadAttr(parent,
"Epsilon", fEpsilon);
// Restore running statistics used by the inference Forward path.
418 this->ReadMatrixXML(parent,
"Training-mu", this->GetMuVector());
419 this->ReadMatrixXML(parent,
"Training-variance", this->GetVarVector());
// Restore trainable gamma (weights 0) and beta (weights 1).
421 this->ReadMatrixXML(parent,
"Gamma", this->GetWeightsAt(0));
422 this->ReadMatrixXML(parent,
"Beta", this->GetWeightsAt(1));