29 #ifndef ROOT_TMVA_DataSet
30 #define ROOT_TMVA_DataSet
// DataSet: container for the events of one data set. Events are kept in one
// vector per tree type (see TreeIndex for the slot mapping); the class also holds
// Results objects keyed by a TString, event-sampling state and per-class event
// counters (fClassEvents).
69 class DataSet :
public TNamed {
// Constructor; takes the DataSetInfo describing this data set
// (presumably kept in fdsi -- confirm in the implementation file).
73 DataSet(
const DataSetInfo&);
// Adds an event to the collection of the given tree type (defined out of line).
// NOTE(review): ownership of the passed Event* is not visible here -- confirm.
76 void AddEvent( Event *, Types::ETreeType );
// Number of events of the given tree type; the default kMaxTreeType means
// "current tree type". If sampling is active for that type, only the selected
// (sampled) events are counted.
78 Long64_t GetNEvents( Types::ETreeType type = Types::kMaxTreeType )
const;
79 Long64_t GetNTrainingEvents()
const {
return GetNEvents(Types::kTraining); }
80 Long64_t GetNTestEvents()
const {
return GetNEvents(Types::kTesting); }
83 const Event* GetEvent()
const;
84 const Event* GetEvent ( Long64_t ievt )
const { fCurrentEventIdx = ievt;
return GetEvent(); }
85 const Event* GetTrainingEvent( Long64_t ievt )
const {
return GetEvent(ievt, Types::kTraining); }
86 const Event* GetTestEvent ( Long64_t ievt )
const {
return GetEvent(ievt, Types::kTesting); }
// Makes `type` the current tree type and `ievt` the current event index, then
// returns that event via GetEvent(). Although declared const, it updates the
// cursor members fCurrentTreeIdx/fCurrentEventIdx, which are declared mutable.
87 const Event* GetEvent ( Long64_t ievt, Types::ETreeType type )
const
89 fCurrentTreeIdx = TreeIndex(type); fCurrentEventIdx = ievt;
return GetEvent();
// Numbers of variables, targets and spectators of this data set
// (defined out of line; presumably forwarded from the DataSetInfo -- confirm).
95 UInt_t GetNVariables()
const;
96 UInt_t GetNTargets()
const;
97 UInt_t GetNSpectators()
const;
99 void SetCurrentEvent( Long64_t ievt )
const { fCurrentEventIdx = ievt; }
100 void SetCurrentType ( Types::ETreeType type )
const { fCurrentTreeIdx = TreeIndex(type); }
// Tree type of the current slot fCurrentTreeIdx
// (kMaxTreeType if the slot is not one of the four known ones).
101 Types::ETreeType GetCurrentType()
const;
// Installs the supplied vector as the event collection of the given tree type
// (defined out of line; presumably replacing the previous one).
// NOTE(review): who owns the vector, and which events deleteEvents applies to,
// is decided in the implementation -- confirm before relying on it.
103 void SetEventCollection( std::vector<Event*>*, Types::ETreeType, Bool_t deleteEvents =
true );
// Full event collection of the given tree type (default: current type). Unlike
// GetNEvents, this does not consult the sampling state.
104 const std::vector<Event*>& GetEventCollection( Types::ETreeType type = Types::kMaxTreeType )
const;
// Event collection exposed as a TTree (defined out of line).
105 const TTree* GetEventCollectionAsTree();
// Signal/background event counts for the test and training samples (defined out of line).
107 Long64_t GetNEvtSigTest();
108 Long64_t GetNEvtBkgdTest();
109 Long64_t GetNEvtSigTrain();
110 Long64_t GetNEvtBkgdTrain();
112 Bool_t HasNegativeEventWeights()
const {
return fHasNegativeEventWeights; }
// Access to Results objects stored in fResults (one TString-keyed map per slot;
// presumably one slot per tree type -- confirm). GetResults/DeleteResults act on a
// single entry identified by the TString; DeleteAllResults acts on every entry for
// the given tree and analysis type. All three are defined out of line; lookup and
// creation details live in the implementation.
114 Results* GetResults (
const TString &,
115 Types::ETreeType type,
116 Types::EAnalysisType analysistype );
117 void DeleteResults (
const TString &,
118 Types::ETreeType type,
119 Types::EAnalysisType analysistype );
120 void DeleteAllResults(Types::ETreeType type,
121 Types::EAnalysisType analysistype);
123 void SetVerbose( Bool_t ) {}
// Training-set block handling (defined out of line). Presumably DivideTrainingSet
// splits the training set into blockNum blocks and MoveTrainingBlock moves block
// blockInd to the `dest` tree type, applying it immediately when applyChanges is set.
127 void DivideTrainingSet( UInt_t blockNum );
130 void MoveTrainingBlock( Int_t blockInd,Types::ETreeType dest, Bool_t applyChanges = kTRUE );
// Per-class event counters (stored in fClassEvents), addressed by type and class number.
132 void IncrementNClassEvents( Int_t type, UInt_t classNumber );
133 Long64_t GetNClassEvents ( Int_t type, UInt_t classNumber );
134 void ClearNClassEvents ( Int_t type );
// TTree built from the events of the given tree type (defined out of line).
136 TTree* GetTree( Types::ETreeType type );
// Event sampling (defined out of line): InitSampling sets fraction, weight and seed;
// EventResult reports the outcome for an event; CreateSampling builds the sample.
// CreateSampling is const, so presumably it fills the mutable
// fSamplingEventList/fSamplingSelected -- confirm.
139 void InitSampling( Float_t fraction, Float_t weight, UInt_t seed = 0 );
140 void EventResult( Bool_t successful, Long64_t evtNumber = -1 );
141 void CreateSampling()
const;
// Maps a tree type to its slot index in the per-type vectors (defined inline below).
143 UInt_t TreeIndex(Types::ETreeType type)
const;
// Destroys the event collection of the given tree type (defined out of line;
// deleteEvents presumably controls whether the Event objects are deleted -- confirm).
148 void DestroyCollection( Types::ETreeType type, Bool_t deleteEvents );
// Description of this data set. NOTE(review): ownership not visible here.
150 const DataSetInfo *fdsi;
// One event vector per slot, with slot = TreeIndex(type):
// 0 training, 1 testing, 2 validation, 3 original training.
152 std::vector< std::vector<Event*> > fEventCollection;
// Results objects keyed by TString, one map per slot (presumably per tree type).
154 std::vector< std::map< TString, Results* > > fResults;
// Cursor: current slot and current event index. Mutable so that const accessors
// such as GetEvent(ievt) and SetCurrentType() can move it.
156 mutable UInt_t fCurrentTreeIdx;
157 mutable Long64_t fCurrentEventIdx;
// Sampling state, indexed by slot. A non-zero fSampling[slot] makes GetNEvents
// report fSamplingSelected[slot].size() instead of the full collection size.
160 std::vector<Char_t> fSampling;
161 std::vector<Int_t> fSamplingNEvents;
162 std::vector<Float_t> fSamplingWeight;
// Per slot: candidate list and selected sample, presumably (weight, event index)
// pairs -- confirm. Mutable so const CreateSampling() can fill them.
163 mutable std::vector< std::vector< std::pair< Float_t, Long64_t > > > fSamplingEventList;
164 mutable std::vector< std::vector< std::pair< Float_t, Long64_t > > > fSamplingSelected;
// Random generator used for sampling. NOTE(review): raw pointer; ownership and
// deletion are handled (if at all) in the implementation -- confirm.
165 TRandom3 *fSamplingRandom;
// Per-class event counters (see IncrementNClassEvents / GetNClassEvents).
169 std::vector< std::vector<Long64_t> > fClassEvents;
// Flag returned by HasNegativeEventWeights().
172 Bool_t fHasNegativeEventWeights;
// Message logger; Log() dereferences it without a null check.
174 mutable MsgLogger* fLogger;
175 MsgLogger& Log()
const {
return *fLogger; }
// Training-block bookkeeping, presumably maintained by DivideTrainingSet /
// MoveTrainingBlock: one flag per block (still in training or not) and the block
// size -- confirm in the implementation.
176 std::vector<Char_t> fBlockBelongToTraining;
180 Long64_t fTrainingBlockSize;
// Helpers that apply the pending block / training-set division (defined out of line).
182 void ApplyTrainingBlockDivision();
183 void ApplyTrainingSetDivision();
// Maps a tree type to its slot in the per-type vectors (fEventCollection, fSampling,
// fSamplingSelected): kTraining -> 0, kTesting -> 1, kValidation -> 2,
// kTrainingOriginal -> 3. kMaxTreeType, and any value not listed, resolve to the
// current slot fCurrentTreeIdx. This is what makes the "type = kMaxTreeType"
// default arguments mean "current tree type".
192 inline UInt_t TMVA::DataSet::TreeIndex(Types::ETreeType type)
const
195 case Types::kMaxTreeType :
return fCurrentTreeIdx;
196 case Types::kTraining :
return 0;
197 case Types::kTesting :
return 1;
198 case Types::kValidation :
return 2;
199 case Types::kTrainingOriginal :
return 3;
// Unlisted types fall back to the current slot, same as kMaxTreeType.
200 default :
return fCurrentTreeIdx;
// Inverse of TreeIndex for the four known slots: returns the tree type of the
// current slot fCurrentTreeIdx, or kMaxTreeType if the slot is not 0..3.
205 inline TMVA::Types::ETreeType TMVA::DataSet::GetCurrentType()
const
207 switch (fCurrentTreeIdx) {
208 case 0:
return Types::kTraining;
209 case 1:
return Types::kTesting;
210 case 2:
return Types::kValidation;
211 case 3:
return Types::kTrainingOriginal;
213 return Types::kMaxTreeType;
// Number of events of the given tree type (kMaxTreeType = current type).
// If sampling is enabled for that slot (fSampling[slot] non-zero), the size of
// the selected sample is returned; otherwise the size of the full collection.
// fSamplingSelected.at() / GetEventCollection() use bounds-checked access and
// throw std::out_of_range if the slot does not exist in those vectors.
// NOTE(review): treeIdx is Int_t although TreeIndex returns UInt_t and the value
// is cast back for the size comparison; declaring it UInt_t would avoid the round trip.
217 inline Long64_t TMVA::DataSet::GetNEvents(Types::ETreeType type)
const
219 Int_t treeIdx = TreeIndex(type);
220 if (fSampling.size() > UInt_t(treeIdx) && fSampling.at(treeIdx)) {
221 return fSamplingSelected.at(treeIdx).size();
223 return GetEventCollection(type).size();
227 inline const std::vector<TMVA::Event*>& TMVA::DataSet::GetEventCollection( TMVA::Types::ETreeType type )
const
229 return fEventCollection.at(TreeIndex(type));