RecurrentPropagation.hxx
// @(#)root/tmva/tmva/dnn:$Id$
// Author: Saurav Shekhar 23/06/17

/*************************************************************************
 * Copyright (C) 2017, Saurav Shekhar                                    *
 * All rights reserved.                                                  *
 *                                                                       *
 * For the licensing terms see $ROOTSYS/LICENSE.                         *
 * For the list of contributors see $ROOTSYS/README/CREDITS.             *
 *************************************************************************/

/////////////////////////////////////////////////////////////////////
// Implementation of the functions required for the forward and   //
// backward propagation of activations through a recurrent neural //
// network in the TCpu architecture.                               //
/////////////////////////////////////////////////////////////////////

#include "TMVA/DNN/Architectures/Cpu.h"

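// For reference: assuming the standard vanilla recurrent step
//    h_t = f(x_t * W_input^T + h_{t-1} * W_state^T + b),
// with df holding the element-wise product f'(net_t) o dL/dh_t, the function
// below computes
//    dL/dx_t     = df * W_input        (B x D)
//    dL/dh_{t-1} = df * W_state        (B x H)
//    dL/dW_input = df^T * input        (H x D, accumulated)
//    dL/dW_state = df^T * state        (H x H, accumulated)
//    dL/db       = column sums of df   (H, accumulated)
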
namespace TMVA
{
namespace DNN
{

template<typename AFloat>
auto TCpu<AFloat>::RecurrentLayerBackward(TCpuMatrix<AFloat> & state_gradients_backward, // B x H
                                          TCpuMatrix<AFloat> & input_weight_gradients,   // H x D
                                          TCpuMatrix<AFloat> & state_weight_gradients,   // H x H
                                          TCpuMatrix<AFloat> & bias_gradients,           // H x 1
                                          TCpuMatrix<AFloat> & df,                       // B x H
                                          const TCpuMatrix<AFloat> & state,              // B x H
                                          const TCpuMatrix<AFloat> & weights_input,      // H x D
                                          const TCpuMatrix<AFloat> & weights_state,      // H x H
                                          const TCpuMatrix<AFloat> & input,              // B x D
                                          TCpuMatrix<AFloat> & input_gradient)           // B x D
-> TCpuMatrix<AFloat> &
{

   // std::cout << "Recurrent Prop" << std::endl;
   // TMVA_DNN_PrintTCpuMatrix(df,"DF");
   // TMVA_DNN_PrintTCpuMatrix(state_gradients_backward,"State grad");
   // TMVA_DNN_PrintTCpuMatrix(input_weight_gradients,"input w grad");
   // TMVA_DNN_PrintTCpuMatrix(state,"state");
   // TMVA_DNN_PrintTCpuMatrix(input,"input");

   // Element-wise product of the activation derivatives with the incoming state
   // gradients. The call is commented out, so df is presumably expected to
   // already contain this product when the function is entered.
   // Hadamard(df, state_gradients_backward); // B x H

   // Input gradients: dL/dx = df * W_input  (B x H . H x D = B x D).
   if (input_gradient.GetNoElements() > 0) Multiply(input_gradient, df, weights_input);

   // State gradients: dL/dh_{t-1} = df * W_state  (B x H . H x H = B x H).
   if (state_gradients_backward.GetNoElements() > 0) Multiply(state_gradients_backward, df, weights_state);

   // Compute the weight and bias gradients.
   // The operations are performed in place by accumulating the result onto the
   // same gradient matrix, e.g. W += df^T * X.

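   // Note on the helper used below: TransposeMultiply(C, A, B, alpha, beta) is
   // assumed to compute C = alpha * A^T * B + beta * C, so with
   // alpha = beta = 1 the calls accumulate
   //    dL/dW_input += df^T * input   (H x B . B x D = H x D)
   //    dL/dW_state += df^T * state   (H x B . B x H = H x H)
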
   // Weight gradients.
   if (input_weight_gradients.GetNoElements() > 0) {
      TransposeMultiply(input_weight_gradients, df, input, 1., 1.); // H x B . B x D
   }
   if (state_weight_gradients.GetNoElements() > 0) {
      TransposeMultiply(state_weight_gradients, df, state, 1., 1.); // H x B . B x H
   }

   // Bias gradients: accumulate the column sums of df, i.e. sum the activation
   // gradients over the batch dimension.
   if (bias_gradients.GetNoElements() > 0) {
      SumColumns(bias_gradients, df, 1., 1.);
   }

   // std::cout << "RecurrentProp: end" << std::endl;

   // TMVA_DNN_PrintTCpuMatrix(state_gradients_backward,"State grad");
   // TMVA_DNN_PrintTCpuMatrix(input_weight_gradients,"input w grad");
   // TMVA_DNN_PrintTCpuMatrix(bias_gradients,"bias grad");
   // TMVA_DNN_PrintTCpuMatrix(input_gradient,"input grad");

   return input_gradient;
}
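
// Illustrative usage sketch, assuming TCpuMatrix provides a
// TCpuMatrix(nRows, nCols) constructor and that the caller has already folded
// the incoming state gradients into df; the sizes (B = 4, H = 3, D = 5) and
// variable names are hypothetical.
//
//    using Matrix_t = TMVA::DNN::TCpuMatrix<Double_t>;
//    const size_t B = 4, H = 3, D = 5;
//    Matrix_t state_grad(B, H), dw_input(H, D), dw_state(H, H), db(H, 1);
//    Matrix_t df(B, H), state(B, H), w_input(H, D), w_state(H, H);
//    Matrix_t x(B, D), dx(B, D);
//    // One backward step through the recurrent layer; returns a reference to dx.
//    TMVA::DNN::TCpu<Double_t>::RecurrentLayerBackward(state_grad, dw_input, dw_state, db,
//                                                      df, state, w_input, w_state, x, dx);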

} // namespace DNN
} // namespace TMVA