28 #ifndef ROOT_TMVA_MethodDL
29 #define ROOT_TMVA_MethodDL
/// Training options for one entry of MethodDL::fTrainingSettings.
/// NOTE(review): presumably one entry is filled per block of fTrainingStrategyString — confirm in ProcessOptions().
65 struct TTrainingSettings {
// Presumably the number of steps without improvement before training stops — TODO confirm.
68 size_t convergenceSteps;
// Regularization scheme for this entry.
70 DNN::ERegularization regularization;
// Optimizer for this entry, plus its name as a string (assumed to be the name given in the option string — confirm).
71 DNN::EOptimizer optimizer;
72 TString optimizerName;
// Learning rate (assumed to be the optimizer step size — confirm).
73 Double_t learningRate;
// Dropout probabilities; presumably one value per layer, in layer order — confirm.
76 std::vector<Double_t> dropoutProbabilities;
/// TMVA deep-learning method; derives from MethodBase.
81 class MethodDL :
public MethodBase {
// Vector of TString key/value maps: the return type of ParseKeyValueString() and the type of fSettings.
85 using KeyValueVector_t = std::vector<std::map<TString, TString>>;
// Backend architecture used for the stored network (fNet); as written here, the CPU backend with Float_t.
94 using ArchitectureImpl_t = TMVA::DNN::TCpu<Float_t>;
// Concrete network type held by fNet.
99 using DeepNetImpl_t = TMVA::DNN::TDeepNet<ArchitectureImpl_t>;
// Matrix/tensor/scalar/host-buffer types taken from the chosen architecture
// (used by fYHat, fXInput and fXInputBuffer).
100 using MatrixImpl_t =
typename ArchitectureImpl_t::Matrix_t;
101 using TensorImpl_t =
typename ArchitectureImpl_t::Tensor_t;
102 using ScalarImpl_t =
typename ArchitectureImpl_t::Scalar_t;
103 using HostBufferImpl_t =
typename ArchitectureImpl_t::HostBuffer_t;
// Option handling; definitions live outside this header.
106 void DeclareOptions();
107 void ProcessOptions();
// Presumably parse fInputLayoutString / fBatchLayoutString into fInputShape and the
// batch depth/height/width members — TODO confirm in the implementation file.
112 void ParseInputLayout();
113 void ParseBatchLayout();
// Build the layers of deepNet (and of each network in nets).
// NOTE(review): presumably driven by fLayoutString via the Parse*Layer helpers below — confirm.
119 template <
typename Architecture_t,
typename Layer_t>
120 void CreateDeepNet(DNN::TDeepNet<Architecture_t, Layer_t> &deepNet,
121 std::vector<DNN::TDeepNet<Architecture_t, Layer_t>> &nets);
// Parse one dense-layer description (layerString, fields split by delim) and add it to deepNet/nets (assumed — confirm).
123 template <
typename Architecture_t,
typename Layer_t>
124 void ParseDenseLayer(DNN::TDeepNet<Architecture_t, Layer_t> &deepNet,
125 std::vector<DNN::TDeepNet<Architecture_t, Layer_t>> &nets, TString layerString, TString delim);
// Same as ParseDenseLayer, for a convolutional layer description (assumed — confirm).
127 template <
typename Architecture_t,
typename Layer_t>
128 void ParseConvLayer(DNN::TDeepNet<Architecture_t, Layer_t> &deepNet,
129 std::vector<DNN::TDeepNet<Architecture_t, Layer_t>> &nets, TString layerString, TString delim);
// Layer-description parsers for max-pool, reshape and batch-norm layers.
// NOTE(review): on these lines each of the three parameter lists ends with `TString layerString,`
// and is never closed; the sibling Parse*Layer declarations end with `TString delim);` —
// presumably the same trailing parameter belongs here. Confirm against the full header.
131 template <
typename Architecture_t,
typename Layer_t>
132 void ParseMaxPoolLayer(DNN::TDeepNet<Architecture_t, Layer_t> &deepNet,
133 std::vector<DNN::TDeepNet<Architecture_t, Layer_t>> &nets, TString layerString,
136 template <
typename Architecture_t,
typename Layer_t>
137 void ParseReshapeLayer(DNN::TDeepNet<Architecture_t, Layer_t> &deepNet,
138 std::vector<DNN::TDeepNet<Architecture_t, Layer_t>> &nets, TString layerString,
141 template <
typename Architecture_t,
typename Layer_t>
142 void ParseBatchNormLayer(DNN::TDeepNet<Architecture_t, Layer_t> &deepNet,
143 std::vector<DNN::TDeepNet<Architecture_t, Layer_t>> &nets, TString layerString,
// Parse a recurrent-layer description (layerString, fields split by delim) into deepNet/nets (assumed — confirm).
147 template <
typename Architecture_t,
typename Layer_t>
148 void ParseRnnLayer(DNN::TDeepNet<Architecture_t, Layer_t> &deepNet,
149 std::vector<DNN::TDeepNet<Architecture_t, Layer_t>> &nets, TString layerString, TString delim);
// Same as ParseRnnLayer, for an LSTM layer description (assumed — confirm).
151 template <
typename Architecture_t,
typename Layer_t>
152 void ParseLstmLayer(DNN::TDeepNet<Architecture_t, Layer_t> &deepNet,
153 std::vector<DNN::TDeepNet<Architecture_t, Layer_t>> &nets, TString layerString, TString delim);
// NOTE(review): the template head below is not followed by a declaration on these lines
// before the next `template <` — confirm the declaration it belongs to is present.
156 template <
typename Architecture_t>
// Evaluate the network on events firstEvt..lastEvt in batches of batchSize and return one
// Double_t per event (assumed from the signature — confirm range inclusivity).
161 template <
typename Architecture_t>
162 std::vector<Double_t> PredictDeepNet(Long64_t firstEvt, Long64_t lastEvt,
size_t batchSize, Bool_t logProgress);
// Number of validation samples; presumably derived from fNumValidationString — confirm.
165 UInt_t GetNumValidationSamples();
// Input shape: [0] batch size, [1] depth, [2] height, [3] width (see the Get/Set accessors).
170 std::vector<size_t> fInputShape;
180 DNN::EInitialization fWeightInitialization;
181 DNN::EOutputFunction fOutputFunction;
182 DNN::ELossFunction fLossFunction;
// Raw option strings, exposed through the Get*String()/Set*String() accessors.
184 TString fInputLayoutString;
185 TString fBatchLayoutString;
186 TString fLayoutString;
187 TString fErrorStrategy;
188 TString fTrainingStrategyString;
189 TString fWeightInitializationString;
190 TString fArchitectureString;
191 TString fNumValidationString;
// Parsed key/value settings and the training settings built from them (assumed — confirm).
195 KeyValueVector_t fSettings;
196 std::vector<TTrainingSettings> fTrainingSettings;
// Evaluation buffers and the owned network; GetDeepNet() dereferences fNet, so it must be non-null there.
198 TensorImpl_t fXInput;
199 HostBufferImpl_t fXInputBuffer;
200 std::unique_ptr<MatrixImpl_t> fYHat;
201 std::unique_ptr<DeepNetImpl_t> fNet;
// ROOT class macro with version 0, which by ROOT convention excludes the class from ROOT I/O.
204 ClassDef(MethodDL, 0);
// Print usage help for this method (implementation not shown here).
208 void GetHelpMessage()
const;
// Return one MVA value per event in [firstEvt, lastEvt] (range semantics assumed — confirm).
210 virtual std::vector<Double_t> GetMvaValues(Long64_t firstEvt, Long64_t lastEvt, Bool_t logProgress);
// Construct from job name, method title, dataset info and option string.
215 MethodDL(
const TString &jobName,
const TString &methodTitle, DataSetInfo &theData,
const TString &theOption);
// Construct from dataset info and a weight file.
218 MethodDL(DataSetInfo &theData,
const TString &theWeightFile);
// Split parseString into blocks (blockDelim) of key=value tokens (tokenDelim); one map per block (assumed — confirm).
225 KeyValueVector_t ParseKeyValueString(TString parseString, TString blockDelim, TString tokenDelim);
// Report whether this method supports the given analysis type / class / target counts.
228 Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets);
233 Double_t GetMvaValue(Double_t *err = 0, Double_t *errUpper = 0);
// Network outputs for regression and multiclass evaluation, returned by const reference.
234 virtual const std::vector<Float_t>& GetRegressionValues();
235 virtual const std::vector<Float_t>& GetMulticlassValues();
// Re-expose the MethodBase overloads of ReadWeightsFromStream; without this using-declaration
// the overload declared below would hide them.
238 using MethodBase::ReadWeightsFromStream;
// Weight (de)serialization: XML write/read and stream read.
239 void AddWeightsXMLTo(
void *parent)
const;
240 void ReadWeightsFromXML(
void *wghtnode);
241 void ReadWeightsFromStream(std::istream &);
// Variable ranking for this method (ownership of the returned pointer not shown here — confirm).
244 const Ranking *CreateRanking();
247 size_t GetInputDepth()
const {
return fInputShape[1]; }
248 size_t GetInputHeight()
const {
return fInputShape[2]; }
249 size_t GetInputWidth()
const {
return fInputShape[3]; }
250 size_t GetInputDim()
const {
return fInputShape.size() - 2; }
251 std::vector<size_t> GetInputShape()
const {
return fInputShape; }
253 size_t GetBatchSize()
const {
return fInputShape[0]; }
254 size_t GetBatchDepth()
const {
return fBatchDepth; }
255 size_t GetBatchHeight()
const {
return fBatchHeight; }
256 size_t GetBatchWidth()
const {
return fBatchWidth; }
258 const DeepNetImpl_t & GetDeepNet()
const {
return *fNet; }
260 DNN::EInitialization GetWeightInitialization()
const {
return fWeightInitialization; }
261 DNN::EOutputFunction GetOutputFunction()
const {
return fOutputFunction; }
262 DNN::ELossFunction GetLossFunction()
const {
return fLossFunction; }
264 TString GetInputLayoutString()
const {
return fInputLayoutString; }
265 TString GetBatchLayoutString()
const {
return fBatchLayoutString; }
266 TString GetLayoutString()
const {
return fLayoutString; }
267 TString GetErrorStrategyString()
const {
return fErrorStrategy; }
268 TString GetTrainingStrategyString()
const {
return fTrainingStrategyString; }
269 TString GetWeightInitializationString()
const {
return fWeightInitializationString; }
270 TString GetArchitectureString()
const {
return fArchitectureString; }
272 const std::vector<TTrainingSettings> &GetTrainingSettings()
const {
return fTrainingSettings; }
273 std::vector<TTrainingSettings> &GetTrainingSettings() {
return fTrainingSettings; }
274 const KeyValueVector_t &GetKeyValueSettings()
const {
return fSettings; }
275 KeyValueVector_t &GetKeyValueSettings() {
return fSettings; }
278 void SetInputDepth (
int inputDepth) { fInputShape[1] = inputDepth; }
279 void SetInputHeight(
int inputHeight) { fInputShape[2] = inputHeight; }
280 void SetInputWidth (
int inputWidth) { fInputShape[3] = inputWidth; }
281 void SetInputShape (std::vector<size_t> inputShape) { fInputShape = std::move(inputShape); }
283 void SetBatchSize (
size_t batchSize) { fInputShape[0] = batchSize; }
284 void SetBatchDepth (
size_t batchDepth) { fBatchDepth = batchDepth; }
285 void SetBatchHeight(
size_t batchHeight) { fBatchHeight = batchHeight; }
286 void SetBatchWidth (
size_t batchWidth) { fBatchWidth = batchWidth; }
// Store the weight-initialization scheme in fWeightInitialization.
288 void SetWeightInitialization(DNN::EInitialization weightInitialization)
290 fWeightInitialization = weightInitialization;
292 void SetOutputFunction (DNN::EOutputFunction outputFunction) { fOutputFunction = outputFunction; }
293 void SetErrorStrategyString (TString errorStrategy) { fErrorStrategy = errorStrategy; }
294 void SetTrainingStrategyString (TString trainingStrategyString) { fTrainingStrategyString = trainingStrategyString; }
// Store the raw weight-initialization option string in fWeightInitializationString.
295 void SetWeightInitializationString(TString weightInitializationString)
297 fWeightInitializationString = weightInitializationString;
299 void SetArchitectureString (TString architectureString) { fArchitectureString = architectureString; }
300 void SetLayoutString (TString layoutString) { fLayoutString = layoutString; }