Logo ROOT   6.30.04
Reference Guide
 All Namespaces Files Pages
MethodCrossValidation.h
Go to the documentation of this file.
1 // @(#)root/tmva $Id$
2 // Author: Kim Albertsson
3 
4 /*************************************************************************
5  * Copyright (C) 2018, Rene Brun and Fons Rademakers. *
6  * All rights reserved. *
7  * *
8  * For the licensing terms see $ROOTSYS/LICENSE. *
9  * For the list of contributors see $ROOTSYS/README/CREDITS. *
10  *************************************************************************/
11 
12 #ifndef ROOT_TMVA_MethodCrossValidation
13 #define ROOT_TMVA_MethodCrossValidation
14 
15 //////////////////////////////////////////////////////////////////////////
16 // //
17 // MethodCrossValidation //
18 // //
19 //////////////////////////////////////////////////////////////////////////
20 
#include "TMVA/CvSplit.h"
#include "TMVA/DataSetInfo.h"
#include "TMVA/MethodBase.h"

#include "TString.h"

#include <iostream>
#include <map>
#include <memory>
#include <vector>
29 
30 namespace TMVA {
31 
32 class CrossValidation;
33 class Ranking;
34 
35 // Looks for serialised methods of the form methodTitle + "_fold" + iFold;
36 class MethodCrossValidation : public MethodBase {
37 
38  friend CrossValidation;
39 
40 public:
41  // constructor for training and reading
42  MethodCrossValidation(const TString &jobName, const TString &methodTitle, DataSetInfo &theData,
43  const TString &theOption = "");
44 
45  // constructor for calculating BDT-MVA using previously generatad decision trees
46  MethodCrossValidation(DataSetInfo &theData, const TString &theWeightFile);
47 
48  virtual ~MethodCrossValidation(void);
49 
50  // optimize tuning parameters
51  // virtual std::map<TString,Double_t> OptimizeTuningParameters(TString fomType="ROCIntegral", TString
52  // fitType="FitGA"); virtual void SetTuneParameters(std::map<TString,Double_t> tuneParameters);
53 
54  // training method
55  void Train(void);
56 
57  // revoke training
58  void Reset(void);
59 
60  using MethodBase::ReadWeightsFromStream;
61 
62  // write weights to file
63  void AddWeightsXMLTo(void *parent) const;
64 
65  // read weights from file
66  void ReadWeightsFromStream(std::istream &istr);
67  void ReadWeightsFromXML(void *parent);
68 
69  // write method specific histos to target file
70  void WriteMonitoringHistosToFile(void) const;
71 
72  // calculate the MVA value
73  Double_t GetMvaValue(Double_t *err = 0, Double_t *errUpper = 0);
74  const std::vector<Float_t> &GetMulticlassValues();
75  const std::vector<Float_t> &GetRegressionValues();
76 
77  // the option handling methods
78  void DeclareOptions();
79  void ProcessOptions();
80 
81  // make ROOT-independent C++ class for classifier response (classifier-specific implementation)
82  void MakeClassSpecific(std::ostream &, const TString &) const;
83  void MakeClassSpecificHeader(std::ostream &, const TString &) const;
84 
85  void GetHelpMessage() const;
86 
87  const Ranking *CreateRanking();
88  Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets);
89 
90 protected:
91  void Init(void);
92  void DeclareCompatibilityOptions();
93 
94 private:
95  TString GetWeightFileNameForFold(UInt_t iFold) const;
96  MethodBase *InstantiateMethodFromXML(TString methodTypeName, TString weightfile) const;
97 
98 private:
99  TString fEncapsulatedMethodName;
100  TString fEncapsulatedMethodTypeName;
101  UInt_t fNumFolds;
102  TString fOutputEnsembling;
103 
104  TString fSplitExprString;
105  std::unique_ptr<CvSplitKFoldsExpr> fSplitExpr;
106 
107  std::vector<Float_t> fMulticlassValues;
108  std::vector<Float_t> fRegressionValues;
109 
110  std::vector<MethodBase *> fEncapsulatedMethods;
111 
112  // Used for CrossValidation with random splits (not using the
113  // CVSplitCrossValisationExpr functionality) to communicate Event to fold
114  // mapping.
115  std::map<const TMVA::Event *, UInt_t> fEventToFoldMapping;
116 
117  // for backward compatibility
118  ClassDef(MethodCrossValidation, 0);
119 };
120 
121 } // namespace TMVA
122 
123 #endif