#ifndef TMVA_DNN_ARCHITECTURES_CUDA_CUDATENSOR
#define TMVA_DNN_ARCHITECTURES_CUDA_CUDATENSOR

#include "RConfigure.h"

#define CUDNNCHECK(ans) {cudnnError((ans), __FILE__, __LINE__); }
namespace Experimental {

enum class MemoryLayout : uint8_t {
   RowMajor = 1,
   ColumnMajor = 2
};

} // namespace Experimental

using MemoryLayout = TMVA::Experimental::MemoryLayout;
/// Handle the status output of cuDNN function calls.
inline void cudnnError(cudnnStatus_t status, const char *file, int line, bool abort = true)
{
   if (status != CUDNN_STATUS_SUCCESS) {
      fprintf(stderr, "CUDNN Error: %s %s %d\n", cudnnGetErrorString(status), file, line);
      if (abort)
         exit(status);
   }
}
/** TCudaTensor: GPU-resident tensor class extending TCudaMatrix to more than two dimensions. */
template<typename AFloat>
class TCudaTensor
{
public:

   using Shape_t = std::vector<size_t>;
   using MemoryLayout = TMVA::Experimental::MemoryLayout;
   using Scalar_t = AFloat;
private:
#ifdef R__HAS_CUDNN
   struct TensorDescriptor {
      cudnnTensorDescriptor_t fCudnnDesc;
   };

   static std::vector<cudnnHandle_t> fCudnnHandle;   ///< cuDNN handle associated with each device/stream
   static cudnnDataType_t fDataType;                 ///< cuDNN data type used for the tensor
#else
   struct TensorDescriptor {};
#endif

   static std::vector<int> fInstances;               ///< number of instances sharing each device/stream

   std::shared_ptr<TensorDescriptor> fTensorDescriptor;
   TCudaDeviceBuffer<AFloat> fElementBuffer;

   MemoryLayout fMemoryLayout;
public:

   TCudaTensor(const AFloat * data, const std::vector<size_t> & shape,
               MemoryLayout memlayout = MemoryLayout::ColumnMajor,
               int deviceIndx = 0, int streamIndx = 0);
   TCudaTensor(TCudaDeviceBuffer<AFloat> buffer, const std::vector<size_t> & shape,
               MemoryLayout memlayout = MemoryLayout::ColumnMajor,
               int deviceIndx = 0, int streamIndx = 0);
   TCudaTensor(const std::vector<size_t> & shape,
               MemoryLayout memlayout = MemoryLayout::ColumnMajor,
               int deviceIndx = 0, int streamIndx = 0);
   TCudaTensor(size_t bsize, size_t csize, size_t hwsize,
               MemoryLayout memlayout = MemoryLayout::ColumnMajor,
               int deviceIndx = 0, int streamIndx = 0)
      : TCudaTensor((memlayout == MemoryLayout::ColumnMajor) ? Shape_t({csize, hwsize, bsize})
                                                             : Shape_t({bsize, csize, hwsize}),
                    memlayout, deviceIndx, streamIndx)
   {}
   TCudaTensor(size_t bsize, size_t csize, size_t hsize, size_t wsize,
               MemoryLayout memlayout = MemoryLayout::ColumnMajor,
               int deviceIndx = 0, int streamIndx = 0)
      : TCudaTensor({bsize, csize, hsize, wsize}, memlayout, deviceIndx, streamIndx)
   {
      // for column-major layout rebuild the tensor with the batch dimension placed last
      if (memlayout == MemoryLayout::ColumnMajor)
         *this = TCudaTensor(fElementBuffer, {csize, hsize, wsize, bsize}, memlayout, deviceIndx, streamIndx);
   }
   TCudaTensor(size_t n, size_t m,
               MemoryLayout memlayout = MemoryLayout::ColumnMajor,
               int deviceIndx = 0, int streamIndx = 0)
      : TCudaTensor({n, m}, memlayout, deviceIndx, streamIndx)
   {}
   TCudaTensor(const TCudaMatrix<AFloat> & m, size_t dim = 2);

   TCudaTensor(TCudaDeviceBuffer<AFloat> buffer, size_t n, size_t m)
      : TCudaTensor(buffer, {n, m}, MemoryLayout::ColumnMajor, 0, 0) {}
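   // Construction sketch (illustrative only; sizes are arbitrary):
   //   TCudaTensor<float> images(/*bsize*/ 32, /*csize*/ 3, /*hsize*/ 28, /*wsize*/ 28,
   //                             MemoryLayout::RowMajor);   // 4D tensor, shape {32, 3, 28, 28}
   //   TCudaTensor<float> weights(100, 10);                  // 2D tensor, column-major by default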
   TCudaTensor(const TCudaTensor &) = default;
   TCudaTensor(TCudaTensor &&) = default;
   TCudaTensor & operator=(const TCudaTensor &) = default;
   TCudaTensor & operator=(TCudaTensor &&) = default;
   /// Convert the device tensor to a host-resident TMatrixT.
   operator TMatrixT<AFloat>() const;

   // shape and layout accessors
   MemoryLayout GetLayout() const { return fMemoryLayout; }
   const Shape_t & GetShape() const { return fShape; }
   const Shape_t & GetStrides() const { return fStrides; }
   size_t GetDimAt(size_t i) const { return fShape[i]; }
   size_t GetNDim() const { return fNDim; }
   size_t GetSize() const { return fSize; }

   // raw access to the device memory
   const AFloat * GetDataPointer() const { return fElementBuffer; }
   AFloat * GetDataPointer() { return fElementBuffer; }
   const AFloat * GetData() const { return fElementBuffer; }
   AFloat * GetData() { return fElementBuffer; }
   const AFloat * GetDataPointerAt(size_t i) const {
      return (const_cast<TCudaDeviceBuffer<AFloat> &>(fElementBuffer)).GetSubBuffer(i * GetFirstStride(), GetFirstStride());
   }
   AFloat * GetDataPointerAt(size_t i) {
      return fElementBuffer.GetSubBuffer(i * GetFirstStride(), GetFirstStride());
   }

   const TCudaDeviceBuffer<AFloat> & GetDeviceBuffer() const { return fElementBuffer; }
   TCudaDeviceBuffer<AFloat> & GetDeviceBuffer() { return fElementBuffer; }
   // cuDNN handle and descriptor access
   const cudnnHandle_t & GetCudnnHandle() const { return fCudnnHandle[fStreamIndx]; }
   const cudnnTensorDescriptor_t & GetTensorDescriptor() const { return fTensorDescriptor->fCudnnDesc; }
   static cudnnDataType_t GetDataType() { return fDataType; }

   cudaStream_t GetComputeStream() const {
      return fElementBuffer.GetComputeStream();
   }
   void SetComputeStream(cudaStream_t stream) {
      fElementBuffer.SetComputeStream(stream);
   }
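   // Stream-assignment sketch (illustrative only; 't' stands for any TCudaTensor instance):
   //   cudaStream_t s;
   //   cudaStreamCreate(&s);
   //   t.SetComputeStream(s);   // subsequent work on t is issued on stream s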
   // Compare element-wise against another device tensor (copies both buffers to the host).
   bool isEqual(TCudaTensor<AFloat> & other) {
      if (fSize != other.GetSize()) return false;

      std::unique_ptr<AFloat[]> hostBufferThis(new AFloat[fSize]);
      std::unique_ptr<AFloat[]> hostBufferOther(new AFloat[fSize]);
      cudaMemcpy(hostBufferThis.get(), fElementBuffer, fSize * sizeof(AFloat), cudaMemcpyDeviceToHost);
      cudaMemcpy(hostBufferOther.get(), other.GetDeviceBuffer(), fSize * sizeof(AFloat), cudaMemcpyDeviceToHost);

      for (size_t i = 0; i < fSize; i++) {
         if (hostBufferThis[i] != hostBufferOther[i]) return false;
      }
      return true;
   }
   // Compare element-wise against a host buffer of the given size.
   bool isEqual(const AFloat * hostBufferOther, size_t otherSize) {
      if (fSize != otherSize) return false;

      std::unique_ptr<AFloat[]> hostBufferThis(new AFloat[fSize]);
      cudaMemcpy(hostBufferThis.get(), fElementBuffer, fSize * sizeof(AFloat), cudaMemcpyDeviceToHost);

      for (size_t i = 0; i < fSize; i++) {
         if (hostBufferThis[i] != hostBufferOther[i]) return false;
      }
      return true;
   }
   void Print(const char * name = "Tensor", bool truncate = false) const;

   void PrintShape(const char * name = "Tensor") const;

   void Zero() {
      cudaMemset(GetDataPointer(), 0, sizeof(AFloat) * GetSize());
   }
   void SetConstVal(const AFloat constVal) {
      TCudaHostBuffer<AFloat> hostBuffer(fSize);
      hostBuffer.SetConstVal(constVal);
      fElementBuffer.CopyFrom(hostBuffer);
   }
   // Size and stride of the batch (first) dimension, independent of the memory layout.
   size_t GetFirstSize() const {
      return (GetLayout() == MemoryLayout::ColumnMajor) ? fShape.back() : fShape.front();
   }
   size_t GetFirstStride() const {
      return (GetLayout() == MemoryLayout::ColumnMajor) ? fStrides.back() : fStrides.front();
   }
   size_t GetCSize() const {
      if (fNDim == 2) return 1;
      return (GetLayout() == MemoryLayout::ColumnMajor) ? fShape.front() : fShape[1];
   }
   size_t GetHSize() const {
      if (fNDim == 2) return fShape[0];
      if (fNDim == 3) return (GetLayout() == MemoryLayout::ColumnMajor) ? fShape[0] : fShape[1];
      if (fNDim >= 4) return fShape[2];   // same index for both layouts
      return 0;
   }
   size_t GetWSize() const {
      if (fNDim == 2) return fShape[1];
      if (fNDim == 3) return (GetLayout() == MemoryLayout::ColumnMajor) ? fShape[1] : fShape[2];
      if (fNDim == 4) return fShape[3];   // same index for both layouts
      return 0;
   }
   size_t GetNrows() const { return (GetLayout() == MemoryLayout::ColumnMajor) ? fStrides.back() : fShape.front(); }
   size_t GetNcols() const { return (GetLayout() == MemoryLayout::ColumnMajor) ? fShape.back() : fStrides.front(); }
   // Return a matrix view of the tensor, assuming the remaining dimensions are trivial (size 1).
   TCudaMatrix<AFloat> GetMatrix() const {
      if (fNDim == 2 || (fNDim == 3 && GetFirstSize() == 1))
         return TCudaMatrix<AFloat>(fElementBuffer, GetHSize(), GetWSize());

      // N x M x 1 x 1 ... case: keep the first two dimensions
      bool caseNM11 = true;
      for (size_t i = 2; i < fNDim; ++i) caseNM11 &= fShape[i] == 1;
      if (caseNM11) {
         return (GetLayout() == MemoryLayout::ColumnMajor) ?
            TCudaMatrix<AFloat>(fElementBuffer, fShape[0], fShape[1]) :
            TCudaMatrix<AFloat>(fElementBuffer, fShape[1], fShape[0]);
      }
      // 1 x 1 x ... x N x M case: keep the last two dimensions
      bool case11NM = true;
      for (size_t i = 0; i < fNDim - 2; ++i) case11NM &= fShape[i] == 1;
      if (case11NM) {
         return (GetLayout() == MemoryLayout::ColumnMajor) ?
            TCudaMatrix<AFloat>(fElementBuffer, fShape[fNDim-2], fShape[fNDim-1]) :
            TCudaMatrix<AFloat>(fElementBuffer, fShape[fNDim-1], fShape[fNDim-2]);
      }
      return TCudaMatrix<AFloat>();
   }
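   // Matrix-view sketch (illustrative only): a tensor with a single leading entry can be
   // handed to matrix-based routines without copying.
   //   TCudaTensor<float> t(/*bsize*/ 1, /*csize*/ 4, /*hwsize*/ 5, MemoryLayout::RowMajor);
   //   TCudaMatrix<float> m = t.GetMatrix();   // 4 x 5 matrix sharing the same device buffer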
   // Compute dense strides for the given shape in row-major or column-major order.
   static inline std::vector<std::size_t> ComputeStridesFromShape(const std::vector<std::size_t> & shape,
                                                                  bool rowmajorLayout);

   void ReshapeInPlace(const Shape_t & newShape) {
      fShape = newShape;
      fStrides = ComputeStridesFromShape(fShape, fMemoryLayout == MemoryLayout::RowMajor);
      fNDim = fShape.size();

      // the reshaped tensor must not need more elements than the existing buffer holds
      size_t newSize = (fMemoryLayout == MemoryLayout::RowMajor) ? fStrides.front() * fShape.front()
                                                                 : fStrides.back() * fShape.back();
      R__ASSERT(newSize <= fSize);
      fSize = newSize;

      // update the cuDNN tensor descriptor to match the new shape
      SetTensorDescriptor();
   }
   // Return a reshaped copy; the copy gets its own cuDNN descriptor before being reshaped.
   TCudaTensor<AFloat> Reshape(const Shape_t & newShape) const {
      TCudaTensor<AFloat> tmp(*this);

      tmp.fTensorDescriptor.reset(new TensorDescriptor());
      CUDNNCHECK(cudnnCreateTensorDescriptor(&(tmp.fTensorDescriptor->fCudnnDesc)));

      tmp.ReshapeInPlace(newShape);
      return tmp;
   }
   void SetTensorDescriptor();
   // Return the i-th slice along the batch (first) dimension; the slice shares this tensor's buffer.
   TCudaTensor<AFloat> At(size_t i) const {
      Shape_t sliced_shape = (GetLayout() == MemoryLayout::RowMajor)
                                ? Shape_t(fShape.begin() + 1, fShape.end())
                                : Shape_t(fShape.begin(), fShape.end() - 1);

      size_t buffsize = (GetLayout() == MemoryLayout::RowMajor) ? fStrides.front() : fStrides.back();
      size_t offset = i * buffsize;

      return TCudaTensor<AFloat>(
         (const_cast<TCudaDeviceBuffer<AFloat> &>(fElementBuffer)).GetSubBuffer(offset, buffsize),
         sliced_shape, GetLayout());
   }
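   // Slicing and reshaping sketch (illustrative only; sizes are arbitrary):
   //   TCudaTensor<float> batch(/*bsize*/ 8, /*csize*/ 3, /*hwsize*/ 16, MemoryLayout::RowMajor);
   //   auto sample = batch.At(2);            // view of the third sample, shape {3, 16}
   //   auto flat   = batch.Reshape({8, 48}); // same buffer, shape {8, 48}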
   // Element access: the returned device reference can be read from or assigned to.
   TCudaDeviceReference<AFloat> operator()(size_t i, size_t j) const
   {
      // 2D (matrix-like) access
      size_t nrows = GetNrows();
      size_t ncols = GetNcols();

      size_t offset = (GetLayout() == MemoryLayout::RowMajor) ? i * ncols + j : j * nrows + i;

      AFloat * elementPointer = fElementBuffer + offset;
      return TCudaDeviceReference<AFloat>(elementPointer);
   }
   TCudaDeviceReference<AFloat> operator()(size_t i, size_t j, size_t k) const
   {
      // 3D access; note the reversed index order for the column-major layout
      size_t offset = (GetLayout() == MemoryLayout::RowMajor)
                         ? i * fStrides[0] + j * fStrides[1] + k
                         : i * fStrides[2] + k * fStrides[1] + j;

      AFloat * elementPointer = fElementBuffer + offset;
      return TCudaDeviceReference<AFloat>(elementPointer);
   }
   TCudaDeviceReference<AFloat> operator()(size_t i, size_t j, size_t k, size_t l) const
   {
      // 4D access
      size_t offset = (GetLayout() == MemoryLayout::RowMajor)
                         ? i * fStrides[0] + j * fStrides[1] + k * fStrides[2] + l
                         : l * fStrides[3] + k * fStrides[2] + j * fStrides[1] + i;

      AFloat * elementPointer = fElementBuffer + offset;
      return TCudaDeviceReference<AFloat>(elementPointer);
   }
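   // Element-access sketch (illustrative only):
   //   TCudaTensor<float> t(2, 3, MemoryLayout::RowMajor);
   //   t(1, 2) = 5.f;       // write through the device reference
   //   float v = t(1, 2);   // read the value back on the host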
private:

   // Initialize the per-device CUDA/cuDNN resources (handles, descriptors) used by this tensor.
   void InitializeCuda();
   // Initialize the cuRAND states used for random number generation on the device.
   void InitializeCurandStates();