Logo ROOT   6.30.04
Reference Guide
 All Namespaces Files Pages
DataSet.h
Go to the documentation of this file.
1 // @(#)root/tmva $Id$
2 // Author: Andreas Hoecker, Peter Speckmayer, Joerg Stelzer, Helge Voss
3 
4 /**********************************************************************************
5  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6  * Package: TMVA *
7  * Class : DataSet *
8  * Web : http://tmva.sourceforge.net *
9  * *
10  * Description: *
11  * Contains all the data information *
12  * *
13  * Authors (alphabetical): *
14  * Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland *
15  * Joerg Stelzer <Joerg.Stelzer@cern.ch> - CERN, Switzerland *
16  * Peter Speckmayer <Peter.Speckmayer@cern.ch> - CERN, Switzerland *
17  * Helge Voss <Helge.Voss@cern.ch> - MPI-K Heidelberg, Germany *
18  * *
19  * Copyright (c) 2006: *
20  * CERN, Switzerland *
21  * U. of Victoria, Canada *
22  * MPI-K Heidelberg, Germany *
23  * *
24  * Redistribution and use in source and binary forms, with or without *
25  * modification, are permitted according to the terms listed in LICENSE *
26  * (http://tmva.sourceforge.net/LICENSE) *
27  **********************************************************************************/
28 
29 #ifndef ROOT_TMVA_DataSet
30 #define ROOT_TMVA_DataSet
31 
32 //////////////////////////////////////////////////////////////////////////
33 // //
34 // DataSet //
35 // //
36 // Class that contains all the data information //
37 // //
38 //////////////////////////////////////////////////////////////////////////
39 
40 #include <vector>
41 #include <map>
42 #include <string>
43 
44 #include "TObject.h"
45 #include "TNamed.h"
46 #include "TString.h"
47 #include "TTree.h"
48 //#ifndef ROOT_TCut
49 //#include "TCut.h"
50 //#endif
51 //#ifndef ROOT_TMatrixDfwd
52 //#include "TMatrixDfwd.h"
53 //#endif
54 //#ifndef ROOT_TPrincipal
55 //#include "TPrincipal.h"
56 //#endif
57 #include "TRandom3.h"
58 
59 #include "TMVA/Types.h"
60 #include "TMVA/VariableInfo.h"
61 
62 namespace TMVA {
63 
64  class Event;
65  class DataSetInfo;
66  class MsgLogger;
67  class Results;
68 
69  class DataSet :public TNamed {
70 
71  public:
72  DataSet();
73  DataSet(const DataSetInfo&);
74  virtual ~DataSet();
75 
76  void AddEvent( Event *, Types::ETreeType );
77 
78  Long64_t GetNEvents( Types::ETreeType type = Types::kMaxTreeType ) const;
79  Long64_t GetNTrainingEvents() const { return GetNEvents(Types::kTraining); }
80  Long64_t GetNTestEvents() const { return GetNEvents(Types::kTesting); }
81 
82  // const getters
83  const Event* GetEvent() const; // returns event without transformations
84  const Event* GetEvent ( Long64_t ievt ) const { fCurrentEventIdx = ievt; return GetEvent(); } // returns event without transformations
85  const Event* GetTrainingEvent( Long64_t ievt ) const { return GetEvent(ievt, Types::kTraining); }
86  const Event* GetTestEvent ( Long64_t ievt ) const { return GetEvent(ievt, Types::kTesting); }
87  const Event* GetEvent ( Long64_t ievt, Types::ETreeType type ) const
88  {
89  fCurrentTreeIdx = TreeIndex(type); fCurrentEventIdx = ievt; return GetEvent();
90  }
91 
92 
93 
94 
95  UInt_t GetNVariables() const;
96  UInt_t GetNTargets() const;
97  UInt_t GetNSpectators() const;
98 
99  void SetCurrentEvent( Long64_t ievt ) const { fCurrentEventIdx = ievt; }
100  void SetCurrentType ( Types::ETreeType type ) const { fCurrentTreeIdx = TreeIndex(type); }
101  Types::ETreeType GetCurrentType() const;
102 
103  void SetEventCollection( std::vector<Event*>*, Types::ETreeType, Bool_t deleteEvents = true );
104  const std::vector<Event*>& GetEventCollection( Types::ETreeType type = Types::kMaxTreeType ) const;
105  const TTree* GetEventCollectionAsTree();
106 
107  Long64_t GetNEvtSigTest();
108  Long64_t GetNEvtBkgdTest();
109  Long64_t GetNEvtSigTrain();
110  Long64_t GetNEvtBkgdTrain();
111 
112  Bool_t HasNegativeEventWeights() const { return fHasNegativeEventWeights; }
113 
114  Results* GetResults ( const TString &,
115  Types::ETreeType type,
116  Types::EAnalysisType analysistype );
117  void DeleteResults ( const TString &,
118  Types::ETreeType type,
119  Types::EAnalysisType analysistype );
120  void DeleteAllResults(Types::ETreeType type,
121  Types::EAnalysisType analysistype);
122 
123  void SetVerbose( Bool_t ) {}
124 
125  // sets the number of blocks to which the training set is divided,
126  // some of which are given to the Validation sample. As default they belong all to Training set.
127  void DivideTrainingSet( UInt_t blockNum );
128 
129  // sets a certrain block from the origin training set to belong to either Training or Validation set
130  void MoveTrainingBlock( Int_t blockInd,Types::ETreeType dest, Bool_t applyChanges = kTRUE );
131 
132  void IncrementNClassEvents( Int_t type, UInt_t classNumber );
133  Long64_t GetNClassEvents ( Int_t type, UInt_t classNumber );
134  void ClearNClassEvents ( Int_t type );
135 
136  TTree* GetTree( Types::ETreeType type );
137 
138  // accessors for random and importance sampling
139  void InitSampling( Float_t fraction, Float_t weight, UInt_t seed = 0 );
140  void EventResult( Bool_t successful, Long64_t evtNumber = -1 );
141  void CreateSampling() const;
142 
143  UInt_t TreeIndex(Types::ETreeType type) const;
144 
145  private:
146 
147  // data members
148  void DestroyCollection( Types::ETreeType type, Bool_t deleteEvents );
149 
150  const DataSetInfo *fdsi; //-> datasetinfo that created this dataset
151 
152  std::vector< std::vector<Event*> > fEventCollection; // list of events for training/testing/...
153 
154  std::vector< std::map< TString, Results* > > fResults; //! [train/test/...][method-identifier]
155 
156  mutable UInt_t fCurrentTreeIdx;
157  mutable Long64_t fCurrentEventIdx;
158 
159  // event sampling
160  std::vector<Char_t> fSampling; // random or importance sampling (not all events are taken) !! Bool_t are stored ( no std::vector<bool> taken for speed (performance) issues )
161  std::vector<Int_t> fSamplingNEvents; // number of events which should be sampled
162  std::vector<Float_t> fSamplingWeight; // weight change factor [weight is indicating if sampling is random (1.0) or importance (<1.0)]
163  mutable std::vector< std::vector< std::pair< Float_t, Long64_t > > > fSamplingEventList; // weights and indices for sampling
164  mutable std::vector< std::vector< std::pair< Float_t, Long64_t > > > fSamplingSelected; // selected events
165  TRandom3 *fSamplingRandom; //-> random generator for sampling
166 
167 
168  // further things
169  std::vector< std::vector<Long64_t> > fClassEvents; // number of events of class 0,1,2,... in training[0]
170  // and testing[1] (+validation, trainingoriginal)
171 
172  Bool_t fHasNegativeEventWeights; // true if at least one signal or bkg event has negative weight
173 
174  mutable MsgLogger* fLogger; //! message logger
175  MsgLogger& Log() const { return *fLogger; }
176  std::vector<Char_t> fBlockBelongToTraining; // when dividing the dataset to blocks, sets whether
177  // the certain block is in the Training set or else
178  // in the validation set
179  // boolean are stored, taken std::vector<Char_t> for performance reasons (instead of std::vector<Bool_t>)
180  Long64_t fTrainingBlockSize; // block size into which the training dataset is divided
181 
182  void ApplyTrainingBlockDivision();
183  void ApplyTrainingSetDivision();
184  public:
185 
186  ClassDef(DataSet,1);
187  };
188 }
189 
190 
191 //_______________________________________________________________________
192 inline UInt_t TMVA::DataSet::TreeIndex(Types::ETreeType type) const
193 {
194  switch (type) {
195  case Types::kMaxTreeType : return fCurrentTreeIdx;
196  case Types::kTraining : return 0;
197  case Types::kTesting : return 1;
198  case Types::kValidation : return 2;
199  case Types::kTrainingOriginal : return 3;
200  default : return fCurrentTreeIdx;
201  }
202 }
203 
204 //_______________________________________________________________________
205 inline TMVA::Types::ETreeType TMVA::DataSet::GetCurrentType() const
206 {
207  switch (fCurrentTreeIdx) {
208  case 0: return Types::kTraining;
209  case 1: return Types::kTesting;
210  case 2: return Types::kValidation;
211  case 3: return Types::kTrainingOriginal;
212  }
213  return Types::kMaxTreeType;
214 }
215 
216 //_______________________________________________________________________
217 inline Long64_t TMVA::DataSet::GetNEvents(Types::ETreeType type) const
218 {
219  Int_t treeIdx = TreeIndex(type);
220  if (fSampling.size() > UInt_t(treeIdx) && fSampling.at(treeIdx)) {
221  return fSamplingSelected.at(treeIdx).size();
222  }
223  return GetEventCollection(type).size();
224 }
225 
226 //_______________________________________________________________________
227 inline const std::vector<TMVA::Event*>& TMVA::DataSet::GetEventCollection( TMVA::Types::ETreeType type ) const
228 {
229  return fEventCollection.at(TreeIndex(type));
230 }
231 
232 
233 #endif