17 #ifndef TMVA_DNN_ARCHITECTURES_CUDA_DEVICE
18 #define TMVA_DNN_ARCHITECTURES_CUDA_DEVICE
21 #include "vector_types.h"
/* Block extent along the first (x) dimension. */
static constexpr int BlockDimX = 1;
/* Block extent along the second (y) dimension. */
static constexpr int BlockDimY = 32;
/* Total block size: the product of both extents (1 * 32 = 32). */
static constexpr int BlockSize = BlockDimX * BlockDimY;
48 static dim3 BlockDims1D()
50 return dim3(1, BlockSize);
55 static dim3 BlockDims2D()
57 return dim3(BlockDimX, BlockDimY);
62 template<
typename AMatrix>
63 static dim3 GridDims1D(
const AMatrix &A)
65 int gridDim = A.GetNrows() / TDevice::BlockSize;
66 if ((A.GetNrows() % TDevice::BlockSize) != 0) {
69 return dim3(1, gridDim);
74 static dim3 GridDims2D(
int nrows,
int ncols)
76 int gridDimX = ncols / TDevice::BlockDimX;
77 if ((ncols % TDevice::BlockDimX) != 0)
79 int gridDimY = nrows / TDevice::BlockDimY;
80 if ((nrows % TDevice::BlockDimY) != 0)
82 return dim3(gridDimX, gridDimY);
87 template<
typename AMatrix>
88 static dim3 GridDims2D(
const AMatrix &A)
90 int gridDimX = A.GetNcols() / TDevice::BlockDimX;
91 if ((A.GetNcols() % TDevice::BlockDimX) != 0)
93 int gridDimY = A.GetNrows() / TDevice::BlockDimY;
94 if ((A.GetNrows() % TDevice::BlockDimY) != 0)
96 return dim3(gridDimX, gridDimY);
100 template<
typename AMatrix>
101 static int NThreads(
const AMatrix &A)
103 int gridDimX = A.GetNcols() / TDevice::BlockDimX;
104 if ((A.GetNcols() % TDevice::BlockDimX) != 0) {
107 int gridDimY = A.GetNrows() / TDevice::BlockDimY;
108 if ((A.GetNrows() % TDevice::BlockDimY) != 0) {
111 return gridDimX * gridDimY * TDevice::BlockDimX * TDevice::BlockDimY;