27 #ifndef TMVA_DNN_RESHAPELAYER
28 #define TMVA_DNN_RESHAPELAYER
// Reshape layer: re-lays-out the activations of the previous layer into a new
// (depth, height, width) shape.
// - Forward() calls Architecture_t::Flatten or Architecture_t::Deflatten.
// - Backward() calls Deflatten / Flatten on the gradients.
// - isFlattening() reports the mode the layer was built in.
// NOTE(review): this extract of the class body is incomplete. The trailing
// constructor parameter, the destructor, Print() and the data members are not
// shown here.
40 template <
typename Architecture_t>
41 class TReshapeLayer :
public VGeneralLayer<Architecture_t> {
// Tensor, matrix and scalar types provided by the architecture backend.
43 using Tensor_t =
typename Architecture_t::Tensor_t;
44 using Matrix_t =
typename Architecture_t::Matrix_t;
45 using Scalar_t =
typename Architecture_t::Scalar_t;
// Constructor.
// - Input*: shape of the incoming activations.
// - Depth/Height/Width: reshaped shape, expected to hold the same number of
//   elements. The constructor prints a diagnostic on std::cout if it does not.
// - OutputNSlices/NRows/NCols: output tensor layout handed to VGeneralLayer.
// NOTE(review): the declaration looks truncated in this extract. The
// out-of-line definition also takes a trailing `bool flattening`.
52 TReshapeLayer(
size_t BatchSize,
size_t InputDepth,
size_t InputHeight,
size_t InputWidth,
size_t Depth,
53 size_t Height,
size_t Width,
size_t OutputNSlices,
size_t OutputNRows,
size_t OutputNCols,
// Copy-constructs from a pointer to another reshape layer (base part and
// flattening flag).
57 TReshapeLayer(TReshapeLayer<Architecture_t> *layer);
// Copy constructor.
60 TReshapeLayer(
const TReshapeLayer &);
// Forward pass: calls Architecture_t::Flatten / Deflatten with GetOutput() as
// the first argument and `input` as the second.
// applyDropout is accepted for interface compatibility. The definition leaves
// that parameter unnamed, so it is ignored.
68 void Forward(Tensor_t &input,
bool applyDropout =
false);
// Backward pass: does nothing when gradients_backward is empty. Otherwise it
// calls Architecture_t::Deflatten / Flatten with gradients_backward as the first
// argument and this layer's activation gradients as the second.
// activations_backward is unnamed in the definition, so it is unused.
70 void Backward(Tensor_t &gradients_backward,
const Tensor_t &activations_backward);
// Writes a "ReshapeLayer" XML child of `parent` with the attributes Depth,
// Height, Width and Flattening.
77 virtual void AddWeightsXMLTo(
void *parent);
// The definition leaves this parameter unnamed, so it appears to read nothing.
// NOTE(review): confirm how the layer configuration is restored.
80 virtual void ReadWeightsFromXML(
void *parent);
// True if the layer was constructed (or copied) in flattening mode.
86 bool isFlattening()
const {
return fFlattening; }
// Constructs a reshape layer.
// - batchSize/inputDepth/inputHeight/inputWidth: the incoming activations.
// - depth/height/width: the reshaped shape.
// - outputNSlices/outputNRows/outputNCols: the output tensor layout.
// All of these are forwarded to the VGeneralLayer base together with
// EInitialization::kZero.
// NOTE(review): the six literal zeros presumably declare an empty weight and
// bias shape, since a reshape has no trainable parameters. Confirm against the
// VGeneralLayer constructor signature.
// `flattening` is stored in fFlattening and reported by isFlattening().
93 template <
typename Architecture_t>
94 TReshapeLayer<Architecture_t>::TReshapeLayer(
size_t batchSize,
size_t inputDepth,
size_t inputHeight,
size_t inputWidth,
95 size_t depth,
size_t height,
size_t width,
size_t outputNSlices,
96 size_t outputNRows,
size_t outputNCols,
bool flattening)
97 : VGeneralLayer<Architecture_t>(batchSize, inputDepth, inputHeight, inputWidth, depth, height, width, 0, 0, 0, 0, 0,
98 0, outputNSlices, outputNRows, outputNCols, EInitialization::kZero),
99 fFlattening(flattening)
// The reshape is expected to keep the number of elements per sample (the batch
// size is not part of either product). A mismatch is reported on std::cout.
// NOTE(review): the end of this body is not visible in this extract. The
// visible code does not throw or abort on a mismatch.
101 if (this->GetInputDepth() * this->GetInputHeight() * this->GetInputWidth() !=
102 this->GetDepth() * this->GetHeight() * this->GetWidth()) {
103 std::cout <<
"Reshape Dimensions not compatible \n"
104 << this->GetInputDepth() <<
" x " << this->GetInputHeight() <<
" x " << this->GetInputWidth() <<
" --> "
105 << this->GetDepth() <<
" x " << this->GetHeight() <<
" x " << this->GetWidth() << std::endl;
// Constructs a copy of the reshape layer pointed to by `layer`.
// The base part is built by VGeneralLayer's pointer constructor, and the
// flattening flag is copied via layer->isFlattening().
// NOTE(review): the constructor body is not visible in this extract.
111 template <
typename Architecture_t>
112 TReshapeLayer<Architecture_t>::TReshapeLayer(TReshapeLayer<Architecture_t> *layer)
113 : VGeneralLayer<Architecture_t>(layer), fFlattening(layer->isFlattening())
// Copy constructor: copies the base VGeneralLayer part and fFlattening.
// NOTE(review): the constructor body is not visible in this extract.
118 template <
typename Architecture_t>
119 TReshapeLayer<Architecture_t>::TReshapeLayer(
const TReshapeLayer &layer)
120 : VGeneralLayer<Architecture_t>(layer), fFlattening(layer.fFlattening)
// Destructor.
// NOTE(review): its body is not visible in this extract.
126 template <
typename Architecture_t>
127 TReshapeLayer<Architecture_t>::~TReshapeLayer()
// Forward pass of the reshape layer.
// Calls Architecture_t::Flatten or Architecture_t::Deflatten with this layer's
// output tensor as the first argument and `input` as the second.
// The second parameter (applyDropout in the declaration) is unnamed here, so it
// is unused.
// NOTE(review): in this extract the two calls below appear back-to-back. They
// are presumably the two branches of a test on fFlattening, whose lines (and
// the function braces) are not visible here. Confirm against the full file.
133 template <
typename Architecture_t>
134 auto TReshapeLayer<Architecture_t>::Forward(Tensor_t &input,
bool ) ->
void
138 Architecture_t::Flatten(this->GetOutput(), input);
143 Architecture_t::Deflatten(this->GetOutput(), input);
// Backward pass of the reshape layer.
// Returns immediately when gradients_backward is empty (GetSize() == 0).
// Otherwise it calls Architecture_t::Deflatten or Architecture_t::Flatten with
// gradients_backward as the first argument and this layer's activation
// gradients as the second. This presumably mirrors Forward(): Deflatten where
// Forward flattens, and vice versa.
// activations_backward is unnamed in the definition, so it is unused.
// NOTE(review): the rest of the signature, the branch on fFlattening and the
// braces are missing from this extract. Confirm against the full file.
148 template <
typename Architecture_t>
149 auto TReshapeLayer<Architecture_t>::Backward(Tensor_t &gradients_backward,
const Tensor_t &
154 size_t size = gradients_backward.GetSize();
// Nothing to do when gradients_backward holds no elements.
156 if (size == 0)
return;
159 Architecture_t::Deflatten(gradients_backward, this->GetActivationGradients());
162 Architecture_t::Flatten(gradients_backward, this->GetActivationGradients() );
168 template <
typename Architecture_t>
169 auto TReshapeLayer<Architecture_t>::Print() const ->
void
171 std::cout <<
" RESHAPE Layer \t ";
172 std::cout <<
"Input = ( " << this->GetInputDepth() <<
" , " << this->GetInputHeight() <<
" , " << this->GetInputWidth() <<
" ) ";
173 if (this->GetOutput().GetSize() > 0) {
174 std::cout <<
"\tOutput = ( " << this->GetOutput().GetFirstSize() <<
" , " << this->GetOutput().GetHSize() <<
" , " << this->GetOutput().GetWSize() <<
" ) ";
176 std::cout << std::endl;
// Serializes the layer configuration as a "ReshapeLayer" child node of `parent`.
// The node carries the attributes Depth, Height, Width (this layer's reshaped
// shape) and Flattening. Flattening is written as an integer: the bool from
// isFlattening() is passed to StringFromInt.
// NOTE(review): the input shape and OutputNSlices/NRows/NCols are not written
// here. They are presumably recovered from neighbouring layers when the network
// is read back. Verify.
179 template <
typename Architecture_t>
180 auto TReshapeLayer<Architecture_t>::AddWeightsXMLTo(
void *parent) ->
void
182 auto layerxml = gTools().xmlengine().NewChild(parent, 0,
"ReshapeLayer");
185 gTools().xmlengine().NewAttr(layerxml, 0,
"Depth", gTools().StringFromInt(this->GetDepth()));
186 gTools().xmlengine().NewAttr(layerxml, 0,
"Height", gTools().StringFromInt(this->GetHeight()));
187 gTools().xmlengine().NewAttr(layerxml, 0,
"Width", gTools().StringFromInt(this->GetWidth()));
188 gTools().xmlengine().NewAttr(layerxml, 0,
"Flattening", gTools().StringFromInt(this->isFlattening()));
// Read-back counterpart of AddWeightsXMLTo.
// The parameter is unnamed, so the body cannot read from the XML node. The
// layer configuration is presumably restored by the code that parses the
// "ReshapeLayer" attributes and constructs the layer. Verify.
// NOTE(review): the body lies beyond this extract.
194 template <
typename Architecture_t>
195 void TReshapeLayer<Architecture_t>::ReadWeightsFromXML(
void * )