30 #ifndef ROOT_TMVA_MethodDNN
31 #define ROOT_TMVA_MethodDNN
////////////////////////////////////////////////////////////////////////////////
/// TMVA method class `MethodDNN`, derived from MethodBase.
///
/// Exposes MVA, regression and multiclass outputs (GetMvaValue,
/// GetRegressionValues, GetMulticlassValues). The network type Net_t is a
/// DNN::TNet over the reference architecture DNN::TReference<Float_t>.
///
/// NOTE(review): this listing has the original line numbers fused into the
/// code. It has also dropped lines: braces, access specifiers, some struct
/// fields, part of ParseKeyValueString's parameter list and the closing `};`.
/// It does not compile as shown.
72 class MethodDNN :
public MethodBase
// Grants TestMethodDNNValidationSize access to private members (presumably a
// unit-test fixture for the validation-size logic — confirm).
74 friend struct TestMethodDNNValidationSize;
// Network-related type aliases, all derived from the Float_t reference architecture.
76 using Architecture_t = DNN::TReference<Float_t>;
77 using Net_t = DNN::TNet<Architecture_t>;
78 using Matrix_t =
typename Architecture_t::Matrix_t;
79 using Scalar_t =
typename Architecture_t::Scalar_t;
// One (int, activation function) pair per layer; the int is presumably the
// number of units in that layer — confirm against ParseLayoutString.
82 using LayoutVector_t = std::vector<std::pair<int, DNN::EActivationFunction>>;
// A list of TString -> TString option maps, as returned by ParseKeyValueString.
83 using KeyValueVector_t = std::vector<std::map<TString, TString>>;
// Training hyper-parameters; fTrainingSettings holds a vector of these
// (presumably one per training phase). Only some fields survive in this listing.
85 struct TTrainingSettings
89 size_t convergenceSteps;
90 DNN::ERegularization regularization;
91 Double_t learningRate;
94 std::vector<Double_t> dropoutProbabilities;
// Option declaration / processing (presumably the MethodBase option hooks).
99 void DeclareOptions();
100 void ProcessOptions();
// Number of samples used for validation; presumably derived from fValidationSize.
102 UInt_t GetNumValidationSamples();
// Enum-valued settings, presumably decoded from the option strings below.
108 DNN::EInitialization fWeightInitialization;
109 DNN::EOutputFunction fOutputFunction;
// Raw option strings (presumably as supplied in the booking option string).
111 TString fLayoutString;
112 TString fErrorStrategy;
113 TString fTrainingStrategyString;
114 TString fWeightInitializationString;
115 TString fArchitectureString;
116 TString fValidationSize;
// Parsed forms of the layout and training-strategy options.
117 LayoutVector_t fLayout;
118 std::vector<TTrainingSettings> fTrainingSettings;
121 KeyValueVector_t fSettings;
// ROOT dictionary macro (class version 0).
123 ClassDef(MethodDNN,0);
// Helpers that write / read a TMatrixT<Double_t> as a named child element of an XML node.
125 static inline void WriteMatrixXML(
void *parent,
const char *name,
126 const TMatrixT<Double_t> &X);
127 static inline void ReadMatrixXML(
void *xml,
const char *name,
128 TMatrixT<Double_t> &X);
// Writes class-specific code to the given stream; prints the method's help text.
131 void MakeClassSpecific( std::ostream&,
const TString& )
const;
132 void GetHelpMessage()
const;
// Constructor taking a job name, method title, dataset info and option string.
137 MethodDNN(
const TString& jobName,
138 const TString& methodTitle,
139 DataSetInfo& theData,
140 const TString& theOption);
// Constructor taking dataset info and a weight-file name.
141 MethodDNN(DataSetInfo& theData,
142 const TString& theWeightFile);
143 virtual ~MethodDNN();
// Reports whether the given analysis type with the given numbers of classes
// and targets is supported.
145 virtual Bool_t HasAnalysisType(Types::EAnalysisType type,
146 UInt_t numberClasses,
147 UInt_t numberTargets );
// Parsers for the layout string and for key/value option strings.
148 LayoutVector_t ParseLayoutString(TString layerSpec);
// NOTE(review): the remainder of this parameter list (original lines 150-155)
// is missing from the listing.
149 KeyValueVector_t ParseKeyValueString(TString parseString,
// Evaluation interface; err/errUpper default to null pointers.
156 virtual Double_t GetMvaValue( Double_t* err=0, Double_t* errUpper=0 );
157 virtual const std::vector<Float_t>& GetRegressionValues();
158 virtual const std::vector<Float_t>& GetMulticlassValues();
// Brings the MethodBase overloads of ReadWeightsFromStream into scope, so the
// std::istream overload declared below does not hide them.
160 using MethodBase::ReadWeightsFromStream;
// Weight persistence (XML and stream forms).
163 void AddWeightsXMLTo (
void* parent )
const;
166 void ReadWeightsFromStream( std::istream & i );
167 void ReadWeightsFromXML (
void* wghtnode );
// Input-variable ranking (presumably the MethodBase ranking hook).
170 const Ranking* CreateRanking();
174 inline void MethodDNN::WriteMatrixXML(
void *parent,
176 const TMatrixT<Double_t> &X)
178 std::stringstream matrixStringStream(
"");
179 matrixStringStream.precision( 16 );
181 for (
size_t i = 0; i < (size_t) X.GetNrows(); i++)
183 for (
size_t j = 0; j < (size_t) X.GetNcols(); j++)
185 matrixStringStream << std::scientific << X(i,j) <<
" ";
188 std::string s = matrixStringStream.str();
189 void* matxml = gTools().xmlengine().NewChild(parent, 0, name);
190 gTools().xmlengine().NewAttr(matxml, 0,
"rows",
191 gTools().StringFromInt((
int)X.GetNrows()));
192 gTools().xmlengine().NewAttr(matxml, 0,
"cols",
193 gTools().StringFromInt((
int)X.GetNcols()));
194 gTools().xmlengine().AddRawLine (matxml, s.c_str());
////////////////////////////////////////////////////////////////////////////////
/// Read a matrix stored as a child element of the XML node `xml` into `X`.
///
/// Looks up the child element `name`, reads its `rows` and `cols` attributes,
/// then extracts rows*cols whitespace-separated values from the element's
/// text content, filling `X` in row-major order.
///
/// NOTE(review): this listing has fused line numbers and is missing lines:
/// the `const char *name,` parameter line (present in the class declaration),
/// the declaration of `rows`/`cols`, and the braces. The end of the function
/// lies outside this view, so the block does not compile as shown.
/// NOTE(review): no null check on the GetChild() result, no resize of `X`,
/// and no check of the stream state after extraction are visible. `X`
/// presumably must already be rows x cols; confirm against callers.
197 inline void MethodDNN::ReadMatrixXML(
void *xml,
199 TMatrixT<Double_t> &X)
// Locate the child element holding the serialized matrix.
201 void *matrixXML = gTools().GetChild(xml, name);
// The dimensions are stored as attributes of that element.
203 gTools().ReadAttr(matrixXML,
"rows", rows);
204 gTools().ReadAttr(matrixXML,
"cols", cols);
// The element's text content is the whitespace-separated list of values.
206 const char * matrixString = gTools().xmlengine().GetNodeContent(matrixXML);
207 std::stringstream matrixStringStream(matrixString);
// Fill X row by row, in the same row-major order the values are listed.
209 for (
size_t i = 0; i < rows; i++)
211 for (
size_t j = 0; j < cols; j++)
213 matrixStringStream >> X(i,j);