// ROOT class-implementation macro for RooBarlowBeestonLL (presumably generates
// the ROOT dictionary / I/O glue for the class; the declaration lives elsewhere).
50 ClassImp(RooStats::HistFactory::RooBarlowBeestonLL);
// Default constructor.
// Names the RooAbsReal base "RooBarlowBeestonLL" (name and title) and starts
// with no model pdf and no dataset attached (_pdf/_data are NULL).
// NOTE(review): excerpt only — original lines 57-59 and the body are not
// visible, so further member initialisers may exist there.
55 RooStats::HistFactory::RooBarlowBeestonLL::RooBarlowBeestonLL() :
56 RooAbsReal(
"RooBarlowBeestonLL",
"RooBarlowBeestonLL"),
60 _pdf(NULL), _data(NULL)
// Constructs the Barlow-Beeston wrapper around an input -log(L) function.
// The wrapped function (nllIn) is held by the _nll proxy, registered under the
// proxy name "input" with title "-log(L) function" and owned by this object.
// The model pdf and dataset start as NULL. initializeBarlowCache() refuses to
// run until both have been set.
// NOTE(review): the nllIn parameter and the ") :" (original line 72) are not
// visible in this excerpt.
71 RooStats::HistFactory::RooBarlowBeestonLL::RooBarlowBeestonLL(
const char *name,
const char *title,
73 RooAbsReal(name,title),
74 _nll(
"input",
"-log(L) function",this,nllIn),
77 _pdf(NULL), _data(NULL)
// Copy constructor (optionally renaming the copy).
// Copies the RooAbsReal base and re-binds the wrapped -log(L) proxy (proxy
// name "nll") to other's function. It also copies the recorded constant/floating
// status of parameters (_paramFixed).
// NOTE(review): _pdf and _data are reset to NULL rather than copied from
// `other`, so a copy has no pdf/data attached until they are set again.
// Confirm this is intended; original lines 107-108 and the body are not
// visible here.
104 RooStats::HistFactory::RooBarlowBeestonLL::RooBarlowBeestonLL(
const RooBarlowBeestonLL& other,
const char* name) :
105 RooAbsReal(other,name),
106 _nll(
"nll",this,other._nll),
109 _pdf(NULL), _data(NULL),
110 _paramFixed(other._paramFixed)
// Destructor.
// NOTE(review): the body is not visible in this excerpt. Check that it releases
// whatever initializeBarlowCache() creates and stores by raw pointer (e.g. the
// per-bin bin_center snapshots and the observable set), if this object owns them.
127 RooStats::HistFactory::RooBarlowBeestonLL::~RooBarlowBeestonLL()
// Moves the observables to this bin's centre.
// For every variable in the bin_center snapshot, it looks up the observable with
// the same name in `observables` and sets it to the snapshot's value.
// The method is const, but it mutates the (shared) observable variables the
// cache points to.
// NOTE(review): the result of observables->find() is dereferenced without a NULL
// check. The iterator from createIterator() must be deleted on a line not
// visible in this excerpt.
142 void RooStats::HistFactory::RooBarlowBeestonLL::BarlowCache::SetBinCenter()
const {
143 TIterator* iter = bin_center->createIterator() ;
145 while((var=(RooRealVar*)iter->Next())) {
146 RooRealVar* target = (RooRealVar*) observables->find(var->GetName()) ;
147 target->setVal(var->getVal()) ;
// Builds _barlowCache: for every channel of the (simultaneous) model pdf that
// has statistical-uncertainty "gamma" parameters, it stores one BarlowCache
// entry per bin. Each entry holds the gamma, its constraint term (pois mean /
// tau), the channel's sum pdf, the bin centre, the bin volume and the observed
// data. It also records the names of the analytically profiled gammas in
// _statUncertParams.
// Throws std::runtime_error if data or pdf is not set, or if a channel is
// missing from the per-channel bin-data map.
155 void RooStats::HistFactory::RooBarlowBeestonLL::initializeBarlowCache() {
159 std::cout <<
"Error: Must initialize data before initializing cache" << std::endl;
160 throw std::runtime_error(
"Uninitialized Data");
163 std::cout <<
"Error: Must initialize model pdf before initializing cache" << std::endl;
164 throw std::runtime_error(
"Uninitialized model pdf");
// Per-channel list of observed bin contents, filled from _data / _pdf.
168 std::map< std::string, std::vector<double> > ChannelBinDataMap;
169 getDataValuesForObservables( ChannelBinDataMap, _data, _pdf );
// Split the model into observable terms and constraint terms.
// NOTE(review): obsSet is a raw pointer stored in every cache entry below; its
// ownership/deletion is not visible in this excerpt.
173 RooArgList constraints;
174 RooArgSet* obsSet = _pdf->getObservables(*_data);
175 FactorizeHistFactoryPdf(*obsSet, *_pdf, obsTerms, constraints);
177 if( obsTerms.getSize() == 0 ) {
178 std::cout <<
"Error: Found no observable terms with pdf: " << _pdf->GetName()
179 <<
" using dataset: " << _data->GetName() << std::endl;
182 if( constraints.getSize() == 0 ) {
183 std::cout <<
"Error: Found no constraint terms with pdf: " << _pdf->GetName()
184 <<
" using dataset: " << _data->GetName() << std::endl;
// NOTE(review): unchecked C-style cast; assumes _pdf is a RooSimultaneous.
// A dynamic_cast + error would be safer if non-simultaneous models can reach here.
196 RooSimultaneous* simPdf = (RooSimultaneous*) _pdf;
197 RooCategory* channelCat = (RooCategory*) (&simPdf->indexCat());
// Loop over channels (category states). The iterator's deletion is not visible here.
198 TIterator* iter = channelCat->typeIterator() ;
199 RooCatType* tt = NULL;
200 while((tt=(RooCatType*) iter->Next())) {
215 RooAbsPdf* channelPdf = simPdf->getPdf(tt->GetName());
216 std::string channel_name = channelPdf->GetName();
// NOTE(review): heap-allocated list; no matching delete is visible in this excerpt.
219 RooArgList* gammas =
new RooArgList();
220 ParamHistFunc* param_func=NULL;
221 bool hasStatUncert = getStatUncertaintyFromChannel( channelPdf, param_func, gammas );
222 if( ! hasStatUncert ) {
224 std::cout <<
"Channel: " << channel_name
225 <<
" doesn't have statistical uncertainties"
231 if(verbose) std::cout <<
"Found ParamHistFunc: " << param_func->GetName() << std::endl;
238 int num_bins = param_func->numBins();
// One cache slot per bin. Only copied into _barlowCache if at least one bin
// has a floating gamma (channel_has_stat_uncertainty).
243 std::vector<BarlowCache> temp_cache( num_bins );
244 bool channel_has_stat_uncertainty=
false;
246 for( Int_t bin_index = 0; bin_index < num_bins; ++bin_index ) {
// Constant gammas are skipped (not profiled analytically).
252 RooRealVar* gamma_stat = &(param_func->getParameter(bin_index));
253 if( gamma_stat->isConstant() ) {
254 if(verbose) std::cout <<
"Ignoring constant gamma: " << gamma_stat->GetName() << std::endl;
258 cache.hasStatUncert=
true;
259 channel_has_stat_uncertainty=
true;
260 cache.gamma = gamma_stat;
// Remembered so getParameters() can hide these gammas from callers.
261 _statUncertParams.insert( gamma_stat->GetName() );
// Bin centre snapshot: used by BarlowCache::SetBinCenter() to move the
// observables to this bin before evaluating the sum pdf.
265 RooArgSet* bin_center = (RooArgSet*) param_func->get( bin_index )->snapshot();
266 cache.bin_center = bin_center;
267 cache.observables = obsSet;
269 cache.binVolume = param_func->binVolume();
// NOTE(review): obs_list is not used in any visible line.
272 RooArgList obs_list( *(cache.bin_center) );
// Find the constraint term for this gamma (its pois mean and tau).
275 RooAbsReal* pois_mean = NULL;
276 RooRealVar* tau = NULL;
277 getStatUncertaintyConstraintTerm( &constraints, gamma_stat, pois_mean, tau );
278 if( !tau || !pois_mean ) {
279 std::cout <<
"Failed to find pois mean or tau parameter for " << gamma_stat->GetName() << std::endl;
282 if(verbose) std::cout <<
"Found pois mean and tau for parameter: " << gamma_stat->GetName()
283 <<
" tau: " << tau->GetName() <<
" " << tau->getVal()
284 <<
" pois_mean: " << pois_mean->GetName() <<
" " << pois_mean->getVal()
289 cache.nom_pois_mean = pois_mean;
// The channel's sum pdf is what evaluate() uses to compute expected bin yields.
// NOTE(review): "therefor" in the message below is a typo for "therefore"
// (runtime string, left untouched here).
292 RooAbsPdf* sum_pdf = getSumPdfFromChannel( channelPdf );
293 if( sum_pdf == NULL ) {
294 std::cout <<
"Failed to find RooRealSumPdf in channel " << channel_name
295 <<
", therefor skipping this channel for analytic uncertainty minimization"
297 channel_has_stat_uncertainty=
false;
300 cache.sumPdf = sum_pdf;
303 if( ChannelBinDataMap.find(channel_name) == ChannelBinDataMap.end() ) {
304 std::cout <<
"Error: channel with name: " << channel_name
305 <<
" not found in BinDataMap" << std::endl;
// NOTE(review): unqualified runtime_error here vs std::runtime_error above;
// relies on a using-declaration outside this excerpt.
306 throw runtime_error(
"BinDataMap");
// NOTE(review): second lookup via operator[] after the find() above; the
// find() iterator could be reused.
308 double nData = ChannelBinDataMap[channel_name].at(bin_index);
311 temp_cache.at(bin_index) = cache;
316 if( channel_has_stat_uncertainty ) {
317 std::cout <<
"Adding channel: " << channel_name
318 <<
" to the barlow cache" << std::endl;
319 _barlowCache[channel_name] = temp_cache;
// Returns the parameters of this object as computed by RooAbsArg::getParameters,
// minus every parameter whose name is in _statUncertParams. Those are the stat
// gammas registered by initializeBarlowCache(); evaluate() sets them
// analytically via setVal(), so they are hidden from callers.
// NOTE(review): elements are removed from allArgs while iter_args is still
// walking allArgs. Depending on the collection implementation this can skip
// elements or invalidate the iterator. Consider collecting the names first and
// removing after the loop. The iterator's deletion is not visible in this
// excerpt.
374 RooArgSet* RooStats::HistFactory::RooBarlowBeestonLL::getParameters(
const RooArgSet* depList, Bool_t stripDisconnected)
const {
375 RooArgSet* allArgs = RooAbsArg::getParameters( depList, stripDisconnected );
377 TIterator* iter_args = allArgs->createIterator();
379 while((arg=(RooRealVar*)iter_args->Next())) {
380 std::string arg_name = arg->GetName();
385 if( _statUncertParams.find(arg_name.c_str()) != _statUncertParams.end() ) {
386 allArgs->remove( *arg, kTRUE );
// Returns the parameter values stored at the last absolute minimum (_paramAbsMin).
// validateAbsMin() fills this set with the non-constant parameters only.
// NOTE(review): original lines 400-401 are not visible; they may revalidate the
// minimum before returning.
399 const RooArgSet& RooStats::HistFactory::RooBarlowBeestonLL::bestFitParams() const
402 return _paramAbsMin ;
// Best-fit observable values.
// NOTE(review): body not visible in this excerpt. It presumably returns
// _obsAbsMin (filled in validateAbsMin); confirm.
408 const RooArgSet& RooStats::HistFactory::RooBarlowBeestonLL::bestFitObs() const
// Evaluates the likelihood after analytically profiling every cached stat gamma.
// For each channel in _barlowCache and each bin with a floating gamma:
//   nu_b      = expected bin yield (pdf density at bin centre * expectedEvents *
//               bin volume) after the first gamma pass,
//   nu_b_stat = yield after the second gamma pass minus nu_b.
// gamma_hat_hat is then the '+' root of A*g^2 + B*g + C = 0. These coefficients
// are exactly the stationarity condition
//   d/dg [ -ln( Pois(nData | nu_b + g*nu_b_stat) * Pois(m_val | g*tau_val) ) ] = 0
// multiplied through by g*(nu_b + g*nu_b_stat).
// The result is written into the gamma variable.
// NOTE(review): the gamma values set in the two passes (original lines 503-504
// and 525-526) and the final return are not visible here; presumably gamma is
// set to 0 and then 1 so that nu_b / nu_b_stat split the yield — confirm.
462 Double_t RooStats::HistFactory::RooBarlowBeestonLL::evaluate()
const
480 std::map< std::string, std::vector< BarlowCache > >::iterator iter_cache;
481 for( iter_cache = _barlowCache.begin(); iter_cache != _barlowCache.end(); ++iter_cache ) {
483 std::string channel_name = (*iter_cache).first;
484 std::vector< BarlowCache >& channel_cache = (*iter_cache).second;
// Pass 1: adjust each profiled gamma (assignment not visible in excerpt).
499 for(
unsigned int i = 0; i < channel_cache.size(); ++i ) {
500 BarlowCache& bin_cache = channel_cache.at(i);
501 if( !bin_cache.hasStatUncert )
continue;
502 RooRealVar* gamma = bin_cache.gamma;
// Expected yield per bin with the pass-1 gamma settings.
505 std::vector< double > nu_b_vec( channel_cache.size() );
506 for(
unsigned int i = 0; i < channel_cache.size(); ++i ) {
507 BarlowCache& bin_cache = channel_cache.at(i);
508 if( !bin_cache.hasStatUncert )
continue;
510 RooAbsPdf* sum_pdf = (RooAbsPdf*) bin_cache.sumPdf;
511 RooArgSet* obsSet = bin_cache.observables;
512 double binVolume = bin_cache.binVolume;
514 bin_cache.SetBinCenter();
515 double nu_b = sum_pdf->getVal(*obsSet)*sum_pdf->expectedEvents(*obsSet)*binVolume;
516 nu_b_vec.at(i) = nu_b;
// Pass 2: adjust each profiled gamma again (assignment not visible in excerpt).
521 for(
unsigned int i = 0; i < channel_cache.size(); ++i ) {
522 BarlowCache& bin_cache = channel_cache.at(i);
523 if( !bin_cache.hasStatUncert )
continue;
524 RooRealVar* gamma = bin_cache.gamma;
// Part of the yield that scales with gamma: pass-2 yield minus nu_b.
527 std::vector< double > nu_b_stat_vec( channel_cache.size() );
528 for(
unsigned int i = 0; i < channel_cache.size(); ++i ) {
529 BarlowCache& bin_cache = channel_cache.at(i);
530 if( !bin_cache.hasStatUncert )
continue;
532 RooAbsPdf* sum_pdf = (RooAbsPdf*) bin_cache.sumPdf;
533 RooArgSet* obsSet = bin_cache.observables;
534 double binVolume = bin_cache.binVolume;
536 bin_cache.SetBinCenter();
537 double nu_b_stat = sum_pdf->getVal(*obsSet)*sum_pdf->expectedEvents(*obsSet)*binVolume - nu_b_vec.at(i);
538 nu_b_stat_vec.at(i) = nu_b_stat;
// Pass 3: solve for gamma_hat_hat bin by bin and write it into gamma.
549 for(
unsigned int i = 0; i < channel_cache.size(); ++i ) {
551 BarlowCache& bin_cache = channel_cache.at(i);
553 if( !bin_cache.hasStatUncert ) {
561 bin_cache.SetBinCenter();
564 RooRealVar* gamma = bin_cache.gamma;
565 RooRealVar* tau = bin_cache.tau;
566 RooAbsReal* pois_mean = bin_cache.nom_pois_mean;
573 double nu_b = nu_b_vec.at(i);
574 double nu_b_stat = nu_b_stat_vec.at(i);
576 double tau_val = tau->getVal();
577 double nData = bin_cache.nData;
578 double m_val = pois_mean->getVal();
581 double gamma_hat_hat = 1.0;
// Negligible gamma-dependent yield: only the constraint term depends on gamma,
// whose minimum is m_val/tau_val (else-branch at 609).
584 if(nu_b_stat > 0.00000001) {
586 double A = nu_b_stat*nu_b_stat + tau_val*nu_b_stat;
587 double B = nu_b*tau_val + nu_b*nu_b_stat - nData*nu_b_stat - m_val*nu_b_stat;
588 double C = -1*m_val*nu_b;
// With A > 0 and C <= 0 (m_val, nu_b >= 0) the discriminant is >= B*B >= 0.
// A negative value therefore implies a negative m_val or nu_b.
590 double discrim = B*B-4*A*C;
593 std::cout <<
"Warning: Discriminant (B*B - 4AC) < 0" << std::endl;
594 std::cout <<
"Warning: Taking B*B - 4*A*C == 0" << std::endl;
// NOTE(review): the warning says "A <= 0" but the exception text says "A < 0".
// The actual condition (original lines 597-598) is not visible; make the two
// strings consistent with it.
599 std::cout <<
"Warning: A <= 0" << std::endl;
600 throw runtime_error(
"BarlowBeestonLL::evaluate() : A < 0");
603 gamma_hat_hat = ( -1*B + TMath::Sqrt(discrim) ) / (2*A);
609 gamma_hat_hat = m_val/tau_val;
613 if( TMath::IsNaN(gamma_hat_hat) ) {
614 std::cout <<
"ERROR: gamma hat hat is NAN" << std::endl;
615 throw runtime_error(
"BarlowBeestonLL::evaluate() : gamma hat hat is NAN");
618 if( gamma_hat_hat <= 0 ) {
619 std::cout <<
"WARNING: gamma hat hat <= 0. Setting to 0" << std::endl;
636 gamma->setVal( gamma_hat_hat );
// Ensures the cached absolute minimum (_paramAbsMin/_obsAbsMin) is current.
// The minimum is invalidated if any parameter's constant/floating status differs
// from the status recorded in _paramFixed. If the minimum is not valid, it is
// recomputed with all _obs floating, starting from the previous minimum. The
// method then stores the non-constant parameter values and all _obs values,
// records the constant status of all parameters, logs the minimum and restores
// _obs to its values on entry.
// NOTE(review): the log messages below say "evaluate(" although they come from
// validateAbsMin(). If _paramFixed is a std::map, operator[] inserts missing
// names (as false).
666 void RooStats::HistFactory::RooBarlowBeestonLL::validateAbsMin() const
668 // Check if constant status of any of the parameters have changed
672 while((par=(RooAbsArg*)_piter->Next())) {
673 if (_paramFixed[par->GetName()] != par->isConstant()) {
674 cxcoutI(Minimization) << "RooStats::HistFactory::RooBarlowBeestonLL::evaluate(" << GetName() << ") constant status of parameter " << par->GetName() << " has changed from "
675 << (_paramFixed[par->GetName()]?"fixed":"floating") << " to " << (par->isConstant()?"fixed":"floating")
676 << ", recalculating absolute minimum" << endl ;
677 _absMinValid = kFALSE ;
684 // If we don't have the absolute minimum w.r.t all observables, calculate that first
687 cxcoutI(Minimization) << "RooStats::HistFactory::RooBarlowBeestonLL::evaluate(" << GetName() << ") determining minimum likelihood for current configurations w.r.t all observable" << endl ;
690 // Save current values of the _obs set so they can be restored at the end
// NOTE(review): deletion of obsStart is not visible in this excerpt.
691 RooArgSet* obsStart = (RooArgSet*) _obs.snapshot(kFALSE) ;
693 // Start from previous global minimum
694 if (_paramAbsMin.getSize()>0) {
695 const_cast<RooSetProxy&>(_par).assignValueOnly(_paramAbsMin) ;
697 if (_obsAbsMin.getSize()>0) {
698 const_cast<RooSetProxy&>(_obs).assignValueOnly(_obsAbsMin) ;
701 // Find minimum with all observables floating (the minimisation call itself is not visible in this excerpt)
702 const_cast<RooSetProxy&>(_obs).setAttribAll("Constant",kFALSE) ;
705 // Save value and remember
707 _absMinValid = kTRUE ;
709 // Save parameter values at abs minimum as well
710 _paramAbsMin.removeAll() ;
712 // Only store non-constant parameters here!
// NOTE(review): deletion of tmp is not visible in this excerpt.
713 RooArgSet* tmp = (RooArgSet*) _par.selectByAttrib("Constant",kFALSE) ;
714 _paramAbsMin.addClone(*tmp) ;
717 _obsAbsMin.addClone(_obs) ;
719 // Save constant status of all parameters
722 while((par=(RooAbsArg*)_piter->Next())) {
723 _paramFixed[par->GetName()] = par->isConstant() ;
726 if (dologI(Minimization)) {
727 cxcoutI(Minimization) << "RooStats::HistFactory::RooBarlowBeestonLL::evaluate(" << GetName() << ") minimum found at (" ;
732 while ((arg=(RooAbsReal*)_oiter->Next())) {
733 ccxcoutI(Minimization) << (first?"":", ") << arg->GetName() << "=" << arg->getVal() ;
736 ccxcoutI(Minimization) << ")" << endl ;
739 // Restore the original values of the _obs set (snapshot taken on entry)
740 const_cast<RooSetProxy&>(_obs) = *obsStart ;
750 Bool_t RooStats::HistFactory::RooBarlowBeestonLL::redirectServersHook(
const RooAbsCollection& , Bool_t ,