Logo ROOT   6.30.04
Reference Guide
 All Namespaces Files Pages
DataSetInfo.cxx
Go to the documentation of this file.
1 // @(#)root/tmva $Id$
2 // Author: Joerg Stelzer, Peter Speckmayer
3 
4 /**********************************************************************************
5  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6  * Package: TMVA *
7  * Class : DataSetInfo *
8  * Web : http://tmva.sourceforge.net *
9  * *
10  * Description: *
11  * Implementation (see header for description) *
12  * *
13  * Authors (alphabetical): *
14  * Peter Speckmayer <speckmay@mail.cern.ch> - CERN, Switzerland *
15  * Joerg Stelzer <Joerg.Stelzer@cern.ch> - DESY, Germany *
16  * *
17  * Copyright (c) 2008: *
18  * CERN, Switzerland *
19  * MPI-K Heidelberg, Germany *
20  * DESY Hamburg, Germany *
21  * *
22  * Redistribution and use in source and binary forms, with or without *
23  * modification, are permitted according to the terms listed in LICENSE *
24  * (http://tmva.sourceforge.net/LICENSE) *
25  **********************************************************************************/
26 
27 /*! \class TMVA::DataSetInfo
28 \ingroup TMVA
29 
30 Class that contains all the data information.
31 
32 */
33 
34 #include <vector>
35 
36 #include "TEventList.h"
37 #include "TFile.h"
38 #include "TH1.h"
39 #include "TH2.h"
40 #include "TProfile.h"
41 #include "TRandom3.h"
42 #include "TMatrixF.h"
43 #include "TVectorF.h"
44 #include "TMath.h"
45 #include "TROOT.h"
46 #include "TObjString.h"
47 
48 #include "TMVA/MsgLogger.h"
49 #include "TMVA/Tools.h"
50 #include "TMVA/DataSet.h"
51 #include "TMVA/DataSetInfo.h"
52 #include "TMVA/DataSetManager.h"
53 #include "TMVA/Event.h"
54 
55 #include "TMVA/Types.h"
56 #include "TMVA/VariableInfo.h"
57 
58 ////////////////////////////////////////////////////////////////////////////////
59 /// constructor
60 
61 TMVA::DataSetInfo::DataSetInfo(const TString& name)
62  : TObject(),
63  fDataSetManager(NULL),
64  fName(name),
65  fDataSet( 0 ),
66  fNeedsRebuilding( kTRUE ),
67  fVariables(),
68  fTargets(),
69  fSpectators(),
70  fClasses( 0 ),
71  fNormalization( "NONE" ),
72  fSplitOptions(""),
73  fTrainingSumSignalWeights(-1),
74  fTrainingSumBackgrWeights(-1),
75  fTestingSumSignalWeights (-1),
76  fTestingSumBackgrWeights (-1),
77  fOwnRootDir(0),
78  fVerbose( kFALSE ),
79  fSignalClass(0),
80  fTargetsForMulticlass(0),
81  fLogger( new MsgLogger("DataSetInfo", kINFO) )
82 {
83  std::cout << "create data set info " << name << std::endl;
84 }
85 
86 ////////////////////////////////////////////////////////////////////////////////
87 /// destructor
88 
89 TMVA::DataSetInfo::~DataSetInfo()
90 {
91  ClearDataSet();
92 
93  for(UInt_t i=0, iEnd = fClasses.size(); i<iEnd; ++i) {
94  delete fClasses[i];
95  }
96 
97  delete fTargetsForMulticlass;
98 
99  delete fLogger;
100 }
101 
102 ////////////////////////////////////////////////////////////////////////////////
103 
104 void TMVA::DataSetInfo::ClearDataSet() const
105 {
106  if(fDataSet!=0) { delete fDataSet; fDataSet=0; }
107 }
108 
109 ////////////////////////////////////////////////////////////////////////////////
110 
111 void
112 TMVA::DataSetInfo::SetMsgType( EMsgType t ) const
113 {
114  fLogger->SetMinType(t);
115 }
116 
117 ////////////////////////////////////////////////////////////////////////////////
118 
119 TMVA::ClassInfo* TMVA::DataSetInfo::AddClass( const TString& className )
120 {
121  ClassInfo* theClass = GetClassInfo(className);
122  if (theClass) return theClass;
123 
124 
125  fClasses.push_back( new ClassInfo(className) );
126  fClasses.back()->SetNumber(fClasses.size()-1);
127 
128  //Log() << kHEADER << Endl;
129 
130  Log() << kHEADER << Form("[%s] : ",fName.Data()) << "Added class \"" << className << "\""<< Endl;
131 
132  Log() << kDEBUG <<"\t with internal class number " << fClasses.back()->GetNumber() << Endl;
133 
134 
135  if (className == "Signal") fSignalClass = fClasses.size()-1; // store the signal class index ( for comparison reasons )
136 
137  return fClasses.back();
138 }
139 
140 ////////////////////////////////////////////////////////////////////////////////
141 
142 TMVA::ClassInfo* TMVA::DataSetInfo::GetClassInfo( const TString& name ) const
143 {
144  for (std::vector<ClassInfo*>::iterator it = fClasses.begin(); it < fClasses.end(); ++it) {
145  if ((*it)->GetName() == name) return (*it);
146  }
147  return 0;
148 }
149 
150 ////////////////////////////////////////////////////////////////////////////////
151 
152 TMVA::ClassInfo* TMVA::DataSetInfo::GetClassInfo( Int_t cls ) const
153 {
154  try {
155  return fClasses.at(cls);
156  }
157  catch(...) {
158  return 0;
159  }
160 }
161 
162 ////////////////////////////////////////////////////////////////////////////////
163 
164 void TMVA::DataSetInfo::PrintClasses() const
165 {
166  for (UInt_t cls = 0; cls < GetNClasses() ; cls++) {
167  Log() << kINFO << Form("Dataset[%s] : ",fName.Data()) << "Class index : " << cls << " name : " << GetClassInfo(cls)->GetName() << Endl;
168  }
169 }
170 
171 ////////////////////////////////////////////////////////////////////////////////
172 
173 Bool_t TMVA::DataSetInfo::IsSignal( const TMVA::Event* ev ) const
174 {
175  return (ev->GetClass() == fSignalClass);
176 }
177 
178 ////////////////////////////////////////////////////////////////////////////////
179 
180 std::vector<Float_t>* TMVA::DataSetInfo::GetTargetsForMulticlass( const TMVA::Event* ev )
181 {
182  if( !fTargetsForMulticlass ) fTargetsForMulticlass = new std::vector<Float_t>( GetNClasses() );
183  // fTargetsForMulticlass->resize( GetNClasses() );
184  fTargetsForMulticlass->assign( GetNClasses(), 0.0 );
185  fTargetsForMulticlass->at( ev->GetClass() ) = 1.0;
186  return fTargetsForMulticlass;
187 }
188 
189 
190 ////////////////////////////////////////////////////////////////////////////////
191 
192 Bool_t TMVA::DataSetInfo::HasCuts() const
193 {
194  Bool_t hasCuts = kFALSE;
195  for (std::vector<ClassInfo*>::iterator it = fClasses.begin(); it < fClasses.end(); ++it) {
196  if( TString((*it)->GetCut()) != TString("") ) hasCuts = kTRUE;
197  }
198  return hasCuts;
199 }
200 
201 ////////////////////////////////////////////////////////////////////////////////
202 
203 const TMatrixD* TMVA::DataSetInfo::CorrelationMatrix( const TString& className ) const
204 {
205  ClassInfo* ptr = GetClassInfo(className);
206  return ptr?ptr->GetCorrelationMatrix():0;
207 }
208 
209 ////////////////////////////////////////////////////////////////////////////////
210 /// add a variable (can be a complex expression) to the set of
211 /// variables used in the MV analysis
212 
213 TMVA::VariableInfo& TMVA::DataSetInfo::AddVariable( const TString& expression,
214  const TString& title,
215  const TString& unit,
216  Double_t min, Double_t max,
217  char varType,
218  Bool_t normalized,
219  void* external )
220 {
221  TString regexpr = expression; // remove possible blanks
222  regexpr.ReplaceAll(" ", "" );
223  fVariables.push_back(VariableInfo( regexpr, title, unit,
224  fVariables.size()+1, varType, external, min, max, normalized ));
225  fNeedsRebuilding = kTRUE;
226  return fVariables.back();
227 }
228 
229 ////////////////////////////////////////////////////////////////////////////////
230 /// add variable with given VariableInfo
231 
232 TMVA::VariableInfo& TMVA::DataSetInfo::AddVariable( const VariableInfo& varInfo){
233  fVariables.push_back(VariableInfo( varInfo ));
234  fNeedsRebuilding = kTRUE;
235  return fVariables.back();
236 }
237 
238 ////////////////////////////////////////////////////////////////////////////////
239 /// add an array of variables identified by an expression corresponding to an array entry in the tree
240 
241 void TMVA::DataSetInfo::AddVariablesArray(const TString &expression, Int_t size, const TString &title, const TString &unit,
242  Double_t min, Double_t max, char varType, Bool_t normalized,
243  void *external)
244 {
245  TString regexpr = expression; // remove possible blanks
246  regexpr.ReplaceAll(" ", "");
247  fVariables.reserve(fVariables.size() + size);
248  for (int i = 0; i < size; ++i) {
249  TString newTitle = title + TString::Format("[%d]", i);
250 
251  fVariables.emplace_back(regexpr, newTitle, unit, fVariables.size() + 1, varType, external, min, max, normalized);
252  // set corresponding bit indicating is a variable from an array
253  fVariables.back().SetBit(kIsArrayVariable);
254  TString newVarName = fVariables.back().GetInternalName() + TString::Format("[%d]", i);
255  fVariables.back().SetInternalName(newVarName);
256  }
257  fVarArrays[regexpr] = size;
258  fNeedsRebuilding = kTRUE;
259 }
260 
261 ////////////////////////////////////////////////////////////////////////////////
262 /// add a variable (can be a complex expression) to the set of
263 /// variables used in the MV analysis
264 
265 TMVA::VariableInfo& TMVA::DataSetInfo::AddTarget( const TString& expression,
266  const TString& title,
267  const TString& unit,
268  Double_t min, Double_t max,
269  Bool_t normalized,
270  void* external )
271 {
272  TString regexpr = expression; // remove possible blanks
273  regexpr.ReplaceAll(" ", "" );
274  char type='F';
275  fTargets.push_back(VariableInfo( regexpr, title, unit,
276  fTargets.size()+1, type, external, min,
277  max, normalized ));
278  fNeedsRebuilding = kTRUE;
279  return fTargets.back();
280 }
281 
282 ////////////////////////////////////////////////////////////////////////////////
283 /// add target with given VariableInfo
284 
285 TMVA::VariableInfo& TMVA::DataSetInfo::AddTarget( const VariableInfo& varInfo){
286  fTargets.push_back(VariableInfo( varInfo ));
287  fNeedsRebuilding = kTRUE;
288  return fTargets.back();
289 }
290 
291 ////////////////////////////////////////////////////////////////////////////////
292 /// add a spectator (can be a complex expression) to the set of spectator variables used in
293 /// the MV analysis
294 
295 TMVA::VariableInfo& TMVA::DataSetInfo::AddSpectator( const TString& expression,
296  const TString& title,
297  const TString& unit,
298  Double_t min, Double_t max, char type,
299  Bool_t normalized, void* external )
300 {
301  TString regexpr = expression; // remove possible blanks
302  regexpr.ReplaceAll(" ", "" );
303  fSpectators.push_back(VariableInfo( regexpr, title, unit,
304  fSpectators.size()+1, type, external, min, max, normalized ));
305  fNeedsRebuilding = kTRUE;
306  return fSpectators.back();
307 }
308 
309 ////////////////////////////////////////////////////////////////////////////////
310 /// add spectator with given VariableInfo
311 
312 TMVA::VariableInfo& TMVA::DataSetInfo::AddSpectator( const VariableInfo& varInfo){
313  fSpectators.push_back(VariableInfo( varInfo ));
314  fNeedsRebuilding = kTRUE;
315  return fSpectators.back();
316 }
317 
318 ////////////////////////////////////////////////////////////////////////////////
319 /// find variable by name
320 
321 Int_t TMVA::DataSetInfo::FindVarIndex(const TString& var) const
322 {
323  for (UInt_t ivar=0; ivar<GetNVariables(); ivar++)
324  if (var == GetVariableInfo(ivar).GetInternalName()) return ivar;
325 
326  for (UInt_t ivar=0; ivar<GetNVariables(); ivar++)
327  Log() << kINFO << Form("Dataset[%s] : ",fName.Data()) << GetVariableInfo(ivar).GetInternalName() << Endl;
328 
329  Log() << kFATAL << Form("Dataset[%s] : ",fName.Data()) << "<FindVarIndex> Variable \'" << var << "\' not found." << Endl;
330 
331  return -1;
332 }
333 
334 ////////////////////////////////////////////////////////////////////////////////
335 /// set the weight expressions for the classes
336 /// if class name is specified, set only for this class
337 /// if class name is unknown, register new class with this name
338 
339 void TMVA::DataSetInfo::SetWeightExpression( const TString& expr, const TString& className )
340 {
341  if (className != "") {
342  TMVA::ClassInfo* ci = AddClass(className);
343  ci->SetWeight( expr );
344  }
345  else {
346  // no class name specified, set weight for all classes
347  if (fClasses.empty()) {
348  Log() << kWARNING << Form("Dataset[%s] : ",fName.Data()) << "No classes registered yet, cannot specify weight expression!" << Endl;
349  }
350  for (std::vector<ClassInfo*>::iterator it = fClasses.begin(); it < fClasses.end(); ++it) {
351  (*it)->SetWeight( expr );
352  }
353  }
354 }
355 
356 ////////////////////////////////////////////////////////////////////////////////
357 
358 void TMVA::DataSetInfo::SetCorrelationMatrix( const TString& className, TMatrixD* matrix )
359 {
360  GetClassInfo(className)->SetCorrelationMatrix(matrix);
361 }
362 
363 ////////////////////////////////////////////////////////////////////////////////
364 /// set the cut for the classes
365 
366 void TMVA::DataSetInfo::SetCut( const TCut& cut, const TString& className )
367 {
368  if (className == "") { // if no className has been given set the cut for all the classes
369  for (std::vector<ClassInfo*>::iterator it = fClasses.begin(); it < fClasses.end(); ++it) {
370  (*it)->SetCut( cut );
371  }
372  }
373  else {
374  TMVA::ClassInfo* ci = AddClass(className);
375  ci->SetCut( cut );
376  }
377 }
378 
379 ////////////////////////////////////////////////////////////////////////////////
380 /// set the cut for the classes
381 
382 void TMVA::DataSetInfo::AddCut( const TCut& cut, const TString& className )
383 {
384  if (className == "") { // if no className has been given set the cut for all the classes
385  for (std::vector<ClassInfo*>::iterator it = fClasses.begin(); it < fClasses.end(); ++it) {
386  const TCut& oldCut = (*it)->GetCut();
387  (*it)->SetCut( oldCut+cut );
388  }
389  }
390  else {
391  TMVA::ClassInfo* ci = AddClass(className);
392  ci->SetCut( ci->GetCut()+cut );
393  }
394 }
395 
396 ////////////////////////////////////////////////////////////////////////////////
397 /// returns list of variables
398 
399 std::vector<TString> TMVA::DataSetInfo::GetListOfVariables() const
400 {
401  std::vector<TString> vNames;
402  std::vector<TMVA::VariableInfo>::const_iterator viIt = GetVariableInfos().begin();
403  for(;viIt != GetVariableInfos().end(); ++viIt) vNames.push_back( (*viIt).GetInternalName() );
404 
405  return vNames;
406 }
407 
408 ////////////////////////////////////////////////////////////////////////////////
409 /// calculates the correlation matrices for signal and background,
410 /// prints them to standard output, and fills 2D histograms
411 
412 void TMVA::DataSetInfo::PrintCorrelationMatrix( const TString& className )
413 {
414 
415  Log() << kHEADER //<< Form("Dataset[%s] : ",fName.Data())
416  << "Correlation matrix (" << className << "):" << Endl;
417  gTools().FormattedOutput( *CorrelationMatrix( className ), GetListOfVariables(), Log() );
418 }
419 
420 ////////////////////////////////////////////////////////////////////////////////
421 
422 TH2* TMVA::DataSetInfo::CreateCorrelationMatrixHist( const TMatrixD* m,
423  const TString& hName,
424  const TString& hTitle ) const
425 {
426  if (m==0) return 0;
427 
428  const UInt_t nvar = GetNVariables();
429 
430  // workaround till the TMatrix templates are commonly used
431  // this keeps backward compatibility
432  TMatrixF* tm = new TMatrixF( nvar, nvar );
433  for (UInt_t ivar=0; ivar<nvar; ivar++) {
434  for (UInt_t jvar=0; jvar<nvar; jvar++) {
435  (*tm)(ivar, jvar) = (*m)(ivar,jvar);
436  }
437  }
438 
439  TH2F* h2 = new TH2F( *tm );
440  h2->SetNameTitle( hName, hTitle );
441 
442  for (UInt_t ivar=0; ivar<nvar; ivar++) {
443  h2->GetXaxis()->SetBinLabel( ivar+1, GetVariableInfo(ivar).GetTitle() );
444  h2->GetYaxis()->SetBinLabel( ivar+1, GetVariableInfo(ivar).GetTitle() );
445  }
446 
447  // present in percent, and round off digits
448  // also, use absolute value of correlation coefficient (ignore sign)
449  h2->Scale( 100.0 );
450  for (UInt_t ibin=1; ibin<=nvar; ibin++) {
451  for (UInt_t jbin=1; jbin<=nvar; jbin++) {
452  h2->SetBinContent( ibin, jbin, Int_t(h2->GetBinContent( ibin, jbin )) );
453  }
454  }
455 
456  // style settings
457  const Float_t labelSize = 0.055;
458  h2->SetStats( 0 );
459  h2->GetXaxis()->SetLabelSize( labelSize );
460  h2->GetYaxis()->SetLabelSize( labelSize );
461  h2->SetMarkerSize( 1.5 );
462  h2->SetMarkerColor( 0 );
463  h2->LabelsOption( "d" ); // diagonal labels on x axis
464  h2->SetLabelOffset( 0.011 );// label offset on x axis
465  h2->SetMinimum( -100.0 );
466  h2->SetMaximum( +100.0 );
467 
468  // -------------------------------------------------------------------------------------
469  // just in case one wants to change the position of the color palette axis
470  // -------------------------------------------------------------------------------------
471  // gROOT->SetStyle("Plain");
472  // TStyle* gStyle = gROOT->GetStyle( "Plain" );
473  // gStyle->SetPalette( 1, 0 );
474  // TPaletteAxis* paletteAxis
475  // = (TPaletteAxis*)h2->GetListOfFunctions()->FindObject( "palette" );
476  // -------------------------------------------------------------------------------------
477 
478  Log() << kDEBUG << Form("Dataset[%s] : ",fName.Data()) << "Created correlation matrix as 2D histogram: " << h2->GetName() << Endl;
479 
480  return h2;
481 }
482 
483 ////////////////////////////////////////////////////////////////////////////////
484 /// returns data set
485 
486 TMVA::DataSet* TMVA::DataSetInfo::GetDataSet() const
487 {
488  if (fDataSet==0 || fNeedsRebuilding) {
489  if(fDataSet!=0) ClearDataSet();
490  // fDataSet = DataSetManager::Instance().CreateDataSet(GetName()); //DSMTEST replaced by following lines
491  if( !fDataSetManager )
492  Log() << kFATAL << Form("Dataset[%s] : ",fName.Data()) << "DataSetManager has not been set in DataSetInfo (GetDataSet() )." << Endl;
493  fDataSet = fDataSetManager->CreateDataSet(GetName());
494 
495  fNeedsRebuilding = kFALSE;
496  }
497  return fDataSet;
498 }
499 
500 ////////////////////////////////////////////////////////////////////////////////
501 
502 UInt_t TMVA::DataSetInfo::GetNSpectators(bool all) const
503 {
504  if(all)
505  return fSpectators.size();
506  UInt_t nsp(0);
507  for(std::vector<VariableInfo>::const_iterator spit=fSpectators.begin(); spit!=fSpectators.end(); ++spit) {
508  if(spit->GetVarType()!='C') nsp++;
509  }
510  return nsp;
511 }
512 
513 ////////////////////////////////////////////////////////////////////////////////
514 
515 Int_t TMVA::DataSetInfo::GetClassNameMaxLength() const
516 {
517  Int_t maxL = 0;
518  for (UInt_t cl = 0; cl < GetNClasses(); cl++) {
519  if (TString(GetClassInfo(cl)->GetName()).Length() > maxL) maxL = TString(GetClassInfo(cl)->GetName()).Length();
520  }
521 
522  return maxL;
523 }
524 
525 ////////////////////////////////////////////////////////////////////////////////
526 
527 Int_t TMVA::DataSetInfo::GetVariableNameMaxLength() const
528 {
529  Int_t maxL = 0;
530  for (UInt_t i = 0; i < GetNVariables(); i++) {
531  if (TString(GetVariableInfo(i).GetExpression()).Length() > maxL) maxL = TString(GetVariableInfo(i).GetExpression()).Length();
532  }
533 
534  return maxL;
535 }
536 
537 ////////////////////////////////////////////////////////////////////////////////
538 
539 Int_t TMVA::DataSetInfo::GetTargetNameMaxLength() const
540 {
541  Int_t maxL = 0;
542  for (UInt_t i = 0; i < GetNTargets(); i++) {
543  if (TString(GetTargetInfo(i).GetExpression()).Length() > maxL) maxL = TString(GetTargetInfo(i).GetExpression()).Length();
544  }
545 
546  return maxL;
547 }
548 
549 ////////////////////////////////////////////////////////////////////////////////
550 
551 Double_t TMVA::DataSetInfo::GetTrainingSumSignalWeights(){
552  if (fTrainingSumSignalWeights<0) Log() << kFATAL << Form("Dataset[%s] : ",fName.Data()) << " asking for the sum of training signal event weights which is not initialized yet" << Endl;
553  return fTrainingSumSignalWeights;
554 }
555 
556 ////////////////////////////////////////////////////////////////////////////////
557 
558 Double_t TMVA::DataSetInfo::GetTrainingSumBackgrWeights(){
559  if (fTrainingSumBackgrWeights<0) Log() << kFATAL << Form("Dataset[%s] : ",fName.Data()) << " asking for the sum of training backgr event weights which is not initialized yet" << Endl;
560  return fTrainingSumBackgrWeights;
561 }
562 
563 ////////////////////////////////////////////////////////////////////////////////
564 
565 Double_t TMVA::DataSetInfo::GetTestingSumSignalWeights (){
566  if (fTestingSumSignalWeights<0) Log() << kFATAL << Form("Dataset[%s] : ",fName.Data()) << " asking for the sum of testing signal event weights which is not initialized yet" << Endl;
567  return fTestingSumSignalWeights ;
568 }
569 
570 ////////////////////////////////////////////////////////////////////////////////
571 
572 Double_t TMVA::DataSetInfo::GetTestingSumBackgrWeights (){
573  if (fTestingSumBackgrWeights<0) Log() << kFATAL << Form("Dataset[%s] : ",fName.Data()) << " asking for the sum of testing backgr event weights which is not initialized yet" << Endl;
574  return fTestingSumBackgrWeights ;
575 }