#ifndef TMVA_DNN_ARCHITECTURES_CUDA_CUDAMATRIX
#define TMVA_DNN_ARCHITECTURES_CUDA_CUDAMATRIX

#include <cstdio>   // fprintf, used by cudaError()
#include <cstdlib>  // exit, used by cudaError()

#include "cuda_runtime.h"
#include "cublas_v2.h"
#include "curand_kernel.h"

/** Check the cudaError_t result of a CUDA runtime call.
 *
 * Forwards the result and the call site (__FILE__, __LINE__) to cudaError(),
 * which prints a diagnostic and, by default, terminates the process on
 * failure. The body is wrapped in do { ... } while (0) so that
 * `CUDACHECK(call);` is exactly one statement and composes safely with an
 * unbraced if/else (the previous `{ ... }` form followed by `;` did not).
 */
#define CUDACHECK(ans) do { cudaError((ans), __FILE__, __LINE__); } while (0)

/** Print a diagnostic for a failed CUDA runtime call and optionally abort.
 *
 * \param code  value returned by a CUDA runtime API call
 * \param file  source file of the call site
 * \param line  source line of the call site
 * \param abort if true (default), terminate the process when
 *              code != cudaSuccess
 */
inline void cudaError(cudaError_t code,
                      const char *file,
                      int line,
                      bool abort = true);
// NOTE(review): this declaration appears to come from a damaged extraction:
// every line carries a stray numeric prefix, and the class braces, access
// specifiers and some members (e.g. the `operator AFloat()` declaration whose
// out-of-line definition appears later in this file) are not present here.
// Only comments were added to this region.

/** Host-side proxy for a single AFloat element residing in device memory.
 *
 * The out-of-line members move the element with cudaMemcpy on every access,
 * so each read or write costs one host<->device copy (two for += and -=).
 * Suitable for occasional element access, not for bulk work.
 */
53 template<
typename AFloat>
54 class TCudaDeviceReference
// Device address of the referenced element, stored unchanged by the
// constructor; no code visible here allocates or frees it.
58 AFloat * fDevicePointer;
// Bind the proxy to an existing device address.
62 TCudaDeviceReference(AFloat * devicePointer);
// Copy the element referenced by `other` into this element (device-to-device).
// Returns void, so chained assignment is not supported.
66 void operator=(
const TCudaDeviceReference &other);
// Write a host value into the referenced device element.
67 void operator=(AFloat value);
// Update through the host: the element is copied to the host, modified, and
// copied back with two separate cudaMemcpy calls, so the update is not atomic
// with respect to other writers of the same element.
68 void operator+=(AFloat value);
69 void operator-=(AFloat value);
// NOTE(review): damaged extraction -- each line carries a stray numeric
// prefix, and the `class TCudaMatrix` line, braces, access specifiers and
// several members (e.g. fNRows / fNCols, which the accessors below read) are
// missing from this view. Only comments were added to this region.

/** Matrix whose elements live in a TCudaDeviceBuffer on the GPU.
 *
 * Elements are stored column-major: element (i, j) is at offset
 * j * fNRows + i (see the out-of-line operator()). Several resources are
 * held in static members and are therefore shared by all instances with the
 * same AFloat.
 */
97 template<
typename AFloat>
// NOTE(review): presumably the number of live TCudaMatrix<AFloat> objects,
// used to decide when to set up / tear down the shared resources -- confirm.
104 static size_t fInstances;
// cuBLAS handle shared by all instances (returned by GetCublasHandle()).
105 static cublasHandle_t fCublasHandle;
// Device scalar written by ResetDeviceReturn() and read by GetDeviceReturn().
106 static AFloat * fDeviceReturn;
// NOTE(review): presumably a device vector of ones of length fNOnes -- confirm
// against the implementation file.
107 static AFloat * fOnes;
108 static size_t fNOnes;
// cuRAND generator states exposed through GetCurandStatesPointer(); the count
// is presumably fNCurandStates -- confirm.
109 static curandState_t * fCurandStates;
110 static size_t fNCurandStates;
// Device storage holding the matrix elements.
115 TCudaDeviceBuffer<AFloat> fElementBuffer;
// NOTE(review): presumably controls whether InitializeCurandStates() runs --
// confirm.
119 static Bool_t gInitializeCurand;
121 static AFloat * GetOnes() {
return fOnes;}
// NOTE(review): presumably allocates an i x j matrix on the device -- confirm.
124 TCudaMatrix(
size_t i,
size_t j);
// Construct from a host-side TMatrixT (definition not visible here).
125 TCudaMatrix(
const TMatrixT<AFloat> &);
// NOTE(review): presumably wraps an existing device buffer as an m x n
// matrix (definition not visible here) -- confirm.
126 TCudaMatrix(TCudaDeviceBuffer<AFloat> buffer,
size_t m,
size_t n);
// Defaulted copy/move operations copy/move fElementBuffer member-wise.
// NOTE(review): whether a copy shares or duplicates device memory depends on
// TCudaDeviceBuffer's copy semantics -- confirm before relying on value
// semantics.
128 TCudaMatrix(
const TCudaMatrix &) =
default;
129 TCudaMatrix( TCudaMatrix &&) =
default;
130 TCudaMatrix & operator=(
const TCudaMatrix &) =
default;
131 TCudaMatrix & operator=( TCudaMatrix &&) =
default;
132 ~TCudaMatrix() =
default;
// Conversion to a host-side TMatrixT (definition not visible here).
135 operator TMatrixT<AFloat>()
const;
// Get / set the compute stream of fElementBuffer (defined out of line).
137 inline cudaStream_t GetComputeStream()
const;
138 inline void SetComputeStream(cudaStream_t stream);
// Helpers for the shared device scalar fDeviceReturn. The default argument
// `0.0` is converted to AFloat.
142 inline static void ResetDeviceReturn(AFloat value = 0.0);
145 inline static AFloat GetDeviceReturn();
147 inline static AFloat * GetDeviceReturnPointer() {
return fDeviceReturn;}
148 inline static curandState_t * GetCurandStatesPointer() {
return fCurandStates;}
// Make this matrix's compute stream wait for work already enqueued on the
// argument's compute stream (defined out of line).
152 inline void Synchronize(
const TCudaMatrix &)
const;
// Rank and shape accessors; GetNoElements() is rows * columns.
154 static size_t GetNDim() {
return 2;}
155 size_t GetNrows()
const {
return fNRows;}
156 size_t GetNcols()
const {
return fNCols;}
157 size_t GetNoElements()
const {
return fNRows * fNCols;}
// Raw device pointer to the element storage, obtained by converting
// fElementBuffer. Must not be dereferenced on the host.
159 const AFloat * GetDataPointer()
const {
return fElementBuffer;}
160 AFloat * GetDataPointer() {
return fElementBuffer;}
161 const cublasHandle_t & GetCublasHandle()
const {
return fCublasHandle;}
// Returns the buffer object by value (see the copy-semantics note above).
163 inline TCudaDeviceBuffer<AFloat> GetDeviceBuffer()
const {
return fElementBuffer;}
// Element proxy for (i, j), column-major, no bounds checking (defined out of
// line). Although const, the returned proxy can write to the element.
168 TCudaDeviceReference<AFloat> operator()(
size_t i,
size_t j)
const;
// NOTE(review): fragment of a member whose signature is not visible; it builds
// a host-side TMatrixT copy of this matrix.
171 TMatrixT<AFloat> mat(*
this);
// NOTE(review): fragment of a member whose signature is not visible. It zeroes
// all GetNoElements() elements with cudaMemset, which fills bytes (correct for
// 0 only). Its return code is not checked, and it is not issued on
// GetComputeStream() -- confirm the intended ordering with compute work.
176 cudaMemset(GetDataPointer(), 0,
sizeof(AFloat) * GetNoElements());
// NOTE(review): presumably set-up routines for the shared static resources
// (definitions not visible here) -- confirm.
186 void InitializeCuda();
187 void InitializeCurandStates();
/** Print a diagnostic for a failed CUDA runtime call and optionally abort.
 *
 * Does nothing when \p code is cudaSuccess. Otherwise writes
 * "CUDA Error: <message> <file> <line>" to stderr and, if \p abort is true,
 * terminates the process with exit(code). Normally invoked through CUDACHECK.
 *
 * \param code  value returned by a CUDA runtime API call
 * \param file  source file of the call site
 * \param line  source line of the call site
 * \param abort whether to terminate the process on error
 */
inline void cudaError(cudaError_t code, const char *file, int line, bool abort)
{
   if (code == cudaSuccess) {
      return;
   }
   fprintf(stderr, "CUDA Error: %s %s %d\n", cudaGetErrorString(code), file, line);
   if (abort) {
      // Note: the status observed by a parent process is only the low 8 bits
      // of `code`.
      exit(code);
   }
}
/** Bind the proxy to the device address \p devicePointer.
 *
 *  Only the pointer is stored; no device memory is accessed or allocated.
 */
template<typename AFloat>
TCudaDeviceReference<AFloat>::TCudaDeviceReference(AFloat * devicePointer)
   : fDevicePointer(devicePointer)
{
}
/** Read the referenced element back to the host.
 *
 *  Performs a single-element device-to-host cudaMemcpy. A failing copy is
 *  reported through CUDACHECK instead of silently returning an indeterminate
 *  value.
 */
template<typename AFloat>
TCudaDeviceReference<AFloat>::operator AFloat()
{
   AFloat buffer{};
   CUDACHECK(cudaMemcpy(&buffer, fDevicePointer, sizeof(AFloat),
                        cudaMemcpyDeviceToHost));
   return buffer;
}
/** Copy the element referenced by \p other into this element.
 *
 *  Device-to-device copy of one element. Self-assignment (both proxies
 *  pointing at the same address) is a no-op; this avoids handing
 *  cudaMemcpy identical source and destination, which would violate its
 *  non-overlapping requirement.
 */
template<typename AFloat>
void TCudaDeviceReference<AFloat>::operator=(const TCudaDeviceReference &other)
{
   if (fDevicePointer == other.fDevicePointer) {
      return;
   }
   CUDACHECK(cudaMemcpy(fDevicePointer, other.fDevicePointer, sizeof(AFloat),
                        cudaMemcpyDeviceToDevice));
}
/** Write the host value \p value into the referenced device element.
 *
 *  Single-element host-to-device cudaMemcpy, checked via CUDACHECK. The
 *  by-value parameter itself serves as the host source buffer.
 */
template<typename AFloat>
void TCudaDeviceReference<AFloat>::operator=(AFloat value)
{
   CUDACHECK(cudaMemcpy(fDevicePointer, &value, sizeof(AFloat),
                        cudaMemcpyHostToDevice));
}
/** Add \p value to the referenced device element.
 *
 *  Read-modify-write through the host: the element is copied down, updated,
 *  and copied back. The two copies are separate calls, so the update is not
 *  atomic with respect to other writers of the same element. Both copies are
 *  checked via CUDACHECK.
 */
template<typename AFloat>
void TCudaDeviceReference<AFloat>::operator+=(AFloat value)
{
   AFloat buffer{};
   CUDACHECK(cudaMemcpy(&buffer, fDevicePointer, sizeof(AFloat),
                        cudaMemcpyDeviceToHost));
   buffer += value;
   CUDACHECK(cudaMemcpy(fDevicePointer, &buffer, sizeof(AFloat),
                        cudaMemcpyHostToDevice));
}
/** Subtract \p value from the referenced device element.
 *
 *  Read-modify-write through the host: the element is copied down, updated,
 *  and copied back. The two copies are separate calls, so the update is not
 *  atomic with respect to other writers of the same element. Both copies are
 *  checked via CUDACHECK.
 */
template<typename AFloat>
void TCudaDeviceReference<AFloat>::operator-=(AFloat value)
{
   AFloat buffer{};
   CUDACHECK(cudaMemcpy(&buffer, fDevicePointer, sizeof(AFloat),
                        cudaMemcpyDeviceToHost));
   buffer -= value;
   CUDACHECK(cudaMemcpy(fDevicePointer, &buffer, sizeof(AFloat),
                        cudaMemcpyHostToDevice));
}
/** Return the compute stream associated with this matrix's element buffer. */
template<typename AFloat>
inline cudaStream_t TCudaMatrix<AFloat>::GetComputeStream() const
{
   return fElementBuffer.GetComputeStream();
}
/** Associate \p stream with this matrix's element buffer as its compute
 *  stream (forwards to the buffer). */
template<typename AFloat>
inline void TCudaMatrix<AFloat>::SetComputeStream(cudaStream_t stream)
{
   fElementBuffer.SetComputeStream(stream);
}
/** Make this matrix's compute stream wait for the work currently enqueued on
 *  the compute stream of \p A.
 *
 *  An event is recorded on A's stream, and this stream is made to wait on it
 *  with cudaStreamWaitEvent. The dependency is therefore enforced on the
 *  device and the host does not block. Destroying the event right after
 *  enqueueing the wait is allowed: CUDA releases it once the recorded work has
 *  completed. Every runtime call is checked via CUDACHECK.
 */
template<typename AFloat>
inline void TCudaMatrix<AFloat>::Synchronize(const TCudaMatrix &A) const
{
   cudaEvent_t event;
   CUDACHECK(cudaEventCreateWithFlags(&event, cudaEventDisableTiming));
   CUDACHECK(cudaEventRecord(event, A.GetComputeStream()));
   CUDACHECK(cudaStreamWaitEvent(fElementBuffer.GetComputeStream(), event, 0));
   CUDACHECK(cudaEventDestroy(event));
}
/** Store \p value in the shared device scalar fDeviceReturn.
 *
 *  Host-to-device cudaMemcpy of one element, checked via CUDACHECK. The
 *  by-value parameter serves as the host source buffer.
 */
template<typename AFloat>
inline void TCudaMatrix<AFloat>::ResetDeviceReturn(AFloat value)
{
   CUDACHECK(cudaMemcpy(fDeviceReturn, &value, sizeof(AFloat),
                        cudaMemcpyHostToDevice));
}
/** Copy the shared device scalar fDeviceReturn to the host and return it.
 *
 *  Device-to-host cudaMemcpy of one element, checked via CUDACHECK.
 */
template<typename AFloat>
inline AFloat TCudaMatrix<AFloat>::GetDeviceReturn()
{
   AFloat buffer{};
   CUDACHECK(cudaMemcpy(&buffer, fDeviceReturn, sizeof(AFloat),
                        cudaMemcpyDeviceToHost));
   return buffer;
}
/** Return a proxy for element (i, j).
 *
 *  Elements are stored column-major, so (i, j) maps to offset
 *  j * fNRows + i in the device buffer. No bounds checking is performed.
 *  Although this method is const, the returned proxy can write to the
 *  element.
 */
template<typename AFloat>
TCudaDeviceReference<AFloat> TCudaMatrix<AFloat>::operator()(size_t i, size_t j) const
{
   AFloat * elementPointer = fElementBuffer;
   elementPointer += j * fNRows + i;
   return TCudaDeviceReference<AFloat>(elementPointer);
}