Logo ROOT   6.30.04
Reference Guide
 All Namespaces Files Pages
ContextHandles.h
Go to the documentation of this file.
1 // @(#)root/tmva/tmva/dnn:$Id$
2 // Author: Simon Pfreundschuh 20/06/16
3 
4 /*************************************************************************
5  * Copyright (C) 2016, Simon Pfreundschuh *
6  * All rights reserved. *
7  * *
8  * For the licensing terms see $ROOTSYS/LICENSE. *
9  * For the list of contributors see $ROOTSYS/README/CREDITS. *
10  *************************************************************************/
11 
/////////////////////////////////////////////////////////////////////
// Defines the descriptor and workspace handle structs used by the //
// DNN layers: generic polymorphic bases (TDescriptors, TWorkspace) //
// plus bundles whose member types are supplied by the Layer_t     //
// template parameter (TDNNGenDescriptors, CNN::TCNNDescriptors,   //
// CNN::TCNNWorkspace).                                            //
/////////////////////////////////////////////////////////////////////
17 
18 #ifndef TMVA_DNN_CNN_DESCRIPTORS
19 #define TMVA_DNN_CNN_DESCRIPTORS
20 
21 #include <stddef.h>
22 
23 namespace TMVA
24 {
25 namespace DNN
26 {
27 
/// Polymorphic base for the descriptor bundles used by the DNN layers.
/// Concrete bundles (e.g. TDNNGenDescriptors, CNN::TCNNDescriptors) derive
/// from it; the virtual destructor makes it safe to delete a derived bundle
/// through a `TDescriptors *`.
struct TDescriptors {
   virtual ~TDescriptors() = default;
};
/// Polymorphic base for the workspace bundles used by the DNN layers
/// (e.g. CNN::TCNNWorkspace derives from it). Like TDescriptors, it declares
/// a virtual destructor. Without one, deleting a derived workspace through a
/// `TWorkspace *` would be undefined behaviour and would skip the derived
/// destructor.
struct TWorkspace {
   virtual ~TWorkspace() = default;
};
32 
33 template <typename Layer_t>
34 struct TDNNGenDescriptors : public TMVA::DNN::TDescriptors {
35  using HelperDescriptor_t = typename Layer_t::HelperDescriptor_t;
36 
37  HelperDescriptor_t HelperDescriptor;
38 };
39 
40 namespace CNN {
41 
42 //______________________________________________________________________________
43 //
44 // Keeps the descriptors for the CNN
45 //______________________________________________________________________________
46 
47 template <typename Layer_t>
48 struct TCNNDescriptors : public TMVA::DNN::TDescriptors {
49  using LayerDescriptor_t = typename Layer_t::LayerDescriptor_t; // Main layer operation
50  using HelperDescriptor_t = typename Layer_t::HelperDescriptor_t; // Used to define possible helpers for the layers (e.g. activations)
51  using WeightsDescriptor_t = typename Layer_t::WeightsDescriptor_t; // The weights that are modified (e.g filters)
52 
53  LayerDescriptor_t LayerDescriptor;
54  HelperDescriptor_t HelperDescriptor;
55  WeightsDescriptor_t WeightsDescriptor;
56 };
57 
58 template <typename Layer_t>
59 struct TCNNWorkspace : public TMVA::DNN::TWorkspace {
60  using AlgorithmForward_t = typename Layer_t::AlgorithmForward_t; // Forward layer operation
61  using AlgorithmBackward_t = typename Layer_t::AlgorithmBackward_t; // Backward layer operation
62  using AlgorithmHelper_t = typename Layer_t::AlgorithmHelper_t; // Used for weight grad backward pass
63 
64  using ReduceTensorDescriptor_t = typename Layer_t::ReduceTensorDescriptor_t;
65 
66  using AlgorithmDataType_t = typename Layer_t::AlgorithmDataType_t;
67 
68  AlgorithmForward_t AlgorithmForward;
69  AlgorithmBackward_t AlgorithmBackward;
70  AlgorithmHelper_t HelperAlgorithm;
71 
72  AlgorithmDataType_t DataType;
73 
74  size_t *ForwardWorkspace;
75  size_t *BackwardWorkspace;
76  size_t *HelperWorkspace;
77 
78  void *fReductionWorkspace = nullptr;
79 
80  size_t ForwardWorkspaceSize;
81  size_t BackwardWorkspaceSize;
82  size_t HelperWorkspaceSize;
83  size_t fReductionWorkspaceSize = 0;
84 
85  ReduceTensorDescriptor_t fReduceTensorDesc;
86 };
87 
88 } // namespace CNN
89 } // namespace DNN
90 } // namespace TMVA
91 
92 #endif