Logo ROOT   6.30.04
Reference Guide
 All Namespaces Files Pages
HistFactoryModelUtils.cxx
Go to the documentation of this file.
1 /**
2  * \ingroup HistFactory
3  */
4 
// A set of utils for navigating HistFactory models
#include <cstring>
#include <iostream>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>
#include <typeinfo>
#include <vector>

#include "TIterator.h"
#include "RooAbsArg.h"
#include "RooAbsPdf.h"
#include "RooArgSet.h"
#include "RooArgList.h"
#include "RooSimultaneous.h"
#include "RooCategory.h"
#include "RooRealVar.h"
#include "RooProdPdf.h"
#include "TH1.h"
20 
23 
24 namespace RooStats{
25 namespace HistFactory{
26 
27 
28  std::string channelNameFromPdf( RooAbsPdf* channelPdf ) {
29  std::string channelPdfName = channelPdf->GetName();
30  std::string ChannelName = channelPdfName.substr(6, channelPdfName.size() );
31  return ChannelName;
32  }
33 
34  RooAbsPdf* getSumPdfFromChannel( RooAbsPdf* sim_channel ) {
35 
36  bool verbose=false;
37 
38  if(verbose) std::cout << "Getting the RooRealSumPdf for the channel: "
39  << sim_channel->GetName() << std::endl;
40 
41  std::string channelPdfName = sim_channel->GetName();
42  std::string ChannelName = channelPdfName.substr(6, channelPdfName.size() );
43 
44  // Now, get the RooRealSumPdf
45  // ie the channel WITHOUT constraints
46  std::string realSumPdfName = ChannelName + "_model";
47 
48  RooAbsPdf* sum_pdf = NULL;
49  TIterator* iter_sum_pdf = sim_channel->getComponents()->createIterator(); //serverIterator();
50  bool FoundSumPdf=false;
51  RooAbsArg* sum_pdf_arg=NULL;
52  while((sum_pdf_arg=(RooAbsArg*)iter_sum_pdf->Next())) {
53  std::string NodeClassName = sum_pdf_arg->ClassName();
54  if( NodeClassName == std::string("RooRealSumPdf") ) {
55  FoundSumPdf=true;
56  sum_pdf = (RooAbsPdf*) sum_pdf_arg;
57  break;
58  }
59  }
60  if( ! FoundSumPdf ) {
61  if(verbose) {
62  std::cout << "Failed to find RooRealSumPdf for channel: " << sim_channel->GetName() << std::endl;
63  sim_channel->getComponents()->Print("V");
64  }
65  sum_pdf=NULL;
66  //throw std::runtime_error("Failed to find RooRealSumPdf for channel");
67  }
68  else {
69  if(verbose) std::cout << "Found RooRealSumPdf: " << sum_pdf->GetName() << std::endl;
70  }
71  delete iter_sum_pdf;
72  iter_sum_pdf = NULL;
73 
74  return sum_pdf;
75 
76  }
77 
78 
79  void FactorizeHistFactoryPdf(const RooArgSet &observables, RooAbsPdf &pdf, RooArgList &obsTerms, RooArgList &constraints) {
80  // utility function to factorize constraint terms from a pdf
81  // (from G. Petrucciani)
82  const std::type_info & id = typeid(pdf);
83  if (id == typeid(RooProdPdf)) {
84  RooProdPdf *prod = dynamic_cast<RooProdPdf *>(&pdf);
85  RooArgList list(prod->pdfList());
86  for (int i = 0, n = list.getSize(); i < n; ++i) {
87  RooAbsPdf *pdfi = (RooAbsPdf *) list.at(i);
88  FactorizeHistFactoryPdf(observables, *pdfi, obsTerms, constraints);
89  }
90  } else if (id == typeid(RooSimultaneous) || id == typeid(HistFactorySimultaneous) ) { //|| id == typeid(RooSimultaneousOpt)) {
91  RooSimultaneous *sim = dynamic_cast<RooSimultaneous *>(&pdf);
92  RooAbsCategoryLValue *cat = (RooAbsCategoryLValue *) sim->indexCat().Clone();
93  for (int ic = 0, nc = cat->numBins((const char *)0); ic < nc; ++ic) {
94  cat->setBin(ic);
95  FactorizeHistFactoryPdf(observables, *sim->getPdf(cat->getLabel()), obsTerms, constraints);
96  }
97  delete cat;
98  } else if (pdf.dependsOn(observables)) {
99  if (!obsTerms.contains(pdf)) obsTerms.add(pdf);
100  } else {
101  if (!constraints.contains(pdf)) constraints.add(pdf);
102  }
103  }
104 
105  /*
106  void getChannelsFromModel( RooAbsPdf* model, RooArgSet* channels, RooArgSet* channelsWithConstraints ) {
107 
108  // Loop through the model
109  // Find all channels
110 
111  std::string modelClassName = model->ClassName();
112 
113  if( modelClassName == std::string("RooSimultaneous") || model->InheritsFrom("RooSimultaneous") ) {
114 
115  TIterator* simServerItr = model->serverIterator();
116 
117  // Loop through the child nodes of the sim pdf
118  // and find the channel nodes
119  RooAbsArg* sim_channel_arg = NULL;
120  while(( sim_channel = (RooAbsArg*) simServerItr->Next() )) {
121 
122  RooAbsPdf* sim_channel = (RooAbsPdf*) sim_channel_arg;
123 
124  // Ignore the Channel Cat
125  std::string channelPdfName = sim_channel->GetName();
126  std::string channelClassName = sim_channel->ClassName();
127  if( channelClassName == std::string("RooCategory") ) continue;
128 
129  // If we got here, we found a channel.
130  // Format is model_<ChannelName>
131 
132  std::string ChannelName = channelPdfName.substr(6, channelPdfName.size() );
133 
134  // Now, get the RooRealSumPdf
135  RooAbsPdf* sum_pdf = getSumPdfFromChannel( sim_channel );
136 
137 
138  / *
139  // Now, get the RooRealSumPdf
140  // ie the channel WITHOUT constraints
141 
142  std::string realSumPdfName = ChannelName + "_model";
143 
144  RooAbsPdf* sum_pdf = NULL;
145  TIterator* iter_sum_pdf = sim_channel->getComponents()->createIterator(); //serverIterator();
146  bool FoundSumPdf=false;
147  RooAbsArg* sum_pdf_arg=NULL;
148  while((sum_pdf_arg=(RooAbsArg*)iter_sum_pdf->Next())) {
149 
150  std::string NodeClassName = sum_pdf_arg->ClassName();
151  if( NodeClassName == std::string("RooRealSumPdf") ) {
152  FoundSumPdf=true;
153  sum_pdf = (RooAbsPdf*) sum_pdf_arg;
154  break;
155  }
156  }
157  if( ! FoundSumPdf ) {
158  std::cout << "Failed to find RooRealSumPdf for channel: " << sim_channel->GetName() << std::endl;
159  sim_channel->getComponents()->Print("V");
160  throw std::runtime_error("Failed to find RooRealSumPdf for channel");
161  }
162  delete iter_sum_pdf;
163  iter_sum_pdf = NULL;
164  * /
165 
166  // Okay, now add to the arg sets
167  channels->add( *sum_pdf );
168  channelsWithConstraints->add( *sim_channel );
169 
170  }
171 
172  delete simServerItr;
173 
174  }
175  else {
176  std::cout << "Model is not a RooSimultaneous or doesn't derive from one." << std::endl;
177  std::cout << "HistFactoryModelUtils isn't yet implemented for these pdf's" << std::endl;
178  }
179 
180  }
181  */
182 
183  bool getStatUncertaintyFromChannel( RooAbsPdf* channel, ParamHistFunc*& paramfunc, RooArgList* gammaList ) {
184 
185  bool verbose=false;
186 
187  // Find the servers of this channel
188  //TIterator* iter = channel->serverIterator();
189  TIterator* iter = channel->getComponents()->createIterator(); //serverIterator();
190  bool FoundParamHistFunc=false;
191  RooAbsArg* paramfunc_arg = NULL;
192  while(( paramfunc_arg = (RooAbsArg*) iter->Next() )) {
193  std::string NodeName = paramfunc_arg->GetName();
194  std::string NodeClassName = paramfunc_arg->ClassName();
195  if( NodeClassName != std::string("ParamHistFunc") ) continue;
196  if( NodeName.find("mc_stat_") != std::string::npos ) {
197  FoundParamHistFunc=true;
198  paramfunc = (ParamHistFunc*) paramfunc_arg;
199  break;
200  }
201  }
202  if( ! FoundParamHistFunc || !paramfunc ) {
203  if(verbose) std::cout << "Failed to find ParamHistFunc for channel: " << channel->GetName() << std::endl;
204  return false;
205  }
206 
207  delete iter;
208  iter = NULL;
209 
210  // Now, get the set of gamma's
211  gammaList = (RooArgList*) &( paramfunc->paramList());
212  if(verbose) gammaList->Print("V");
213 
214  return true;
215 
216  }
217 
218 
219  void getDataValuesForObservables( std::map< std::string, std::vector<double> >& ChannelBinDataMap,
220  RooAbsData* data, RooAbsPdf* pdf ) {
221 
222  bool verbose=false;
223 
224  //std::map< std::string, std::vector<int> ChannelBinDataMap;
225 
226  RooSimultaneous* simPdf = (RooSimultaneous*) pdf;
227 
228  // get category label
229  RooArgSet* allobs = (RooArgSet*) data->get();
230  TIterator* obsIter = allobs->createIterator();
231  RooCategory* cat = NULL;
232  RooAbsArg* temp = NULL;
233  while( (temp=(RooAbsArg*) obsIter->Next())) {
234  // use dynamic cast here instead
235  if( strcmp(temp->ClassName(),"RooCategory")==0){
236  cat = (RooCategory*) temp;
237  break;
238  }
239  }
240  if(verbose) {
241  if(!cat) std::cout <<"didn't find category"<< std::endl;
242  else std::cout <<"found category"<< std::endl;
243  }
244  delete obsIter;
245 
246  // split dataset
247  TList* dataByCategory = data->split(*cat);
248  if(verbose) dataByCategory->Print();
249  // note :
250  // RooAbsData* dataForChan = (RooAbsData*) dataByCategory->FindObject("");
251 
252  // loop over channels
253  RooCategory* channelCat = (RooCategory*) (&simPdf->indexCat());
254  TIterator* iter = channelCat->typeIterator() ;
255  RooCatType* tt = NULL;
256  while((tt=(RooCatType*) iter->Next())) {
257 
258  // Get pdf associated with state from simpdf
259  RooAbsPdf* pdftmp = simPdf->getPdf(tt->GetName()) ;
260 
261  std::string ChannelName = pdftmp->GetName(); //tt->GetName();
262  if(verbose) std::cout << "Getting data for channel: " << ChannelName << std::endl;
263  ChannelBinDataMap[ ChannelName ] = std::vector<double>();
264 
265  RooAbsData* dataForChan = (RooAbsData*) dataByCategory->FindObject(tt->GetName());
266  if(verbose) dataForChan->Print();
267 
268  // Generate observables defined by the pdf associated with this state
269  RooArgSet* obstmp = pdftmp->getObservables(*dataForChan->get()) ;
270  RooRealVar* obs = ((RooRealVar*)obstmp->first());
271  if(verbose) obs->Print();
272 
273  //double expected = pdftmp->expectedEvents(*obstmp);
274 
275  // set value to desired value (this is just an example)
276  // double obsVal = obs->getVal();
277  // set obs to desired value of observable
278  // obs->setVal( obsVal );
279  //double fracAtObsValue = pdftmp->getVal(*obstmp);
280 
281  // get num events expected in bin for obsVal
282  // double nu = expected * fracAtObsValue;
283 
284  // an easier way to get n
285  TH1* histForN = dataForChan->createHistogram("HhstForN",*obs);
286  for(int i=1; i<=histForN->GetNbinsX(); ++i){
287  double n = histForN->GetBinContent(i);
288  if(verbose) std::cout << "n" << i << " = " << n << std::endl;
289  ChannelBinDataMap[ ChannelName ].push_back( n );
290  }
291  delete histForN;
292 
293  } // End Loop Over Categories
294 
295  delete iter;
296  return;
297 
298  }
299 
300 
301  int getStatUncertaintyConstraintTerm( RooArgList* constraints, RooRealVar* gamma_stat,
302  RooAbsReal*& pois_nom, RooRealVar*& tau ) {
303  // Given a set of constraint terms,
304  // find the poisson constraint for the
305  // given gamma and return the mean
306  // as well as the 'tau' parameter
307 
308  bool verbose=false;
309 
310  // To get the constraint term, loop over all constraint terms
311  // and look for the gamma_stat name as well as '_constraint'
312  // std::string constraintTermName = std::string(gamma_stat->GetName()) + "_constraint";
313  TIterator* iter_list = constraints->createIterator();
314  RooAbsArg* term_constr=NULL;
315  bool FoundConstraintTerm=false;
316  RooAbsPdf* constraintTerm=NULL;
317  while((term_constr=(RooAbsArg*)iter_list->Next())) {
318  std::string TermName = term_constr->GetName();
319  // std::cout << "Checking if is a constraint term: " << TermName << std::endl;
320 
321  //if( TermName.find(gamma_stat->GetName())!=string::npos ) {
322  if( term_constr->dependsOn( *gamma_stat) ) {
323  if( TermName.find("_constraint")!=std::string::npos ) {
324  FoundConstraintTerm=true;
325  constraintTerm = (RooAbsPdf*) term_constr;
326  break;
327  }
328  }
329  }
330  if( FoundConstraintTerm==false ) {
331  std::cout << "Error: Couldn't find constraint term for parameter: " << gamma_stat->GetName()
332  << " among constraints: " << constraints->GetName() << std::endl;
333  constraints->Print("V");
334  throw std::runtime_error("Failed to find Gamma ConstraintTerm");
335  return -1;
336  }
337  delete iter_list;
338 
339  /*
340  RooAbsPdf* constraintTerm = (RooAbsPdf*) constraints->find( constraintTermName.c_str() );
341  if( constraintTerm == NULL ) {
342  std::cout << "Error: Couldn't find constraint term: " << constraintTermName
343  << " for parameter: " << gamma_stat->GetName()
344  << std::endl;
345  throw std::runtime_error("Failed to find Gamma ConstraintTerm");
346  return -1;
347  }
348  */
349 
350  // Find the "data" of the poisson term
351  // This is the nominal value
352  bool FoundNomMean=false;
353  TIterator* iter_pois = constraintTerm->serverIterator(); //constraint_args
354  RooAbsArg* term_pois ;
355  while((term_pois=(RooAbsArg*)iter_pois->Next())) {
356  std::string serverName = term_pois->GetName();
357  //std::cout << "Checking Server: " << serverName << std::endl;
358  if( serverName.find("nom_")!=std::string::npos ) {
359  FoundNomMean = true;
360  pois_nom = (RooRealVar*) term_pois;
361  }
362  }
363  if( !FoundNomMean || !pois_nom ) {
364  std::cout << "Error: Did not find Nominal Pois Mean parameter in gamma constraint term PoissonMean: "
365  << constraintTerm->GetName() << std::endl;
366  throw std::runtime_error("Failed to find Nom Pois Mean");
367  }
368  else {
369  if(verbose) std::cout << "Found Poisson 'data' term: " << pois_nom->GetName() << std::endl;
370  }
371  delete iter_pois;
372 
373  // Taking the constraint term (a Poisson), find
374  // the "mean" which is the product: gamma*tau
375  // Then, from that mean, find tau
376  TIterator* iter_constr = constraintTerm->serverIterator(); //constraint_args
377  RooAbsArg* pois_mean_arg=NULL;
378  bool FoundPoissonMean = false;
379  while(( pois_mean_arg = (RooAbsArg*) iter_constr->Next() )) {
380  std::string serverName = pois_mean_arg->GetName();
381  if( pois_mean_arg->dependsOn( *gamma_stat ) ) {
382  FoundPoissonMean=true;
383  // pois_mean = (RooAbsReal*) pois_mean_arg;
384  break;
385  }
386  }
387  if( !FoundPoissonMean || !pois_mean_arg ) {
388  std::cout << "Error: Did not find PoissonMean parameter in gamma constraint term: "
389  << constraintTerm->GetName() << std::endl;
390  throw std::runtime_error("Failed to find PoissonMean");
391  return -1;
392  }
393  else {
394  if(verbose) std::cout << "Found Poisson 'mean' term: " << pois_mean_arg->GetName() << std::endl;
395  }
396  delete iter_constr;
397 
398 
399  TIterator* iter_product = pois_mean_arg->serverIterator(); //constraint_args
400  RooAbsArg* term_in_product ;
401  bool FoundTau=false;
402  while((term_in_product=(RooAbsArg*)iter_product->Next())) {
403  std::string serverName = term_in_product->GetName();
404  //std::cout << "Checking Server: " << serverName << std::endl;
405  if( serverName.find("_tau")!=std::string::npos ) {
406  FoundTau = true;
407  tau = (RooRealVar*) term_in_product;
408  }
409  }
410  if( !FoundTau || !tau ) {
411  std::cout << "Error: Did not find Tau parameter in gamma constraint term PoissonMean: "
412  << pois_mean_arg->GetName() << std::endl;
413  throw std::runtime_error("Failed to find Tau");
414  }
415  else {
416  if(verbose) std::cout << "Found Poisson 'tau' term: " << tau->GetName() << std::endl;
417  }
418  delete iter_product;
419 
420  return 0;
421 
422  }
423 
424 
425 
426 } // close RooStats namespace
427 } // close HistFactory namespace