18 #ifndef TMVA_DNN_ARCHITECTURES_CPU_BLAS
19 #define TMVA_DNN_ARCHITECTURES_CPU_BLAS
// Fortran BLAS level-1 AXPY routines, single and double precision.
// Every argument is passed by pointer, following the Fortran calling
// convention that the type-generic wrappers in this header forward to.
extern "C" {
void saxpy_(const int *n, const float *alpha, const float *x, const int *incx,
            float *y, const int *incy);
void daxpy_(const int *n, const double *alpha, const double *x, const int *incx,
            double *y, const int *incy);
}
// Fortran BLAS level-2 rank-1 update routines (GER), single and double
// precision. All arguments are passed by pointer (Fortran convention).
extern "C" {
void sger_(const int *m, const int *n, const float *alpha,
           const float *x, const int *incx,
           const float *y, const int *incy,
           float *A, const int *lda);
void dger_(const int *m, const int *n, const double *alpha,
           const double *x, const int *incx,
           const double *y, const int *incy,
           double *A, const int *lda);
}
// Fortran BLAS level-2 matrix-vector product routines (GEMV), single and
// double precision. All arguments are passed by pointer (Fortran convention).
extern "C" {
void sgemv_(const char *trans, const int *m, const int *n,
            const float *alpha, const float *A, const int *lda,
            const float *x, const int *incx,
            const float *beta, float *y, const int *incy);
void dgemv_(const char *trans, const int *m, const int *n,
            const double *alpha, const double *A, const int *lda,
            const double *x, const int *incx,
            const double *beta, double *y, const int *incy);
}
// Fortran BLAS level-3 matrix-matrix product routines (GEMM), double and
// single precision. All arguments are passed by pointer (Fortran convention).
extern "C" {
void dgemm_(const char *transa, const char *transb,
            const int *m, const int *n, const int *k,
            const double *alpha, const double *A, const int *lda,
            const double *B, const int *ldb,
            const double *beta, double *C, const int *ldc);
void sgemm_(const char *transa, const char *transb,
            const int *m, const int *n, const int *k,
            const float *alpha, const float *A, const int *lda,
            const float *B, const int *ldb,
            const float *beta, float *C, const int *ldc);
}
58 #include "gsl/gsl_cblas.h"
/// Type-generic AXPY wrapper: y := alpha * x + y.
/// Only declared here; explicit specializations for float and double are
/// provided later in this header. Arguments are taken by pointer to mirror
/// the Fortran BLAS interface.
template <class AReal>
inline void Axpy(const int *n, const AReal *alpha,
                 const AReal *x, const int *incx,
                 AReal *y, const int *incy);
/// Type-generic GEMV wrapper: y := alpha * op(A) * x + beta * y.
/// Only declared here; explicit specializations for float and double are
/// provided later in this header. Arguments are taken by pointer to mirror
/// the Fortran BLAS interface.
template <class AReal>
inline void Gemv(const char *trans, const int *m, const int *n,
                 const AReal *alpha, const AReal *A, const int *lda,
                 const AReal *x, const int *incx,
                 const AReal *beta, AReal *y, const int *incy);
/// Type-generic GEMM wrapper: C := alpha * op(A) * op(B) + beta * C.
/// Only declared here; explicit specializations for float and double are
/// provided later in this header. Arguments are taken by pointer to mirror
/// the Fortran BLAS interface.
template <class AReal>
inline void Gemm(const char *transa, const char *transb,
                 const int *m, const int *n, const int *k,
                 const AReal *alpha, const AReal *A, const int *lda,
                 const AReal *B, const int *ldb,
                 const AReal *beta, AReal *C, const int *ldc);
/// Type-generic GER wrapper (rank-1 update): A := alpha * x * y^T + A.
/// Only declared here; explicit specializations for float and double are
/// provided later in this header. Arguments are taken by pointer to mirror
/// the Fortran BLAS interface.
template <class AReal>
inline void Ger(const int *m, const int *n, const AReal *alpha,
                const AReal *x, const int *incx,
                const AReal *y, const int *incy,
                AReal *A, const int *lda);
100 #ifndef DNN_USE_CBLAS
103 inline void Axpy<double>(
const int * n,
const double * alpha,
104 const double * x,
const int * incx,
105 double * y,
const int * incy)
107 daxpy_(n, alpha, x, incx, y, incy);
111 inline void Axpy<float>(
const int * n,
const float * alpha,
112 const float * x,
const int * incx,
113 float * y,
const int * incy)
115 saxpy_(n, alpha, x, incx, y, incy);
119 inline void Gemv<double>(
const char *trans,
const int * m,
const int * n,
120 const double * alpha,
const double * A,
const int * lda,
121 const double * x,
const int * incx,
122 const double * beta,
double * y,
const int * incy)
124 dgemv_(trans, m, n, alpha, A, lda, x, incx, beta, y, incy);
128 inline void Gemv<float>(
const char *trans,
const int * m,
const int * n,
129 const float * alpha,
const float * A,
const int * lda,
130 const float * x,
const int * incx,
131 const float * beta,
float * y,
const int * incy)
133 sgemv_(trans, m, n, alpha, A, lda, x, incx, beta, y, incy);
137 inline void Gemm<double>(
const char *transa,
const char *transb,
138 const int * m,
const int * n,
const int* k,
139 const double * alpha,
const double * A,
const int * lda,
140 const double * B,
const int * ldb,
const double * beta,
141 double * C,
const int * ldc)
143 dgemm_(transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
147 inline void Gemm<float>(
const char *transa,
const char *transb,
148 const int * m,
const int * n,
const int* k,
149 const float * alpha,
const float * A,
const int * lda,
150 const float * B,
const int * ldb,
const float * beta,
151 float * C,
const int * ldc)
153 sgemm_(transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
157 inline void Ger<double>(
const int * m,
const int * n,
const double * alpha,
158 const double * x,
const int * incx,
159 const double * y,
const int * incy,
160 double * A,
const int * lda)
162 dger_(m, n, alpha, x, incx, y, incy, A, lda);
166 inline void Ger<float>(
const int * m,
const int * n,
const float * alpha,
167 const float * x,
const int * incx,
168 const float * y,
const int * incy,
169 float * A,
const int * lda)
171 sger_(m, n, alpha, x, incx, y, incy, A, lda);
179 inline void Axpy<double>(
const int * n,
const double * alpha,
180 const double * x,
const int * incx,
181 double * y,
const int * incy)
183 cblas_daxpy(*n, *alpha, x, *incx, y, *incy);
187 inline void Axpy<float>(
const int * n,
const float * alpha,
188 const float * x,
const int * incx,
189 float * y,
const int * incy)
191 cblas_saxpy(*n, *alpha, x, *incx, y, *incy);
195 inline void Gemv<double>(
const char *trans,
const int * m,
const int * n,
196 const double * alpha,
const double * A,
const int * lda,
197 const double * x,
const int * incx,
198 const double * beta,
double * y,
const int * incy)
200 CBLAS_TRANSPOSE kTrans = (*trans ==
'T') ? CblasTrans : CblasNoTrans;
201 cblas_dgemv(CblasColMajor, kTrans, *m, *n, *alpha, A, *lda, x, *incx, *beta, y, *incy);
205 inline void Gemv<float>(
const char *trans,
const int * m,
const int * n,
206 const float * alpha,
const float * A,
const int * lda,
207 const float * x,
const int * incx,
208 const float * beta,
float * y,
const int * incy)
210 CBLAS_TRANSPOSE kTrans = (*trans ==
'T') ? CblasTrans : CblasNoTrans;
211 cblas_sgemv(CblasColMajor, kTrans, *m, *n, *alpha, A, *lda, x, *incx, *beta, y, *incy);
215 inline void Gemm<double>(
const char *transa,
const char *transb,
216 const int * m,
const int * n,
const int* k,
217 const double * alpha,
const double * A,
const int * lda,
218 const double * B,
const int * ldb,
const double * beta,
219 double * C,
const int * ldc)
221 CBLAS_TRANSPOSE kTransA = (*transa ==
'T') ? CblasTrans : CblasNoTrans;
222 CBLAS_TRANSPOSE kTransB = (*transb ==
'T') ? CblasTrans : CblasNoTrans;
223 cblas_dgemm(CblasColMajor, kTransA, kTransB, *m, *n, *k, *alpha, A, *lda, B, *ldb, *beta, C, *ldc);
227 inline void Gemm<float>(
const char *transa,
const char *transb,
228 const int * m,
const int * n,
const int* k,
229 const float * alpha,
const float * A,
const int * lda,
230 const float * B,
const int * ldb,
const float * beta,
231 float * C,
const int * ldc)
233 CBLAS_TRANSPOSE kTransA = (*transa ==
'T') ? CblasTrans : CblasNoTrans;
234 CBLAS_TRANSPOSE kTransB = (*transb ==
'T') ? CblasTrans : CblasNoTrans;
235 cblas_sgemm(CblasColMajor, kTransA, kTransB, *m, *n, *k, *alpha, A, *lda, B, *ldb, *beta, C, *ldc);
239 inline void Ger<double>(
const int * m,
const int * n,
const double * alpha,
240 const double * x,
const int * incx,
241 const double * y,
const int * incy,
242 double * A,
const int * lda)
244 cblas_dger(CblasColMajor, *m, *n, *alpha, x, *incx, y, *incy, A, *lda);
248 inline void Ger<float>(
const int * m,
const int * n,
const float * alpha,
249 const float * x,
const int * incx,
250 const float * y,
const int * incy,
251 float * A,
const int * lda)
253 cblas_sger(CblasColMajor, *m, *n, *alpha, x, *incx, y, *incy, A, *lda);