25 template <
typename AFloat_t>
26 TRandom * TCpu<AFloat_t>::fgRandomGen =
nullptr;
28 template<
typename AFloat>
29 void TCpu<AFloat>::SetRandomSeed(
size_t seed)
31 if (!fgRandomGen) fgRandomGen =
new TRandom3();
32 fgRandomGen->SetSeed(seed);
34 template<
typename AFloat>
35 TRandom & TCpu<AFloat>::GetRandomGenerator()
37 if (!fgRandomGen) fgRandomGen =
new TRandom3(0);
42 template<
typename AFloat>
43 void TCpu<AFloat>::InitializeGauss(TCpuMatrix<AFloat> & A)
45 size_t n = A.GetNcols();
47 TRandom & rand = GetRandomGenerator();
49 AFloat sigma = sqrt(2.0 / ((AFloat) n));
51 for (
size_t i = 0; i < A.GetSize(); ++i) {
52 A.GetRawDataPointer()[i] = rand.Gaus(0.0, sigma);
57 template<
typename AFloat>
58 void TCpu<AFloat>::InitializeUniform(TCpuMatrix<AFloat> & A)
61 size_t n = A.GetNcols();
63 TRandom & rand = GetRandomGenerator();
65 AFloat range = sqrt(2.0 / ((AFloat) n));
71 for (
size_t i = 0; i < A.GetSize(); ++i) {
72 A.GetRawDataPointer()[i] = rand.Uniform(-range, range);
// Fills A with Gaussian random values that are re-drawn until they lie within
// two standard deviations of zero, using sigma = sqrt(2 / (n + m)).
// NOTE(review): the declarations of `n`, `m` and `value`, the opening of the
// do-while loop and the closing braces are not among the lines visible here;
// presumably n and m are the matrix dimensions -- confirm against the full file.
81 template<
typename AFloat>
82 void TCpu<AFloat>::InitializeGlorotNormal(TCpuMatrix<AFloat> & A)
// Generator obtained through GetRandomGenerator().
90 TRandom & rand = GetRandomGenerator();
// Standard deviation of the draws: sqrt(2 / (n + m)).
92 AFloat sigma = sqrt(2.0 /( ((AFloat) n) + ((AFloat) m)) );
// Element count read once before the loop.
95 size_t nsize = A.GetSize();
96 for (
size_t i = 0; i < nsize; i++) {
// Draw a candidate with mean 0.0 and width sigma.
99 value = rand.Gaus(0.0, sigma);
// Tail of the rejection loop: draw again while |value| exceeds 2 * sigma.
100 }
while (std::abs(value) > 2 * sigma);
// NOTE(review): the loop above stops when |value| <= 2 * sigma, but this
// assertion requires |value| < 2 * sigma, so a draw of exactly 2 * sigma would
// leave the loop and then fail the assertion.
101 R__ASSERT(std::abs(value) < 2 * sigma);
// Store the accepted draw in element i of the raw data buffer.
102 A.GetRawDataPointer()[i] = value;
// Fills every element of A with values drawn by rand.Uniform(-range, range),
// using range = sqrt(6 / (n + m)).
// NOTE(review): the declarations of `n` and `m` and the closing braces are not
// among the lines visible here; presumably n and m are the matrix dimensions --
// confirm against the full file.
111 template<
typename AFloat>
112 void TCpu<AFloat>::InitializeGlorotUniform(TCpuMatrix<AFloat> & A)
// Generator obtained through GetRandomGenerator().
118 TRandom & rand = GetRandomGenerator();
// Half-width of the sampling interval: sqrt(6 / (n + m)).
120 AFloat range = sqrt(6.0 /( ((AFloat) n) + ((AFloat) m)) );
// Element count read once before the loop.
122 size_t nsize = A.GetSize();
123 for (
size_t i = 0; i < nsize; i++) {
// Overwrite element i of the raw data buffer with a fresh draw.
124 A.GetRawDataPointer()[i] = rand.Uniform(-range, range);
// Initializer for A that iterates over an m-by-n index range.
// NOTE(review): only the two loop headers are visible here. The declarations of
// `m` and `n`, the loop bodies and the closing braces are missing, so what is
// written into A cannot be confirmed from these lines. Judging by the name it
// presumably writes an identity pattern -- verify against the full file.
129 template<
typename AFloat>
130 void TCpu<AFloat>::InitializeIdentity(TCpuMatrix<AFloat> & A)
// Outer loop: index i runs over [0, m).
136 for (
size_t i = 0; i < m; i++) {
// Inner loop: index j runs over [0, n).
137 for (
size_t j = 0; j < n ; j++) {
149 template<
typename AFloat>
150 void TCpu<AFloat>::InitializeZero(TCpuMatrix<AFloat> & A)
156 for (
size_t i = 0; i < m; i++) {
157 for (
size_t j = 0; j < n ; j++) {