18 #ifndef TMVA_DNN_FUNCTIONS
19 #define TMVA_DNN_FUNCTIONS
/// Selector consumed by the activation dispatchers evaluate() and
/// evaluateDerivative() in this header. The enumerators those dispatchers
/// handle are kIdentity, kRelu, kSigmoid, kTanh, kSymmRelu, kSoftSign and
/// kGauss.
31 enum class EActivationFunction
/// Selector consumed by the output-function overload of evaluate() in this
/// header, which handles the enumerators kIdentity, kSigmoid and kSoftmax.
43 enum class EOutputFunction
/// Selector consumed by the loss overload of evaluate() and by
/// evaluateGradients() in this header. Those functions handle kCrossEntropy,
/// kMeanSquaredError and kSoftmaxCrossEntropy.
/// The enumerators carry explicit character values.
/// NOTE(review): the characters look like mnemonic codes ('R', 'S');
/// confirm where they are read back before changing any of them.
54 enum class ELossFunction
57 kMeanSquaredError =
'R',
58 kSoftmaxCrossEntropy =
'S'
/// Selector consumed by regularization() and addRegularizationGradients() in
/// this header, which handle the enumerators kNone, kL1 and kL2.
62 enum class ERegularization
/// Selector consumed by initialize() in this header, which handles the
/// enumerators kGauss, kUniform, kIdentity, kZero, kGlorotNormal and
/// kGlorotUniform.
70 enum class EInitialization {
/// Presumably selects the optimizer to use (inferred from the name only --
/// TODO confirm). It is not switched on by evaluate(), evaluateDerivative(),
/// evaluateGradients(), regularization(), addRegularizationGradients() or
/// initialize().
80 enum class EOptimizer {
/// Dispatches on the activation function `f` and hands the tensor `A` to the
/// matching Architecture_t routine (Relu, Sigmoid, Tanh, SymmetricRelu,
/// SoftSign or Gauss).
///
/// \tparam Architecture_t Backend type providing Tensor_t and the activation
///                        routines called below.
/// \param A Tensor passed by non-const reference to the selected routine
///          (presumably transformed in place -- TODO confirm against the
///          Architecture_t implementation).
/// \param f Activation function to apply.
95 template<
typename Architecture_t>
96 inline void evaluate(
typename Architecture_t::Tensor_t &A,
97 EActivationFunction f)
// Identity: no Architecture_t routine is invoked for this case.
101 case EActivationFunction::kIdentity :
break;
102 case EActivationFunction::kRelu : Architecture_t::Relu(A);
104 case EActivationFunction::kSigmoid : Architecture_t::Sigmoid(A);
106 case EActivationFunction::kTanh : Architecture_t::Tanh(A);
108 case EActivationFunction::kSymmRelu : Architecture_t::SymmetricRelu(A);
110 case EActivationFunction::kSoftSign : Architecture_t::SoftSign(A);
112 case EActivationFunction::kGauss : Architecture_t::Gauss(A);
/// Dispatches on the activation function `f` and calls the matching
/// Architecture_t derivative routine with the arguments (B, A).
///
/// Unlike the activation overload of evaluate(), kIdentity is not a no-op
/// here: it calls Architecture_t::IdentityDerivative(B, A).
///
/// \tparam Architecture_t Backend type providing Tensor_t and the
///                        *Derivative routines called below.
/// \param B Tensor passed by non-const reference (presumably receives the
///          derivative values -- TODO confirm).
/// \param f Activation function whose derivative routine is selected.
/// \param A Tensor passed by const reference (presumably the point at which
///          the derivative is evaluated -- TODO confirm).
120 template<
typename Architecture_t>
121 inline void evaluateDerivative(
typename Architecture_t::Tensor_t & B,
122 EActivationFunction f,
123 const typename Architecture_t::Tensor_t & A)
127 case EActivationFunction::kIdentity : Architecture_t::IdentityDerivative(B, A);
129 case EActivationFunction::kRelu : Architecture_t::ReluDerivative(B, A);
131 case EActivationFunction::kSigmoid : Architecture_t::SigmoidDerivative(B, A);
133 case EActivationFunction::kTanh : Architecture_t::TanhDerivative(B, A);
135 case EActivationFunction::kSymmRelu : Architecture_t::SymmetricReluDerivative(B, A);
137 case EActivationFunction::kSoftSign : Architecture_t::SoftSignDerivative(B, A);
139 case EActivationFunction::kGauss : Architecture_t::GaussDerivative(B, A);
/// Dispatches on an EOutputFunction value and calls the matching
/// Architecture_t routine with the arguments (A, X).
///
/// \tparam Architecture_t Backend type providing Matrix_t plus the Copy,
///                        Sigmoid and Softmax routines called below.
/// \param A Matrix passed by non-const reference (presumably the destination
///          of the result -- TODO confirm).
/// \param X Matrix passed by const reference (presumably the input the
///          output function is applied to -- TODO confirm).
150 template<
typename Architecture_t>
151 inline void evaluate(
typename Architecture_t::Matrix_t &A,
153 const typename Architecture_t::Matrix_t &X)
// Identity is implemented as Architecture_t::Copy(A, X).
157 case EOutputFunction::kIdentity : Architecture_t::Copy(A, X);
159 case EOutputFunction::kSigmoid : Architecture_t::Sigmoid(A, X);
161 case EOutputFunction::kSoftmax : Architecture_t::Softmax(A, X);
/// Dispatches on the loss function `f` and returns the value produced by the
/// matching Architecture_t routine (CrossEntropy, MeanSquaredError or
/// SoftmaxCrossEntropy), each called with (Y, output, weights).
///
/// The return type is whatever Architecture_t::CrossEntropy(Y, output,
/// weights) returns; the results of the other two routines are returned
/// through that same type.
///
/// \tparam Architecture_t Backend type providing Matrix_t and the loss
///                        routines called below.
/// \param f       Loss function to evaluate.
/// \param Y       Forwarded as first argument (presumably the target values
///                -- TODO confirm).
/// \param output  Forwarded as second argument (presumably the network
///                output -- TODO confirm).
/// \param weights Forwarded as third argument (presumably per-event weights
///                -- TODO confirm).
173 template <
typename Architecture_t>
174 inline auto evaluate(ELossFunction f,
const typename Architecture_t::Matrix_t &Y,
175 const typename Architecture_t::Matrix_t &output,
const typename Architecture_t::Matrix_t &weights)
176 -> decltype(Architecture_t::CrossEntropy(Y, output, weights))
180 case ELossFunction::kCrossEntropy:
return Architecture_t::CrossEntropy(Y, output, weights);
181 case ELossFunction::kMeanSquaredError:
return Architecture_t::MeanSquaredError(Y, output, weights);
182 case ELossFunction::kSoftmaxCrossEntropy:
return Architecture_t::SoftmaxCrossEntropy(Y, output, weights);
/// Dispatches on the loss function `f` and calls the matching Architecture_t
/// gradient routine (CrossEntropyGradients, MeanSquaredErrorGradients or
/// SoftmaxCrossEntropyGradients) with (dY, Y, output, weights).
///
/// \tparam Architecture_t Backend type providing Matrix_t and the
///                        *Gradients routines called below.
/// \param dY      Matrix passed by non-const reference (presumably receives
///                the gradients -- TODO confirm).
/// \param f       Loss function whose gradient routine is selected.
/// \param Y       Forwarded unchanged to the gradient routine.
/// \param output  Forwarded unchanged to the gradient routine.
/// \param weights Forwarded unchanged to the gradient routine.
190 template <
typename Architecture_t>
191 inline void evaluateGradients(
typename Architecture_t::Matrix_t &dY, ELossFunction f,
192 const typename Architecture_t::Matrix_t &Y,
193 const typename Architecture_t::Matrix_t &output,
194 const typename Architecture_t::Matrix_t &weights)
198 case ELossFunction::kCrossEntropy: Architecture_t::CrossEntropyGradients(dY, Y, output, weights);
break;
199 case ELossFunction::kMeanSquaredError: Architecture_t::MeanSquaredErrorGradients(dY, Y, output, weights);
break;
200 case ELossFunction::kSoftmaxCrossEntropy :
201 Architecture_t::SoftmaxCrossEntropyGradients(dY, Y, output, weights);
/// Dispatches on an ERegularization value and returns the result of the
/// matching Architecture_t routine: L1Regularization(A) for kL1 and
/// L2Regularization(A) for kL2.
///
/// The return type is whatever Architecture_t::L1Regularization(A) returns.
/// NOTE(review): confirm which value the kNone case yields.
///
/// \tparam Architecture_t Backend type providing Matrix_t and the
///                        L1Regularization / L2Regularization routines.
/// \param A Matrix forwarded to the selected routine (presumably the weights
///          being regularized -- TODO confirm).
213 template<
typename Architecture_t>
214 inline auto regularization(
const typename Architecture_t::Matrix_t &A,
216 -> decltype(Architecture_t::L1Regularization(A))
220 case ERegularization::kNone :
222 case ERegularization::kL1 :
223 return Architecture_t::L1Regularization(A);
224 case ERegularization::kL2 :
225 return Architecture_t::L2Regularization(A);
/// Dispatches on an ERegularization value and calls the matching
/// Architecture_t routine with (A, W, weightDecay):
/// AddL1RegularizationGradients for kL1 and AddL2RegularizationGradients for
/// kL2. No Architecture_t call is associated with kNone.
///
/// \tparam Architecture_t Backend type providing Matrix_t, Scalar_t and the
///                        AddL*RegularizationGradients routines.
/// \param A           Matrix passed by non-const reference (presumably the
///                    gradient matrix the term is added to -- TODO confirm).
/// \param W           Matrix passed by const reference (presumably the
///                    weights -- TODO confirm).
/// \param weightDecay Scalar forwarded unchanged to the selected routine.
233 template<
typename Architecture_t>
234 inline void addRegularizationGradients(
typename Architecture_t::Matrix_t &A,
235 const typename Architecture_t::Matrix_t &W,
236 typename Architecture_t::Scalar_t weightDecay,
241 case ERegularization::kNone :
243 case ERegularization::kL1 :
244 Architecture_t::AddL1RegularizationGradients(A, W, weightDecay);
246 case ERegularization::kL2 :
247 Architecture_t::AddL2RegularizationGradients(A, W, weightDecay);
/// Dispatches on an EInitialization value and hands the matrix `A` to the
/// matching Architecture_t routine: InitializeGauss, InitializeUniform,
/// InitializeIdentity, InitializeZero, InitializeGlorotNormal or
/// InitializeGlorotUniform.
///
/// \tparam Architecture_t Backend type providing Matrix_t and the
///                        Initialize* routines called below.
/// \param A Matrix passed by non-const reference to the selected routine
///          (presumably filled in place -- TODO confirm).
257 template<
typename Architecture_t>
258 inline void initialize(
typename Architecture_t::Matrix_t & A,
262 case EInitialization::kGauss : Architecture_t::InitializeGauss(A);
264 case EInitialization::kUniform : Architecture_t::InitializeUniform(A);
266 case EInitialization::kIdentity : Architecture_t::InitializeIdentity(A);
268 case EInitialization::kZero : Architecture_t::InitializeZero(A);
270 case EInitialization::kGlorotNormal : Architecture_t::InitializeGlorotNormal(A);
272 case EInitialization::kGlorotUniform : Architecture_t::InitializeGlorotUniform(A);