Logo ROOT   6.30.04
Reference Guide
 All Namespaces Files Pages
ActivationFunctions.hxx
Go to the documentation of this file.
1 // @(#)root/tmva/tmva/dnn:$Id$
2 // Author: Simon Pfreundschuh 10/07/16
3 
4 /*************************************************************************
5  * Copyright (C) 2016, Simon Pfreundschuh *
6  * All rights reserved. *
7  * *
8  * For the licensing terms see $ROOTSYS/LICENSE. *
9  * For the list of contributors see $ROOTSYS/README/CREDITS. *
10  *************************************************************************/
11 
12  //////////////////////////////////////////////////////////////////
13  // Implementation of the activation functions for the reference //
14  // implementation. //
15  //////////////////////////////////////////////////////////////////
16 
18 #include <math.h>
19 
20 namespace TMVA
21 {
22 namespace DNN
23 {
24 
25 //______________________________________________________________________________
26 template<typename Real_t>
27 void TReference<Real_t>::IdentityDerivative(TMatrixT<Real_t> & B,
28  const TMatrixT<Real_t> &/*A*/)
29 {
30  size_t m,n;
31  m = B.GetNrows();
32  n = B.GetNcols();
33 
34  for (size_t i = 0; i < m; i++) {
35  for (size_t j = 0; j < n; j++) {
36  B(i,j) = 1.0;
37  }
38  }
39 }
40 
41 //______________________________________________________________________________
42 template<typename Real_t>
43 void TReference<Real_t>::Relu(TMatrixT<Real_t> &A)
44 {
45  size_t m,n;
46  m = A.GetNrows();
47  n = A.GetNcols();
48 
49  for (size_t i = 0; i < m; i++) {
50  for (size_t j = 0; j < n; j++) {
51  A(i,j) = std::max((Real_t) 0.0, A(i,j));
52  }
53  }
54 }
55 
56 //______________________________________________________________________________
57 template<typename Real_t>
58 inline void TReference<Real_t>::ReluDerivative(TMatrixT<Real_t> & B,
59  const TMatrixT<Real_t> & A)
60 {
61  size_t m,n;
62  m = A.GetNrows();
63  n = A.GetNcols();
64 
65  for (size_t i = 0; i < m; i++)
66  {
67  for (size_t j = 0; j < n; j++)
68  {
69  B(i,j) = (A(i,j) < 0) ? 0.0 : 1.0;
70  }
71  }
72 }
73 
74 //______________________________________________________________________________
75 template<typename Real_t>
76 void TReference<Real_t>::Sigmoid(TMatrixT<Real_t> & A)
77 {
78  size_t m,n;
79  m = A.GetNrows();
80  n = A.GetNcols();
81 
82  for (size_t i = 0; i < m; i++) {
83  for (size_t j = 0; j < n; j++) {
84  Real_t sig = 1.0 / (1.0 + std::exp(-A(i,j)));
85  A(i,j) = sig;
86  }
87  }
88 }
89 
90 //______________________________________________________________________________
91 template<typename Real_t>
92 inline void TReference<Real_t>::SigmoidDerivative(TMatrixT<Real_t> & B,
93  const TMatrixT<Real_t> & A)
94 {
95  size_t m,n;
96  m = A.GetNrows();
97  n = A.GetNcols();
98 
99  for (size_t i = 0; i < m; i++) {
100  for (size_t j = 0; j < n; j++) {
101  Real_t sig = 1.0 / (1.0 + std::exp(-A(i,j)));
102  B(i,j) = sig * (1.0 - sig);
103  }
104  }
105 }
106 
107 //______________________________________________________________________________
108 template<typename Real_t>
109 inline void TReference<Real_t>::Tanh(TMatrixT<Real_t> & B)
110 {
111  size_t m,n;
112  m = B.GetNrows();
113  n = B.GetNcols();
114 
115  for (size_t i = 0; i < m; i++) {
116  for (size_t j = 0; j < n; j++) {
117  Real_t t = tanh(B(i,j));
118  B(i,j) = t;
119  }
120  }
121 }
122 
123 //______________________________________________________________________________
124 template<typename Real_t>
125 inline void TReference<Real_t>::TanhDerivative(TMatrixT<Real_t> & B,
126  const TMatrixT<Real_t> & A)
127 {
128  size_t m,n;
129  m = A.GetNrows();
130  n = A.GetNcols();
131 
132  for (size_t i = 0; i < m; i++) {
133  for (size_t j = 0; j < n; j++) {
134  Real_t t = tanh(A(i,j));
135  B(i,j) = 1 - t * t;
136  }
137  }
138 }
139 
140 //______________________________________________________________________________
141 template<typename Real_t>
142 inline void TReference<Real_t>::SymmetricRelu(TMatrixT<Real_t> & B)
143 {
144  size_t m,n;
145  m = B.GetNrows();
146  n = B.GetNcols();
147 
148  for (size_t i = 0; i < m; i++) {
149  for (size_t j = 0; j < n; j++) {
150  B(i,j) = fabs(B(i,j));
151  }
152  }
153 }
154 
155 //______________________________________________________________________________
156 template<typename Real_t>
157 inline void TReference<Real_t>::SymmetricReluDerivative(TMatrixT<Real_t> & B,
158  const TMatrixT<Real_t> & A)
159 {
160  size_t m,n;
161  m = A.GetNrows();
162  n = A.GetNcols();
163 
164  for (size_t i = 0; i < m; i++) {
165  for (size_t j = 0; j < n; j++) {
166  B(i,j) = (A(i,j) < 0.0) ? -1.0 : 1.0;
167  }
168  }
169 }
170 
171 //______________________________________________________________________________
172 template<typename Real_t>
173 inline void TReference<Real_t>::SoftSign(TMatrixT<Real_t> & A)
174 {
175  size_t m,n;
176  m = A.GetNrows();
177  n = A.GetNcols();
178 
179  for (size_t i = 0; i < m; i++) {
180  for (size_t j = 0; j < n; j++) {
181  Real_t x = A(i,j);
182  A(i,j) = x / (1 + fabs(x));
183  }
184  }
185 }
186 
187 //______________________________________________________________________________
188 template<typename Real_t>
189 inline void TReference<Real_t>::SoftSignDerivative(TMatrixT<Real_t> & B,
190  const TMatrixT<Real_t> & A)
191 {
192  size_t m,n;
193  m = A.GetNrows();
194  n = A.GetNcols();
195 
196  for (size_t i = 0; i < m; i++) {
197  for (size_t j = 0; j < n; j++) {
198  Real_t x = 1.0 + fabs(A(i,j));
199  B(i,j) = 1.0 / (x * x);
200  }
201  }
202 }
203 
204 //______________________________________________________________________________
205 template<typename Real_t>
206 inline void TReference<Real_t>::Gauss(TMatrixT<Real_t> & A)
207 {
208  size_t m,n;
209  m = A.GetNrows();
210  n = A.GetNcols();
211 
212  for (size_t i = 0; i < m; i++) {
213  for (size_t j = 0; j < n; j++) {
214  Real_t x = A(i,j);
215  A(i,j) = exp(- x * x);
216  }
217  }
218 }
219 
220 //______________________________________________________________________________
221 template<typename Real_t>
222 inline void TReference<Real_t>::GaussDerivative(TMatrixT<Real_t> & B,
223  const TMatrixT<Real_t> & A)
224 {
225  size_t m,n;
226  m = A.GetNrows();
227  n = A.GetNcols();
228 
229  for (size_t i = 0; i < m; i++) {
230  for (size_t j = 0; j < n; j++) {
231  Real_t x = A(i,j);
232  B(i,j) = - 2.0 * x * exp(- x * x);
233  }
234  }
235 }
236 } // namespace DNN
237 } // namespace TMVA