26 #ifndef ROOT_TMVA_MethodPyKeras
27 #define ROOT_TMVA_MethodPyKeras
33 class MethodPyKeras :
public PyMethodBase {
38 MethodPyKeras(
const TString &jobName,
39 const TString &methodTitle,
41 const TString &theOption =
"");
42 MethodPyKeras(DataSetInfo &dsi,
43 const TString &theWeightFile);
48 void DeclareOptions();
49 void ProcessOptions();
53 Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t);
55 Double_t GetMvaValue(Double_t *errLower, Double_t *errUpper);
56 std::vector<Double_t> GetMvaValues(Long64_t firstEvt, Long64_t lastEvt, Bool_t logProgress);
58 std::vector<Float_t>& GetRegressionValues();
60 std::vector<Float_t>& GetMulticlassValues();
62 const Ranking *CreateRanking() {
return 0; }
63 virtual void TestClassification();
64 virtual void AddWeightsXMLTo(
void*)
const{}
65 virtual void ReadWeightsFromXML(
void*){}
66 virtual void ReadWeightsFromStream(std::istream&) {}
67 virtual void ReadWeightsFromStream(TFile&){}
68 void ReadModelFromFile();
70 void GetHelpMessage()
const;
/// Keras backend identifiers.
/// kUndefined (-1) is the sentinel outside the valid range; the concrete
/// backends are numbered consecutively from 0.
enum EBackendType {
   kUndefined  = -1,
   kTensorFlow = 0,
   kTheano     = 1,
   kCNTK       = 2
};
76 EBackendType GetKerasBackend();
77 TString GetKerasBackendName();
81 TString fFilenameModel;
82 UInt_t fBatchSize {0};
83 UInt_t fNumEpochs {0};
84 Int_t fNumThreads {0};
86 Bool_t fContinueTraining;
88 Int_t fTriesEarlyStopping;
89 TString fLearningRateSchedule;
91 TString fNumValidationString;
94 bool fModelIsSetup =
false;
95 float* fVals =
nullptr;
96 std::vector<float> fOutput;
99 TString fFilenameTrainedModel;
101 void SetupKerasModel(Bool_t loadTrainedModel);
102 UInt_t GetNumValidationSamples();
104 ClassDef(MethodPyKeras, 0);
109 #endif // ROOT_TMVA_MethodPyKeras