27 #ifndef ROOT_TMVA_DataSetInfo
28 #define ROOT_TMVA_DataSetInfo
56 class VariableTransformBase;
60 class DataSetInfo :
public TObject {
64 enum { kIsArrayVariable = BIT(15) };
66 DataSetInfo(
const TString& name =
"Default");
67 virtual ~DataSetInfo();
69 virtual const char* GetName()
const {
return fName.Data(); }
72 void ClearDataSet()
const;
73 DataSet* GetDataSet()
const;
78 VariableInfo& AddVariable(
const TString& expression,
const TString& title =
"",
const TString& unit =
"",
79 Double_t min = 0, Double_t max = 0,
char varType=
'F',
80 Bool_t normalized = kTRUE,
void* external = 0 );
81 VariableInfo& AddVariable(
const VariableInfo& varInfo );
84 void AddVariablesArray(
const TString &expression, Int_t size,
const TString &title =
"",
const TString &unit =
"",
85 Double_t min = 0, Double_t max = 0,
char type =
'F', Bool_t normalized = kTRUE,
88 VariableInfo& AddTarget (
const TString& expression,
const TString& title,
const TString& unit,
89 Double_t min, Double_t max, Bool_t normalized = kTRUE,
void* external = 0 );
90 VariableInfo& AddTarget (
const VariableInfo& varInfo );
92 VariableInfo& AddSpectator (
const TString& expression,
const TString& title,
const TString& unit,
93 Double_t min, Double_t max,
char type =
'F', Bool_t normalized = kTRUE,
void* external = 0 );
94 VariableInfo& AddSpectator (
const VariableInfo& varInfo );
96 ClassInfo* AddClass (
const TString& className );
101 std::vector<VariableInfo>& GetVariableInfos() {
return fVariables; }
102 const std::vector<VariableInfo>& GetVariableInfos()
const {
return fVariables; }
103 VariableInfo& GetVariableInfo( Int_t i ) {
return fVariables.at(i); }
104 const VariableInfo& GetVariableInfo( Int_t i )
const {
return fVariables.at(i); }
106 Int_t GetVarArraySize(
const TString &expression)
const {
107 auto element = fVarArrays.find(expression);
108 return (element != fVarArrays.end()) ? element->second : -1;
110 Bool_t IsVariableFromArray(Int_t i)
const {
return GetVariableInfo(i).TestBit(DataSetInfo::kIsArrayVariable); }
112 std::vector<VariableInfo> &GetTargetInfos()
116 const std::vector<VariableInfo> &GetTargetInfos()
const {
return fTargets; }
117 VariableInfo &GetTargetInfo(Int_t i) {
return fTargets.at(i); }
118 const VariableInfo &GetTargetInfo(Int_t i)
const {
return fTargets.at(i); }
120 std::vector<VariableInfo> &GetSpectatorInfos() {
return fSpectators; }
121 const std::vector<VariableInfo> &GetSpectatorInfos()
const {
return fSpectators; }
122 VariableInfo &GetSpectatorInfo(Int_t i) {
return fSpectators.at(i); }
123 const VariableInfo &GetSpectatorInfo(Int_t i)
const {
return fSpectators.at(i); }
125 UInt_t GetNVariables()
const {
return fVariables.size(); }
126 UInt_t GetNTargets()
const {
return fTargets.size(); }
127 UInt_t GetNSpectators(
bool all = kTRUE)
const;
129 const TString &GetNormalization()
const {
return fNormalization; }
130 void SetNormalization(
const TString &norm) { fNormalization = norm; }
132 void SetTrainingSumSignalWeights(Double_t trainingSumSignalWeights)
134 fTrainingSumSignalWeights = trainingSumSignalWeights;}
135 void SetTrainingSumBackgrWeights(Double_t trainingSumBackgrWeights){fTrainingSumBackgrWeights = trainingSumBackgrWeights;}
136 void SetTestingSumSignalWeights (Double_t testingSumSignalWeights ){fTestingSumSignalWeights = testingSumSignalWeights ;}
137 void SetTestingSumBackgrWeights (Double_t testingSumBackgrWeights ){fTestingSumBackgrWeights = testingSumBackgrWeights ;}
139 Double_t GetTrainingSumSignalWeights();
140 Double_t GetTrainingSumBackgrWeights();
141 Double_t GetTestingSumSignalWeights ();
142 Double_t GetTestingSumBackgrWeights ();
147 Int_t GetClassNameMaxLength()
const;
148 Int_t GetVariableNameMaxLength()
const;
149 Int_t GetTargetNameMaxLength()
const;
150 ClassInfo* GetClassInfo( Int_t clNum )
const;
151 ClassInfo* GetClassInfo(
const TString& name )
const;
152 void PrintClasses()
const;
153 UInt_t GetNClasses()
const {
return fClasses.size(); }
154 Bool_t IsSignal(
const Event* ev )
const;
155 std::vector<Float_t>* GetTargetsForMulticlass(
const Event* ev );
156 UInt_t GetSignalClassIndex(){
return fSignalClass;}
159 Int_t FindVarIndex(
const TString& )
const;
162 const TString GetWeightExpression(Int_t i)
const {
return GetClassInfo(i)->GetWeight(); }
163 void SetWeightExpression(
const TString& exp,
const TString& className =
"" );
166 const TCut& GetCut (Int_t i)
const {
return GetClassInfo(i)->GetCut(); }
167 const TCut& GetCut (
const TString& className )
const {
return GetClassInfo(className)->GetCut(); }
168 void SetCut (
const TCut& cut,
const TString& className );
169 void AddCut (
const TCut& cut,
const TString& className );
170 Bool_t HasCuts()
const;
172 std::vector<TString> GetListOfVariables()
const;
175 const TMatrixD* CorrelationMatrix (
const TString& className )
const;
176 void SetCorrelationMatrix (
const TString& className, TMatrixD* matrix );
177 void PrintCorrelationMatrix(
const TString& className );
178 TH2* CreateCorrelationMatrixHist(
const TMatrixD* m,
179 const TString& hName,
180 const TString& hTitle )
const;
183 void SetSplitOptions(
const TString& so) { fSplitOptions = so; fNeedsRebuilding = kTRUE; }
184 const TString& GetSplitOptions()
const {
return fSplitOptions; }
187 void SetRootDir(TDirectory* d) { fOwnRootDir = d; }
188 TDirectory* GetRootDir()
const {
return fOwnRootDir; }
190 void SetMsgType( EMsgType t )
const;
192 DataSetManager* GetDataSetManager(){
return fDataSetManager;}
195 TMVA::DataSetManager* fDataSetManager;
196 void SetDataSetManager( DataSetManager* dsm ) { fDataSetManager = dsm; }
197 friend class DataSetManager;
199 DataSetInfo(
const DataSetInfo& ) : TObject() {}
201 void PrintCorrelationMatrix( TTree* theTree );
205 mutable DataSet* fDataSet;
206 mutable Bool_t fNeedsRebuilding;
209 std::vector<VariableInfo> fVariables;
210 std::vector<VariableInfo> fTargets;
211 std::vector<VariableInfo> fSpectators;
214 std::map<TString, int> fVarArrays;
217 mutable std::vector<ClassInfo*> fClasses;
219 TString fNormalization;
220 TString fSplitOptions;
222 Double_t fTrainingSumSignalWeights;
223 Double_t fTrainingSumBackgrWeights;
224 Double_t fTestingSumSignalWeights ;
225 Double_t fTestingSumBackgrWeights ;
229 TDirectory* fOwnRootDir;
234 std::vector<Float_t>* fTargetsForMulticlass;
236 mutable MsgLogger* fLogger;
237 MsgLogger& Log()
const {
return *fLogger; }
241 ClassDef(DataSetInfo,1);