25 template <
typename Real_t>
26 TRandom * TReference<Real_t>::fgRandomGen =
nullptr;
28 template<
typename Real_t>
29 void TReference<Real_t>::SetRandomSeed(
size_t seed)
31 if (!fgRandomGen) fgRandomGen =
new TRandom3();
32 fgRandomGen->SetSeed(seed);
34 template<
typename Real_t>
35 TRandom & TReference<Real_t>::GetRandomGenerator()
37 if (!fgRandomGen) fgRandomGen =
new TRandom3(0);
42 template<
typename Real_t>
43 void TReference<Real_t>::InitializeGauss(TMatrixT<Real_t> & A)
49 TRandom & rand = GetRandomGenerator();
51 Real_t sigma = sqrt(2.0 / ((Real_t) n));
53 for (
size_t i = 0; i < m; i++) {
54 for (
size_t j = 0; j < n; j++) {
55 A(i,j) = rand.Gaus(0.0, sigma);
61 template<
typename Real_t>
62 void TReference<Real_t>::InitializeUniform(TMatrixT<Real_t> & A)
68 TRandom & rand = GetRandomGenerator();
70 Real_t range = sqrt(2.0 / ((Real_t) n));
72 for (
size_t i = 0; i < m; i++) {
73 for (
size_t j = 0; j < n; j++) {
74 A(i,j) = rand.Uniform(-range, range);
84 template<
typename Real_t>
85 void TReference<Real_t>::InitializeGlorotNormal(TMatrixT<Real_t> & A)
91 TRandom & rand = GetRandomGenerator();
93 Real_t sigma = sqrt(2.0 /( ((Real_t) n) + ((Real_t) m)) );
95 for (
size_t i = 0; i < m; i++) {
96 for (
size_t j = 0; j < n; j++) {
97 Real_t value = rand.Gaus(0.0, sigma);
98 if ( std::abs(value) > 2*sigma)
continue;
99 A(i,j) = rand.Gaus(0.0, sigma);
109 template<
typename Real_t>
110 void TReference<Real_t>::InitializeGlorotUniform(TMatrixT<Real_t> & A)
116 TRandom & rand = GetRandomGenerator();
118 Real_t range = sqrt(6.0 /( ((Real_t) n) + ((Real_t) m)) );
120 for (
size_t i = 0; i < m; i++) {
121 for (
size_t j = 0; j < n; j++) {
122 A(i,j) = rand.Uniform(-range, range);
/// Initialize A in place by iterating over all (i, j) with i < m and j < n.
/// NOTE(review): presumably sets A to the (possibly rectangular) identity, going
/// by the name; the loop body is not visible in this view -- confirm.
/// NOTE(review): m and n are used here but not declared in the visible lines;
/// assumed to be A's row and column counts -- confirm.
128 template<
typename Real_t>
129 void TReference<Real_t>::InitializeIdentity(TMatrixT<Real_t> & A)
135 for (
size_t i = 0; i < m; i++) {
136 for (
size_t j = 0; j < n ; j++) {
/// Initialize A in place by iterating over all (i, j) with i < m and j < n.
/// NOTE(review): presumably writes zero to every element, going by the name;
/// the loop body continues past the visible lines -- confirm.
/// NOTE(review): m and n are not declared in the visible lines; assumed to be
/// A's row and column counts -- confirm.
147 template<
typename Real_t>
148 void TReference<Real_t>::InitializeZero(TMatrixT<Real_t> & A)
154 for (
size_t i = 0; i < m; i++) {
155 for (
size_t j = 0; j < n ; j++) {