5 #ifndef ROOT_TMVA_Classification
6 #define ROOT_TMVA_Classification
124 class ResultsClassification;
125 namespace Experimental {
/// Per-method classification result object. Derives publicly from TObject and
/// declares Classification as a friend, so that class can access every member.
// NOTE(review): this listing is fragmentary. Access specifiers, the declarations
// of fMethod and fIsCuts (both used by inline getters below) and the closing
// brace of the class are not visible here; check the complete header before
// relying on member visibility or layout.
126 class ClassificationResult :
public TObject {
127 friend class Classification;
/// Data loader name; returned by value from GetDataLoaderName().
131 TString fDataLoaderName;
/// Map from a UInt_t key to a vector of (Float_t, Float_t, Bool_t) tuples.
// NOTE(review): presumably keyed by class index and filled from the training
// sample; the meaning of the three tuple fields is not shown here — confirm.
132 std::map<UInt_t, std::vector<std::tuple<Float_t, Float_t, Bool_t>>> fMvaTrain;
// NOTE(review): the next line repeats the type of fMvaTrain, but the member
// name that should follow it is missing from this listing.
133 std::map<UInt_t, std::vector<std::tuple<Float_t, Float_t, Bool_t>>>
/// List of class names.
135 std::vector<TString> fClassNames;
/// Takes a method name and a method title (both by value) and returns a Bool_t.
// NOTE(review): presumably reports whether this result belongs to the given
// method — confirm against the implementation file.
137 Bool_t IsMethod(TString methodname, TString methodtitle);
/// ROC integral stored as a Double_t.
// NOTE(review): how/when it is filled is not visible in this header.
139 Double_t fROCIntegral;
/// Default constructor (defined out of line).
142 ClassificationResult();
/// Copy constructor (defined out of line).
143 ClassificationResult(
const ClassificationResult &cr);
/// Destructor with an empty inline body.
144 ~ClassificationResult() {}
/// Returns the "MethodName" value looked up in fMethod, as a TString by value.
146 const TString GetMethodName()
const {
return fMethod.GetValue<TString>(
"MethodName"); }
/// Returns the "MethodTitle" value looked up in fMethod, as a TString by value.
147 const TString GetMethodTitle()
const {
return fMethod.GetValue<TString>(
"MethodTitle"); }
/// Returns a ROCCurve pointer for class index iClass (default 0) and tree type
/// `type` (default TMVA::Types::kTesting).
// NOTE(review): ownership of the returned pointer is not stated here — confirm.
148 ROCCurve *GetROC(UInt_t iClass = 0, TMVA::Types::ETreeType type = TMVA::Types::kTesting);
/// Returns a Double_t ROC integral for class index iClass (default 0) and tree
/// type `type` (default TMVA::Types::kTesting).
149 Double_t GetROCIntegral(UInt_t iClass = 0, TMVA::Types::ETreeType type = TMVA::Types::kTesting);
/// Returns a copy of fDataLoaderName. Not const-qualified, so it cannot be
/// called on a const ClassificationResult.
150 TString GetDataLoaderName() {
return fDataLoaderName; }
/// Returns fIsCuts. Not const-qualified, so it cannot be called on a const
/// ClassificationResult.
151 Bool_t IsCutsMethod() {
return fIsCuts; }
/// Returns a TGraph pointer for class index iClass (default 0) and tree type
/// `type` (default TMVA::Types::kTesting).
// NOTE(review): ownership of the returned pointer is not stated here — confirm.
155 TGraph *GetROCGraph(UInt_t iClass = 0, TMVA::Types::ETreeType type = TMVA::Types::kTesting);
/// Copy assignment operator (defined out of line); returns a reference to a
/// ClassificationResult.
156 ClassificationResult &operator=(
const ClassificationResult &r);
// ROOT ClassDef macro invoked with class version argument 3.
158 ClassDef(ClassificationResult, 3);
/// Classification driver class; derives publicly from Envelope and stores one
/// ClassificationResult per entry in fResults.
// NOTE(review): this listing is fragmentary. Access specifiers, some
// declarations and the closing brace of the class are not visible here; check
// the complete header before relying on member visibility.
161 class Classification :
public Envelope {
/// Collected results; exposed by reference through GetResults().
162 std::vector<ClassificationResult> fResults;
/// Raw pointers to IMethod objects.
// NOTE(review): ownership of the pointed-to methods is not shown here — confirm.
163 std::vector<IMethod *> fIMethods;
/// Analysis type selector (Types::EAnalysisType).
164 Types::EAnalysisType fAnalysisType;
/// Boolean flag.
// NOTE(review): presumably toggles correlation output — confirm in the
// implementation file.
165 Bool_t fCorrelations;
/// Explicit constructor taking a data loader, an output file and an options
/// string.
168 explicit Classification(DataLoader *loader, TFile *file, TString options);
/// Explicit constructor taking a data loader and an options string (no file).
169 explicit Classification(DataLoader *loader, TString options);
/// Virtual, parameterless training entry point.
// NOTE(review): presumably trains every booked method — confirm.
172 virtual void Train();
/// Virtual training entry point for one method, selected by name and title.
173 virtual void TrainMethod(TString methodname, TString methodtitle);
/// Overload selecting the method by its Types::EMVA enumerator and title.
174 virtual void TrainMethod(Types::EMVA method, TString methodtitle);
/// Virtual testing entry point for one method, selected by name and title.
177 virtual void TestMethod(TString methodname, TString methodtitle);
/// Overload selecting the method by its Types::EMVA enumerator and title.
178 virtual void TestMethod(Types::EMVA method, TString methodtitle);
/// Virtual, parameterless evaluation entry point.
180 virtual void Evaluate();
/// Returns a mutable reference to a vector of ClassificationResult.
// NOTE(review): presumably refers to fResults — confirm in the implementation.
182 std::vector<ClassificationResult> &GetResults();
/// Returns a MethodBase pointer for the given method name and title.
// NOTE(review): behaviour when no such method exists is not shown — confirm.
184 MethodBase *GetMethod(TString methodname, TString methodtitle);
/// Returns a TString for the given method name and title.
// NOTE(review): presumably the option string the method was booked with — confirm.
187 TString GetMethodOptions(TString methodname, TString methodtitle);
/// Returns a Bool_t for the given method name and title; `index` is taken by
/// non-const reference, so it can be written by the callee.
// NOTE(review): presumably receives the position of the matching method — confirm.
188 Bool_t HasMethodObject(TString methodname, TString methodtitle, Int_t &index);
/// Returns a Bool_t for the given method object.
189 Bool_t IsCutsMethod(TMVA::MethodBase *method);
// NOTE(review): the return type of the next declaration is missing from this
// listing; the sibling overload below returns TMVA::ROCCurve *.
/// GetROC overload taking the method object, a class index (default 0) and a
/// tree type (default TMVA::Types::kTesting).
191 GetROC(TMVA::MethodBase *method, UInt_t iClass = 0, TMVA::Types::ETreeType type = TMVA::Types::kTesting);
/// GetROC overload selecting the method by name and title; returns a
/// TMVA::ROCCurve pointer. iClass defaults to 0, type to TMVA::Types::kTesting.
// NOTE(review): ownership of the returned pointer is not stated here — confirm.
192 TMVA::ROCCurve *GetROC(TString methodname, TString methodtitle, UInt_t iClass = 0,
193 TMVA::Types::ETreeType type = TMVA::Types::kTesting);
/// Returns a Double_t ROC integral for the method selected by name and title
/// and class index iClass (default 0). Unlike GetROC, no tree-type parameter.
195 Double_t GetROCIntegral(TString methodname, TString methodtitle, UInt_t iClass = 0);
/// Returns a mutable reference to the ClassificationResult selected by method
/// name and title.
// NOTE(review): behaviour when no matching result exists is not shown — confirm.
197 ClassificationResult &GetResults(TString methodname, TString methodtitle);
/// Takes a source TDirectory and a TFile.
// NOTE(review): presumably copies the directory contents into the file — confirm.
198 void CopyFrom(TDirectory *src, TFile *file);
// ROOT ClassDef macro invoked with class version argument 0.
201 ClassDef(Classification, 0);
206 #endif // ROOT_TMVA_Classification