// Register the Fisher method with the TMVA classifier factory and generate
// the ROOT dictionary entry for TMVA::MethodFisher.
128 REGISTER_METHOD(Fisher)
130 ClassImp(TMVA::MethodFisher);
////////////////////////////////////////////////////////////////////////////////
/// Standard constructor for the Fisher discriminant method.
/// Forwards job name, method title, dataset info and the option string to
/// MethodBase and defaults to the classical Fisher variant (kFisher).
/// NOTE(review): the initializer list appears truncated in this excerpt --
/// the DataSetInfo& dsi parameter and further member initializers are
/// presumably on lines not shown here; confirm against the full source.
135 TMVA::MethodFisher::MethodFisher( const TString& jobName,
136 const TString& methodTitle,
138 const TString& theOption ) :
139 MethodBase( jobName, Types::kFisher, methodTitle, dsi, theOption),
141 fTheMethod ( "Fisher" ),
142 fFisherMethod ( kFisher ),
////////////////////////////////////////////////////////////////////////////////
/// Constructor used when instantiating the method from an existing weight
/// file (application phase). Defaults to the classical Fisher variant.
/// NOTE(review): the initializer list appears truncated in this excerpt --
/// remaining member initializers are presumably on lines not shown here.
157 TMVA::MethodFisher::MethodFisher( DataSetInfo& dsi,
158 const TString& theWeightFile) :
159 MethodBase( Types::kFisher, dsi, theWeightFile),
161 fTheMethod (
"Fisher" ),
162 fFisherMethod ( kFisher ),
////////////////////////////////////////////////////////////////////////////////
/// Default initialization called by all constructors:
/// allocates one Fisher coefficient per input variable and sets the cut on
/// the classifier output above which an event is treated as signal-like to 0.
177 void TMVA::MethodFisher::Init(
void )
// one linear coefficient per input variable
180 fFisherCoeff =
new std::vector<Double_t>( GetNvar() );
// minimum classifier output required to call an event signal-like
183 SetSignalReferenceCut( 0.0 );
////////////////////////////////////////////////////////////////////////////////
/// Declare the booking option "Method" selecting the discrimination flavour;
/// the two predefined values are "Fisher" (default) and "Mahalanobis".
194 void TMVA::MethodFisher::DeclareOptions()
196 DeclareOptionRef( fTheMethod =
"Fisher",
"Method",
"Discrimination method" );
197 AddPreDefVal(TString(
"Fisher"));
198 AddPreDefVal(TString(
"Mahalanobis"));
////////////////////////////////////////////////////////////////////////////////
/// Map the string option fTheMethod onto the internal enum:
/// "Fisher" -> kFisher, anything else -> kMahalanobis (the only other
/// predefined value declared in DeclareOptions is "Mahalanobis").
204 void TMVA::MethodFisher::ProcessOptions()
206 if (fTheMethod ==
"Fisher" ) fFisherMethod = kFisher;
207 else fFisherMethod = kMahalanobis;
////////////////////////////////////////////////////////////////////////////////
/// Destructor: releases the heap-allocated matrices and vectors owned by this
/// method and nulls the pointers. (Raw owning pointers are legacy style here;
/// the class interface would need changing to use smart pointers.)
216 TMVA::MethodFisher::~MethodFisher(
void )
218 if (fBetw ) {
delete fBetw; fBetw = 0; }
219 if (fWith ) {
delete fWith; fWith = 0; }
220 if (fCov ) {
delete fCov; fCov = 0; }
221 if (fDiscrimPow ) {
delete fDiscrimPow; fDiscrimPow = 0; }
222 if (fFisherCoeff) {
delete fFisherCoeff; fFisherCoeff = 0; }
////////////////////////////////////////////////////////////////////////////////
/// Fisher discriminant only supports two-class classification.
/// NOTE(review): the `return kFALSE` fallback is presumably on a line not
/// shown in this excerpt.
228 Bool_t TMVA::MethodFisher::HasAnalysisType( Types::EAnalysisType type, UInt_t numberClasses, UInt_t )
230 if (type == Types::kClassification && numberClasses == 2)
return kTRUE;
////////////////////////////////////////////////////////////////////////////////
/// Training: computes the within-class and between-class covariance matrices
/// of the input variables.
/// NOTE(review): the other training steps (matrix initialization, mean
/// computation, coefficient extraction) are presumably on lines not shown
/// in this excerpt.
237 void TMVA::MethodFisher::Train(
void )
// covariance of each class around its own mean
243 GetCov_WithinClass();
// covariance of the class means around the overall mean
246 GetCov_BetweenClass();
////////////////////////////////////////////////////////////////////////////////
/// Classifier response for the current event: the linear combination
/// result = fF0 + sum_i coeff_i * x_i. No MVA error is computed
/// (NoErrorCalc resets err/errUpper).
/// NOTE(review): the `return result;` statement is presumably on a line not
/// shown in this excerpt.
268 Double_t TMVA::MethodFisher::GetMvaValue( Double_t* err, Double_t* errUpper )
270 const Event * ev = GetEvent();
271 Double_t result = fF0;
272 for (UInt_t ivar=0; ivar<GetNvar(); ivar++)
273 result += (*fFisherCoeff)[ivar]*ev->GetValue(ivar);
// this classifier provides no per-event error estimate
276 NoErrorCalc(err, errUpper);
////////////////////////////////////////////////////////////////////////////////
/// Allocate the working matrices used during training:
/// - fMeanMatx (Nvar x 3): per-variable means; GetMean fills column 0 with
///   the signal mean, column 1 with the background mean, column 2 with the
///   overall (weighted) mean.
/// - fBetw, fWith, fCov (Nvar x Nvar): between-class, within-class and total
///   covariance matrices.
/// - fDiscrimPow: per-variable discrimination power.
285 void TMVA::MethodFisher::InitMatrices(
void )
288 fMeanMatx =
new TMatrixD( GetNvar(), 3 );
291 fBetw =
new TMatrixD( GetNvar(), GetNvar() );
292 fWith =
new TMatrixD( GetNvar(), GetNvar() );
293 fCov =
new TMatrixD( GetNvar(), GetNvar() );
296 fDiscrimPow =
new std::vector<Double_t>( GetNvar() );
////////////////////////////////////////////////////////////////////////////////
/// Compute the weighted per-variable means for signal (column 0 of fMeanMatx),
/// background (column 1) and the combined sample (column 2), and accumulate
/// the total event weights fSumOfWeightsS / fSumOfWeightsB.
/// NOTE(review): the matching `delete [] sumS; delete [] sumB;` cleanup is
/// presumably on lines not shown in this excerpt -- confirm no leak in the
/// full source.
302 void TMVA::MethodFisher::GetMean(
void )
308 const UInt_t nvar = DataInfo().GetNVariables();
// per-variable weighted sums, split by class
311 Double_t* sumS =
new Double_t[nvar];
312 Double_t* sumB =
new Double_t[nvar];
313 for (UInt_t ivar=0; ivar<nvar; ivar++) { sumS[ivar] = sumB[ivar] = 0; }
// single pass over the training events
316 for (Int_t ievt=0; ievt<Data()->GetNEvents(); ievt++) {
319 const Event * ev = GetEvent(ievt);
322 Double_t weight = ev->GetWeight();
323 if (DataInfo().IsSignal(ev)) fSumOfWeightsS += weight;
324 else fSumOfWeightsB += weight;
// route the event's contribution to the sum of its own class
326 Double_t* sum = DataInfo().IsSignal(ev) ? sumS : sumB;
328 for (UInt_t ivar=0; ivar<nvar; ivar++) sum[ivar] += ev->GetValue( ivar )*weight;
// normalize: column 2 first holds the raw S+B sum, then is divided by the
// total weight to become the overall mean
331 for (UInt_t ivar=0; ivar<nvar; ivar++) {
332 (*fMeanMatx)( ivar, 2 ) = sumS[ivar];
333 (*fMeanMatx)( ivar, 0 ) = sumS[ivar]/fSumOfWeightsS;
335 (*fMeanMatx)( ivar, 2 ) += sumB[ivar];
336 (*fMeanMatx)( ivar, 1 ) = sumB[ivar]/fSumOfWeightsB;
339 (*fMeanMatx)( ivar, 2 ) /= (fSumOfWeightsS + fSumOfWeightsB);
////////////////////////////////////////////////////////////////////////////////
/// Compute the within-class covariance matrix fWith: for each class, the
/// weighted covariance of the variables around that class's own mean
/// (fMeanMatx column 0 for signal, column 1 for background), then
/// fWith = covS/fSumOfWeightsS + covB/fSumOfWeightsB.
/// Requires GetMean() to have run first (asserted via the weight sums).
/// NOTE(review): the statements accumulating `v` into sumSig/sumBgd, the
/// definition of the flat index `k`, and the `delete []` cleanup are
/// presumably on lines not shown in this excerpt.
351 void TMVA::MethodFisher::GetCov_WithinClass(
void )
// GetMean() must have filled the weight sums before this point
354 assert( fSumOfWeightsS > 0 && fSumOfWeightsB > 0 );
// flat nvar*nvar accumulators, one per class
359 const Int_t nvar = GetNvar();
360 const Int_t nvar2 = nvar*nvar;
361 Double_t *sumSig =
new Double_t[nvar2];
362 Double_t *sumBgd =
new Double_t[nvar2];
// scratch buffer for the current event's variable values
363 Double_t *xval =
new Double_t[nvar];
364 memset(sumSig,0,nvar2*
sizeof(Double_t));
365 memset(sumBgd,0,nvar2*
sizeof(Double_t));
// single pass over the training events
368 for (Int_t ievt=0; ievt<Data()->GetNEvents(); ievt++) {
371 const Event* ev = GetEvent(ievt);
373 Double_t weight = ev->GetWeight();
375 for (Int_t x=0; x<nvar; x++) xval[x] = ev->GetValue( x );
377 for (Int_t x=0; x<nvar; x++) {
378 for (Int_t y=0; y<nvar; y++) {
// weighted product of deviations from the event's own class mean
379 if (DataInfo().IsSignal(ev)) {
380 Double_t v = ( (xval[x] - (*fMeanMatx)(x, 0))*(xval[y] - (*fMeanMatx)(y, 0)) )*weight;
383 Double_t v = ( (xval[x] - (*fMeanMatx)(x, 1))*(xval[y] - (*fMeanMatx)(y, 1)) )*weight;
// combine the per-class accumulators into the within-class matrix
391 for (Int_t x=0; x<nvar; x++) {
392 for (Int_t y=0; y<nvar; y++) {
402 (*fWith)(x, y) = sumSig[k]/fSumOfWeightsS + sumBgd[k]/fSumOfWeightsB;
////////////////////////////////////////////////////////////////////////////////
/// Compute the between-class covariance matrix fBetw: the weight-averaged
/// outer product of each class's mean deviation from the overall mean,
/// fBetw = (WS*prodSig + WB*prodBgd) / (WS + WB),
/// using fMeanMatx columns 0 (signal), 1 (background) and 2 (overall).
/// Requires GetMean() to have run first (asserted via the weight sums).
417 void TMVA::MethodFisher::GetCov_BetweenClass(
void )
420 assert( fSumOfWeightsS > 0 && fSumOfWeightsB > 0);
422 Double_t prodSig, prodBgd;
424 for (UInt_t x=0; x<GetNvar(); x++) {
425 for (UInt_t y=0; y<GetNvar(); y++) {
427 prodSig = ( ((*fMeanMatx)(x, 0) - (*fMeanMatx)(x, 2))*
428 ((*fMeanMatx)(y, 0) - (*fMeanMatx)(y, 2)) );
429 prodBgd = ( ((*fMeanMatx)(x, 1) - (*fMeanMatx)(x, 2))*
430 ((*fMeanMatx)(y, 1) - (*fMeanMatx)(y, 2)) );
432 (*fBetw)(x, y) = (fSumOfWeightsS*prodSig + fSumOfWeightsB*prodBgd) / (fSumOfWeightsS + fSumOfWeightsB);
////////////////////////////////////////////////////////////////////////////////
/// Compute the total covariance matrix as the element-wise sum of the
/// within-class and between-class covariance matrices: fCov = fWith + fBetw.
440 void TMVA::MethodFisher::GetCov_Full(
void )
442 for (UInt_t x=0; x<GetNvar(); x++)
443 for (UInt_t y=0; y<GetNvar(); y++)
444 (*fCov)(x, y) = (*fWith)(x, y) + (*fBetw)(x, y);
////////////////////////////////////////////////////////////////////////////////
/// Compute the Fisher coefficients: invert the covariance matrix selected by
/// the configured method (the switch cases choosing fWith or fCov are
/// presumably on lines not shown in this excerpt), warn/abort if it is
/// (nearly) singular, then take
///   coeff_i = xfact * sum_j invCov(i,j) * (meanS_j - meanB_j)
/// with xfact = sqrt(WS*WB)/(WS+WB), and derive the offset fF0 from the
/// coefficients and the class means.
/// NOTE(review): the declarations of ivar/jvar, the matrix inversion call,
/// and the final normalization of fF0 are presumably on lines not shown.
457 void TMVA::MethodFisher::GetFisherCoeff(
void )
460 assert( fSumOfWeightsS > 0 && fSumOfWeightsB > 0);
// matrix to invert, chosen by GetFisherMethod() below
463 TMatrixD* theMat = 0;
464 switch (GetFisherMethod()) {
472 Log() << kFATAL <<
"<GetFisherCoeff> undefined method" << GetFisherMethod() << Endl;
475 TMatrixD invCov( *theMat );
// guard against (near-)singular covariance; note 10E-24 == 1e-23
477 if ( TMath::Abs(invCov.Determinant()) < 10E-24 ) {
478 Log() << kWARNING <<
"<GetFisherCoeff> matrix is almost singular with determinant="
479 << TMath::Abs(invCov.Determinant())
480 <<
" did you use the variables that are linear combinations or highly correlated?"
483 if ( TMath::Abs(invCov.Determinant()) < 10E-120 ) {
485 Log() << kFATAL <<
"<GetFisherCoeff> matrix is singular with determinant="
486 << TMath::Abs(invCov.Determinant())
487 <<
" did you use the variables that are linear combinations? \n"
488 <<
" do you any clue as to what went wrong in above printout of the covariance matrix? "
// overall scale factor sqrt(WS*WB)/(WS+WB)
495 Double_t xfact = TMath::Sqrt( fSumOfWeightsS*fSumOfWeightsB ) / (fSumOfWeightsS + fSumOfWeightsB);
498 std::vector<Double_t> diffMeans( GetNvar() );
// coefficients: scaled inverse-covariance times mean difference (S - B)
500 for (ivar=0; ivar<GetNvar(); ivar++) {
501 (*fFisherCoeff)[ivar] = 0;
503 for (jvar=0; jvar<GetNvar(); jvar++) {
504 Double_t d = (*fMeanMatx)(jvar, 0) - (*fMeanMatx)(jvar, 1);
505 (*fFisherCoeff)[ivar] += invCov(ivar, jvar)*d;
508 (*fFisherCoeff)[ivar] *= xfact;
// offset fF0 built from the coefficients and the sum of the class means
514 for (ivar=0; ivar<GetNvar(); ivar++){
515 fF0 += (*fFisherCoeff)[ivar]*((*fMeanMatx)(ivar, 0) + (*fMeanMatx)(ivar, 1));
////////////////////////////////////////////////////////////////////////////////
/// Per-variable discrimination power: the ratio of between-class variance to
/// total variance on the diagonal, fBetw(i,i)/fCov(i,i); set to 0 for
/// variables with zero total variance to avoid division by zero.
528 void TMVA::MethodFisher::GetDiscrimPower(
void )
530 for (UInt_t ivar=0; ivar<GetNvar(); ivar++) {
531 if ((*fCov)(ivar, ivar) != 0)
532 (*fDiscrimPow)[ivar] = (*fBetw)(ivar, ivar)/(*fCov)(ivar, ivar);
534 (*fDiscrimPow)[ivar] = 0;
////////////////////////////////////////////////////////////////////////////////
/// Build the variable ranking, ordered by the discrimination power computed
/// in GetDiscrimPower(); one Rank entry per input variable.
541 const TMVA::Ranking* TMVA::MethodFisher::CreateRanking()
544 fRanking =
new Ranking( GetName(),
"Discr. power" );
546 for (UInt_t ivar=0; ivar<GetNvar(); ivar++) {
547 fRanking->AddRank( Rank( GetInputLabel(ivar), (*fDiscrimPow)[ivar] ) );
////////////////////////////////////////////////////////////////////////////////
/// Log the trained Fisher coefficients (plus the "(offset)" term fF0) in a
/// formatted table, list any variable transformations that apply, and -- if
/// the "Normalise" option was used -- print the normalisation formula the
/// user must apply when evaluating the classifier outside the TMVA Reader.
557 void TMVA::MethodFisher::PrintCoefficients(
void )
559 Log() << kHEADER <<
"Results for Fisher coefficients:" << Endl;
// if transformations are configured, coefficients refer to transformed vars
561 if (GetTransformationHandler().GetTransformationList().GetSize() != 0) {
562 Log() << kINFO <<
"NOTE: The coefficients must be applied to TRANFORMED variables" << Endl;
563 Log() << kINFO <<
" List of the transformation: " << Endl;
564 TListIter trIt(&GetTransformationHandler().GetTransformationList());
565 while (VariableTransformBase *trf = (VariableTransformBase*) trIt()) {
566 Log() << kINFO <<
" -- " << trf->GetName() << Endl;
// collect labels and coefficient values for the formatted table
569 std::vector<TString> vars;
570 std::vector<Double_t> coeffs;
571 for (UInt_t ivar=0; ivar<GetNvar(); ivar++) {
572 vars .push_back( GetInputLabel(ivar) );
573 coeffs.push_back( (*fFisherCoeff)[ivar] );
575 vars .push_back(
"(offset)" );
576 coeffs.push_back( fF0 );
577 TMVA::gTools().FormattedOutput( coeffs, vars,
"Variable" ,
"Coefficient", Log() );
// with "Normalise" booked, coefficients refer to normalised (') variables;
// print the explicit x -> 2*(x - xmin)/(xmax - xmin) - 1 recipe
582 if (IsNormalised()) {
583 Log() << kINFO <<
"NOTE: You have chosen to use the \"Normalise\" booking option. Hence, the" << Endl;
584 Log() << kINFO <<
" coefficients must be applied to NORMALISED (') variables as follows:" << Endl;
// maxL: widest input label, for column alignment
// NOTE(review): maxL's declaration is presumably on a line not shown here
586 for (UInt_t ivar=0; ivar<GetNvar(); ivar++)
if (GetInputLabel(ivar).Length() > maxL) maxL = GetInputLabel(ivar).Length();
589 for (UInt_t ivar=0; ivar<GetNvar(); ivar++) {
591 << std::setw(maxL+9) << TString(
"[") + GetInputLabel(ivar) +
"]' = 2*("
592 << std::setw(maxL+2) << TString(
"[") + GetInputLabel(ivar) +
"]"
593 << std::setw(3) << (GetXmin(ivar) > 0 ?
" - " :
" + ")
594 << std::setw(6) << TMath::Abs(GetXmin(ivar)) << std::setw(3) <<
")/"
595 << std::setw(6) << (GetXmax(ivar) - GetXmin(ivar) )
596 << std::setw(3) <<
" - 1"
599 Log() << kINFO <<
"The TMVA Reader will properly account for this normalisation, but if the" << Endl;
600 Log() << kINFO <<
"Fisher classifier is applied outside the Reader, the transformation must be" << Endl;
601 Log() << kINFO <<
"implemented -- or the \"Normalise\" option is removed and Fisher retrained." << Endl;
602 Log() << kINFO << Endl;
////////////////////////////////////////////////////////////////////////////////
/// Read the Fisher coefficients from a legacy text-stream weight file, one
/// value per input variable.
/// NOTE(review): reading of the offset fF0 is presumably on a line not shown
/// in this excerpt.
609 void TMVA::MethodFisher::ReadWeightsFromStream( std::istream& istr )
612 for (UInt_t ivar=0; ivar<GetNvar(); ivar++) istr >> (*fFisherCoeff)[ivar];
////////////////////////////////////////////////////////////////////////////////
/// Serialize the trained weights to XML under a "Weights" node with
/// NCoeff = GetNvar()+1 "Coefficient" children: index 0 carries the offset
/// fF0, indices 1..Nvar carry the per-variable Fisher coefficients.
/// (Mirrors the layout expected by ReadWeightsFromXML.)
618 void TMVA::MethodFisher::AddWeightsXMLTo(
void* parent )
const
620 void* wght = gTools().AddChild(parent,
"Weights");
621 gTools().AddAttr( wght,
"NCoeff", GetNvar()+1 );
// index 0 is reserved for the offset term fF0
622 void* coeffxml = gTools().AddChild(wght,
"Coefficient");
623 gTools().AddAttr( coeffxml,
"Index", 0 );
624 gTools().AddAttr( coeffxml,
"Value", fF0 );
625 for (UInt_t ivar=0; ivar<GetNvar(); ivar++) {
626 coeffxml = gTools().AddChild( wght,
"Coefficient" );
627 gTools().AddAttr( coeffxml,
"Index", ivar+1 );
628 gTools().AddAttr( coeffxml,
"Value", (*fFisherCoeff)[ivar] );
////////////////////////////////////////////////////////////////////////////////
/// Restore the trained weights from XML (inverse of AddWeightsXMLTo):
/// reads NCoeff, sizes the coefficient vector to NCoeff-1 (index 0 is the
/// offset fF0, not a variable coefficient), then walks the "Coefficient"
/// children mapping index 0 -> fF0 and index i -> fFisherCoeff[i-1].
635 void TMVA::MethodFisher::ReadWeightsFromXML(
void* wghtnode )
637 UInt_t ncoeff, coeffidx;
638 gTools().ReadAttr( wghtnode,
"NCoeff", ncoeff );
// NCoeff counts the offset too, hence ncoeff-1 variable coefficients
639 fFisherCoeff->resize(ncoeff-1);
641 void* ch = gTools().GetChild(wghtnode);
644 gTools().ReadAttr( ch,
"Index", coeffidx );
645 gTools().ReadAttr( ch,
"Value", coeff );
646 if (coeffidx==0) fF0 = coeff;
647 else (*fFisherCoeff)[coeffidx-1] = coeff;
648 ch = gTools().GetNextChild(ch);
////////////////////////////////////////////////////////////////////////////////
/// Emit the Fisher-specific part of the standalone C++ response class:
/// member declarations, Initialize() filling fFisher0 and the coefficient
/// vector (written at 12-digit precision), GetMvaValue__() computing the
/// linear response, and Clear(). The stream's original precision is saved
/// first and restored at the end.
655 void TMVA::MethodFisher::MakeClassSpecific( std::ostream& fout,
const TString& className )
const
// remember caller's stream precision; coefficients are written with 12 digits
657 Int_t dp = fout.precision();
// --- member declarations of the generated class ---
658 fout <<
" double fFisher0;" << std::endl;
659 fout <<
" std::vector<double> fFisherCoefficients;" << std::endl;
660 fout <<
"};" << std::endl;
661 fout <<
"" << std::endl;
// --- Initialize(): bake the trained offset and coefficients into the code ---
662 fout <<
"inline void " << className <<
"::Initialize() " << std::endl;
663 fout <<
"{" << std::endl;
664 fout <<
" fFisher0 = " << std::setprecision(12) << fF0 <<
";" << std::endl;
665 for (UInt_t ivar=0; ivar<GetNvar(); ivar++) {
666 fout <<
" fFisherCoefficients.push_back( " << std::setprecision(12) << (*fFisherCoeff)[ivar] <<
" );" << std::endl;
// generated sanity check: coefficient count must match fNvars
669 fout <<
" // sanity check" << std::endl;
670 fout <<
" if (fFisherCoefficients.size() != fNvars) {" << std::endl;
671 fout <<
" std::cout << \"Problem in class \\\"\" << fClassName << \"\\\"::Initialize: mismatch in number of input values\"" << std::endl;
672 fout <<
" << fFisherCoefficients.size() << \" != \" << fNvars << std::endl;" << std::endl;
673 fout <<
" fStatusIsClean = false;" << std::endl;
674 fout <<
" } " << std::endl;
675 fout <<
"}" << std::endl;
// --- GetMvaValue__(): the linear Fisher response ---
677 fout <<
"inline double " << className <<
"::GetMvaValue__( const std::vector<double>& inputValues ) const" << std::endl;
678 fout <<
"{" << std::endl;
679 fout <<
" double retval = fFisher0;" << std::endl;
680 fout <<
" for (size_t ivar = 0; ivar < fNvars; ivar++) {" << std::endl;
681 fout <<
" retval += fFisherCoefficients[ivar]*inputValues[ivar];" << std::endl;
682 fout <<
" }" << std::endl;
684 fout <<
" return retval;" << std::endl;
685 fout <<
"}" << std::endl;
// --- Clear(): release the coefficient vector ---
687 fout <<
"// Clean up" << std::endl;
688 fout <<
"inline void " << className <<
"::Clear() " << std::endl;
689 fout <<
"{" << std::endl;
690 fout <<
" // clear coefficients" << std::endl;
691 fout <<
" fFisherCoefficients.clear(); " << std::endl;
692 fout <<
"}" << std::endl;
// restore the caller's stream precision
693 fout << std::setprecision(dp);
////////////////////////////////////////////////////////////////////////////////
/// Print the user help message for the Fisher method to the logger:
/// a short description, performance-optimisation advice, and the (empty)
/// list of tuning options.
702 void TMVA::MethodFisher::GetHelpMessage()
const
// --- section: short description ---
705 Log() << gTools().Color(
"bold") <<
"--- Short description:" << gTools().Color(
"reset") << Endl;
707 Log() <<
"Fisher discriminants select events by distinguishing the mean " << Endl;
708 Log() <<
"values of the signal and background distributions in a trans- " << Endl;
709 Log() <<
"formed variable space where linear correlations are removed." << Endl;
711 Log() <<
" (More precisely: the \"linear discriminator\" determines" << Endl;
712 Log() <<
" an axis in the (correlated) hyperspace of the input " << Endl;
713 Log() <<
" variables such that, when projecting the output classes " << Endl;
714 Log() <<
" (signal and background) upon this axis, they are pushed " << Endl;
715 Log() <<
" as far as possible away from each other, while events" << Endl;
716 Log() <<
" of a same class are confined in a close vicinity. The " << Endl;
717 Log() <<
" linearity property of this classifier is reflected in the " << Endl;
718 Log() <<
" metric with which \"far apart\" and \"close vicinity\" are " << Endl;
719 Log() <<
" determined: the covariance matrix of the discriminating" << Endl;
720 Log() <<
" variable space.)" << Endl;
// --- section: performance optimisation ---
722 Log() << gTools().Color(
"bold") <<
"--- Performance optimisation:" << gTools().Color(
"reset") << Endl;
724 Log() <<
"Optimal performance for Fisher discriminants is obtained for " << Endl;
725 Log() <<
"linearly correlated Gaussian-distributed variables. Any deviation" << Endl;
726 Log() <<
"from this ideal reduces the achievable separation power. In " << Endl;
727 Log() <<
"particular, no discrimination at all is achieved for a variable" << Endl;
728 Log() <<
"that has the same sample mean for signal and background, even if " << Endl;
729 Log() <<
"the shapes of the distributions are very different. Thus, Fisher " << Endl;
730 Log() <<
"discriminants often benefit from suitable transformations of the " << Endl;
731 Log() <<
"input variables. For example, if a variable x in [-1,1] has a " << Endl;
732 Log() <<
"a parabolic signal distributions, and a uniform background" << Endl;
733 Log() <<
"distributions, their mean value is zero in both cases, leading " << Endl;
734 Log() <<
"to no separation. The simple transformation x -> |x| renders this " << Endl;
735 Log() <<
"variable powerful for the use in a Fisher discriminant." << Endl;
// --- section: tuning options (none for Fisher) ---
737 Log() << gTools().Color(
"bold") <<
"--- Performance tuning via configuration options:" << gTools().Color(
"reset") << Endl;
739 Log() <<
"<None>" << Endl;