29 #ifndef TMVA_DNN_RNN_LAYER
30 #define TMVA_DNN_RNN_LAYER
// TBasicRNNLayer: a vanilla recurrent layer for the TMVA deep-network
// framework, parametrised on a compute backend via Architecture_t.
// NOTE(review): this extraction is partial — several member declarations
// referenced by the accessors below (fTimeSteps, fStateSize, fRememberState,
// fState, fBiases) are not visible here; confirm against the full header.
54 template<
typename Architecture_t>
55 class TBasicRNNLayer :
public VGeneralLayer<Architecture_t>
// Backend-specific tensor / matrix / scalar types.
60 using Tensor_t =
typename Architecture_t::Tensor_t;
61 using Matrix_t =
typename Architecture_t::Matrix_t;
62 using Scalar_t =
typename Architecture_t::Scalar_t;
// Activation function applied to the recurrent pre-activation at each step.
70 DNN::EActivationFunction fF;
// References bound to the weight storage owned by VGeneralLayer:
// slot 0 = input->state weights, slot 1 = state->state weights
// (see the constructor's GetWeightsAt(0)/GetWeightsAt(1) bindings).
73 Matrix_t &fWeightsInput;
74 Matrix_t &fWeightsState;
// Activation first-derivative values, one matrix per time step.
77 Tensor_t fDerivatives;
// References bound to the gradient storage owned by VGeneralLayer.
78 Matrix_t &fWeightInputGradients;
79 Matrix_t &fWeightStateGradients;
80 Matrix_t &fBiasGradients;
// Backend descriptor for the activation (used by forward/backward kernels).
82 typename Architecture_t::ActivationDescriptor_t fActivationDesc;
// Main constructor. The `training` flag is accepted here but unnamed and
// unused in the out-of-class definition.
87 TBasicRNNLayer(
size_t batchSize,
size_t stateSize,
size_t inputSize,
88 size_t timeSteps,
bool rememberState =
false,
89 DNN::EActivationFunction f = DNN::EActivationFunction::kTanh,
90 bool training =
true, DNN::EInitialization fA = DNN::EInitialization::kZero);
// Copy constructor: deep-copies hidden state and derivatives (see definition).
93 TBasicRNNLayer(
const TBasicRNNLayer &);
// Initializes the hidden state. NOTE(review): the definition zero-initializes
// unconditionally; the parameter `m` is currently ignored there.
101 void InitState(DNN::EInitialization m = DNN::EInitialization::kZero);
// Forward pass over all time steps.
105 void Forward(Tensor_t &input,
bool isTraining =
true);
// Forward pass for one time step; dF receives the activation derivatives.
108 void CellForward(
const Matrix_t &input, Matrix_t & dF);
// Backpropagation through time over all steps.
112 void Backward(Tensor_t &gradients_backward,
113 const Tensor_t &activations_backward);
// Gradient-descent update of the layer parameters.
116 void Update(
const Scalar_t learningRate);
// Backpropagation for a single time step.
120 inline Matrix_t & CellBackward(Matrix_t & state_gradients_backward,
121 const Matrix_t & precStateActivations,
122 const Matrix_t & input, Matrix_t & input_gradient, Matrix_t &dF);
// XML (de)serialization of the layer parameters.
128 virtual void AddWeightsXMLTo(
void *parent);
131 virtual void ReadWeightsFromXML(
void *parent);
// --- Accessors ----------------------------------------------------------
135 size_t GetTimeSteps()
const {
return fTimeSteps; }
136 size_t GetStateSize()
const {
return fStateSize; }
137 size_t GetInputSize()
const {
return this->GetInputWidth(); }
138 inline bool IsRememberState()
const {
return fRememberState;}
139 inline DNN::EActivationFunction GetActivationFunction()
const {
return fF;}
140 Matrix_t & GetState() {
return fState;}
141 const Matrix_t & GetState()
const {
return fState;}
142 Matrix_t & GetWeightsInput() {
return fWeightsInput;}
143 const Matrix_t & GetWeightsInput()
const {
return fWeightsInput;}
144 Matrix_t & GetWeightsState() {
return fWeightsState;}
145 const Matrix_t & GetWeightsState()
const {
return fWeightsState;}
146 Tensor_t & GetDerivatives() {
return fDerivatives;}
147 const Tensor_t & GetDerivatives()
const {
return fDerivatives;}
151 Matrix_t & GetBiasesState() {
return fBiases;}
152 const Matrix_t & GetBiasesState()
const {
return fBiases;}
153 Matrix_t & GetBiasStateGradients() {
return fBiasGradients;}
154 const Matrix_t & GetBiasStateGradients()
const {
return fBiasGradients;}
155 Matrix_t & GetWeightInputGradients() {
return fWeightInputGradients;}
156 const Matrix_t & GetWeightInputGradients()
const {
return fWeightInputGradients;}
157 Matrix_t & GetWeightStateGradients() {
return fWeightStateGradients;}
158 const Matrix_t & GetWeightStateGradients()
const {
return fWeightStateGradients;}
// Main constructor: forwards the layer geometry to VGeneralLayer (which
// allocates two weight matrices and one bias slot) and binds the reference
// members to the base-class storage.
// NOTE(review): the unnamed bool parameter is the `training` flag from the
// declaration — it is ignored here. The fF(f) initializer is not visible in
// this extraction (original line ~176).
165 template <
typename Architecture_t>
166 TBasicRNNLayer<Architecture_t>::TBasicRNNLayer(
size_t batchSize,
size_t stateSize,
size_t inputSize,
size_t timeSteps,
167 bool rememberState, DNN::EActivationFunction f,
bool ,
168 DNN::EInitialization fA)
// Base gets input shape (1, timeSteps, inputSize), output shape
// (1, timeSteps, stateSize), 2 weight slices of row sizes {stateSize,
// stateSize} and column sizes {inputSize, stateSize}, and a (stateSize x 1)
// bias — presumably matching VGeneralLayer's parameter order; confirm there.
170 : VGeneralLayer<Architecture_t>(batchSize, 1, timeSteps, inputSize, 1, timeSteps, stateSize, 2,
171 {stateSize, stateSize}, {inputSize, stateSize}, 1, {stateSize}, {1}, batchSize,
172 timeSteps, stateSize, fA),
173 fTimeSteps(timeSteps),
174 fStateSize(stateSize),
175 fRememberState(rememberState),
// Hidden state: one row per batch sample.
177 fState(batchSize, stateSize),
// Bind references into the storage owned by the base class.
178 fWeightsInput(this->GetWeightsAt(0)),
179 fWeightsState(this->GetWeightsAt(1)),
180 fBiases(this->GetBiasesAt(0)),
// One (batchSize x stateSize) derivative matrix per time step.
181 fDerivatives( timeSteps, batchSize, stateSize),
182 fWeightInputGradients(this->GetWeightGradientsAt(0)),
183 fWeightStateGradients(this->GetWeightGradientsAt(1)),
184 fBiasGradients(this->GetBiasGradientsAt(0))
// Copy constructor: copies the geometry via the base class, rebinds the
// reference members to this object's own storage slots, then deep-copies
// the mutable buffers so the two layers do not share state.
190 template <
typename Architecture_t>
191 TBasicRNNLayer<Architecture_t>::TBasicRNNLayer(
const TBasicRNNLayer &layer)
192 : VGeneralLayer<Architecture_t>(layer), fTimeSteps(layer.fTimeSteps), fStateSize(layer.fStateSize),
193 fRememberState(layer.fRememberState), fF(layer.GetActivationFunction()),
194 fState(layer.GetBatchSize(), layer.GetStateSize()), fWeightsInput(this->GetWeightsAt(0)),
195 fWeightsState(this->GetWeightsAt(1)), fBiases(this->GetBiasesAt(0)),
196 fDerivatives( layer.GetDerivatives().GetShape() ), fWeightInputGradients(this->GetWeightGradientsAt(0)),
197 fWeightStateGradients(this->GetWeightGradientsAt(1)), fBiasGradients(this->GetBiasGradientsAt(0))
// Copy the derivative tensor and the hidden state from the source layer.
200 Architecture_t::Copy(fDerivatives, layer.GetDerivatives() );
203 Architecture_t::Copy(fState, layer.GetState());
217 template <
typename Architecture_t>
218 auto TBasicRNNLayer<Architecture_t>::InitState(DNN::EInitialization ) ->
void
220 DNN::initialize<Architecture_t>(this->GetState(), DNN::EInitialization::kZero);
222 Architecture_t::InitializeActivationDescriptor(fActivationDesc,this->GetActivationFunction());
226 template<
typename Architecture_t>
227 auto TBasicRNNLayer<Architecture_t>::Print() const
230 std::cout <<
" RECURRENT Layer: \t ";
231 std::cout <<
" (NInput = " << this->GetInputSize();
232 std::cout <<
", NState = " << this->GetStateSize();
233 std::cout <<
", NTime = " << this->GetTimeSteps() <<
" )";
234 std::cout <<
"\tOutput = ( " << this->GetOutput().GetFirstSize() <<
" , " << this->GetOutput().GetHSize() <<
" , " << this->GetOutput().GetWSize() <<
" )\n";
// Free helper: dumps a matrix to stdout for debugging, preceded by `name`
// and followed by a "********" separator line.
// NOTE(review): this extraction omits the loop-closing braces and whatever
// lay between original lines 244 and 248 (probably a per-row newline).
237 template <
typename Architecture_t>
238 auto debugMatrix(
const typename Architecture_t::Matrix_t &A,
const std::string name =
"matrix")
241 std::cout << name <<
"\n";
// Row-major dump of all elements, space-separated.
242 for (
size_t i = 0; i < A.GetNrows(); ++i) {
243 for (
size_t j = 0; j < A.GetNcols(); ++j) {
244 std::cout << A(i, j) <<
" ";
248 std::cout <<
"********\n";
253 template <
typename Architecture_t>
254 auto inline TBasicRNNLayer<Architecture_t>::Forward(Tensor_t &input,
bool )
262 Tensor_t arrInput (fTimeSteps, this->GetBatchSize(), this->GetInputWidth() );
264 Architecture_t::Rearrange(arrInput, input);
265 Tensor_t arrOutput ( fTimeSteps, this->GetBatchSize(), fStateSize);
268 if (!this->fRememberState) InitState(DNN::EInitialization::kZero);
269 for (
size_t t = 0; t < fTimeSteps; ++t) {
270 Matrix_t arrInput_m = arrInput.At(t).GetMatrix();
271 Matrix_t df_m = fDerivatives.At(t).GetMatrix();
272 CellForward(arrInput_m, df_m );
273 Matrix_t arrOutput_m = arrOutput.At(t).GetMatrix();
274 Architecture_t::Copy(arrOutput_m, fState);
276 Architecture_t::Rearrange(this->GetOutput(), arrOutput);
// Single-step forward: computes the pre-activation
//   fState = input . W_in^T + fState_prev . W_state^T + bias
// and then applies the activation function.
// NOTE(review): original lines 299-302 are not visible in this extraction;
// confirm nothing follows the activation call.
280 template <
typename Architecture_t>
281 auto inline TBasicRNNLayer<Architecture_t>::CellForward(
const Matrix_t &input, Matrix_t &dF)
285 const DNN::EActivationFunction fAF = this->GetActivationFunction();
// tmpState = fState_prev . W_state^T (recurrent contribution).
286 Matrix_t tmpState(fState.GetNrows(), fState.GetNcols());
287 Architecture_t::MultiplyTranspose(tmpState, fState, fWeightsState);
// fState = input . W_in^T, then accumulate the recurrent part and bias.
288 Architecture_t::MultiplyTranspose(fState, input, fWeightsInput);
289 Architecture_t::ScaleAdd(fState, tmpState);
290 Architecture_t::AddRowWise(fState, fBiases);
// Tensor views over dF and fState — presumably sharing their memory, so
// the in-place activation below updates fState itself (TODO confirm).
291 Tensor_t inputActivFunc(dF);
292 Tensor_t tState(fState);
// Stash the pre-activation values into dF, then apply the activation.
297 Architecture_t::Copy(inputActivFunc, tState);
298 Architecture_t::ActivationFunctionForward(tState, fAF, fActivationDesc);
// Backpropagation through time (BPTT): walks the time steps in reverse,
// accumulating the weight/bias gradients and propagating the hidden-state
// gradient backwards.
// NOTE(review): this extraction is incomplete — the if/else keywords around
// the "t > 1" vs. first-step branch (original lines ~370-382), several
// closing braces, the full ActivationFunctionBackward argument list, and
// the handling inside the empty-gradient branch are not visible here.
303 template <
typename Architecture_t>
304 auto inline TBasicRNNLayer<Architecture_t>::Backward(Tensor_t &gradients_backward,
305 const Tensor_t &activations_backward) ->
void
// Empty gradient tensor => nothing to propagate further back (branch body
// not visible in this extraction).
316 if (gradients_backward.GetSize() == 0) {
// Time-major working buffers for the backward tensors.
319 Tensor_t arr_gradients_backward ( fTimeSteps, this->GetBatchSize(), this->GetInputSize());
326 Tensor_t arr_activations_backward ( fTimeSteps, this->GetBatchSize(), this->GetInputSize());
328 Architecture_t::Rearrange(arr_activations_backward, activations_backward);
// Gradient w.r.t. the hidden state, accumulated across time steps.
330 Matrix_t state_gradients_backward(this->GetBatchSize(), fStateSize);
331 DNN::initialize<Architecture_t>(state_gradients_backward, DNN::EInitialization::kZero);
// Zero matrix standing in for the state before the first time step.
333 Matrix_t initState(this->GetBatchSize(), fStateSize);
334 DNN::initialize<Architecture_t>(initState, DNN::EInitialization::kZero);
// Time-major views of this layer's output and incoming activation gradients.
336 Tensor_t arr_output ( fTimeSteps, this->GetBatchSize(), fStateSize);
338 Architecture_t::Rearrange(arr_output, this->GetOutput());
340 Tensor_t arr_actgradients ( fTimeSteps, this->GetBatchSize(), fStateSize);
342 Architecture_t::Rearrange(arr_actgradients, this->GetActivationGradients());
// Gradients are accumulated over time steps, so clear them first.
345 fWeightInputGradients.Zero();
346 fWeightStateGradients.Zero();
347 fBiasGradients.Zero();
// Iterate t = fTimeSteps .. 1 (index t-1 into the time-major tensors).
349 for (
size_t t = fTimeSteps; t > 0; t--) {
// Fold this step's external gradient into the running state gradient.
351 Matrix_t actgrad_m = arr_actgradients.At(t - 1).GetMatrix();
352 Architecture_t::ScaleAdd(state_gradients_backward, actgrad_m);
354 Matrix_t actbw_m = arr_activations_backward.At(t - 1).GetMatrix();
355 Matrix_t gradbw_m = arr_gradients_backward.At(t - 1).GetMatrix();
// Backward through the activation: df holds the stored derivatives and
// y this step's output (full argument list not visible here).
361 Tensor_t df = fDerivatives.At(t-1);
362 Tensor_t dy = Tensor_t(state_gradients_backward);
364 Tensor_t y = arr_output.At(t-1);
365 Architecture_t::ActivationFunctionBackward(df, y,
367 this->GetActivationFunction(), fActivationDesc);
369 Matrix_t df_m = df.GetMatrix();
// For t > 1 the previous state is the output of step t-2 ...
373 Matrix_t precStateActivations = arr_output.At(t - 2).GetMatrix();
374 CellBackward(state_gradients_backward, precStateActivations, actbw_m, gradbw_m, df_m);
// ... while the first step uses the zero initial state.
382 const Matrix_t & precStateActivations = initState;
383 CellBackward(state_gradients_backward, precStateActivations, actbw_m, gradbw_m, df_m);
// Rearrange the computed input gradients back to the caller's layout.
393 Architecture_t::Rearrange(gradients_backward, arr_gradients_backward );
399 template <
typename Architecture_t>
400 auto inline TBasicRNNLayer<Architecture_t>::CellBackward(Matrix_t & state_gradients_backward,
401 const Matrix_t & precStateActivations,
402 const Matrix_t & input, Matrix_t & input_gradient, Matrix_t &dF)
405 return Architecture_t::RecurrentLayerBackward(state_gradients_backward, fWeightInputGradients, fWeightStateGradients,
406 fBiasGradients, dF, precStateActivations, fWeightsInput,
407 fWeightsState, input, input_gradient);
411 template <
typename Architecture_t>
412 void TBasicRNNLayer<Architecture_t>::AddWeightsXMLTo(
void *parent)
414 auto layerxml = gTools().xmlengine().NewChild(parent, 0,
"RNNLayer");
417 gTools().xmlengine().NewAttr(layerxml, 0,
"StateSize", gTools().StringFromInt(this->GetStateSize()));
418 gTools().xmlengine().NewAttr(layerxml, 0,
"InputSize", gTools().StringFromInt(this->GetInputSize()));
419 gTools().xmlengine().NewAttr(layerxml, 0,
"TimeSteps", gTools().StringFromInt(this->GetTimeSteps()));
420 gTools().xmlengine().NewAttr(layerxml, 0,
"RememberState", gTools().StringFromInt(this->IsRememberState()));
423 this->WriteMatrixToXML(layerxml,
"InputWeights",
this -> GetWeightsAt(0));
424 this->WriteMatrixToXML(layerxml,
"StateWeights",
this -> GetWeightsAt(1));
425 this->WriteMatrixToXML(layerxml,
"Biases",
this -> GetBiasesAt(0));
431 template <
typename Architecture_t>
432 void TBasicRNNLayer<Architecture_t>::ReadWeightsFromXML(
void *parent)
435 this->ReadMatrixXML(parent,
"InputWeights",
this -> GetWeightsAt(0));
436 this->ReadMatrixXML(parent,
"StateWeights",
this -> GetWeightsAt(1));
437 this->ReadMatrixXML(parent,
"Biases",
this -> GetBiasesAt(0));