Logo ROOT   6.30.04
Reference Guide
 All Namespaces Files Pages
MethodRSNNS.h
Go to the documentation of this file.
1 // @(#)root/tmva/rmva $Id$
2 // Author: Omar Zapata,Lorenzo Moneta, Sergei Gleyzer 2015
3 
4 /**********************************************************************************
5  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6  * Package: TMVA *
7  * Class : MethodRSNNS *
8  * *
9  * Description: *
10  * R's package RSNNS method based on ROOTR *
11  * *
12  **********************************************************************************/
13 
14 #ifndef ROOT_TMVA_RMethodRSNNS
15 #define ROOT_TMVA_RMethodRSNNS
16 
17 //////////////////////////////////////////////////////////////////////////
18 // //
19 // MethodRSNNS //
20 // //
21 // //
22 //////////////////////////////////////////////////////////////////////////
23 
24 #include "TMVA/RMethodBase.h"
25 
26 namespace TMVA {
27 
28  class Factory; // DSMTEST -- forward declaration; Factory is named as a friend of MethodRSNNS
29  class Reader; // DSMTEST -- forward declaration; Reader is named as a friend of MethodRSNNS
30  class DataSetManager; // DSMTEST -- forward declaration; MethodRSNNS only stores a pointer to it
31  class Types; // forward declaration
32  class MethodRSNNS : public RMethodBase {
33  // TMVA method wrapping the R package RSNNS through ROOTR (see the file header).
34  public :
35 
36  // Constructor taking a job name, a method title, the data set info and an option string (empty by default).
37  MethodRSNNS(const TString &jobName,
38  const TString &methodTitle,
39  DataSetInfo &theData,
40  const TString &theOption = "");
41  // Constructor taking the data set info and the name of a weight file.
42  MethodRSNNS(DataSetInfo &dsi,
43  const TString &theWeightFile);
44 
45 
46  ~MethodRSNNS(void);
47  void Train();
48  // options treatment
49  void Init();
50  void DeclareOptions();
51  void ProcessOptions();
52  // Ranking hook: this class provides no ranking and always returns NULL.
53  const Ranking *CreateRanking()
54  {
55  return NULL; // no Ranking object is created
56  }
57 
58  // Reports whether the given analysis type / class count / target count is supported (defined out of line).
59  Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets);
60 
61  // performs classifier testing
62  virtual void TestClassification();
63 
64  // MVA response; errLower / errUpper are optional output pointers that default to 0.
65  Double_t GetMvaValue(Double_t *errLower = 0, Double_t *errUpper = 0);
66  // Keep the MethodBase overloads of ReadWeightsFromStream visible next to the one declared below.
67  using MethodBase::ReadWeightsFromStream;
68  // the actual "weights": the three hooks below have empty bodies in this class
69  virtual void AddWeightsXMLTo(void * /*parent*/) const {} // empty body: nothing is added to the XML node
70  virtual void ReadWeightsFromXML(void * /*wghtnode*/) {} // empty body: nothing is read from the XML node
71  virtual void ReadWeightsFromStream(std::istream &) {} // empty body; kept for backward compatibility
72  // NOTE(review): presumably the trained R model is restored here instead of via the weight hooks above -- confirm in the .cxx
73  void ReadModelFromFile();
74 
75  // signal/background classification response for the events in [firstEvt, lastEvt] of the current data set (defaults: 0 and -1)
76  virtual std::vector<Double_t> GetMvaValues(Long64_t firstEvt = 0, Long64_t lastEvt = -1, Bool_t logProgress = false);
77 
78  private :
79  DataSetManager *fDataSetManager; // DSMTEST -- raw pointer; ownership is not visible in this header
80  friend class Factory; // DSMTEST -- grants Factory access to private members
81  friend class Reader; // DSMTEST -- grants Reader access to private members
82  protected:
83  UInt_t fMvaCounter; // NOTE(review): presumably an index into the cached results below -- confirm in the .cxx
84  std::vector<Float_t> fProbResultForTrainSig; // NOTE(review): name suggests signal probabilities for the training sample -- confirm
85  std::vector<Float_t> fProbResultForTestSig; // NOTE(review): name suggests signal probabilities for the test sample -- confirm
86 
87  TString fNetType;// network type; original note says default "RMPL" (possibly a typo for RMLP -- confirm in the .cxx)
88  //RSNNS options for all NN methods
89  TString fSize;//number of units in the hidden layer(s)
90  UInt_t fMaxit;//maximum number of iterations to learn
91 
92  TString fInitFunc;//the initialization function to use
93  TString fInitFuncParams;//the parameters for the initialization function (type 6, see getSnnsRFunctionTable() in the RSNNS package)
94 
95  TString fLearnFunc;//the learning function to use
96  TString fLearnFuncParams;//the parameters for the learning function
97 
98  TString fUpdateFunc;//the update function to use
99  TString fUpdateFuncParams;//the parameters for the update function
100 
101  TString fHiddenActFunc;//the activation function of all hidden units
102  Bool_t fShufflePatterns;//should the patterns be shuffled?
103  Bool_t fLinOut;//sets the activation function of the output units to linear or logistic
104 
105  TString fPruneFunc;//the pruning function to use
106  TString fPruneFuncParams;//the parameters for the pruning function. Unlike the
107  //other functions, these have to be given in a named list. See
108  //the pruning demos for further explanation.
109  std::vector<UInt_t> fFactorNumeric; //numeric encoding of the class factor
110  //RSNNS mlp requires a numeric factor, so background=0 and signal=1, derived from fFactorTrain
111  static Bool_t IsModuleLoaded; // NOTE(review): presumably records whether the R package RSNNS was loaded -- confirm in the .cxx
112  ROOT::R::TRFunctionImport predict; // handle to an imported R function (name suggests R's predict)
113  ROOT::R::TRFunctionImport mlp; // handle to an imported R function (name suggests RSNNS mlp)
114  ROOT::R::TRFunctionImport asfactor; // handle to an imported R function (name suggests R's as.factor)
115  ROOT::R::TRObject *fModel; // raw pointer to the R model object; ownership is not visible in this header
116  // get help message text
117  void GetHelpMessage() const;
118  // ROOT dictionary macro; the class version number given here is 0
119  ClassDef(MethodRSNNS, 0)
120  };
121 } // namespace TMVA
122 #endif