Logo ROOT   6.30.04
Reference Guide
 All Namespaces Files Pages
Regularization.hxx
Go to the documentation of this file.
1 // @(#)root/tmva/tmva/dnn:$Id$
2 // Author: Simon Pfreundschuh 10/07/16
3 
4 /*************************************************************************
5  * Copyright (C) 2016, Simon Pfreundschuh *
6  * All rights reserved. *
7  * *
8  * For the licensing terms see $ROOTSYS/LICENSE. *
9  * For the list of contributors see $ROOTSYS/README/CREDITS. *
10  *************************************************************************/
11 
12  //////////////////////////////////////////////////////////////////////
13  // Implementation of the regularization functions for the reference //
14  // implementation. //
15  //////////////////////////////////////////////////////////////////////
16 
18 
19 namespace TMVA
20 {
21 namespace DNN
22 {
23 
24 //______________________________________________________________________________
25 template<typename Real_t>
26 Real_t TReference<Real_t>::L1Regularization(const TMatrixT<Real_t> & W)
27 {
28  size_t m,n;
29  m = W.GetNrows();
30  n = W.GetNcols();
31 
32  Real_t result = 0.0;
33 
34  for (size_t i = 0; i < m; i++) {
35  for (size_t j = 0; j < n; j++) {
36  result += std::abs(W(i,j));
37  }
38  }
39  return result;
40 }
41 
42 //______________________________________________________________________________
43 template<typename Real_t>
44 void TReference<Real_t>::AddL1RegularizationGradients(TMatrixT<Real_t> & A,
45  const TMatrixT<Real_t> & W,
46  Real_t weightDecay)
47 {
48  size_t m,n;
49  m = W.GetNrows();
50  n = W.GetNcols();
51 
52  Real_t sign = 0.0;
53 
54  for (size_t i = 0; i < m; i++) {
55  for (size_t j = 0; j < n; j++) {
56  sign = (W(i,j) > 0.0) ? 1.0 : -1.0;
57  A(i,j) += sign * weightDecay;
58  }
59  }
60 }
61 
62 //______________________________________________________________________________
63 template<typename Real_t>
64 Real_t TReference<Real_t>::L2Regularization(const TMatrixT<Real_t> & W)
65 {
66  size_t m,n;
67  m = W.GetNrows();
68  n = W.GetNcols();
69 
70  Real_t result = 0.0;
71 
72  for (size_t i = 0; i < m; i++) {
73  for (size_t j = 0; j < n; j++) {
74  result += W(i,j) * W(i,j);
75  }
76  }
77  return result;
78 }
79 
80 //______________________________________________________________________________
81 template<typename Real_t>
82 void TReference<Real_t>::AddL2RegularizationGradients(TMatrixT<Real_t> & A,
83  const TMatrixT<Real_t> & W,
84  Real_t weightDecay)
85 {
86  size_t m,n;
87  m = W.GetNrows();
88  n = W.GetNcols();
89 
90  for (size_t i = 0; i < m; i++) {
91  for (size_t j = 0; j < n; j++) {
92  A(i,j) += weightDecay * 2.0 * W(i,j);
93  }
94  }
95 }
96 
97 } // namespace DNN
98 } // namespace TMVA