134 using namespace std ;
// ROOT class-implementation macro for RooFFTConvPdf.
136 ClassImp(RooFFTConvPdf);
/// Construct the convolution pdf1 (*) pdf2 in the convolution variable convVar.
/// The convolution is evaluated numerically with FFTs and the result is kept in
/// a cache managed by the RooAbsCachedPdf base class. ipOrder is forwarded to
/// that base (presumably the interpolation order of the cached histogram --
/// confirm in RooAbsCachedPdf).
151 RooFFTConvPdf::RooFFTConvPdf(
const char *name,
const char *title, RooRealVar& convVar, RooAbsPdf& pdf1, RooAbsPdf& pdf2, Int_t ipOrder) :
152 RooAbsCachedPdf(name,title,ipOrder),
153 _x(
"!x",
"Convolution Variable",this,convVar),
// This overload has no external convolution variable, so _xprime is bound to 0.
154 _xprime(
"!xprime",
"External Convolution Variable",this,0),
155 _pdf1(
"!pdf1",
"pdf1",this,pdf1,kFALSE),
156 _pdf2(
"!pdf2",
"pdf2",this,pdf2,kFALSE),
157 _params(
"!params",
"effective parameters",this),
// NOTE(review): _bufFrac, _bufStrat and _shift1 do not appear in this
// initializer list as written here, yet prepareFFTBinning() below reads
// _bufFrac. Confirm they are initialized (e.g. by in-class initializers)
// before that call.
162 _cacheObs(
"!cacheObs",
"Cached observables",this,kFALSE,kFALSE)
// Make sure convVar carries a "cache" binning, which is what the FFT samples.
164 prepareFFTBinning(convVar);
// Shift used when sampling pdf2 (passed to scanPdf in fillCacheSlice): the
// midpoint of convVar's "cache" range.
166 _shift2 = (convVar.getMax(
"cache")+convVar.getMin(
"cache"))/2 ;
/// Construct the convolution pdf1 (*) pdf2 where the convolution and the cache
/// use the variable convVar, while pdfConvVar is stored as the external
/// convolution variable _xprime. pdfObservable() maps the cache observable _x
/// back onto _xprime. ipOrder is forwarded to the RooAbsCachedPdf base
/// (presumably the interpolation order of the cached histogram -- confirm).
177 RooFFTConvPdf::RooFFTConvPdf(
const char *name,
const char *title, RooAbsReal& pdfConvVar, RooRealVar& convVar, RooAbsPdf& pdf1, RooAbsPdf& pdf2, Int_t ipOrder) :
178 RooAbsCachedPdf(name,title,ipOrder),
179 _x(
"!x",
"Convolution Variable",this,convVar,kFALSE,kFALSE),
// The pdfs' own convolution variable.
180 _xprime(
"!xprime",
"External Convolution Variable",this,pdfConvVar),
181 _pdf1(
"!pdf1",
"pdf1",this,pdf1,kFALSE),
182 _pdf2(
"!pdf2",
"pdf2",this,pdf2,kFALSE),
183 _params(
"!params",
"effective parameters",this),
// NOTE(review): _bufFrac, _bufStrat and _shift1 do not appear in this
// initializer list as written here, yet prepareFFTBinning() below reads
// _bufFrac. Confirm they are initialized before that call.
188 _cacheObs(
"!cacheObs",
"Cached observables",this,kFALSE,kFALSE)
// Make sure convVar carries a "cache" binning, which is what the FFT samples.
190 prepareFFTBinning(convVar);
// Shift used when sampling pdf2: the midpoint of convVar's "cache" range.
192 _shift2 = (convVar.getMax(
"cache")+convVar.getMin(
"cache"))/2 ;
/// Copy constructor. name is forwarded to the RooAbsCachedPdf base copy
/// constructor. All proxies, the buffer settings (_bufFrac, _bufStrat) and
/// both shifts (_shift1, _shift2) are copied from other.
202 RooFFTConvPdf::RooFFTConvPdf(
const RooFFTConvPdf& other,
const char* name) :
203 RooAbsCachedPdf(other,name),
204 _x(
"!x",this,other._x),
205 _xprime(
"!xprime",this,other._xprime),
206 _pdf1(
"!pdf1",this,other._pdf1),
207 _pdf2(
"!pdf2",this,other._pdf2),
208 _params(
"!params",this,other._params),
209 _bufFrac(other._bufFrac),
210 _bufStrat(other._bufStrat),
211 _shift1(other._shift1),
212 _shift2(other._shift2),
213 _cacheObs(
"!cacheObs",this,other._cacheObs)
/// Destructor of RooFFTConvPdf.
222 RooFFTConvPdf::~RooFFTConvPdf()
/// Ensure that convVar has a binning named "cache", the binning sampled by the
/// FFT. An already existing "cache" binning is left untouched.
232 void RooFFTConvPdf::prepareFFTBinning(RooRealVar& convVar)
const {
233 if (!convVar.hasBinning(
"cache")) {
234 const RooAbsBinning& varBinning = convVar.getBinning();
// The zero-padding buffer adds about N*_bufFrac bins in total (Nbuf =
// N*bufferFraction()/2 on each side, see the FFTCacheElem constructor), so
// this choice makes the padded FFT length close to 1024 points.
235 const int optimal =
static_cast<Int_t
>(1024/(1.+_bufFrac));
// A coarse but uniform binning is refined to 'optimal' uniform bins over the
// same bounds.
238 if (varBinning.numBins() < optimal && varBinning.isUniform()) {
239 coutI(Caching) <<
"Changing internal binning of variable '" << convVar.GetName()
240 <<
"' in FFT '" << fName <<
"'"
241 <<
" from " << varBinning.numBins()
242 <<
" to " << optimal <<
" to improve the precision of the numerical FFT."
243 <<
" This can be done manually by setting an additional binning named 'cache'." << std::endl;
244 convVar.setBinning(RooUniformBinning(varBinning.lowBound(), varBinning.highBound(), optimal,
"cache"),
"cache");
// Otherwise the default binning is reused unchanged as the "cache" binning.
// NOTE(review): if this branch is a plain `else` of the condition above, the
// "is not uniform" error is also printed for uniform binnings that already
// have >= optimal bins -- confirm the branch condition.
246 coutE(Caching) <<
"The internal binning of variable " << convVar.GetName()
247 <<
" is not uniform. The numerical FFT will likely yield wrong results." << std::endl;
248 convVar.setBinning(varBinning,
"cache");
/// Build the base name "<pdf1>_CONV_<pdf2>" from the names of the two input pdfs.
/// NOTE(review): the name is assembled in a function-local static TString, so
/// every call (from any instance) overwrites the same buffer. A pointer into it
/// (presumably name.Data() is returned) stays valid only until the next call,
/// and concurrent calls are not safe.
257 const char* RooFFTConvPdf::inputBaseName()
const
259 static TString name ;
260 name = _pdf1.arg().GetName() ;
261 name.Append(
"_CONV_") ;
262 name.Append(_pdf2.arg().GetName()) ;
/// Create a new FFT cache element for the normalization set nset. The element
/// is heap-allocated; presumably the RooAbsCachedPdf cache manager takes
/// ownership -- confirm.
272 RooFFTConvPdf::PdfCacheElem* RooFFTConvPdf::createCache(
const RooArgSet* nset)
const
274 return new FFTCacheElem(*
this,nset) ;
/// Cache element for the FFT convolution. It holds:
///  - clones of both input pdfs attached to the cache histogram, where the
///    convolution observable may be shifted;
///  - the binnings used while sampling (scanBinning, histBinning).
/// The FFT objects start out as null pointers.
283 RooFFTConvPdf::FFTCacheElem::FFTCacheElem(
const RooFFTConvPdf&
self,
const RooArgSet* nsetIn) :
284 PdfCacheElem(self,nsetIn),
285 fftr2c1(0),fftr2c2(0),fftc2r(0)
// Work on private clones of the input pdfs, attached to the cache histogram.
287 RooAbsPdf* clonePdf1 = (RooAbsPdf*)
self._pdf1.arg().cloneTree() ;
288 RooAbsPdf* clonePdf2 = (RooAbsPdf*)
self._pdf2.arg().cloneTree() ;
289 clonePdf1->attachDataSet(*hist()) ;
290 clonePdf2->attachDataSet(*hist()) ;
// The cache histogram's own copy of the convolution observable.
293 RooRealVar* convObs = (RooRealVar*) hist()->get()->find(
self._x.arg().GetName()) ;
// Named range spanning convObs's current [min, max]. It is passed to
// fixAddCoefRange() on both clones below.
296 string refName = Form(
"refrange_fft_%s",
self.GetName()) ;
297 convObs->setRange(refName.c_str(),convObs->getMin(),convObs->getMax()) ;
// Non-zero _shift1: substitute convObs by (convObs - _shift1) in the pdf1
// clone. The shift variable and the un-customized clone are registered as
// owned components of the customized pdf.
299 if (
self._shift1!=0) {
300 RooLinearVar* shiftObs1 =
new RooLinearVar(Form(
"%s_shifted_FFTBuffer1",convObs->GetName()),
"shiftObs1",
301 *convObs,RooFit::RooConst(1),RooFit::RooConst(-1*
self._shift1)) ;
// NOTE(review): clonedBranches1 is never used.
303 RooArgSet clonedBranches1 ;
304 RooCustomizer cust(*clonePdf1,
"fft") ;
305 cust.replaceArg(*convObs,*shiftObs1) ;
307 pdf1Clone = (RooAbsPdf*) cust.build() ;
309 pdf1Clone->addOwnedComponents(*shiftObs1) ;
310 pdf1Clone->addOwnedComponents(*clonePdf1) ;
// No shift: use the plain clone directly.
313 pdf1Clone = clonePdf1 ;
// Same substitution for pdf2 with (convObs - _shift2).
316 if (
self._shift2!=0) {
317 RooLinearVar* shiftObs2 =
new RooLinearVar(Form(
"%s_shifted_FFTBuffer2",convObs->GetName()),
"shiftObs2",
318 *convObs,RooFit::RooConst(1),RooFit::RooConst(-1*
self._shift2)) ;
// NOTE(review): clonedBranches2 is never used.
320 RooArgSet clonedBranches2 ;
321 RooCustomizer cust(*clonePdf2,
"fft") ;
322 cust.replaceArg(*convObs,*shiftObs2) ;
// NOTE(review): unlike the pdf1 branch, shiftObs2 and clonePdf2 are registered
// as owned components of pdf1Clone rather than of pdf2Clone, which is only
// built afterwards. Their lifetime is therefore tied to pdf1Clone. This looks
// like a copy-paste slip for pdf2Clone->addOwnedComponents(...) after
// cust.build(). Confirm before changing, since containedArgs() reports the
// owned components of both clones.
324 pdf1Clone->addOwnedComponents(*shiftObs2) ;
325 pdf1Clone->addOwnedComponents(*clonePdf2) ;
327 pdf2Clone = (RooAbsPdf*) cust.build() ;
// No shift: use the plain clone directly.
330 pdf2Clone = clonePdf2 ;
// Parameters of the convolution that are not observables of the cache
// histogram. Both clones are redirected to these objects so that they use the
// original parameters.
335 RooArgSet* fftParams =
self.getParameters(*convObs) ;
339 fftParams->remove(*hist()->
get(),kTRUE,kTRUE) ;
341 pdf1Clone->recursiveRedirectServers(*fftParams) ;
342 pdf2Clone->recursiveRedirectServers(*fftParams) ;
343 pdf1Clone->fixAddCoefRange(refName.c_str(),
true) ;
344 pdf2Clone->fixAddCoefRange(refName.c_str(),
true) ;
// Fix the add-coefficient normalization of both clones to the convolution
// observable alone.
347 RooArgSet convSet(
self._x.arg());
348 pdf1Clone->fixAddCoefNormalization(convSet,
true);
349 pdf2Clone->fixAddCoefNormalization(convSet,
true);
// NOTE(review): fftParams is a raw pointer returned by getParameters(); confirm
// it is deleted once the redirections above are done.
// N = number of bins of convObs in its current binning. The warning below
// recommends n >= 1000 bins for difficult convolutions.
356 const Int_t N = convObs->numBins();
358 oocoutW(&
self, Eval) <<
"The FFT convolution '" <<
self.GetName() <<
"' will run with " << N
359 <<
" bins. A decent accuracy for difficult convolutions is typically only reached with n >= 1000. Suggest to increase the number"
360 " of bins of the observable '" << convObs->GetName() <<
"'." << std::endl;
// Zero-padding: Nbuf extra bins of width obw on each side (rounded to the
// nearest integer), giving N2 sampling points in total.
362 Int_t Nbuf =
static_cast<Int_t
>((N*
self.bufferFraction())/2 + 0.5) ;
363 Double_t obw = (convObs->getMax() - convObs->getMin())/N ;
364 Int_t N2 = N+2*Nbuf ;
// scanBinning: uniform binning over the padded range, used while sampling when
// the buffer strategy is Extend. histBinning: a copy of the regular binning,
// restored afterwards (see fillCacheSlice).
366 scanBinning =
new RooUniformBinning (convObs->getMin()-Nbuf*obw,convObs->getMax()+Nbuf*obw,N2) ;
367 histBinning = convObs->getBinning().clone() ;
// Turn off dirty-state propagation of the histogram and put convObs in ADirty
// operation mode (presumably so that values are always recomputed while the
// cache is being filled -- confirm).
371 hist()->setDirtyProp(kFALSE) ;
372 convObs->setOperMode(ADirty,kTRUE) ;
/// Suffix for the cache histogram name, encoding the buffer fraction (printed
/// with one decimal) and the buffer strategy value.
/// NOTE(review): with "%3.1f", buffer fractions that differ only beyond the
/// first decimal (e.g. 0.10 and 0.12) produce the same suffix.
379 TString RooFFTConvPdf::histNameSuffix()
const
381 return TString(Form(
"_BufFrac%3.1f_BufStrat%d",_bufFrac,_bufStrat)) ;
/// Collect the RooFit objects held by this cache element:
///  - those reported by the PdfCacheElem base for action a;
///  - both pdf clones;
///  - the components each clone owns, if any (e.g. the shift variables and
///    un-customized clones registered in the FFTCacheElem constructor).
389 RooArgList RooFFTConvPdf::FFTCacheElem::containedArgs(Action a)
391 RooArgList ret(PdfCacheElem::containedArgs(a)) ;
393 ret.add(*pdf1Clone) ;
394 ret.add(*pdf2Clone) ;
395 if (pdf1Clone->ownedComponents()) {
396 ret.add(*pdf1Clone->ownedComponents()) ;
398 if (pdf2Clone->ownedComponents()) {
399 ret.add(*pdf2Clone->ownedComponents()) ;
/// Destructor of the FFT cache element.
408 RooFFTConvPdf::FFTCacheElem::~FFTCacheElem()
/// Fill the cache histogram one slice at a time. For every bin combination of
/// the cache observables other than the convolution variable, fillCacheSlice()
/// performs a 1D FFT convolution along the convolution variable.
428 void RooFFTConvPdf::fillCacheObject(RooAbsCachedPdf::PdfCacheElem& cache)
const
430 RooDataHist& cacheHist = *cache.hist() ;
// Put both pdf clones in ADirty operation mode while the cache is filled.
432 ((FFTCacheElem&)cache).pdf1Clone->setOperMode(ADirty,kTRUE) ;
433 ((FFTCacheElem&)cache).pdf2Clone->setOperMode(ADirty,kTRUE) ;
// otherObs: snapshot (copy) of the cache observables with the convolution
// variable removed.
437 RooArgSet(*cacheHist.get()).snapshot(otherObs) ;
439 RooAbsArg* histArg = otherObs.find(_x.arg().GetName()) ;
441 otherObs.remove(*histArg,kTRUE,kTRUE) ;
// Only the convolution variable: a single slice covers the whole cache.
448 if (otherObs.getSize()==0) {
449 fillCacheSlice((FFTCacheElem&)cache,RooArgSet()) ;
// Otherwise loop over all bin combinations of the remaining observables.
// binCur holds the current bin of each observable, binMax its last bin index,
// and obsLV the lvalue interface used to position it.
457 Int_t n = otherObs.getSize() ;
458 Int_t* binCur =
new Int_t[n+1] ;
459 Int_t* binMax =
new Int_t[n+1] ;
462 RooAbsLValue** obsLV =
new RooAbsLValue*[n] ;
// NOTE(review): binCur, binMax, obsLV and iter are heap-allocated; confirm
// they are released (delete[] / delete) after the loop.
463 TIterator* iter = otherObs.createIterator() ;
466 while((arg=(RooAbsArg*)iter->Next())) {
467 RooAbsLValue* lvarg =
dynamic_cast<RooAbsLValue*
>(arg) ;
471 binMax[i] = lvarg->numBins(binningName())-1 ;
// Move the remaining observables to the current bin combination, then compute
// the slice at that position.
479 for (Int_t j=0 ; j<n ; j++) { obsLV[j]->setBin(binCur[j],binningName()) ; }
484 fillCacheSlice((FFTCacheElem&)cache,otherObs) ;
// Advance binCur to the next bin combination (presumably carrying over into
// the next observable when one reaches binMax).
487 while(binCur[curObs]==binMax[curObs]) {
/// Compute one slice of the convolution at position slicePos of the other cache
/// observables:
///  1. sample pdf1 and pdf2 on N2 points (scanPdf);
///  2. apply real-to-complex FFTs to both;
///  3. multiply the spectra;
///  4. transform back and write N values of the result into the cache histogram.
516 void RooFFTConvPdf::fillCacheSlice(FFTCacheElem& aux,
const RooArgSet& slicePos)
const
519 RooDataHist& cacheHist = *aux.hist() ;
532 Int_t N,N2,binShift1,binShift2 ;
// With the Extend strategy, the histogram's convolution variable temporarily
// gets the padded scan binning while sampling; its regular binning is restored
// right afterwards.
534 RooRealVar* histX = (RooRealVar*) cacheHist.get()->find(_x.arg().GetName()) ;
535 if (_bufStrat==Extend) histX->setBinning(*aux.scanBinning) ;
536 Double_t* input1 = scanPdf((RooRealVar&)_x.arg(),*aux.pdf1Clone,cacheHist,slicePos,N,N2,binShift1,_shift1) ;
537 Double_t* input2 = scanPdf((RooRealVar&)_x.arg(),*aux.pdf2Clone,cacheHist,slicePos,N,N2,binShift2,_shift2) ;
538 if (_bufStrat==Extend) histX->setBinning(*aux.histBinning) ;
// NOTE(review): input1/input2 are allocated with new[] in scanPdf; confirm
// they are released with delete[] after the transforms.
// 1D FFT objects of length N2 ("R2C" / "C2R" transform types; the trailing "K"
// option presumably requests a new TVirtualFFT object instead of reusing the
// global one, see the TVirtualFFT::FFT documentation).
// NOTE(review): these assignments replace any objects from a previous slice.
// Confirm they are created only once (e.g. guarded by !aux.fftr2c1) or
// deleted before being replaced.
545 aux.fftr2c1 = TVirtualFFT::FFT(1, &N2,
"R2CK");
546 aux.fftr2c2 = TVirtualFFT::FFT(1, &N2,
"R2CK");
547 aux.fftc2r = TVirtualFFT::FFT(1, &N2,
"C2RK");
// Forward transforms of both sampled inputs.
551 aux.fftr2c1->SetPoints(input1);
552 aux.fftr2c1->Transform();
555 aux.fftr2c2->SetPoints(input2);
556 aux.fftr2c2->Transform();
// Pointwise complex product of the two spectra over the N2/2+1 coefficients
// returned by the real-to-complex transforms:
// (re1 + i*im1)(re2 + i*im2) = (re1*re2 - im1*im2) + i*(re1*im2 + re2*im1).
560 for (Int_t i=0 ; i<N2/2+1 ; i++) {
561 Double_t re1,re2,im1,im2 ;
562 aux.fftr2c1->GetPointComplex(i,re1,im1) ;
563 aux.fftr2c2->GetPointComplex(i,re2,im2) ;
564 Double_t re = re1*re2 - im1*im2 ;
565 Double_t im = re1*im2 + re2*im1 ;
// NOTE(review): t is used here as the complex product (re, im); presumably
// declared as TComplex t(re,im) -- confirm its declaration.
567 aux.fftc2r->SetPointComplex(i,t) ;
// Inverse transform back to N2 real samples.
571 aux.fftc2r->Transform() ;
// Offset into the periodic result: pdf1's zero-bin shift (binShift1) plus the
// (N2-N)/2 padding bins on the low side.
573 Int_t totalShift = binShift1 + (N2-N)/2 ;
// Iterate over the cache bins of this slice along the convolution variable.
577 TIterator* iter =
const_cast<RooDataHist&
>(cacheHist).sliceIterator(const_cast<RooAbsReal&>(_x.arg()),slicePos) ;
578 for (Int_t i =0 ; i<N ; i++) {
581 Int_t j = i + totalShift ;
// Wrap around: the FFT result has period N2.
583 while (j>=N2) j-= N2 ;
586 cacheHist.set(aux.fftc2r->GetPointReal(j)) ;
/// Sample pdf along obs at the current slice position.
/// Returns a new[]-allocated array of length N2 holding the samples, cyclically
/// rotated by zeroBin (see the final loop). Presumably the caller owns and
/// deletes the array.
/// Output parameters:
///  - N: number of bins of obs in binningName().
///  - N2: padded length (presumably N + 2*Nbuf).
///  - zeroBin: bin index of obs = 0, moved by 'shift' and wrapped into [0, N2).
605 Double_t* RooFFTConvPdf::scanPdf(RooRealVar& obs, RooAbsPdf& pdf,
const RooDataHist& hist,
const RooArgSet& slicePos,
606 Int_t& N, Int_t& N2, Int_t& zeroBin, Double_t shift)
const
609 RooRealVar* histX = (RooRealVar*) hist.get()->find(obs.GetName()) ;
612 N = histX->numBins(binningName()) ;
// Padding bins on each side, rounded to the nearest integer.
613 Int_t Nbuf =
static_cast<Int_t
>((N*bufferFraction())/2 + 0.5) ;
618 Double_t* array =
new Double_t[N2] ;
// Locate the bin of obs = 0. If 0 lies inside the range, the binning is asked
// directly; otherwise it is extrapolated in units of bw = range/N2.
625 if (histX->getMax()>=0 && histX->getMin()<=0) {
626 zeroBin = histX->getBinning().binNumber(0) ;
627 }
else if (histX->getMin()>0) {
628 Double_t bw = (histX->getMax() - histX->getMin())/N2 ;
629 zeroBin = Int_t(-histX->getMin()/bw) ;
// Range entirely below 0. In exact arithmetic, -max/bw differs from -min/bw by
// exactly N2, so after the wrap below it yields the same index as the -min/bw
// form.
631 Double_t bw = (histX->getMax() - histX->getMin())/N2 ;
632 zeroBin = Int_t(-1*histX->getMax()/bw) ;
// Convert 'shift' into a number of bins on the same range/N2 scale, move
// zeroBin by it, and wrap the result into [0, N2).
635 Int_t binShift = Int_t((N2* shift) / (histX->getMax()-histX->getMin())) ;
637 zeroBin += binShift ;
638 while(zeroBin>=N2) zeroBin-= N2 ;
639 while(zeroBin<0) zeroBin+= N2 ;
// Temporary buffer for the N2 samples before rotation.
// NOTE(review): tmp is allocated with new[]; confirm it is released with
// delete[].
642 Double_t *tmp =
new Double_t[N2] ;
// Three fill variants follow, presumably selected by the buffer strategy.
// (a) Presumably Extend: evaluate pdf at all N2 points.
648 for (k=0 ; k<N2 ; k++) {
650 tmp[k] = pdf.getVal(hist.get()) ;
// (b) Presumably a flat buffer: pad the Nbuf low-side and Nbuf high-side
// points with a single pdf value each (val), and evaluate the N bins between.
659 Double_t val = pdf.getVal(hist.get()) ;
660 for (k=0 ; k<Nbuf ; k++) {
663 for (k=0 ; k<N ; k++) {
665 tmp[k+Nbuf] = pdf.getVal(hist.get()) ;
668 val = pdf.getVal(hist.get()) ;
669 for (k=0 ; k<Nbuf ; k++) {
670 tmp[N+Nbuf+k] = val ;
// (c) Presumably a mirrored buffer: evaluate the N bins, then fill tmp[Nbuf-k]
// and tmp[Nbuf+N+k-1] for k = 1..Nbuf from further pdf evaluations.
678 for (k=0 ; k<N ; k++) {
680 tmp[k+Nbuf] = pdf.getVal(hist.get()) ;
682 for (k=1 ; k<=Nbuf ; k++) {
684 tmp[Nbuf-k] = pdf.getVal(hist.get()) ;
686 tmp[Nbuf+N+k-1] = pdf.getVal(hist.get()) ;
// Copy tmp into array, cyclically rotated by zeroBin.
692 for (Int_t i=0 ; i<N2 ; i++) {
694 Int_t j = i - (zeroBin) ;
/// Determine the observables of the cache for the normalization set nset.
/// Starts from the union of the observables of pdf1 and pdf2 in nset, then
/// keeps only the convolution variable and the real-valued observables listed
/// in _cacheObs (see the two branches below). Presumably returns obs1.
720 RooArgSet* RooFFTConvPdf::actualObservables(
const RooArgSet& nset)
const
723 RooArgSet* obs1 = _pdf1.arg().getObservables(nset) ;
724 RooArgSet* obs2 = _pdf2.arg().getObservables(nset) ;
// NOTE(review): obs2 and the iterators below are heap-allocated; confirm they
// are deleted.
725 obs1->add(*obs2,kTRUE) ;
// nset contains the convolution variable. Collect every real-valued observable
// not named in _cacheObs into killList and remove them. Then add the
// convolution variable (silently, if already present) and all of _cacheObs.
728 if (nset.contains(_x.arg())) {
731 TIterator* iter = obs1->createIterator() ;
734 while((arg=(RooAbsArg*)iter->Next())) {
735 if (arg->IsA()->InheritsFrom(RooAbsReal::Class()) && !_cacheObs.find(arg->GetName())) {
740 obs1->remove(killList) ;
743 obs1->add(_x.arg(),kTRUE) ;
745 obs1->add(_cacheObs) ;
// When _cacheObs is non-empty, apply the same pruning, then make sure the
// convolution variable is included.
752 if (_cacheObs.getSize()>0) {
753 TIterator* iter = obs1->createIterator() ;
756 while((arg=(RooAbsArg*)iter->Next())) {
757 if (arg->IsA()->InheritsFrom(RooAbsReal::Class()) && !_cacheObs.find(arg->GetName())) {
762 obs1->remove(killList) ;
767 obs1->add(_x.arg(),kTRUE) ;
/// Determine the parameters of the cache for the normalization set nset:
/// presumably all variables of this pdf (getVariables()) minus the observables
/// returned by actualObservables(nset) -- confirm.
782 RooArgSet* RooFFTConvPdf::actualParameters(
const RooArgSet& nset)
const
784 RooArgSet* vars = getVariables() ;
785 RooArgSet* obs = actualObservables(nset) ;
/// Map a cache-histogram observable to the corresponding pdf observable. If an
/// external convolution variable (_xprime) is set and histObservable has the
/// name of the convolution variable _x, return _xprime. Otherwise return
/// histObservable unchanged.
798 RooAbsArg& RooFFTConvPdf::pdfObservable(RooAbsArg& histObservable)
const
800 if (_xprime.absArg() && string(histObservable.GetName())==_x.absArg()->GetName()) {
801 return (*_xprime.absArg()) ;
803 return histObservable ;
/// Select the event-generation context.
/// A specialized RooConvGenContext is returned only when both conditions hold:
///  - only the convolution variable is to be generated;
///  - both input pdfs have an internal generator for it that is safe to use
///    directly.
/// Otherwise a generic RooGenContext is returned, which the log message below
/// describes as accept/reject generation. The context is heap-allocated.
815 RooAbsGenContext* RooFFTConvPdf::genContext(
const RooArgSet &vars,
const RooDataSet *prototype,
816 const RooArgSet* auxProto, Bool_t verbose)
const
// Count the requested observables other than the convolution variable.
818 RooArgSet vars2(vars) ;
819 vars2.remove(_x.arg(),kTRUE,kTRUE) ;
820 Int_t numAddDep = vars2.getSize() ;
// An input pdf can be used directly if it offers a generator for _x and
// reports direct generation as safe.
823 Bool_t pdfCanDir = (((RooAbsPdf&)_pdf1.arg()).getGenerator(_x.arg(),dummy) != 0 && \
824 ((RooAbsPdf&)_pdf1.arg()).isDirectGenSafe(_x.arg())) ;
825 Bool_t resCanDir = (((RooAbsPdf&)_pdf2.arg()).getGenerator(_x.arg(),dummy) !=0 &&
826 ((RooAbsPdf&)_pdf2.arg()).isDirectGenSafe(_x.arg())) ;
829 cxcoutI(Generation) <<
"RooFFTConvPdf::genContext() input p.d.f " << _pdf1.arg().GetName()
830 <<
" has internal generator that is safe to use in current context" << endl ;
833 cxcoutI(Generation) <<
"RooFFTConvPdf::genContext() input p.d.f. " << _pdf2.arg().GetName()
834 <<
" has internal generator that is safe to use in current context" << endl ;
837 cxcoutI(Generation) <<
"RooFFTConvPdf::genContext() generation requested for observables other than the convolution observable " << _x.arg().GetName() << endl ;
// Fall back to the generic context if extra observables are requested or
// either input cannot generate directly.
841 if (numAddDep>0 || !pdfCanDir || !resCanDir) {
844 cxcoutI(Generation) <<
"RooFFTConvPdf::genContext() selecting accept/reject generator context because one or both of the input "
845 <<
"p.d.f.s cannot use internal generator and/or "
846 <<
"observables other than the convolution variable are requested for generation" << endl ;
847 return new RooGenContext(*
this,vars,prototype,auxProto,verbose) ;
851 cxcoutI(Generation) <<
"RooFFTConvPdf::genContext() selecting specialized convolution generator context as both input "
852 <<
"p.d.fs are safe for internal generator and only "
853 <<
"the convolution observables is requested for generation" << endl ;
854 return new RooConvGenContext(*
this,vars,prototype,auxProto,verbose) ;
/// Set the buffer fraction, i.e. the amount of zero-padding added around the
/// convolution range. Assuming bufferFraction() returns this value, the FFT
/// gets about N*frac/2 extra bins on each side.
/// An error is logged for an invalid (presumably negative) fraction. The cache
/// manager is sterilized afterwards, presumably so that caches built with the
/// old buffer are rebuilt.
863 void RooFFTConvPdf::setBufferFraction(Double_t frac)
866 coutE(InputArguments) <<
"RooFFTConvPdf::setBufferFraction(" << GetName() <<
") fraction should be greater than or equal to zero" << endl ;
872 _cacheMgr.sterilize() ;
/// Select the buffer strategy, i.e. how the padding region is filled when the
/// input pdfs are sampled (see scanPdf). With Extend, the pdfs are sampled on
/// the padded scan binning (see fillCacheSlice).
887 void RooFFTConvPdf::setBufferStrategy(BufStrat bs)
/// Print the convolution structure to os as "pdf1(x) (*) pdf2(x) ", using the
/// names of the input pdfs and of the convolution variable.
898 void RooFFTConvPdf::printMetaArgs(ostream& os)
const
900 os << _pdf1.arg().GetName() <<
"(" << _x.arg().GetName() <<
") (*) " << _pdf2.arg().GetName() <<
"(" << _x.arg().GetName() <<
") " ;
/// Recompute _params as the union of the parameters of pdf1 and pdf2 with
/// respect to the convolution variable _x. When adding pdf2's parameters, the
/// kTRUE flag presumably silences warnings about entries already present.
/// NOTE(review): params1/params2 are raw pointers from getParameters(); confirm
/// they are deleted.
908 void RooFFTConvPdf::calcParams()
910 RooArgSet* params1 = _pdf1.arg().getParameters(_x.arg()) ;
911 RooArgSet* params2 = _pdf2.arg().getParameters(_x.arg()) ;
912 _params.removeAll() ;
913 _params.add(*params1) ;
914 _params.add(*params2,kTRUE) ;
924 Bool_t RooFFTConvPdf::redirectServersHook(
const RooAbsCollection& , Bool_t , Bool_t , Bool_t )