77 using std::stringstream;
// ROOT class-implementation macro for TMVA::MethodFDA (presumably ties the class into
// ROOT's dictionary/I-O machinery — see the ROOT ClassImp documentation).
81 ClassImp(TMVA::MethodFDA);
// Standard constructor used when booking the method for training.
// Forwards job name, method title, dataset info and option string to MethodBase
// with method type Types::kFDA, and zero-initialises the converger pointer, the
// signal/background weight sums and the number of output dimensions.
// NOTE(review): `theData` is used below but the parameter line declaring it (and
// several member initialisers) is not present in this listing — confirm against
// the full source.
86 TMVA::MethodFDA::MethodFDA( const TString& jobName,
87 const TString& methodTitle,
89 const TString& theOption)
90 : MethodBase( jobName, Types::kFDA, methodTitle, theData, theOption),
95 fConvergerFitter( 0 ),
96 fSumOfWeightsSig( 0 ),
97 fSumOfWeightsBkg( 0 ),
99 fOutputDimensions( 0 )
// Constructor from a weight file (presumably used when applying a trained method).
// Forwards Types::kFDA, the dataset info and the weight-file name to MethodBase and
// zero-initialises the converger pointer, the weight sums and the output dimensions.
// NOTE(review): some member initialisers and the constructor body are not present
// in this listing.
106 TMVA::MethodFDA::MethodFDA( DataSetInfo& theData,
107 const TString& theWeightFile)
108 : MethodBase( Types::kFDA, theData, theWeightFile),
113 fConvergerFitter( 0 ),
114 fSumOfWeightsSig( 0 ),
115 fSumOfWeightsBkg( 0 ),
117 fOutputDimensions( 0 )
// Default initialisation of the FDA members.
124 void TMVA::MethodFDA::Init(
void )
// reset the signal/background weight sums accumulated during training
131 fSumOfWeightsSig = 0;
132 fSumOfWeightsBkg = 0;
// clear the user-supplied ("P") and translated ("T") formula / parameter-range strings
134 fFormulaStringP =
"";
135 fParRangeStringP =
"";
136 fFormulaStringT =
"";
137 fParRangeStringT =
"";
// lazily allocate the buffer reused for per-class (multiclass) formula values
// NOTE(review): other statements of this function are not present in this listing.
143 if (fMulticlassReturnVal == NULL) fMulticlassReturnVal =
new std::vector<Float_t>();
163 void TMVA::MethodFDA::DeclareOptions()
165 DeclareOptionRef( fFormulaStringP =
"(0)",
"Formula",
"The discrimination formula" );
166 DeclareOptionRef( fParRangeStringP =
"()",
"ParRanges",
"Parameter ranges" );
169 DeclareOptionRef( fFitMethod =
"MINUIT",
"FitMethod",
"Optimisation Method");
170 AddPreDefVal(TString(
"MC"));
171 AddPreDefVal(TString(
"GA"));
172 AddPreDefVal(TString(
"SA"));
173 AddPreDefVal(TString(
"MINUIT"));
175 DeclareOptionRef( fConverger =
"None",
"Converger",
"FitMethod uses Converger to improve result");
176 AddPreDefVal(TString(
"None"));
177 AddPreDefVal(TString(
"MINUIT"));
// Translate the user formula fFormulaStringP into a TFormula-compatible string
// fFormulaStringT and compile it into fFormula ("FDA_Formula"):
//   "(i)" (fit parameter i)      -> "[i]"
//   "xi"  (input variable i)     -> "[i+fNPars]"
// so parameters occupy TFormula slots 0..fNPars-1 and input variables the slots after them.
183 void TMVA::MethodFDA::CreateFormula()
186 fFormulaStringT = fFormulaStringP;
// "(1)" cannot match inside "(10)" because of the closing parenthesis, so ascending order is safe here
191 for (UInt_t ipar=0; ipar<fNPars; ipar++) {
192 fFormulaStringT.ReplaceAll( Form(
"(%i)",ipar), Form(
"[%i]",ipar) );
// sanity check: any remaining "(i)" with i >= fNPars (scanned up to 999) has no declared range
// NOTE(review): the head of the log statement (Log() << <level>) and its "<< Endl" tail are not
// present in this listing — confirm the severity against the full source.
196 for (Int_t ipar=fNPars; ipar<1000; ipar++) {
197 if (fFormulaStringT.Contains( Form(
"(%i)",ipar) ))
199 <<
"<CreateFormula> Formula contains expression: \"" << Form(
"(%i)",ipar) <<
"\", "
200 <<
"which cannot be attributed to a parameter; "
201 <<
"it may be that the number of variable ranges given via \"ParRanges\" "
202 <<
"does not match the number of parameters in the formula expression, please verify!"
// variables are replaced from the highest index down: replacing "x1" first would also
// rewrite the prefix of "x10", "x11", ...
207 for (Int_t ivar=GetNvar()-1; ivar >= 0; ivar--) {
208 fFormulaStringT.ReplaceAll( Form(
"x%i",ivar), Form(
"[%i]",ivar+fNPars) );
// sanity check: any remaining "xi" with i >= GetNvar() (scanned up to 999) is not an input variable
// NOTE(review): the log statement head (Log() << <level>) is not present in this listing.
212 for (UInt_t ivar=GetNvar(); ivar<1000; ivar++) {
213 if (fFormulaStringT.Contains( Form(
"x%i",ivar) ))
215 <<
"<CreateFormula> Formula contains expression: \"" << Form(
"x%i",ivar) <<
"\", "
216 <<
"which cannot be attributed to an input variable" << Endl;
219 Log() <<
"User-defined formula string : \"" << fFormulaStringP <<
"\"" << Endl;
220 Log() <<
"TFormula-compatible formula string: \"" << fFormulaStringT <<
"\"" << Endl;
221 Log() << kDEBUG <<
"Creating and compiling formula" << Endl;
// replace any previously compiled formula
224 if (fFormula)
delete fFormula;
225 fFormula =
new TFormula(
"FDA_Formula", fFormulaStringT );
// NOTE(review): the two fatal messages below are tagged "<ProcessOptions>" although they are
// emitted from CreateFormula.
228 if (!fFormula->IsValid())
229 Log() << kFATAL <<
"<ProcessOptions> Formula expression could not be properly compiled" << Endl;
// the formula may not use more slots than fit parameters plus input variables
232 if (fFormula->GetNpar() > (Int_t)(fNPars + GetNvar()))
233 Log() << kFATAL <<
"<ProcessOptions> Dubious number of parameters in formula expression: "
234 << fFormula->GetNpar() <<
" - compared to maximum allowed: " << fNPars + GetNvar() << Endl;
// Interpret the options. The steps are:
//   1. Parse "ParRanges" into one Interval per fit parameter.
//   2. Replicate the ranges for every additional output dimension.
//   3. Create the optional converger.
//   4. Create the fitter selected by "FitMethod".
240 void TMVA::MethodFDA::ProcessOptions()
// strip blanks; the number of parameters is the number of ")" in the range string
243 fParRangeStringT = fParRangeStringP;
246 fParRangeStringT.ReplaceAll(
" ",
"" );
247 fNPars = fParRangeStringT.CountChar(
')' );
// split "(a,b);(c,d);..." on ";" — there must be exactly one entry per parameter
// NOTE(review): no delete of parList is visible in this listing — confirm ownership/cleanup.
249 TList* parList = gTools().ParseFormatLine( fParRangeStringT,
";" );
250 if ((UInt_t)parList->GetSize() != fNPars) {
251 Log() << kFATAL <<
"<ProcessOptions> Mismatch in parameter string: "
252 <<
"the number of parameters: " << fNPars <<
" != ranges defined: "
253 << parList->GetSize() <<
"; the format of the \"ParRanges\" string "
254 <<
"must be: \"(-1.2,3.4);(-2.3,4.55);...\", "
255 <<
"where the numbers in \"(a,b)\" correspond to the a=min, b=max parameter ranges; "
256 <<
"each parameter defined in the function string must have a corresponding rang."
// NOTE(review): the message above ends in "rang." — presumably meant "range.".
260 fParRange.resize( fNPars );
261 for (UInt_t ipar=0; ipar<fNPars; ipar++) fParRange[ipar] = 0;
263 for (UInt_t ipar=0; ipar<fNPars; ipar++) {
// each entry has the form "(min,max)": min is between "(" and ",", max between "," and ")"
265 TString str = ((TObjString*)parList->At(ipar))->GetString();
266 Ssiz_t istr = str.First(
',' );
267 TString pminS(str(1,istr-1));
268 TString pmaxS(str(istr+1,str.Length()-2-istr));
270 stringstream stmin; Float_t pmin=0; stmin << pminS.Data(); stmin >> pmin;
271 stringstream stmax; Float_t pmax=0; stmax << pmaxS.Data(); stmax >> pmax;
// treat (numerically) identical bounds as a fixed value
274 if (TMath::Abs(pmax-pmin) < 1.e-30) pmax = pmin;
// NOTE(review): the condition is pmin > pmax, but the message says "max > min" — the wording
// is inverted.
275 if (pmin > pmax) Log() << kFATAL <<
"<ProcessOptions> max > min in interval for parameter: ["
276 << ipar <<
"] : [" << pmin <<
", " << pmax <<
"] " << Endl;
278 Log() << kINFO <<
"Create parameter interval for parameter " << ipar <<
" : [" << pmin <<
"," << pmax <<
"]" << Endl;
279 fParRange[ipar] =
new Interval( pmin, pmax );
// number of output dimensions; the conditions choosing between these three assignments are not
// present in this listing (presumably classification / regression / multiclass)
288 fOutputDimensions = 1;
290 fOutputDimensions = DataInfo().GetNTargets();
292 fOutputDimensions = DataInfo().GetNClasses();
// every extra output dimension gets its own block of fNPars parameters; the ranges are the SAME
// Interval pointers (aliases, not copies), which is why ClearAll only deletes the first fNPars
294 for( Int_t dim = 1; dim < fOutputDimensions; ++dim ){
295 for( UInt_t par = 0; par < fNPars; ++par ){
296 fParRange.push_back( fParRange.at(par) );
// by default the fitter evaluates this object directly; with Converger=MINUIT a MinuitFitter
// sits in between and the fitter below drives it
302 fConvergerFitter = (IFitterTarget*)
this;
303 if (fConverger ==
"MINUIT") {
304 fConvergerFitter =
new MinuitFitter( *
this, Form(
"%s_Converger_Minuit", GetName()), fParRange, GetOptions() );
305 SetOptions(dynamic_cast<Configurable*>(fConvergerFitter)->GetOptions());
// select the primary fitter from "FitMethod"
// NOTE(review): the fatal message below is tagged "<Train>" although it is emitted from ProcessOptions.
308 if(fFitMethod ==
"MC")
309 fFitter =
new MCFitter( *fConvergerFitter, Form(
"%s_Fitter_MC", GetName()), fParRange, GetOptions() );
310 else if (fFitMethod ==
"GA")
311 fFitter =
new GeneticFitter( *fConvergerFitter, Form(
"%s_Fitter_GA", GetName()), fParRange, GetOptions() );
312 else if (fFitMethod ==
"SA")
313 fFitter =
new SimulatedAnnealingFitter( *fConvergerFitter, Form(
"%s_Fitter_SA", GetName()), fParRange, GetOptions() );
314 else if (fFitMethod ==
"MINUIT")
315 fFitter =
new MinuitFitter( *fConvergerFitter, Form(
"%s_Fitter_Minuit", GetName()), fParRange, GetOptions() );
317 Log() << kFATAL <<
"<Train> Do not understand fit method:" << fFitMethod << Endl;
320 fFitter->CheckForUnusedOptions();
// Destructor.
// NOTE(review): the destructor body is not present in this listing — confirm that it releases
// the objects freed by ClearAll (parameter intervals, compiled formula).
326 TMVA::MethodFDA::~MethodFDA(
void )
334 Bool_t TMVA::MethodFDA::HasAnalysisType( Types::EAnalysisType type, UInt_t numberClasses, UInt_t )
336 if (type == Types::kClassification && numberClasses == 2)
return kTRUE;
337 if (type == Types::kMulticlass )
return kTRUE;
338 if (type == Types::kRegression )
return kTRUE;
// Delete the owned parameter intervals and the compiled formula.
346 void TMVA::MethodFDA::ClearAll(
void )
// Only the first fNPars entries own their Interval: entries beyond fNPars are aliases of these
// (pushed for additional output dimensions in ProcessOptions), so deleting them would double-free.
// NOTE(review): the aliased entries are not nulled here and no fParRange.clear() is visible in
// this listing — they dangle if the vector is reused; confirm against the full source.
351 for (UInt_t ipar=0; ipar<fParRange.size() && ipar<fNPars; ipar++) {
352 if (fParRange[ipar] != 0) {
delete fParRange[ipar]; fParRange[ipar] = 0; }
356 if (fFormula != 0) {
delete fFormula; fFormula = 0; }
// Train FDA. The steps are:
//   1. Accumulate the event weight sums and check that they are positive.
//   2. Seed the parameters with the interval means.
//   3. Run the fitter and print the result.
//   4. Delete the fitter and the converger (unless the converger is this object itself).
363 void TMVA::MethodFDA::Train(
void )
367 fSumOfWeightsSig = 0;
368 fSumOfWeightsBkg = 0;
// NOTE(review): fSumOfWeights is checked below but neither reset nor accumulated in the lines
// visible in this listing — confirm against the full source.
370 for (UInt_t ievt=0; ievt<GetNEvents(); ievt++) {
373 const Event* ev = GetEvent(ievt);
376 Float_t w = ev->GetWeight();
// for classification the weights are split into signal and background sums
378 if (!DoRegression()) {
379 if (DataInfo().IsSignal(ev)) { fSumOfWeightsSig += w; }
380 else { fSumOfWeightsBkg += w; }
// the estimator divides by these sums, so they must be positive
386 if (!DoRegression()) {
387 if (fSumOfWeightsSig <= 0 || fSumOfWeightsBkg <= 0) {
388 Log() << kFATAL <<
"<Train> Troubles in sum of weights: "
389 << fSumOfWeightsSig <<
" (S) : " << fSumOfWeightsBkg <<
" (B)" << Endl;
392 else if (fSumOfWeights <= 0) {
393 Log() << kFATAL <<
"<Train> Troubles in sum of weights: "
394 << fSumOfWeights << Endl;
// starting values: the centre of each parameter interval (one per parameter and output dimension)
// NOTE(review): fBestPars is appended to without a visible clear — repeated training would grow it.
399 for (std::vector<Interval*>::const_iterator parIt = fParRange.begin(); parIt != fParRange.end(); ++parIt) {
400 fBestPars.push_back( (*parIt)->GetMean() );
404 Double_t estimator = fFitter->Run( fBestPars );
407 PrintResults( fFitMethod, fBestPars, estimator );
// the fitters are no longer needed once fBestPars holds the result
409 delete fFitter; fFitter = 0;
410 if (fConvergerFitter!=0 && fConvergerFitter!=(IFitterTarget*)
this) {
411 delete fConvergerFitter;
412 fConvergerFitter = 0;
421 void TMVA::MethodFDA::PrintResults(
const TString& fitter, std::vector<Double_t>& pars,
const Double_t estimator )
const
424 Log() << kHEADER <<
"Results for parameter fit using \"" << fitter <<
"\" fitter:" << Endl;
425 std::vector<TString> parNames;
426 for (UInt_t ipar=0; ipar<pars.size(); ipar++) parNames.push_back( Form(
"Par(%i)",ipar ) );
427 gTools().FormattedOutput( pars, parNames,
"Parameter" ,
"Fit result", Log(),
"%g" );
428 Log() <<
"Discriminator expression: \"" << fFormulaStringP <<
"\"" << Endl;
429 Log() <<
"Value of estimator at minimum: " << estimator << Endl;
435 Double_t TMVA::MethodFDA::EstimatorFunction( std::vector<Double_t>& pars )
437 const Double_t sumOfWeights[] = { fSumOfWeightsBkg, fSumOfWeightsSig, fSumOfWeights };
438 Double_t estimator[] = { 0, 0, 0 };
440 Double_t result, deviation;
441 Double_t desired = 0.0;
444 if( DoRegression() ){
445 for (UInt_t ievt=0; ievt<GetNEvents(); ievt++) {
447 const TMVA::Event* ev = GetEvent(ievt);
449 for( Int_t dim = 0; dim < fOutputDimensions; ++dim ){
450 desired = ev->GetTarget( dim );
451 result = InterpretFormula( ev, pars.begin(), pars.end() );
452 deviation = TMath::Power(result - desired, 2);
453 estimator[2] += deviation * ev->GetWeight();
456 estimator[2] /= sumOfWeights[2];
460 }
else if( DoMulticlass() ){
461 for (UInt_t ievt=0; ievt<GetNEvents(); ievt++) {
463 const TMVA::Event* ev = GetEvent(ievt);
465 CalculateMulticlassValues( ev, pars, *fMulticlassReturnVal );
467 Double_t crossEntropy = 0.0;
468 for( Int_t dim = 0; dim < fOutputDimensions; ++dim ){
469 Double_t y = fMulticlassReturnVal->at(dim);
470 Double_t t = (ev->GetClass() ==
static_cast<UInt_t
>(dim) ? 1.0 : 0.0 );
471 crossEntropy += t*log(y);
473 estimator[2] += ev->GetWeight()*crossEntropy;
475 estimator[2] /= sumOfWeights[2];
480 for (UInt_t ievt=0; ievt<GetNEvents(); ievt++) {
482 const TMVA::Event* ev = GetEvent(ievt);
484 desired = (DataInfo().IsSignal(ev) ? 1.0 : 0.0);
485 result = InterpretFormula( ev, pars.begin(), pars.end() );
486 deviation = TMath::Power(result - desired, 2);
487 estimator[Int_t(desired)] += deviation * ev->GetWeight();
489 estimator[0] /= sumOfWeights[0];
490 estimator[1] /= sumOfWeights[1];
492 return estimator[0] + estimator[1];
499 Double_t TMVA::MethodFDA::InterpretFormula(
const Event* event, std::vector<Double_t>::iterator parBegin, std::vector<Double_t>::iterator parEnd )
503 for( std::vector<Double_t>::iterator it = parBegin; it != parEnd; ++it ){
505 fFormula->SetParameter( ipar, (*it) );
508 for (UInt_t ivar=0; ivar<GetNvar(); ivar++) fFormula->SetParameter( ivar+ipar, event->GetValue(ivar) );
510 Double_t result = fFormula->Eval( 0 );
518 Double_t TMVA::MethodFDA::GetMvaValue( Double_t* err, Double_t* errUpper )
520 const Event* ev = GetEvent();
523 NoErrorCalc(err, errUpper);
525 return InterpretFormula( ev, fBestPars.begin(), fBestPars.end() );
530 const std::vector<Float_t>& TMVA::MethodFDA::GetRegressionValues()
532 if (fRegressionReturnVal == NULL) fRegressionReturnVal =
new std::vector<Float_t>();
533 fRegressionReturnVal->clear();
535 const Event* ev = GetEvent();
537 Event* evT =
new Event(*ev);
539 for( Int_t dim = 0; dim < fOutputDimensions; ++dim ){
540 Int_t offset = dim*fNPars;
541 evT->SetTarget(dim,InterpretFormula( ev, fBestPars.begin()+offset, fBestPars.begin()+offset+fNPars ) );
543 const Event* evT2 = GetTransformationHandler().InverseTransform( evT );
544 fRegressionReturnVal->push_back(evT2->GetTarget(0));
548 return (*fRegressionReturnVal);
553 const std::vector<Float_t>& TMVA::MethodFDA::GetMulticlassValues()
555 if (fMulticlassReturnVal == NULL) fMulticlassReturnVal =
new std::vector<Float_t>();
556 fMulticlassReturnVal->clear();
557 std::vector<Float_t> temp;
560 const TMVA::Event* evt = GetEvent();
562 CalculateMulticlassValues( evt, fBestPars, temp );
564 UInt_t nClasses = DataInfo().GetNClasses();
565 for(UInt_t iClass=0; iClass<nClasses; iClass++){
567 for(UInt_t j=0;j<nClasses;j++){
569 norm+=exp(temp[j]-temp[iClass]);
571 (*fMulticlassReturnVal).push_back(1.0/(1.0+norm));
574 return (*fMulticlassReturnVal);
581 void TMVA::MethodFDA::CalculateMulticlassValues(
const TMVA::Event*& evt, std::vector<Double_t>& parameters, std::vector<Float_t>& values)
592 for( Int_t dim = 0; dim < fOutputDimensions; ++dim ){
593 Int_t offset = dim*fNPars;
594 Double_t value = InterpretFormula( evt, parameters.begin()+offset, parameters.begin()+offset+fNPars );
596 values.push_back( value );
// Read the best-fit parameters from a plain-text weight stream: fNPars whitespace-separated values.
// NOTE(review): fNPars itself is not read from the stream in the lines visible here — confirm it
// is read (or already set) before the resize below.
604 void TMVA::MethodFDA::ReadWeightsFromStream( std::istream& istr )
611 fBestPars.resize( fNPars );
612 for (UInt_t ipar=0; ipar<fNPars; ipar++) istr >> fBestPars[ipar];
619 void TMVA::MethodFDA::AddWeightsXMLTo(
void* parent )
const
621 void* wght = gTools().AddChild(parent,
"Weights");
622 gTools().AddAttr( wght,
"NPars", fNPars );
623 gTools().AddAttr( wght,
"NDim", fOutputDimensions );
624 for (UInt_t ipar=0; ipar<fNPars*fOutputDimensions; ipar++) {
625 void* coeffxml = gTools().AddChild( wght,
"Parameter" );
626 gTools().AddAttr( coeffxml,
"Index", ipar );
627 gTools().AddAttr( coeffxml,
"Value", fBestPars[ipar] );
631 gTools().AddAttr( wght,
"Formula", fFormulaStringP );
// Read the fit result written by AddWeightsXMLTo: NPars, NDim (default 1 when absent), one
// value per "Parameter" child (by "Index"), and the user formula string.
637 void TMVA::MethodFDA::ReadWeightsFromXML(
void* wghtnode )
639 gTools().ReadAttr( wghtnode,
"NPars", fNPars );
// "NDim" may be missing (presumably older weight files); fall back to a single output dimension
641 if(gTools().HasAttr( wghtnode,
"NDim")) {
642 gTools().ReadAttr( wghtnode,
"NDim" , fOutputDimensions );
645 fOutputDimensions = 1;
649 fBestPars.resize( fNPars*fOutputDimensions );
// iterate over the "Parameter" children
// NOTE(review): the loop header and the declarations of ipar/par are not present in this listing.
651 void* ch = gTools().GetChild(wghtnode);
655 gTools().ReadAttr( ch,
"Index", ipar );
656 gTools().ReadAttr( ch,
"Value", par );
// reject indices outside fBestPars
// NOTE(review): the message prints fNPars as the bound although the check uses fNPars*fOutputDimensions.
659 if (ipar >= fNPars*fOutputDimensions) Log() << kFATAL <<
"<ReadWeightsFromXML> index out of range: "
660 << ipar <<
" >= " << fNPars << Endl;
661 fBestPars[ipar] = par;
663 ch = gTools().GetNextChild(ch);
// NOTE(review): no formula compilation (CreateFormula) is visible after reading the string — confirm.
667 gTools().ReadAttr( wghtnode,
"Formula", fFormulaStringP );
676 void TMVA::MethodFDA::MakeClassSpecific( std::ostream& fout,
const TString& className )
const
678 fout <<
" double fParameter[" << fNPars <<
"];" << std::endl;
679 fout <<
"};" << std::endl;
680 fout <<
"" << std::endl;
681 fout <<
"inline void " << className <<
"::Initialize() " << std::endl;
682 fout <<
"{" << std::endl;
683 for(UInt_t ipar=0; ipar<fNPars; ipar++) {
684 fout <<
" fParameter[" << ipar <<
"] = " << fBestPars[ipar] <<
";" << std::endl;
686 fout <<
"}" << std::endl;
688 fout <<
"inline double " << className <<
"::GetMvaValue__( const std::vector<double>& inputValues ) const" << std::endl;
689 fout <<
"{" << std::endl;
690 fout <<
" // interpret the formula" << std::endl;
693 TString str = fFormulaStringT;
694 for (UInt_t ipar=0; ipar<fNPars; ipar++) {
695 str.ReplaceAll( Form(
"[%i]", ipar), Form(
"fParameter[%i]", ipar) );
699 for (UInt_t ivar=0; ivar<GetNvar(); ivar++) {
700 str.ReplaceAll( Form(
"[%i]", ivar+fNPars), Form(
"inputValues[%i]", ivar) );
703 fout <<
" double retval = " << str <<
";" << std::endl;
705 fout <<
" return retval; " << std::endl;
706 fout <<
"}" << std::endl;
708 fout <<
"// Clean up" << std::endl;
709 fout <<
"inline void " << className <<
"::Clear() " << std::endl;
710 fout <<
"{" << std::endl;
711 fout <<
" // nothing to clear" << std::endl;
712 fout <<
"}" << std::endl;
// Print the FDA help text to the log in three sections:
//   1. short description;
//   2. performance optimisation;
//   3. performance tuning via configuration options.
// The Users Guide link is emitted as an HTML anchor when gConfig().WriteOptionsReference() is
// set, and as a plain URL otherwise.
721 void TMVA::MethodFDA::GetHelpMessage()
const
724 Log() << gTools().Color(
"bold") <<
"--- Short description:" << gTools().Color(
"reset") << Endl;
726 Log() <<
"The function discriminant analysis (FDA) is a classifier suitable " << Endl;
727 Log() <<
"to solve linear or simple nonlinear discrimination problems." << Endl;
729 Log() <<
"The user provides the desired function with adjustable parameters" << Endl;
730 Log() <<
"via the configuration option string, and FDA fits the parameters to" << Endl;
731 Log() <<
"it, requiring the signal (background) function value to be as close" << Endl;
732 Log() <<
"as possible to 1 (0). Its advantage over the more involved and" << Endl;
733 Log() <<
"automatic nonlinear discriminators is the simplicity and transparency " << Endl;
734 Log() <<
"of the discrimination expression. A shortcoming is that FDA will" << Endl;
735 Log() <<
"underperform for involved problems with complicated, phase space" << Endl;
736 Log() <<
"dependent nonlinear correlations." << Endl;
738 Log() <<
"Please consult the Users Guide for the format of the formula string" << Endl;
739 Log() <<
"and the allowed parameter ranges:" << Endl;
// HTML anchor for the options-reference output, plain URL otherwise
740 if (gConfig().WriteOptionsReference()) {
741 Log() <<
"<a href=\"http://tmva.sourceforge.net/docu/TMVAUsersGuide.pdf\">"
742 <<
"http://tmva.sourceforge.net/docu/TMVAUsersGuide.pdf</a>" << Endl;
744 else Log() <<
"http://tmva.sourceforge.net/docu/TMVAUsersGuide.pdf" << Endl;
746 Log() << gTools().Color(
"bold") <<
"--- Performance optimisation:" << gTools().Color(
"reset") << Endl;
748 Log() <<
"The FDA performance depends on the complexity and fidelity of the" << Endl;
749 Log() <<
"user-defined discriminator function. As a general rule, it should" << Endl;
750 Log() <<
"be able to reproduce the discrimination power of any linear" << Endl;
751 Log() <<
"discriminant analysis. To reach into the nonlinear domain, it is" << Endl;
752 Log() <<
"useful to inspect the correlation profiles of the input variables," << Endl;
753 Log() <<
"and add quadratic and higher polynomial terms between variables as" << Endl;
754 Log() <<
"necessary. Comparison with more involved nonlinear classifiers can" << Endl;
755 Log() <<
"be used as a guide." << Endl;
757 Log() << gTools().Color(
"bold") <<
"--- Performance tuning via configuration options:" << gTools().Color(
"reset") << Endl;
759 Log() <<
"Depending on the function used, the choice of \"FitMethod\" is" << Endl;
760 Log() <<
"crucial for getting valuable solutions with FDA. As a guideline it" << Endl;
761 Log() <<
"is recommended to start with \"FitMethod=MINUIT\". When more complex" << Endl;
762 Log() <<
"functions are used where MINUIT does not converge to reasonable" << Endl;
763 Log() <<
"results, the user should switch to non-gradient FitMethods such" << Endl;
764 Log() <<
"as GeneticAlgorithm (GA) or Monte Carlo (MC). It might prove to be" << Endl;
765 Log() <<
"useful to combine GA (or MC) with MINUIT by setting the option" << Endl;
766 Log() <<
"\"Converger=MINUIT\". GA (MC) will then set the starting parameters" << Endl;
767 Log() <<
"for MINUIT such that the basic quality of GA (MC) of finding global" << Endl;
768 Log() <<
"minima is combined with the efficacy of MINUIT of finding local" << Endl;
769 Log() <<
"minima." << Endl;