// NOTE(review): this listing has source line numbers fused into each line and
// gaps where lines were dropped; it does not compile as-is.
// ROOT ClassImp macro for RooStats::MetropolisHastings.
63 ClassImp(RooStats::MetropolisHastings);
65 using namespace RooFit;
66 using namespace RooStats;
/// Default constructor.
/// NOTE(review): its initializer list and body are not present in this listing.
71 MetropolisHastings::MetropolisHastings()
/// Constructs a Metropolis-Hastings sampler.
///
/// \param function          function to sample; its address (not a copy) is
///                          stored in fFunction, so it must remain valid for as
///                          long as this object uses it.
/// \param paramsOfInterest  parameters to sample, forwarded to SetParameters().
/// \param proposalFunction  proposal function, forwarded to SetProposalFunction().
/// \param numIters          number of iterations. NOTE(review): how it is stored
///                          is not visible in this listing (presumably into
///                          fNumIters, which ConstructChain() loops over — confirm).
///
/// NOTE(review): the function's type and sign are not visibly set here;
/// ConstructChain() refuses to run until SetType()/SetSign() have been called.
84 MetropolisHastings::MetropolisHastings(RooAbsReal&
function,
const RooArgSet& paramsOfInterest,
85 ProposalFunction& proposalFunction, Int_t numIters)
// Store the caller's function by address rather than copying it.
87 fFunction = &
function;
88 SetParameters(paramsOfInterest);
89 SetProposalFunction(proposalFunction);
/// Runs Metropolis-Hastings sampling and returns the resulting chain.
///
/// Algorithm:
/// - Start from a randomized point x of fParameters.
/// - Ask fPropFunc for a proposal xPrime and evaluate fFunction there.
/// - Accept or reject the step via ShouldTakeStep().
/// - On acceptance, add the current point x to the chain with CalcNLL() of its
///   function value and a weight, then move x to xPrime. The weight
///   bookkeeping is not visible here.
///
/// Preconditions: fParameters is non-empty, fPropFunc and fFunction are set,
/// and both type and sign are set via SetType()/SetSign(). If any is missing,
/// an error is logged.
/// NOTE(review): the return statements (error paths and success path) are not
/// present in this listing; presumably nullptr and `chain` respectively.
///
/// \return a MarkovChain allocated with `new`; presumably the caller takes
///         ownership — confirm.
///
/// NOTE(review): parts of this body are absent from this listing, including
/// the declarations of x, xPrime, i and weight and several branches. The
/// comments below describe only the visible statements.
98 MarkovChain* MetropolisHastings::ConstructChain()
// Validate that the sampler has parameters, a proposal and a function.
100 if (fParameters.getSize() == 0 || !fPropFunc || !fFunction) {
// NOTE(review): the message misspells "uninitialized". Both string fragments
// also carry a space, so it prints "proposal  function" with two spaces.
101 coutE(Eval) <<
"Critical members unintialized: parameters, proposal " <<
102 " function, or (log) likelihood function" << endl;
// The function's type and sign must both have been configured explicitly.
105 if (fSign == kSignUnset || fType == kTypeUnset) {
106 coutE(Eval) <<
"Please set type and sign of your function using "
107 <<
"MetropolisHastings::SetType() and MetropolisHastings::SetSign()" <<
// If no chain parameters were chosen, record all sampled parameters.
112 if (fChainParams.getSize() == 0) fChainParams.add(fParameters);
// x is the current point and xPrime the proposed point. Both are clones of
// fParameters with randomized values.
116 x.addClone(fParameters);
117 RandomizeCollection(x);
118 xPrime.addClone(fParameters);
119 RandomizeCollection(xPrime);
121 MarkovChain* chain =
new MarkovChain();
123 chain->SetParameters(fChainParams);
// xL / xPrimeL: function values at x / xPrime.
// a: the acceptance quantity passed to ShouldTakeStep().
126 Double_t xL = 0.0, xPrimeL = 0.0, a = 0.0;
// Suppress RooFit messages below PROGRESS level while sampling. The previous
// level is saved here and restored after sampling.
131 RooFit::MsgLevel oldMsgLevel = RooMsgService::instance().globalKillBelow();
132 RooMsgService::instance().setGlobalKillBelow(RooFit::PROGRESS);
// Count evaluation errors (queried below via numEvalErrors()) and start from
// an empty error log.
// NOTE(review): this logging mode is not restored in the visible lines.
137 RooAbsReal::setEvalErrorLoggingMode(RooAbsReal::CountErrors);
140 RooAbsReal::clearEvalErrorLog();
// Search for a starting point: re-randomize x and evaluate the function,
// for at most 1000 attempts, until hadEvalError is cleared. The counter i is
// initialized and incremented in lines not visible here.
143 bool hadEvalError =
true;
153 while (i < 1000 && hadEvalError) {
154 RandomizeCollection(x);
155 RooStats::SetParameters(&x, &fParameters);
156 xL = fFunction->getVal();
159 if (RooAbsReal::numEvalErrors() > 0) {
160 RooAbsReal::clearEvalErrorLog();
163 hadEvalError =
false;
164 }
else if (fType == kRegular) {
168 hadEvalError =
false;
171 hadEvalError =
false;
// Presumably reported when no acceptable starting point was found (the
// enclosing condition is not visible here).
176 coutE(Eval) <<
"Problem finding a good starting point in " <<
177 "MetropolisHastings::ConstructChain() " << endl;
// Main sampling loop: fNumIters Metropolis-Hastings iterations.
181 ooccoutP((TObject *)0, Generation) <<
"Metropolis-Hastings progress: ";
184 for (i = 0; i < fNumIters; i++) {
186 hadEvalError =
false;
// Print a progress dot roughly every 1% of the iterations.
// NOTE(review): `%` needs integer operands, so fNumIters / 100 is integer
// division. When fNumIters < 100 it is 0, and `i % 0` is undefined
// behavior. Guard the modulo, e.g. with `fNumIters >= 100 &&`.
189 if (i % (fNumIters / 100) == 0) ooccoutP((TObject*)0, Generation) <<
".";
// Ask the proposal function for a new point xPrime given the current x.
191 fPropFunc->Propose(xPrime, x);
// Evaluate the function at the proposed point.
193 RooStats::SetParameters(&xPrime, &fParameters);
194 xPrimeL = fFunction->getVal();
// For a log-type function, an evaluation error sets the proposal's value to
// +infinity and clears the error log.
197 if (fFunction->numEvalErrors() > 0 && fType == kLog) {
198 xPrimeL = RooNumber::infinity();
199 fFunction->clearEvalErrorLog();
// Presumably computes the acceptance quantity a from xL and xPrimeL according
// to sign and type. Only one condition of that computation is visible here.
209 if (fSign == kPositive)
// Hastings correction: for an asymmetric proposal, fold the two proposal
// densities into a. The visible line adds the difference of their logs,
// presumably for kLog; the kRegular form is not visible here.
218 if (!hadEvalError && !fPropFunc->IsSymmetric(xPrime, x)) {
219 Double_t xPrimePD = fPropFunc->GetProposalDensity(xPrime, x);
220 Double_t xPD = fPropFunc->GetProposalDensity(x, xPrime);
221 if (fType == kRegular)
224 a += TMath::Log(xPrimePD) - TMath::Log(xPD);
// On acceptance, record the current point x (with its NLL and weight), then
// move to the proposal by copying xPrime's values into x.
227 if (!hadEvalError && ShouldTakeStep(a)) {
232 chain->Add(x, CalcNLL(xL), (Double_t)weight);
236 RooStats::SetParameters(&xPrime, &x);
// Presumably after the main loop: record the final point.
246 chain->Add(x, CalcNLL(xL), (Double_t)weight);
247 ooccoutP((TObject *)0, Generation) << endl;
// Restore the message level saved before sampling.
249 RooMsgService::instance().setGlobalKillBelow(oldMsgLevel);
// Report statistics. chain->Size() is treated as the number of accepted steps.
251 Int_t numAccepted = chain->Size();
252 coutI(Eval) <<
"Proposal acceptance rate: " <<
253 numAccepted/(Float_t)fNumIters * 100 <<
"%" << endl;
254 coutI(Eval) <<
"Number of steps in chain: " << numAccepted << endl;
/// Decides whether to accept a proposed step, given the acceptance quantity `a`
/// computed by ConstructChain().
///
/// A step is accepted unconditionally when a <= 0 (kLog) or a >= 1 (kRegular).
/// Otherwise a uniform random number u is drawn. In the visible kLog-style
/// test, u is converted to log(u) and the step is accepted when -log(u) >= a.
/// NOTE(review): the condition selecting that test and the kRegular random
/// test are not present in this listing.
///
/// \param a  acceptance quantity from ConstructChain().
/// \return whether the step should be taken.
265 Bool_t MetropolisHastings::ShouldTakeStep(Double_t a)
// Unconditional-accept thresholds for each function type.
267 if ((fType == kLog && a <= 0.0) || (fType == kRegular && a >= 1.0)) {
// Otherwise decide using a uniform random draw.
276 Double_t rand = RooRandom::uniform();
278 rand = TMath::Log(rand);
281 if (-1.0 * rand >= a)
/// Converts a function value `xL` into the negative log-likelihood recorded in
/// the chain. ConstructChain() passes CalcNLL(xL) to MarkovChain::Add().
///
/// Visible cases:
/// - kPositive: returns -log(xL).
/// - Otherwise: returns -log(-xL), which presumes xL < 0.
///
/// NOTE(review): the result of the kNegative branch and any dispatch on fType
/// are not present in this listing. Presumably these log forms apply only to
/// regular (non-log) functions — confirm.
///
/// \param xL  function value at a chain point.
/// \return the NLL stored for that point.
300 Double_t MetropolisHastings::CalcNLL(Double_t xL)
303 if (fSign == kNegative)
308 if (fSign == kPositive)
309 return -1.0 * TMath::Log(xL);
311 return -1.0 * TMath::Log(-xL);