Logo ROOT   6.30.04
Reference Guide
 All Namespaces Files Pages
Blas.h
Go to the documentation of this file.
1 // @(#)root/tmva/tmva/dnn:$Id$
2 // Author: Simon Pfreundschuh 20/07/16
3 
4 /*************************************************************************
5  * Copyright (C) 2016, Simon Pfreundschuh *
6  * All rights reserved. *
7  * *
8  * For the licensing terms see $ROOTSYS/LICENSE. *
9  * For the list of contributors see $ROOTSYS/README/CREDITS. *
10  *************************************************************************/
11 
12 ///////////////////////////////////////////////////////////////////
13 // Declarations of the BLAS functions used for the forward and //
14 // backward propagation of activation through neural networks on //
15 // CPUs. //
16 ///////////////////////////////////////////////////////////////////
17 
18 #ifndef TMVA_DNN_ARCHITECTURES_CPU_BLAS
19 #define TMVA_DNN_ARCHITECTURES_CPU_BLAS
20 
21 #include <iostream>
22 
23 #ifndef DNN_USE_CBLAS
// External Library Routines
//____________________________________________________________________________
// Prototypes of the BLAS entry points used by the wrappers below, declared
// with C linkage. Every parameter -- dimensions, strides and scalars
// included -- is passed by pointer. A leading 's' marks the float variant of
// a routine and a leading 'd' the double variant.
extern "C" {

// axpy
void saxpy_(const int *n, const float *alpha, const float *x,
            const int *incx, float *y, const int *incy);
void daxpy_(const int *n, const double *alpha, const double *x,
            const int *incx, double *y, const int *incy);

// ger
void sger_(const int *m, const int *n, const float *alpha,
           const float *x, const int *incx,
           const float *y, const int *incy,
           float *A, const int *lda);
void dger_(const int *m, const int *n, const double *alpha,
           const double *x, const int *incx,
           const double *y, const int *incy,
           double *A, const int *lda);

// gemv
void sgemv_(const char *trans, const int *m, const int *n,
            const float *alpha, const float *A, const int *lda,
            const float *x, const int *incx,
            const float *beta, float *y, const int *incy);
void dgemv_(const char *trans, const int *m, const int *n,
            const double *alpha, const double *A, const int *lda,
            const double *x, const int *incx,
            const double *beta, double *y, const int *incy);

// gemm
void sgemm_(const char *transa, const char *transb,
            const int *m, const int *n, const int *k,
            const float *alpha, const float *A, const int *lda,
            const float *B, const int *ldb, const float *beta,
            float *C, const int *ldc);
void dgemm_(const char *transa, const char *transb,
            const int *m, const int *n, const int *k,
            const double *alpha, const double *A, const int *lda,
            const double *B, const int *ldb, const double *beta,
            double *C, const int *ldc);

} // extern "C"
56 
57 #else
58 #include "gsl/gsl_cblas.h"
59 #endif
60 
61 namespace TMVA
62 {
63 namespace DNN
64 {
65 namespace Blas
66 {
67 
68 // Type-Generic Wrappers
69 //____________________________________________________________________________
/** Scaled vector addition: \p y := \p alpha * \p x + \p y.
 *
 *  Type-generic front end for the BLAS level-1 routine ?axpy. Only \p x is
 *  scaled; \p y is accumulated into as-is (there is no second scaling
 *  factor). This is a declaration only: the explicit specializations for
 *  \c float and \c double provide the implementation.
 *
 *  All arguments are passed by pointer, following the BLAS calling
 *  convention of the routines being wrapped.
 *
 *  \param n     Number of elements to process.
 *  \param alpha Scalar applied to \p x.
 *  \param x     Input vector.
 *  \param incx  Stride between consecutive elements of \p x.
 *  \param y     Vector updated in place.
 *  \param incy  Stride between consecutive elements of \p y.
 */
template <typename AReal>
inline void Axpy(const int *n, const AReal *alpha,
                 const AReal *x, const int *incx,
                 AReal *y, const int *incy);
75 
/** Matrix-vector product with accumulation:
 *  \p y := \p alpha * op(\p A) * \p x + \p beta * \p y.
 *
 *  Type-generic front end for the BLAS level-2 routine ?gemv. op(A) is
 *  either \p A or its transpose, selected by \p trans. Note that \p y is an
 *  input as well as the output unless \p beta is zero. This is a declaration
 *  only: the explicit specializations for \c float and \c double provide
 *  the implementation.
 *
 *  \param trans Pointer to the transposition flag ('N': use A, 'T': use A^T).
 *  \param m     Number of rows of \p A.
 *  \param n     Number of columns of \p A.
 *  \param alpha Scalar applied to the product op(A) * x.
 *  \param A     Matrix in column-major storage.
 *  \param lda   Leading dimension of \p A.
 *  \param x     Input vector.
 *  \param incx  Stride between consecutive elements of \p x.
 *  \param beta  Scalar applied to the previous content of \p y.
 *  \param y     Vector updated in place.
 *  \param incy  Stride between consecutive elements of \p y.
 */
template <typename AReal>
inline void Gemv(const char *trans, const int *m, const int *n,
                 const AReal *alpha, const AReal *A, const int *lda,
                 const AReal *x, const int *incx,
                 const AReal *beta, AReal *y, const int *incy);
82 
/** Matrix-matrix product with accumulation:
 *  \p C := \p alpha * op(\p A) * op(\p B) + \p beta * \p C.
 *
 *  Type-generic front end for the BLAS level-3 routine ?gemm. op(A) and
 *  op(B) are the matrices themselves or their transposes, selected by
 *  \p transa and \p transb. Note that \p C is an input as well as the output
 *  unless \p beta is zero. This is a declaration only: the explicit
 *  specializations for \c float and \c double provide the implementation.
 *
 *  \param transa Pointer to the transposition flag for \p A ('N' or 'T').
 *  \param transb Pointer to the transposition flag for \p B ('N' or 'T').
 *  \param m      Number of rows of op(A) and of \p C.
 *  \param n      Number of columns of op(B) and of \p C.
 *  \param k      Number of columns of op(A) and rows of op(B).
 *  \param alpha  Scalar applied to the product op(A) * op(B).
 *  \param A      First factor, column-major storage.
 *  \param lda    Leading dimension of \p A.
 *  \param B      Second factor, column-major storage.
 *  \param ldb    Leading dimension of \p B.
 *  \param beta   Scalar applied to the previous content of \p C.
 *  \param C      Matrix updated in place.
 *  \param ldc    Leading dimension of \p C.
 */
template <typename AReal>
inline void Gemm(const char *transa, const char *transb,
                 const int *m, const int *n, const int *k,
                 const AReal *alpha, const AReal *A, const int *lda,
                 const AReal *B, const int *ldb, const AReal *beta,
                 AReal *C, const int *ldc);
90 
/** Rank-1 update: \p A := \p alpha * \p x * \p y^T + \p A.
 *
 *  Type-generic front end for the BLAS level-2 routine ?ger: the outer
 *  product of \p x and \p y, scaled by \p alpha, is added to \p A in place.
 *  This is a declaration only: the explicit specializations for \c float
 *  and \c double provide the implementation.
 *
 *  \param m     Number of rows of \p A (length of \p x).
 *  \param n     Number of columns of \p A (length of \p y).
 *  \param alpha Scalar applied to the outer product.
 *  \param x     Left (column) vector of the outer product.
 *  \param incx  Stride between consecutive elements of \p x.
 *  \param y     Right (row) vector of the outer product.
 *  \param incy  Stride between consecutive elements of \p y.
 *  \param A     Matrix updated in place, column-major storage.
 *  \param lda   Leading dimension of \p A.
 */
template <typename AReal>
inline void Ger(const int *m, const int *n, const AReal *alpha,
                const AReal *x, const int *incx,
                const AReal *y, const int *incy,
                AReal *A, const int *lda);
97 
98 // Specializations
99 //____________________________________________________________________________
100 #ifndef DNN_USE_CBLAS
101 
102 template<>
103 inline void Axpy<double>(const int * n, const double * alpha,
104  const double * x, const int * incx,
105  double * y, const int * incy)
106 {
107  daxpy_(n, alpha, x, incx, y, incy);
108 }
109 
110 template<>
111 inline void Axpy<float>(const int * n, const float * alpha,
112  const float * x, const int * incx,
113  float * y, const int * incy)
114 {
115  saxpy_(n, alpha, x, incx, y, incy);
116 }
117 
118 template<>
119 inline void Gemv<double>(const char *trans, const int * m, const int * n,
120  const double * alpha, const double * A, const int * lda,
121  const double * x, const int * incx,
122  const double * beta, double * y, const int * incy)
123 {
124  dgemv_(trans, m, n, alpha, A, lda, x, incx, beta, y, incy);
125 }
126 
127 template<>
128 inline void Gemv<float>(const char *trans, const int * m, const int * n,
129  const float * alpha, const float * A, const int * lda,
130  const float * x, const int * incx,
131  const float * beta, float * y, const int * incy)
132 {
133  sgemv_(trans, m, n, alpha, A, lda, x, incx, beta, y, incy);
134 }
135 
136 template<>
137 inline void Gemm<double>(const char *transa, const char *transb,
138  const int * m, const int * n, const int* k,
139  const double * alpha, const double * A, const int * lda,
140  const double * B, const int * ldb, const double * beta,
141  double * C, const int * ldc)
142 {
143  dgemm_(transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
144 }
145 
146 template<>
147 inline void Gemm<float>(const char *transa, const char *transb,
148  const int * m, const int * n, const int* k,
149  const float * alpha, const float * A, const int * lda,
150  const float * B, const int * ldb, const float * beta,
151  float * C, const int * ldc)
152 {
153  sgemm_(transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
154 }
155 
156 template <>
157 inline void Ger<double>(const int * m, const int * n, const double * alpha,
158  const double * x, const int * incx,
159  const double * y, const int * incy,
160  double * A, const int * lda)
161 {
162  dger_(m, n, alpha, x, incx, y, incy, A, lda);
163 }
164 
165 template <>
166 inline void Ger<float>(const int * m, const int * n, const float * alpha,
167  const float * x, const int * incx,
168  const float * y, const int * incy,
169  float * A, const int * lda)
170 {
171  sger_(m, n, alpha, x, incx, y, incy, A, lda);
172 }
173 
174 #else // use cblas
175 //--------------------------------------------------------
176 // cblas implementation
177 //-----------------------------------------------------------
178 template<>
179 inline void Axpy<double>(const int * n, const double * alpha,
180  const double * x, const int * incx,
181  double * y, const int * incy)
182 {
183  cblas_daxpy(*n, *alpha, x, *incx, y, *incy);
184 }
185 
186 template<>
187 inline void Axpy<float>(const int * n, const float * alpha,
188  const float * x, const int * incx,
189  float * y, const int * incy)
190 {
191  cblas_saxpy(*n, *alpha, x, *incx, y, *incy);
192 }
193 
194 template<>
195 inline void Gemv<double>(const char *trans, const int * m, const int * n,
196  const double * alpha, const double * A, const int * lda,
197  const double * x, const int * incx,
198  const double * beta, double * y, const int * incy)
199 {
200  CBLAS_TRANSPOSE kTrans = (*trans == 'T') ? CblasTrans : CblasNoTrans;
201  cblas_dgemv(CblasColMajor, kTrans, *m, *n, *alpha, A, *lda, x, *incx, *beta, y, *incy);
202 }
203 
204 template<>
205 inline void Gemv<float>(const char *trans, const int * m, const int * n,
206  const float * alpha, const float * A, const int * lda,
207  const float * x, const int * incx,
208  const float * beta, float * y, const int * incy)
209 {
210  CBLAS_TRANSPOSE kTrans = (*trans == 'T') ? CblasTrans : CblasNoTrans;
211  cblas_sgemv(CblasColMajor, kTrans, *m, *n, *alpha, A, *lda, x, *incx, *beta, y, *incy);
212 }
213 
214 template<>
215 inline void Gemm<double>(const char *transa, const char *transb,
216  const int * m, const int * n, const int* k,
217  const double * alpha, const double * A, const int * lda,
218  const double * B, const int * ldb, const double * beta,
219  double * C, const int * ldc)
220 {
221  CBLAS_TRANSPOSE kTransA = (*transa == 'T') ? CblasTrans : CblasNoTrans;
222  CBLAS_TRANSPOSE kTransB = (*transb == 'T') ? CblasTrans : CblasNoTrans;
223  cblas_dgemm(CblasColMajor, kTransA, kTransB, *m, *n, *k, *alpha, A, *lda, B, *ldb, *beta, C, *ldc);
224 }
225 
226 template<>
227 inline void Gemm<float>(const char *transa, const char *transb,
228  const int * m, const int * n, const int* k,
229  const float * alpha, const float * A, const int * lda,
230  const float * B, const int * ldb, const float * beta,
231  float * C, const int * ldc)
232 {
233  CBLAS_TRANSPOSE kTransA = (*transa == 'T') ? CblasTrans : CblasNoTrans;
234  CBLAS_TRANSPOSE kTransB = (*transb == 'T') ? CblasTrans : CblasNoTrans;
235  cblas_sgemm(CblasColMajor, kTransA, kTransB, *m, *n, *k, *alpha, A, *lda, B, *ldb, *beta, C, *ldc);
236 }
237 
238 template <>
239 inline void Ger<double>(const int * m, const int * n, const double * alpha,
240  const double * x, const int * incx,
241  const double * y, const int * incy,
242  double * A, const int * lda)
243 {
244  cblas_dger(CblasColMajor, *m, *n, *alpha, x, *incx, y, *incy, A, *lda);
245 }
246 
247 template <>
248 inline void Ger<float>(const int * m, const int * n, const float * alpha,
249  const float * x, const int * incx,
250  const float * y, const int * incy,
251  float * A, const int * lda)
252 {
253  cblas_sger(CblasColMajor, *m, *n, *alpha, x, *incx, y, *incy, A, *lda);
254 }
255 
256 #endif
257 
258 } // namespace Blas
259 } // namespace DNN
260 } // namespace TMVA
261 
262 #endif