Logo ROOT   6.30.04
Reference Guide
 All Namespaces Files Pages
MethodDNN.h
Go to the documentation of this file.
1 // @(#)root/tmva $Id$
2 // Author: Peter Speckmayer
3 
4 /**********************************************************************************
5  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6  * Package: TMVA *
7  * Class : MethodDNN *
8  * Web : http://tmva.sourceforge.net *
9  * *
10  * Description: *
11  * Deep neural network (DNN) method for TMVA *
12  * *
13  * Authors (alphabetical): *
14  * Peter Speckmayer <peter.speckmayer@gmx.at> - CERN, Switzerland *
15  * Simon Pfreundschuh <s.pfreundschuh@gmail.com> - CERN, Switzerland *
16  * *
17  * Copyright (c) 2005-2015: *
18  * CERN, Switzerland *
19  * U. of Victoria, Canada *
20  * MPI-K Heidelberg, Germany *
21  * U. of Bonn, Germany *
22  * *
23  * Redistribution and use in source and binary forms, with or without *
24  * modification, are permitted according to the terms listed in LICENSE *
25  * (http://tmva.sourceforge.net/LICENSE) *
26  **********************************************************************************/
27 
28 //#pragma once
29 
30 #ifndef ROOT_TMVA_MethodDNN
31 #define ROOT_TMVA_MethodDNN
32 
33 //////////////////////////////////////////////////////////////////////////
34 // //
35 // MethodDNN //
36 // //
37 // Neural Network implementation //
38 // //
39 //////////////////////////////////////////////////////////////////////////
40 
#include <map>
#include <sstream>
#include <string>
#include <vector>
#include "TString.h"
#include "TTree.h"
#include "TRandom3.h"
#include "TH1F.h"
#include "TMVA/MethodBase.h"
#include "TMVA/NeuralNet.h"

#include "TMVA/Tools.h"

#include "TMVA/DNN/Net.h"
#include "TMVA/DNN/Minimizers.h"
54 
55 #ifdef R__HAS_TMVACPU
56 #define DNNCPU
57 #endif
58 #ifdef R__HAS_TMVAGPU
59 #define DNNCUDA
60 #endif
61 
62 #ifdef DNNCPU
64 #endif
65 
66 #ifdef DNNCUDA
68 #endif
69 
70 namespace TMVA {
71 
72 class MethodDNN : public MethodBase
73 {
74  friend struct TestMethodDNNValidationSize;
75 
76  using Architecture_t = DNN::TReference<Float_t>;
77  using Net_t = DNN::TNet<Architecture_t>;
78  using Matrix_t = typename Architecture_t::Matrix_t;
79  using Scalar_t = typename Architecture_t::Scalar_t;
80 
81 private:
82  using LayoutVector_t = std::vector<std::pair<int, DNN::EActivationFunction>>;
83  using KeyValueVector_t = std::vector<std::map<TString, TString>>;
84 
85  struct TTrainingSettings
86  {
87  size_t batchSize;
88  size_t testInterval;
89  size_t convergenceSteps;
90  DNN::ERegularization regularization;
91  Double_t learningRate;
92  Double_t momentum;
93  Double_t weightDecay;
94  std::vector<Double_t> dropoutProbabilities;
95  bool multithreading;
96  };
97 
98  // the option handling methods
99  void DeclareOptions();
100  void ProcessOptions();
101 
102  UInt_t GetNumValidationSamples();
103 
104  // general helper functions
105  void Init();
106 
107  Net_t fNet;
108  DNN::EInitialization fWeightInitialization;
109  DNN::EOutputFunction fOutputFunction;
110 
111  TString fLayoutString;
112  TString fErrorStrategy;
113  TString fTrainingStrategyString;
114  TString fWeightInitializationString;
115  TString fArchitectureString;
116  TString fValidationSize;
117  LayoutVector_t fLayout;
118  std::vector<TTrainingSettings> fTrainingSettings;
119  bool fResume;
120 
121  KeyValueVector_t fSettings;
122 
123  ClassDef(MethodDNN,0); // neural network
124 
125  static inline void WriteMatrixXML(void *parent, const char *name,
126  const TMatrixT<Double_t> &X);
127  static inline void ReadMatrixXML(void *xml, const char *name,
128  TMatrixT<Double_t> &X);
129 protected:
130 
131  void MakeClassSpecific( std::ostream&, const TString& ) const;
132  void GetHelpMessage() const;
133 
134 public:
135 
136  // Standard Constructors
137  MethodDNN(const TString& jobName,
138  const TString& methodTitle,
139  DataSetInfo& theData,
140  const TString& theOption);
141  MethodDNN(DataSetInfo& theData,
142  const TString& theWeightFile);
143  virtual ~MethodDNN();
144 
145  virtual Bool_t HasAnalysisType(Types::EAnalysisType type,
146  UInt_t numberClasses,
147  UInt_t numberTargets );
148  LayoutVector_t ParseLayoutString(TString layerSpec);
149  KeyValueVector_t ParseKeyValueString(TString parseString,
150  TString blockDelim,
151  TString tokenDelim);
152  void Train();
153  void TrainGpu();
154  void TrainCpu();
155 
156  virtual Double_t GetMvaValue( Double_t* err=0, Double_t* errUpper=0 );
157  virtual const std::vector<Float_t>& GetRegressionValues();
158  virtual const std::vector<Float_t>& GetMulticlassValues();
159 
160  using MethodBase::ReadWeightsFromStream;
161 
162  // write weights to stream
163  void AddWeightsXMLTo ( void* parent ) const;
164 
165  // read weights from stream
166  void ReadWeightsFromStream( std::istream & i );
167  void ReadWeightsFromXML ( void* wghtnode );
168 
169  // ranking of input variables
170  const Ranking* CreateRanking();
171 
172 };
173 
174 inline void MethodDNN::WriteMatrixXML(void *parent,
175  const char *name,
176  const TMatrixT<Double_t> &X)
177 {
178  std::stringstream matrixStringStream("");
179  matrixStringStream.precision( 16 );
180 
181  for (size_t i = 0; i < (size_t) X.GetNrows(); i++)
182  {
183  for (size_t j = 0; j < (size_t) X.GetNcols(); j++)
184  {
185  matrixStringStream << std::scientific << X(i,j) << " ";
186  }
187  }
188  std::string s = matrixStringStream.str();
189  void* matxml = gTools().xmlengine().NewChild(parent, 0, name);
190  gTools().xmlengine().NewAttr(matxml, 0, "rows",
191  gTools().StringFromInt((int)X.GetNrows()));
192  gTools().xmlengine().NewAttr(matxml, 0, "cols",
193  gTools().StringFromInt((int)X.GetNcols()));
194  gTools().xmlengine().AddRawLine (matxml, s.c_str());
195 }
196 
197 inline void MethodDNN::ReadMatrixXML(void *xml,
198  const char *name,
199  TMatrixT<Double_t> &X)
200 {
201  void *matrixXML = gTools().GetChild(xml, name);
202  size_t rows, cols;
203  gTools().ReadAttr(matrixXML, "rows", rows);
204  gTools().ReadAttr(matrixXML, "cols", cols);
205 
206  const char * matrixString = gTools().xmlengine().GetNodeContent(matrixXML);
207  std::stringstream matrixStringStream(matrixString);
208 
209  for (size_t i = 0; i < rows; i++)
210  {
211  for (size_t j = 0; j < cols; j++)
212  {
213  matrixStringStream >> X(i,j);
214  }
215  }
216 }
217 } // namespace TMVA
218 
219 #endif