Logo ROOT   6.30.04
Reference Guide
 All Namespaces Files Pages
MCMCCalculator.cxx
Go to the documentation of this file.
1 // @(#)root/roostats:$Id$
2 // Authors: Kevin Belasco 17/06/2009
3 // Authors: Kyle Cranmer 17/06/2009
4 /*************************************************************************
5  * Copyright (C) 1995-2008, Rene Brun and Fons Rademakers. *
6  * All rights reserved. *
7  * *
8  * For the licensing terms see $ROOTSYS/LICENSE. *
9  * For the list of contributors see $ROOTSYS/README/CREDITS. *
10  *************************************************************************/
11 
12 
13 /** \class RooStats::MCMCCalculator
14  \ingroup Roostats
15 
16  Bayesian Calculator estimating an interval or a credible region using the
17  Markov-Chain Monte Carlo method to integrate the likelihood function with the
18  prior to obtain the posterior function.
19 
 By using Markov-Chain Monte Carlo methods, this calculator can work with
 models that require integration over a large number of parameters.
22 
23  MCMCCalculator is a concrete implementation of IntervalCalculator. It uses a
24  MetropolisHastings object to construct a Markov Chain of data points in the
25  parameter space. From this Markov Chain, this class can generate a
26  MCMCInterval as per user specification.
27 
28  The interface allows one to pass the model, data, and parameters via a
29  workspace and then specify them with names.
30 
 After configuring the calculator, one only needs to call GetInterval(), which
 will return a ConfInterval (an MCMCInterval in this case).
33  */
34 
35 #include "Rtypes.h"
36 #include "RooGlobalFunc.h"
37 #include "RooAbsReal.h"
38 #include "RooArgSet.h"
39 #include "RooArgList.h"
40 #include "RooStats/ModelConfig.h"
41 #include "RooStats/RooStatsUtils.h"
44 #include "RooStats/MarkovChain.h"
45 #include "RooStats/MCMCInterval.h"
46 #include "TIterator.h"
48 #include "RooStats/PdfProposal.h"
49 #include "RooProdPdf.h"
50 
51 ClassImp(RooStats::MCMCCalculator);
52 
53 using namespace RooFit;
54 using namespace RooStats;
55 using namespace std;
56 
57 ////////////////////////////////////////////////////////////////////////////////
58 /// default constructor
59 
60 MCMCCalculator::MCMCCalculator() :
61  fPropFunc(0),
62  fPdf(0),
63  fPriorPdf(0),
64  fData(0),
65  fAxes(0)
66 {
67  fNumIters = 0;
68  fNumBurnInSteps = 0;
69  fNumBins = 0;
70  fUseKeys = kFALSE;
71  fUseSparseHist = kFALSE;
72  fSize = -1;
73  fIntervalType = MCMCInterval::kShortest;
74  fLeftSideTF = -1;
75  fEpsilon = -1;
76  fDelta = -1;
77 }
78 
79 ////////////////////////////////////////////////////////////////////////////////
80 /// constructor from a Model Config with a basic settings package configured
81 /// by SetupBasicUsage()
82 
83 MCMCCalculator::MCMCCalculator(RooAbsData& data, const ModelConfig & model) :
84  fPropFunc(0),
85  fData(&data),
86  fAxes(0)
87 {
88  SetModel(model);
89  SetupBasicUsage();
90 }
91 
92 void MCMCCalculator::SetModel(const ModelConfig & model) {
93  // set the model
94  fPdf = model.GetPdf();
95  fPriorPdf = model.GetPriorPdf();
96  fPOI.removeAll();
97  fNuisParams.removeAll();
98  fConditionalObs.removeAll();
99  fGlobalObs.removeAll();
100  if (model.GetParametersOfInterest())
101  fPOI.add(*model.GetParametersOfInterest());
102  if (model.GetNuisanceParameters())
103  fNuisParams.add(*model.GetNuisanceParameters());
104  if (model.GetConditionalObservables())
105  fConditionalObs.add( *(model.GetConditionalObservables() ) );
106  if (model.GetGlobalObservables())
107  fGlobalObs.add( *(model.GetGlobalObservables() ) );
108 
109 }
110 
111 ////////////////////////////////////////////////////////////////////////////////
112 /// Constructor for automatic configuration with basic settings. Uses a
113 /// UniformProposal, 10,000 iterations, 40 burn in steps, 50 bins for each
114 /// RooRealVar, determines interval by histogram. Finds a 95% confidence
115 /// interval.
116 
117 void MCMCCalculator::SetupBasicUsage()
118 {
119  fPropFunc = 0;
120  fNumIters = 10000;
121  fNumBurnInSteps = 40;
122  fNumBins = 50;
123  fUseKeys = kFALSE;
124  fUseSparseHist = kFALSE;
125  SetTestSize(0.05);
126  fIntervalType = MCMCInterval::kShortest;
127  fLeftSideTF = -1;
128  fEpsilon = -1;
129  fDelta = -1;
130 }
131 
132 ////////////////////////////////////////////////////////////////////////////////
133 
134 void MCMCCalculator::SetLeftSideTailFraction(Double_t a)
135 {
136  if (a < 0 || a > 1) {
137  coutE(InputArguments) << "MCMCCalculator::SetLeftSideTailFraction: "
138  << "Fraction must be in the range [0, 1]. "
139  << a << "is not allowed." << endl;
140  return;
141  }
142 
143  fLeftSideTF = a;
144  fIntervalType = MCMCInterval::kTailFraction;
145 }
146 
147 ////////////////////////////////////////////////////////////////////////////////
148 /// Main interface to get a RooStats::ConfInterval.
149 
150 MCMCInterval* MCMCCalculator::GetInterval() const
151 {
152 
153  if (!fData || !fPdf ) return 0;
154  if (fPOI.getSize() == 0) return 0;
155 
156  if (fSize < 0) {
157  coutE(InputArguments) << "MCMCCalculator::GetInterval: "
158  << "Test size/Confidence level not set. Returning NULL." << endl;
159  return NULL;
160  }
161 
162  // if a proposal function has not been specified create a default one
163  bool useDefaultPropFunc = (fPropFunc == 0);
164  bool usePriorPdf = (fPriorPdf != 0);
165  if (useDefaultPropFunc) fPropFunc = new UniformProposal();
166 
167  // if prior is given create product
168  RooAbsPdf * prodPdf = fPdf;
169  if (usePriorPdf) {
170  TString prodName = TString("product_") + TString(fPdf->GetName()) + TString("_") + TString(fPriorPdf->GetName() );
171  prodPdf = new RooProdPdf(prodName,prodName,RooArgList(*fPdf,*fPriorPdf) );
172  }
173 
174  RooArgSet* constrainedParams = prodPdf->getParameters(*fData);
175  RooAbsReal* nll = prodPdf->createNLL(*fData, Constrain(*constrainedParams),ConditionalObservables(fConditionalObs),GlobalObservables(fGlobalObs));
176  delete constrainedParams;
177 
178  RooArgSet* params = nll->getParameters(*fData);
179  RemoveConstantParameters(params);
180  if (fNumBins > 0) {
181  SetBins(*params, fNumBins);
182  SetBins(fPOI, fNumBins);
183  if (dynamic_cast<PdfProposal*>(fPropFunc)) {
184  RooArgSet* proposalVars = ((PdfProposal*)fPropFunc)->GetPdf()->
185  getParameters((RooAbsData*)NULL);
186  SetBins(*proposalVars, fNumBins);
187  }
188  }
189 
190  MetropolisHastings mh;
191  mh.SetFunction(*nll);
192  mh.SetType(MetropolisHastings::kLog);
193  mh.SetSign(MetropolisHastings::kNegative);
194  mh.SetParameters(*params);
195  if (fChainParams.getSize() > 0) mh.SetChainParameters(fChainParams);
196  mh.SetProposalFunction(*fPropFunc);
197  mh.SetNumIters(fNumIters);
198 
199  MarkovChain* chain = mh.ConstructChain();
200 
201  TString name = TString("MCMCInterval_") + TString(GetName() );
202  MCMCInterval* interval = new MCMCInterval(name, fPOI, *chain);
203  if (fAxes != NULL)
204  interval->SetAxes(*fAxes);
205  if (fNumBurnInSteps > 0)
206  interval->SetNumBurnInSteps(fNumBurnInSteps);
207  interval->SetUseKeys(fUseKeys);
208  interval->SetUseSparseHist(fUseSparseHist);
209  interval->SetIntervalType(fIntervalType);
210  if (fIntervalType == MCMCInterval::kTailFraction)
211  interval->SetLeftSideTailFraction(fLeftSideTF);
212  if (fEpsilon >= 0)
213  interval->SetEpsilon(fEpsilon);
214  if (fDelta >= 0)
215  interval->SetDelta(fDelta);
216  interval->SetConfidenceLevel(1.0 - fSize);
217 
218  if (useDefaultPropFunc) delete fPropFunc;
219  if (usePriorPdf) delete prodPdf;
220  delete nll;
221  delete params;
222 
223  return interval;
224 }