Logo ROOT   6.30.04
Reference Guide
 All Namespaces Files Pages
Initialization.hxx
Go to the documentation of this file.
1 // @(#)root/tmva/tmva/dnn:$Id$
2 // Author: Simon Pfreundschuh 21/07/16
3 
4 /*************************************************************************
5  * Copyright (C) 2016, Simon Pfreundschuh *
6  * All rights reserved. *
7  * *
8  * For the licensing terms see $ROOTSYS/LICENSE. *
9  * For the list of contributors see $ROOTSYS/README/CREDITS. *
10  *************************************************************************/
11 
12  //////////////////////////////////////////////////////////////
13  // Implementation of the DNN initialization methods for the //
14  // multi-threaded CPU backend. //
15  //////////////////////////////////////////////////////////////
16 
17 #include "TRandom3.h"
19 
20 namespace TMVA
21 {
22 namespace DNN
23 {
24 
25 template <typename AFloat_t>
26 TRandom * TCpu<AFloat_t>::fgRandomGen = nullptr;
27 //______________________________________________________________________________
28 template<typename AFloat>
29 void TCpu<AFloat>::SetRandomSeed(size_t seed)
30 {
31  if (!fgRandomGen) fgRandomGen = new TRandom3();
32  fgRandomGen->SetSeed(seed);
33 }
34 template<typename AFloat>
35 TRandom & TCpu<AFloat>::GetRandomGenerator()
36 {
37  if (!fgRandomGen) fgRandomGen = new TRandom3(0);
38  return *fgRandomGen;
39 }
40 
41 //______________________________________________________________________________
42 template<typename AFloat>
43 void TCpu<AFloat>::InitializeGauss(TCpuMatrix<AFloat> & A)
44 {
45  size_t n = A.GetNcols();
46 
47  TRandom & rand = GetRandomGenerator();
48 
49  AFloat sigma = sqrt(2.0 / ((AFloat) n));
50 
51  for (size_t i = 0; i < A.GetSize(); ++i) {
52  A.GetRawDataPointer()[i] = rand.Gaus(0.0, sigma);
53  }
54 }
55 
56 //______________________________________________________________________________
57 template<typename AFloat>
58 void TCpu<AFloat>::InitializeUniform(TCpuMatrix<AFloat> & A)
59 {
60  //size_t m = A.GetNrows();
61  size_t n = A.GetNcols();
62 
63  TRandom & rand = GetRandomGenerator();
64 
65  AFloat range = sqrt(2.0 / ((AFloat) n));
66 
67  // for debugging
68  //range = 1;
69  //rand.SetSeed(111);
70 
71  for (size_t i = 0; i < A.GetSize(); ++i) {
72  A.GetRawDataPointer()[i] = rand.Uniform(-range, range);
73  }
74 }
75 
76  //______________________________________________________________________________
77 /// Truncated normal initialization (Glorot, called also Xavier normal)
78 /// The values are sample with a normal distribution with stddev = sqrt(2/N_input + N_output) and
79 /// values larger than 2 * stddev are discarded
80 /// See Glorot & Bengio, AISTATS 2010 - http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf
81 template<typename AFloat>
82 void TCpu<AFloat>::InitializeGlorotNormal(TCpuMatrix<AFloat> & A)
83 {
84  size_t m,n;
85  // for conv layer weights output m is only output depth. It shouild ne multiplied also by filter sizes
86  // e.g. 9 for a 3x3 filter. But this information is lost if we use Tensors of dims 2
87  m = A.GetNrows();
88  n = A.GetNcols();
89 
90  TRandom & rand = GetRandomGenerator();
91 
92  AFloat sigma = sqrt(2.0 /( ((AFloat) n) + ((AFloat) m)) );
93  // AFloat sigma = sqrt(2.0 /( ((AFloat) m)) );
94 
95  size_t nsize = A.GetSize();
96  for (size_t i = 0; i < nsize; i++) {
97  AFloat value = 0;
98  do {
99  value = rand.Gaus(0.0, sigma);
100  } while (std::abs(value) > 2 * sigma);
101  R__ASSERT(std::abs(value) < 2 * sigma);
102  A.GetRawDataPointer()[i] = value;
103  }
104 }
105 
106 //______________________________________________________________________________
107 /// Sample from a uniform distribution in range [ -lim,+lim] where
108 /// lim = sqrt(6/N_in+N_out).
109 /// This initialization is also called Xavier uniform
110 /// see Glorot & Bengio, AISTATS 2010 - http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf
111 template<typename AFloat>
112 void TCpu<AFloat>::InitializeGlorotUniform(TCpuMatrix<AFloat> & A)
113 {
114  size_t m,n;
115  m = A.GetNrows();
116  n = A.GetNcols();
117 
118  TRandom & rand = GetRandomGenerator();
119 
120  AFloat range = sqrt(6.0 /( ((AFloat) n) + ((AFloat) m)) );
121 
122  size_t nsize = A.GetSize();
123  for (size_t i = 0; i < nsize; i++) {
124  A.GetRawDataPointer()[i] = rand.Uniform(-range, range);
125  }
126 }
127 
128 //______________________________________________________________________________
129 template<typename AFloat>
130 void TCpu<AFloat>::InitializeIdentity(TCpuMatrix<AFloat> & A)
131 {
132  size_t m,n;
133  m = A.GetNrows();
134  n = A.GetNcols();
135 
136  for (size_t i = 0; i < m; i++) {
137  for (size_t j = 0; j < n ; j++) {
138  //A(i,j) = 0.0;
139  A(i,j) = 1.0;
140  }
141 
142  if (i < n) {
143  A(i,i) = 1.0;
144  }
145  }
146 }
147 
148 //______________________________________________________________________________
149 template<typename AFloat>
150 void TCpu<AFloat>::InitializeZero(TCpuMatrix<AFloat> & A)
151 {
152  size_t m,n;
153  m = A.GetNrows();
154  n = A.GetNcols();
155 
156  for (size_t i = 0; i < m; i++) {
157  for (size_t j = 0; j < n ; j++) {
158  A(i,j) = 0.0;
159  }
160  }
161 }
162 
163 } // namespace DNN
164 } // namespace TMVA