Logo ROOT   6.30.04
Reference Guide
 All Namespaces Files Pages
DataSetInfo.h
Go to the documentation of this file.
1 // @(#)root/tmva $Id$
2 // Author: Andreas Hoecker, Peter Speckmayer, Joerg Stelzer, Helge Voss
3 
4 /**********************************************************************************
5  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6  * Package: TMVA *
7  * Class : DataSetInfo *
8  * Web : http://tmva.sourceforge.net *
9  * *
10  * Description: *
11  * Contains all the data information *
12  * *
13  * Authors (alphabetical): *
14  * Peter Speckmayer <speckmay@mail.cern.ch> - CERN, Switzerland *
15  * Joerg Stelzer <Joerg.Stelzer@cern.ch> - DESY, Germany *
16  * *
17  * Copyright (c) 2008-2011: *
18  * CERN, Switzerland *
19  * MPI-K Heidelberg, Germany *
20  * DESY Hamburg, Germany *
21  * *
22  * Redistribution and use in source and binary forms, with or without *
23  * modification, are permitted according to the terms listed in LICENSE *
24  * (http://tmva.sourceforge.net/LICENSE) *
25  **********************************************************************************/
26 
27 #ifndef ROOT_TMVA_DataSetInfo
28 #define ROOT_TMVA_DataSetInfo
29 
30 //////////////////////////////////////////////////////////////////////////
31 // //
32 // DataSetInfo //
33 // //
34 // Class that contains all the data information //
35 // //
36 //////////////////////////////////////////////////////////////////////////
37 
38 #include <iosfwd>
39 
40 #include "TObject.h"
41 #include "TString.h"
42 #include "TTree.h"
43 #include "TCut.h"
44 #include "TMatrixDfwd.h"
45 
46 #include "TMVA/Types.h"
47 #include "TMVA/VariableInfo.h"
48 #include "TMVA/ClassInfo.h"
49 #include "TMVA/Event.h"
50 
51 class TH2;
52 
53 namespace TMVA {
54 
55  class DataSet;
56  class VariableTransformBase;
57  class MsgLogger;
58  class DataSetManager;
59 
60  class DataSetInfo : public TObject {
61 
62  public:
63 
64  enum { kIsArrayVariable = BIT(15) };
65 
66  DataSetInfo(const TString& name = "Default");
67  virtual ~DataSetInfo();
68 
69  virtual const char* GetName() const { return fName.Data(); }
70 
71  // the data set
72  void ClearDataSet() const;
73  DataSet* GetDataSet() const;
74 
75  // ---
76  // the variable data
77  // ---
78  VariableInfo& AddVariable( const TString& expression, const TString& title = "", const TString& unit = "",
79  Double_t min = 0, Double_t max = 0, char varType='F',
80  Bool_t normalized = kTRUE, void* external = 0 );
81  VariableInfo& AddVariable( const VariableInfo& varInfo );
82 
83  // NEW: add an array of variables (e.g. for image data)
84  void AddVariablesArray(const TString &expression, Int_t size, const TString &title = "", const TString &unit = "",
85  Double_t min = 0, Double_t max = 0, char type = 'F', Bool_t normalized = kTRUE,
86  void *external = 0 );
87 
88  VariableInfo& AddTarget ( const TString& expression, const TString& title, const TString& unit,
89  Double_t min, Double_t max, Bool_t normalized = kTRUE, void* external = 0 );
90  VariableInfo& AddTarget ( const VariableInfo& varInfo );
91 
92  VariableInfo& AddSpectator ( const TString& expression, const TString& title, const TString& unit,
93  Double_t min, Double_t max, char type = 'F', Bool_t normalized = kTRUE, void* external = 0 );
94  VariableInfo& AddSpectator ( const VariableInfo& varInfo );
95 
96  ClassInfo* AddClass ( const TString& className );
97 
98  // accessors
99 
100  // general
101  std::vector<VariableInfo>& GetVariableInfos() { return fVariables; }
102  const std::vector<VariableInfo>& GetVariableInfos() const { return fVariables; }
103  VariableInfo& GetVariableInfo( Int_t i ) { return fVariables.at(i); }
104  const VariableInfo& GetVariableInfo( Int_t i ) const { return fVariables.at(i); }
105 
106  Int_t GetVarArraySize(const TString &expression) const {
107  auto element = fVarArrays.find(expression);
108  return (element != fVarArrays.end()) ? element->second : -1;
109  }
110  Bool_t IsVariableFromArray(Int_t i) const { return GetVariableInfo(i).TestBit(DataSetInfo::kIsArrayVariable); }
111 
112  std::vector<VariableInfo> &GetTargetInfos()
113  {
114  return fTargets;
115  }
116  const std::vector<VariableInfo> &GetTargetInfos() const { return fTargets; }
117  VariableInfo &GetTargetInfo(Int_t i) { return fTargets.at(i); }
118  const VariableInfo &GetTargetInfo(Int_t i) const { return fTargets.at(i); }
119 
120  std::vector<VariableInfo> &GetSpectatorInfos() { return fSpectators; }
121  const std::vector<VariableInfo> &GetSpectatorInfos() const { return fSpectators; }
122  VariableInfo &GetSpectatorInfo(Int_t i) { return fSpectators.at(i); }
123  const VariableInfo &GetSpectatorInfo(Int_t i) const { return fSpectators.at(i); }
124 
125  UInt_t GetNVariables() const { return fVariables.size(); }
126  UInt_t GetNTargets() const { return fTargets.size(); }
127  UInt_t GetNSpectators(bool all = kTRUE) const;
128 
129  const TString &GetNormalization() const { return fNormalization; }
130  void SetNormalization(const TString &norm) { fNormalization = norm; }
131 
132  void SetTrainingSumSignalWeights(Double_t trainingSumSignalWeights)
133  {
134  fTrainingSumSignalWeights = trainingSumSignalWeights;}
135  void SetTrainingSumBackgrWeights(Double_t trainingSumBackgrWeights){fTrainingSumBackgrWeights = trainingSumBackgrWeights;}
136  void SetTestingSumSignalWeights (Double_t testingSumSignalWeights ){fTestingSumSignalWeights = testingSumSignalWeights ;}
137  void SetTestingSumBackgrWeights (Double_t testingSumBackgrWeights ){fTestingSumBackgrWeights = testingSumBackgrWeights ;}
138 
139  Double_t GetTrainingSumSignalWeights();
140  Double_t GetTrainingSumBackgrWeights();
141  Double_t GetTestingSumSignalWeights ();
142  Double_t GetTestingSumBackgrWeights ();
143 
144 
145 
146  // classification information
147  Int_t GetClassNameMaxLength() const;
148  Int_t GetVariableNameMaxLength() const;
149  Int_t GetTargetNameMaxLength() const;
150  ClassInfo* GetClassInfo( Int_t clNum ) const;
151  ClassInfo* GetClassInfo( const TString& name ) const;
152  void PrintClasses() const;
153  UInt_t GetNClasses() const { return fClasses.size(); }
154  Bool_t IsSignal( const Event* ev ) const;
155  std::vector<Float_t>* GetTargetsForMulticlass( const Event* ev );
156  UInt_t GetSignalClassIndex(){return fSignalClass;}
157 
158  // by variable
159  Int_t FindVarIndex( const TString& ) const;
160 
161  // weights
162  const TString GetWeightExpression(Int_t i) const { return GetClassInfo(i)->GetWeight(); }
163  void SetWeightExpression( const TString& exp, const TString& className = "" );
164 
165  // cuts
166  const TCut& GetCut (Int_t i) const { return GetClassInfo(i)->GetCut(); }
167  const TCut& GetCut ( const TString& className ) const { return GetClassInfo(className)->GetCut(); }
168  void SetCut ( const TCut& cut, const TString& className );
169  void AddCut ( const TCut& cut, const TString& className );
170  Bool_t HasCuts() const;
171 
172  std::vector<TString> GetListOfVariables() const;
173 
174  // correlation matrix
175  const TMatrixD* CorrelationMatrix ( const TString& className ) const;
176  void SetCorrelationMatrix ( const TString& className, TMatrixD* matrix );
177  void PrintCorrelationMatrix( const TString& className );
178  TH2* CreateCorrelationMatrixHist( const TMatrixD* m,
179  const TString& hName,
180  const TString& hTitle ) const;
181 
182  // options
183  void SetSplitOptions(const TString& so) { fSplitOptions = so; fNeedsRebuilding = kTRUE; }
184  const TString& GetSplitOptions() const { return fSplitOptions; }
185 
186  // root dir
187  void SetRootDir(TDirectory* d) { fOwnRootDir = d; }
188  TDirectory* GetRootDir() const { return fOwnRootDir; }
189 
190  void SetMsgType( EMsgType t ) const;
191 
192  DataSetManager* GetDataSetManager(){return fDataSetManager;}
193  private:
194 
195  TMVA::DataSetManager* fDataSetManager; // DSMTEST
196  void SetDataSetManager( DataSetManager* dsm ) { fDataSetManager = dsm; } // DSMTEST
197  friend class DataSetManager; // DSMTEST (datasetmanager test)
198 
199  DataSetInfo( const DataSetInfo& ) : TObject() {}
200 
201  void PrintCorrelationMatrix( TTree* theTree );
202 
203  TString fName; // name of the dataset info object
204 
205  mutable DataSet* fDataSet; // dataset, owned by this datasetinfo object
206  mutable Bool_t fNeedsRebuilding; // flag if rebuilding of dataset is needed (after change of cuts, vars, etc.)
207 
208  // expressions/formulas
209  std::vector<VariableInfo> fVariables; // list of variable expressions/internal names
210  std::vector<VariableInfo> fTargets; // list of targets expressions/internal names
211  std::vector<VariableInfo> fSpectators; // list of spectators expressions/internal names
212 
213  // variable arrays
214  std::map<TString, int> fVarArrays;
215 
216  // the classes
217  mutable std::vector<ClassInfo*> fClasses; // name and other infos of the classes
218 
219  TString fNormalization; //
220  TString fSplitOptions; //
221 
222  Double_t fTrainingSumSignalWeights;
223  Double_t fTrainingSumBackgrWeights;
224  Double_t fTestingSumSignalWeights ;
225  Double_t fTestingSumBackgrWeights ;
226 
227 
228 
229  TDirectory* fOwnRootDir; // ROOT output dir
230  Bool_t fVerbose; // Verbosity
231 
232  UInt_t fSignalClass; // index of the class with the name signal
233 
234  std::vector<Float_t>* fTargetsForMulticlass;//-> all targets 0 except the one with index==classNumber
235 
236  mutable MsgLogger* fLogger; //! message logger
237  MsgLogger& Log() const { return *fLogger; }
238 
239  public:
240 
241  ClassDef(DataSetInfo,1);
242  };
243 }
244 
245 #endif