Logo ROOT   6.30.04
Reference Guide
 All Namespaces Files Pages
Arithmetic.hxx
Go to the documentation of this file.
1 // @(#)root/tmva/tmva/dnn:$Id$
2 // Author: Ravi Kiran S
3 
4 /*************************************************************************
5  * Copyright (C) 2018, Ravi Kiran S *
6  * All rights reserved. *
7  * *
8  * For the licensing terms see $ROOTSYS/LICENSE. *
9  * For the list of contributors see $ROOTSYS/README/CREDITS. *
10  *************************************************************************/
11 
12 //////////////////////////////////////////////////////////////////
13 // Implementation of the Helper arithmetic functions for the //
14 // reference implementation. //
15 //////////////////////////////////////////////////////////////////
16 
18 #include <math.h>
19 
20 namespace TMVA {
21 namespace DNN {
22 
23 //______________________________________________________________________________
24 template <typename AReal>
25 void TReference<AReal>::SumColumns(TMatrixT<AReal> &B, const TMatrixT<AReal> &A)
26 {
27  B = 0.0;
28  for (Int_t i = 0; i < A.GetNrows(); i++) {
29  for (Int_t j = 0; j < A.GetNcols(); j++) {
30  B(0, j) += A(i, j);
31  }
32  }
33 }
34 
35 //______________________________________________________________________________
36 template <typename AReal>
37 void TReference<AReal>::Hadamard(TMatrixT<AReal> &A, const TMatrixT<AReal> &B)
38 {
39  for (Int_t i = 0; i < A.GetNrows(); i++) {
40  for (Int_t j = 0; j < A.GetNcols(); j++) {
41  A(i, j) *= B(i, j);
42  }
43  }
44 }
45 
46 //______________________________________________________________________________
47 template <typename AReal>
48 void TReference<AReal>::ConstAdd(TMatrixT<AReal> &A, AReal beta)
49 {
50  for (Int_t i = 0; i < A.GetNrows(); i++) {
51  for (Int_t j = 0; j < A.GetNcols(); j++) {
52  A(i, j) += beta;
53  }
54  }
55 }
56 
57 //______________________________________________________________________________
58 template <typename AReal>
59 void TReference<AReal>::ConstMult(TMatrixT<AReal> &A, AReal beta)
60 {
61  for (Int_t i = 0; i < A.GetNrows(); i++) {
62  for (Int_t j = 0; j < A.GetNcols(); j++) {
63  A(i, j) *= beta;
64  }
65  }
66 }
67 
68 //______________________________________________________________________________
69 template <typename AReal>
70 void TReference<AReal>::ReciprocalElementWise(TMatrixT<AReal> &A)
71 {
72  for (Int_t i = 0; i < A.GetNrows(); i++) {
73  for (Int_t j = 0; j < A.GetNcols(); j++) {
74  A(i, j) = 1.0 / A(i, j);
75  }
76  }
77 }
78 
79 //______________________________________________________________________________
80 template <typename AReal>
81 void TReference<AReal>::SquareElementWise(TMatrixT<AReal> &A)
82 {
83  for (Int_t i = 0; i < A.GetNrows(); i++) {
84  for (Int_t j = 0; j < A.GetNcols(); j++) {
85  A(i, j) *= A(i, j);
86  }
87  }
88 }
89 
90 //______________________________________________________________________________
91 template <typename AReal>
92 void TReference<AReal>::SqrtElementWise(TMatrixT<AReal> &A)
93 {
94  for (Int_t i = 0; i < A.GetNrows(); i++) {
95  for (Int_t j = 0; j < A.GetNcols(); j++) {
96  A(i, j) = sqrt(A(i, j));
97  }
98  }
99 }
100 /// Adam updates
101 //____________________________________________________________________________
102 template<typename AReal>
103 void TReference<AReal>::AdamUpdate(TMatrixT<AReal> &A, const TMatrixT<AReal> & M, const TMatrixT<AReal> & V, AReal alpha, AReal eps)
104 {
105  // ADAM update the weights.
106  // Weight = Weight - alpha * M / (sqrt(V) + epsilon)
107  AReal * a = A.GetMatrixArray();
108  const AReal * m = M.GetMatrixArray();
109  const AReal * v = V.GetMatrixArray();
110  for (int index = 0; index < A.GetNoElements() ; ++index) {
111  a[index] = a[index] - alpha * m[index]/( sqrt(v[index]) + eps);
112  }
113 }
114 
115 //____________________________________________________________________________
116 template<typename AReal>
117 void TReference<AReal>::AdamUpdateFirstMom(TMatrixT<AReal> &A, const TMatrixT<AReal> & B, AReal beta)
118 {
119  // First momentum weight gradient update for ADAM
120  // Mt = beta1 * Mt-1 + (1-beta1) * WeightGradients
121  AReal * a = A.GetMatrixArray();
122  const AReal * b = B.GetMatrixArray();
123  for (int index = 0; index < A.GetNoElements() ; ++index) {
124  a[index] = beta * a[index] + (1.-beta) * b[index];
125  }
126 }
127 //____________________________________________________________________________
128 template<typename AReal>
129 void TReference<AReal>::AdamUpdateSecondMom(TMatrixT<AReal> &A, const TMatrixT<AReal> & B, AReal beta)
130 {
131  // Second momentum weight gradient update for ADAM
132  // Vt = beta2 * Vt-1 + (1-beta2) * WeightGradients^2
133  AReal * a = A.GetMatrixArray();
134  const AReal * b = B.GetMatrixArray();
135  for (int index = 0; index < A.GetNoElements() ; ++index) {
136  a[index] = beta * a[index] + (1.-beta) * b[index] * b[index];
137  }
138 }
139 
140 } // namespace DNN
141 } // namespace TMVA