ReshapeLayer.h
// @(#)root/tmva/tmva/dnn:$Id$
// Author: Vladimir Ilievski

/**********************************************************************************
 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis       *
 * Package: TMVA                                                                  *
 * Class  : TReshapeLayer                                                         *
 * Web    : http://tmva.sourceforge.net                                           *
 *                                                                                *
 * Description:                                                                   *
 *      Reshape Deep Neural Network Layer                                         *
 *                                                                                *
 * Authors (alphabetical):                                                        *
 *      Vladimir Ilievski <ilievski.vladimir@live.com> - CERN, Switzerland        *
 *                                                                                *
 * Copyright (c) 2005-2015:                                                       *
 *      CERN, Switzerland                                                         *
 *      U. of Victoria, Canada                                                    *
 *      MPI-K Heidelberg, Germany                                                 *
 *      U. of Bonn, Germany                                                       *
 *                                                                                *
 * Redistribution and use in source and binary forms, with or without            *
 * modification, are permitted according to the terms listed in LICENSE          *
 * (http://tmva.sourceforge.net/LICENSE)                                         *
 **********************************************************************************/

#ifndef TMVA_DNN_RESHAPELAYER
#define TMVA_DNN_RESHAPELAYER

#include "TMatrix.h"

#include "TMVA/DNN/GeneralLayer.h"
#include "TMVA/DNN/Functions.h"

#include <iostream>

namespace TMVA {
namespace DNN {

template <typename Architecture_t>
class TReshapeLayer : public VGeneralLayer<Architecture_t> {
public:
   using Tensor_t = typename Architecture_t::Tensor_t;
   using Matrix_t = typename Architecture_t::Matrix_t;
   using Scalar_t = typename Architecture_t::Scalar_t;

private:
   bool fFlattening; ///< Whether the layer is doing flattening

public:
   /*! Constructor */
   TReshapeLayer(size_t BatchSize, size_t InputDepth, size_t InputHeight, size_t InputWidth, size_t Depth,
                 size_t Height, size_t Width, size_t OutputNSlices, size_t OutputNRows, size_t OutputNCols,
                 bool Flattening);

   /*! Copy the reshape layer provided as a pointer */
   TReshapeLayer(TReshapeLayer<Architecture_t> *layer);

   /*! Copy Constructor */
   TReshapeLayer(const TReshapeLayer &);

   /*! Destructor. */
   ~TReshapeLayer();

   /*! The input must be in 3D tensor form with the different matrices
    *  corresponding to different events in the batch. It transforms the
    *  input matrices. */
   void Forward(Tensor_t &input, bool applyDropout = false);

   /*! Transforms the gradients backward, i.e. applies the inverse of the
    *  Forward reshape to the activation gradients. */
   void Backward(Tensor_t &gradients_backward, const Tensor_t &activations_backward);

   /*! Prints the info about the layer. */
   void Print() const;

   /*! Writes the information and the weights about the layer in an XML node. */
   virtual void AddWeightsXMLTo(void *parent);

   /*! Read the information and the weights about the layer from XML node. */
   virtual void ReadWeightsFromXML(void *parent);

   /*! Does this layer flatten its input? (needed by the dense layer)
    *  B x D1 x D2 --> 1 x B x (D1 * D2) */
   bool isFlattening() const { return fFlattening; }
};
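
// A minimal usage sketch (illustrative only, not part of the class): constructing a
// flattening reshape layer by hand. The architecture type TCpu<float> from
// "TMVA/DNN/Architectures/Cpu.h" and the chosen shapes are assumptions made for this
// example; in practice the layer is normally added through a TDeepNet rather than
// instantiated directly.
//
//    using Arch_t = TMVA::DNN::TCpu<float>;
//
//    // Flatten a batch of 32 events of shape 3 x 8 x 8 into a single
//    // 1 x 32 x 192 output tensor (3 * 8 * 8 = 192 columns per event).
//    TMVA::DNN::TReshapeLayer<Arch_t> flatten(
//       32,          // batch size
//       3, 8, 8,     // input depth, height, width
//       1, 1, 192,   // output depth, height, width (same number of elements)
//       1, 32, 192,  // output tensor: slices, rows, columns
//       true);       // flattening
//    flatten.Print();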

//
//
//  The Reshape Layer Class - Implementation
//_________________________________________________________________________________________________
template <typename Architecture_t>
TReshapeLayer<Architecture_t>::TReshapeLayer(size_t batchSize, size_t inputDepth, size_t inputHeight, size_t inputWidth,
                                             size_t depth, size_t height, size_t width, size_t outputNSlices,
                                             size_t outputNRows, size_t outputNCols, bool flattening)
   : VGeneralLayer<Architecture_t>(batchSize, inputDepth, inputHeight, inputWidth, depth, height, width, 0, 0, 0, 0, 0,
                                   0, outputNSlices, outputNRows, outputNCols, EInitialization::kZero),
     fFlattening(flattening)
{
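   // The reshape must preserve the number of elements per event: for example
   // (illustrative shapes) a 2 x 4 x 4 input (32 values) can be reshaped to
   // 1 x 1 x 32 or to 4 x 2 x 4, but not to 1 x 1 x 30.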
   if (this->GetInputDepth() * this->GetInputHeight() * this->GetInputWidth() !=
       this->GetDepth() * this->GetHeight() * this->GetWidth()) {
      std::cout << "Reshape Dimensions not compatible \n"
                << this->GetInputDepth() << " x " << this->GetInputHeight() << " x " << this->GetInputWidth()
                << " --> " << this->GetDepth() << " x " << this->GetHeight() << " x " << this->GetWidth() << std::endl;
      return;
   }
}

//_________________________________________________________________________________________________
template <typename Architecture_t>
TReshapeLayer<Architecture_t>::TReshapeLayer(TReshapeLayer<Architecture_t> *layer)
   : VGeneralLayer<Architecture_t>(layer), fFlattening(layer->isFlattening())
{
}

//_________________________________________________________________________________________________
template <typename Architecture_t>
TReshapeLayer<Architecture_t>::TReshapeLayer(const TReshapeLayer &layer)
   : VGeneralLayer<Architecture_t>(layer), fFlattening(layer.fFlattening)
{
   // Nothing to do here.
}

//_________________________________________________________________________________________________
template <typename Architecture_t>
TReshapeLayer<Architecture_t>::~TReshapeLayer()
{
   // Nothing to do here.
}

//_________________________________________________________________________________________________
template <typename Architecture_t>
auto TReshapeLayer<Architecture_t>::Forward(Tensor_t &input, bool /*applyDropout*/) -> void
{
   if (fFlattening) {
      Architecture_t::Flatten(this->GetOutput(), input);
   } else {
      Architecture_t::Deflatten(this->GetOutput(), input);
   }
}
//_________________________________________________________________________________________________
template <typename Architecture_t>
auto TReshapeLayer<Architecture_t>::Backward(Tensor_t &gradients_backward,
                                             const Tensor_t & /*activations_backward*/) -> void
{
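   // Backward applies the inverse of the Forward reshape to the incoming activation
   // gradients and writes the result into gradients_backward; the layer has no weights,
   // so there are no weight or bias gradients. Illustrative shapes for a flattening
   // layer: the activation gradients arrive as 1 x B x (D1 * D2) and are deflattened
   // back into the B x D1 x D2 input layout.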
   size_t size = gradients_backward.GetSize();
   // In case of the first layer the gradient tensor is empty - do nothing.
   if (size == 0) return;

   if (fFlattening) {
      // Deflatten in backpropagation.
      Architecture_t::Deflatten(gradients_backward, this->GetActivationGradients());
   } else {
      Architecture_t::Flatten(gradients_backward, this->GetActivationGradients());
   }
}

//_________________________________________________________________________________________________
template <typename Architecture_t>
auto TReshapeLayer<Architecture_t>::Print() const -> void
{
   std::cout << " RESHAPE Layer \t ";
   std::cout << "Input = ( " << this->GetInputDepth() << " , " << this->GetInputHeight() << " , "
             << this->GetInputWidth() << " ) ";
   if (this->GetOutput().GetSize() > 0) {
      std::cout << "\tOutput = ( " << this->GetOutput().GetFirstSize() << " , " << this->GetOutput().GetHSize()
                << " , " << this->GetOutput().GetWSize() << " ) ";
   }
   std::cout << std::endl;
}

template <typename Architecture_t>
auto TReshapeLayer<Architecture_t>::AddWeightsXMLTo(void *parent) -> void
{
   auto layerxml = gTools().xmlengine().NewChild(parent, 0, "ReshapeLayer");

   // Write the layer information (a reshape layer has no weights).
   gTools().xmlengine().NewAttr(layerxml, 0, "Depth", gTools().StringFromInt(this->GetDepth()));
   gTools().xmlengine().NewAttr(layerxml, 0, "Height", gTools().StringFromInt(this->GetHeight()));
   gTools().xmlengine().NewAttr(layerxml, 0, "Width", gTools().StringFromInt(this->GetWidth()));
   gTools().xmlengine().NewAttr(layerxml, 0, "Flattening", gTools().StringFromInt(this->isFlattening()));
}

//______________________________________________________________________________
template <typename Architecture_t>
void TReshapeLayer<Architecture_t>::ReadWeightsFromXML(void * /*parent*/)
{
   // No weights or additional information to read for a reshape layer.
}

} // namespace DNN
} // namespace TMVA

#endif