ClassImp(TMVA::MethodKNN);
// standard constructor
TMVA::MethodKNN::MethodKNN( const TString& jobName,
                            const TString& methodTitle,
                            DataSetInfo& theData,
                            const TString& theOption )
   : TMVA::MethodBase(jobName, Types::kKNN, methodTitle, theData, theOption)
{
}
// constructor from weight file
TMVA::MethodKNN::MethodKNN( DataSetInfo& theData,
                            const TString& theWeightFile)
   : TMVA::MethodBase( Types::kKNN, theData, theWeightFile)
{
}
// destructor
TMVA::MethodKNN::~MethodKNN()
{
   if (fModule) delete fModule;
}
// declare the configuration options for the kNN method
void TMVA::MethodKNN::DeclareOptions()
{
   DeclareOptionRef(fnkNN = 20, "nkNN", "Number of k-nearest neighbors");
   DeclareOptionRef(fBalanceDepth = 6, "BalanceDepth", "Binary tree balance depth");
   DeclareOptionRef(fScaleFrac = 0.80, "ScaleFrac", "Fraction of events used to compute variable width");
   DeclareOptionRef(fSigmaFact = 1.0, "SigmaFact", "Scale factor for sigma in Gaussian kernel");
   DeclareOptionRef(fKernel = "Gaus", "Kernel", "Use polynomial (=Poln) or Gaussian (=Gaus) kernel");
   DeclareOptionRef(fTrim = kFALSE, "Trim", "Use equal number of signal and background events");
   DeclareOptionRef(fUseKernel = kFALSE, "UseKernel", "Use polynomial kernel weight");
   DeclareOptionRef(fUseWeight = kTRUE, "UseWeight", "Use weight to count kNN events");
   DeclareOptionRef(fUseLDA = kFALSE, "UseLDA", "Use local linear discriminant - experimental feature");
}
void TMVA::MethodKNN::DeclareCompatibilityOptions() {
   MethodBase::DeclareCompatibilityOptions();
   DeclareOptionRef(fTreeOptDepth = 6, "TreeOptDepth", "Binary tree optimisation depth");
}
// process the options specified by the user
void TMVA::MethodKNN::ProcessOptions()
{
   if (!(fnkNN > 0)) {
      fnkNN = 20; // restore the declared default
      Log() << kWARNING << "kNN must be a positive integer: set kNN = " << fnkNN << Endl;
   }
   if (fScaleFrac < 0.0) {
      fScaleFrac = 0.0;
      Log() << kWARNING << "ScaleFrac cannot be negative: set ScaleFrac = " << fScaleFrac << Endl;
   }
   if (fScaleFrac > 1.0) {
      fScaleFrac = 1.0;
   }
   if (!(fBalanceDepth > 0)) {
      fBalanceDepth = 6; // restore the declared default
      Log() << kWARNING << "Optimize must be a positive integer: set Optimize = " << fBalanceDepth << Endl;
   }

   Log() << kVERBOSE
         << " kNN       = " << fnkNN << "\n"
         << " UseKernel = " << fUseKernel << "\n"
         << " SigmaFact = " << fSigmaFact << "\n"
         << " ScaleFrac = " << fScaleFrac << "\n"
         << " Kernel    = " << fKernel << "\n"
         << " Trim      = " << fTrim << "\n"
         << " Optimize  = " << fBalanceDepth << Endl;
}
// kNN handles two-class classification and (multi-target) regression
Bool_t TMVA::MethodKNN::HasAnalysisType( Types::EAnalysisType type, UInt_t numberClasses, UInt_t )
{
   if (type == Types::kClassification && numberClasses == 2) return kTRUE;
   if (type == Types::kRegression) return kTRUE;
   return kFALSE;
}
void TMVA::MethodKNN::Init()
{
   fModule = new kNN::ModulekNN();
}
// fill the kNN module with the stored training events and build the kd-tree
void TMVA::MethodKNN::MakeKNN()
{
   if (!fModule) {
      Log() << kFATAL << "ModulekNN is not created" << Endl;
   }

   std::string option;
   if (fScaleFrac > 0.0) {
      option += "metric"; // scale variables using the ScaleFrac fraction of events
   }
   if (fTrim) {
      option += "trim"; // use equal numbers of signal and background events
   }

   Log() << kINFO << "Creating kd-tree with " << fEvent.size() << " events" << Endl;

   for (kNN::EventVec::const_iterator event = fEvent.begin(); event != fEvent.end(); ++event) {
      fModule->Add(*event);
   }

   // create the binary tree
   fModule->Fill(static_cast<UInt_t>(fBalanceDepth),
                 static_cast<UInt_t>(100.0*fScaleFrac),
                 option);
}
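// Note on the Fill() call above (an interpretation, not stated in this file):
// the first argument is the tree balance depth (the BalanceDepth option) and
// the second passes ScaleFrac as an integer percentage (100*ScaleFrac) of
// events used to compute the variable widths.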
void TMVA::MethodKNN::Train()
{
   Log() << kHEADER << "<Train> start..." << Endl;

   if (IsNormalised()) {
      Log() << kINFO << "Input events are normalized - setting ScaleFrac to 0" << Endl;
      fScaleFrac = 0.0;
   }

   if (!fEvent.empty()) {
      Log() << kINFO << "Erasing " << fEvent.size() << " previously stored events" << Endl;
      fEvent.clear();
   }
   if (GetNVariables() < 1)
      Log() << kFATAL << "MethodKNN::Train() - mismatched or wrong number of event variables" << Endl;

   Log() << kINFO << "Reading " << GetNEvents() << " events" << Endl;

   for (UInt_t ievt = 0; ievt < GetNEvents(); ++ievt) {
      // read the training event
      const Event* evt_   = GetEvent(ievt);
      Double_t     weight = evt_->GetWeight();

      // skip events with negative weights if requested
      if (IgnoreEventsWithNegWeightsInTraining() && weight <= 0) continue;

      kNN::VarVec vvec(GetNVariables(), 0.0);
      for (UInt_t ivar = 0; ivar < evt_->GetNVariables(); ++ivar) vvec[ivar] = evt_->GetValue(ivar);

      Short_t event_type = 0;

      if (DataInfo().IsSignal(evt_)) { // signal type = 1
         fSumOfWeightsS += weight;
         event_type = 1;
      }
      else { // background type = 2
         fSumOfWeightsB += weight;
         event_type = 2;
      }

      // create an event with classification variables, weight and type,
      // and attach the regression targets
      kNN::Event event_knn(vvec, weight, event_type);
      event_knn.SetTargets(evt_->GetTargets());
      fEvent.push_back(event_knn);
   }

   Log() << kINFO
         << "Number of signal events " << fSumOfWeightsS << Endl
         << "Number of background events " << fSumOfWeightsB << Endl;

   // create the kd-tree from the stored events
   MakeKNN();
}
// classification response: (kernel-weighted) fraction of signal events among the k nearest neighbors
Double_t TMVA::MethodKNN::GetMvaValue( Double_t* err, Double_t* errUpper )
{
   // cannot determine error
   NoErrorCalc(err, errUpper);

   const Event *ev = GetEvent();
   const Int_t nvar = GetNVariables();
   const Double_t weight = ev->GetWeight();
   const UInt_t knn = static_cast<UInt_t>(fnkNN);

   kNN::VarVec vvec(static_cast<UInt_t>(nvar), 0.0);

   for (Int_t ivar = 0; ivar < nvar; ++ivar) {
      vvec[ivar] = ev->GetValue(ivar);
   }

   // search for knn+2 nearest neighbors: the two extra events guard against
   // training events with zero distance to the query event
   const kNN::Event event_knn(vvec, weight, 3);
   fModule->Find(event_knn, knn + 2);

   const kNN::List &rlist = fModule->GetkNNList();
   if (rlist.size() != knn + 2) {
      Log() << kFATAL << "kNN result list is empty" << Endl;
   }

   if (fUseLDA) return MethodKNN::getLDAValue(rlist, event_knn);

   // set flags for the kernel option: Gaus or Poln
   Bool_t use_gaus = false, use_poln = false;
   if (fUseKernel) {
      if (fKernel == "Gaus") use_gaus = true;
      else if (fKernel == "Poln") use_poln = true;
   }

   // compute the radius that normalizes distances for the polynomial kernel
   Double_t kradius = -1.0;
   if (use_poln) {
      kradius = MethodKNN::getKernelRadius(rlist);

      if (!(kradius > 0.0)) {
         Log() << kFATAL << "kNN radius is not positive" << Endl;
      }

      kradius = 1.0/TMath::Sqrt(kradius);
   }

   // compute the per-variable RMS used as sigma in the Gaussian kernel
   std::vector<Double_t> rms_vec;
   if (use_gaus) {
      rms_vec = TMVA::MethodKNN::getRMS(rlist, event_knn);

      if (rms_vec.empty() || rms_vec.size() != event_knn.GetNVar()) {
         Log() << kFATAL << "Failed to compute RMS vector" << Endl;
      }
   }

   UInt_t count_all = 0;
   Double_t weight_all = 0, weight_sig = 0, weight_bac = 0;

   for (kNN::List::const_iterator lit = rlist.begin(); lit != rlist.end(); ++lit) {

      // get reference to the current node
      const kNN::Node<kNN::Event> &node = *(lit->first);

      // a zero distance means the query event is also in the training set
      if (lit->second < 0.0) {
         Log() << kFATAL << "A neighbor has negative distance to query event" << Endl;
      }
      else if (!(lit->second > 0.0)) {
         Log() << kVERBOSE << "A neighbor has zero distance to query event" << Endl;
      }

      // get event weight and scale it by the kernel function
      Double_t evweight = node.GetWeight();
      if (use_gaus) evweight *= MethodKNN::GausKernel(event_knn, node.GetEvent(), rms_vec);
      else if (use_poln) evweight *= MethodKNN::PolnKernel(TMath::Sqrt(lit->second)*kradius);

      if (fUseWeight) weight_all += evweight;
      else            ++weight_all;

      if (node.GetEvent().GetType() == 1) { // signal type = 1
         if (fUseWeight) weight_sig += evweight;
         else            ++weight_sig;
      }
      else if (node.GetEvent().GetType() == 2) { // background type = 2
         if (fUseWeight) weight_bac += evweight;
         else            ++weight_bac;
      }
      else {
         Log() << kFATAL << "Unknown type for training event" << Endl;
      }

      // use only the first knn events
      ++count_all;
      if (count_all >= knn) {
         break;
      }
   }

   // sanity checks on the neighbor count and the total weight
   if (!(count_all > 0)) {
      Log() << kFATAL << "Size kNN result list is not positive" << Endl;
   }

   if (count_all < knn) {
      Log() << kDEBUG << "count_all and kNN have different size: " << count_all << " < " << knn << Endl;
   }

   if (!(weight_all > 0.0)) {
      Log() << kFATAL << "kNN result total weight is not positive" << Endl;
   }

   return weight_sig/weight_all;
}
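// Example (sketch, not from this file): evaluating the trained method on a
// single event through the TMVA Reader. The variable names and the weight-file
// path are illustrative assumptions.
//
//    TMVA::Reader reader( "!Color:!Silent" );
//    Float_t var1, var2;
//    reader.AddVariable( "var1", &var1 );
//    reader.AddVariable( "var2", &var2 );
//    reader.BookMVA( "KNN", "dataset/weights/TMVAClassification_KNN.weights.xml" );
//    var1 = 0.5; var2 = -1.2;
//    Double_t mva = reader.EvaluateMVA( "KNN" ); // signal fraction among the k neighbors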
// regression response: (weighted) average of the target vectors of the k nearest neighbors
const std::vector< Float_t >& TMVA::MethodKNN::GetRegressionValues()
{
   if (fRegressionReturnVal == 0)
      fRegressionReturnVal = new std::vector<Float_t>;
   else
      fRegressionReturnVal->clear();

   const Event *evt = GetEvent();
   const Int_t nvar = GetNVariables();
   const UInt_t knn = static_cast<UInt_t>(fnkNN);
   std::vector<float> reg_vec;

   kNN::VarVec vvec(static_cast<UInt_t>(nvar), 0.0);

   for (Int_t ivar = 0; ivar < nvar; ++ivar) {
      vvec[ivar] = evt->GetValue(ivar);
   }

   // search for knn+2 nearest neighbors (two extra events guard against zero distances)
   const kNN::Event event_knn(vvec, evt->GetWeight(), 3);
   fModule->Find(event_knn, knn + 2);

   const kNN::List &rlist = fModule->GetkNNList();
   if (rlist.size() != knn + 2) {
      Log() << kFATAL << "kNN result list is empty" << Endl;
      return *fRegressionReturnVal;
   }

   Double_t weight_all = 0;
   UInt_t count_all = 0;

   for (kNN::List::const_iterator lit = rlist.begin(); lit != rlist.end(); ++lit) {

      // get reference to the current node
      const kNN::Node<kNN::Event> &node = *(lit->first);
      const kNN::VarVec &tvec = node.GetEvent().GetTargets();
      const Double_t weight = node.GetEvent().GetWeight();

      if (reg_vec.empty()) {
         reg_vec = kNN::VarVec(tvec.size(), 0.0);
      }

      for (UInt_t ivar = 0; ivar < tvec.size(); ++ivar) {
         if (fUseWeight) reg_vec[ivar] += tvec[ivar]*weight;
         else            reg_vec[ivar] += tvec[ivar];
      }

      if (fUseWeight) weight_all += weight;
      else            ++weight_all;

      // use only the first knn events
      ++count_all;
      if (count_all == knn) {
         break;
      }
   }

   // check that the total weight is positive
   if (!(weight_all > 0.0)) {
      Log() << kFATAL << "Total weight sum is not positive: " << weight_all << Endl;
      return *fRegressionReturnVal;
   }

   for (UInt_t ivar = 0; ivar < reg_vec.size(); ++ivar) {
      reg_vec[ivar] /= weight_all;
   }

   // copy the result
   fRegressionReturnVal->insert(fRegressionReturnVal->begin(), reg_vec.begin(), reg_vec.end());

   return *fRegressionReturnVal;
}
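// Example (sketch, not from this file): retrieving the regression targets
// through the TMVA Reader; the method tag and weight-file path are
// illustrative assumptions.
//
//    reader.BookMVA( "KNN", "dataset/weights/TMVARegression_KNN.weights.xml" );
//    const std::vector<Float_t>& targets = reader.EvaluateRegression( "KNN" );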
// no ranking of input variables is available for kNN
const TMVA::Ranking* TMVA::MethodKNN::CreateRanking()
{
   return 0;
}
// write the list of training events to the XML weight file
void TMVA::MethodKNN::AddWeightsXMLTo( void* parent ) const {
   void* wght = gTools().AddChild(parent, "Weights");
   gTools().AddAttr(wght, "NEvents", fEvent.size());
   if (fEvent.size()>0) gTools().AddAttr(wght, "NVar", fEvent.begin()->GetNVar());
   if (fEvent.size()>0) gTools().AddAttr(wght, "NTgt", fEvent.begin()->GetNTgt());

   for (kNN::EventVec::const_iterator event = fEvent.begin(); event != fEvent.end(); ++event) {

      std::stringstream s("");

      for (UInt_t ivar = 0; ivar < event->GetNVar(); ++ivar) {
         if (ivar>0) s << " ";
         s << std::scientific << event->GetVar(ivar);
      }

      for (UInt_t itgt = 0; itgt < event->GetNTgt(); ++itgt) {
         s << " " << std::scientific << event->GetTgt(itgt);
      }

      void* evt = gTools().AddChild(wght, "Event", s.str().c_str());
      gTools().AddAttr(evt, "Type", event->GetType());
      gTools().AddAttr(evt, "Weight", event->GetWeight());
   }
}
// read the list of training events back from the XML weight file
void TMVA::MethodKNN::ReadWeightsFromXML( void* wghtnode ) {
   void* ch = gTools().GetChild(wghtnode); // first event
   UInt_t nvar = 0, ntgt = 0;
   gTools().ReadAttr( wghtnode, "NVar", nvar );
   gTools().ReadAttr( wghtnode, "NTgt", ntgt );

   Short_t evtType(0);
   Double_t evtWeight(0);

   while (ch) {
      // build an event from the stored variables and targets
      kNN::VarVec vvec(nvar, 0);
      kNN::VarVec tvec(ntgt, 0);

      gTools().ReadAttr( ch, "Type",   evtType );
      gTools().ReadAttr( ch, "Weight", evtWeight );
      std::stringstream s( gTools().GetContent(ch) );

      for (UInt_t ivar=0; ivar<nvar; ivar++)
         s >> vvec[ivar];

      for (UInt_t itgt=0; itgt<ntgt; itgt++)
         s >> tvec[itgt];

      ch = gTools().GetNextChild(ch);

      kNN::Event event_knn(vvec, evtWeight, evtType, tvec);
      fEvent.push_back(event_knn);
   }

   // create the kd-tree from the stored events
   MakeKNN();
}
// read weights from a plain-text stream, one comma-separated event per line
void TMVA::MethodKNN::ReadWeightsFromStream(std::istream& is)
{
   Log() << kINFO << "Starting ReadWeightsFromStream(std::istream& is) function..." << Endl;

   if (!fEvent.empty()) {
      Log() << kINFO << "Erasing " << fEvent.size() << " previously stored events" << Endl;
      fEvent.clear();
   }

   UInt_t nvar = 0;

   while (!is.eof()) {
      std::string line;
      std::getline(is, line);

      // skip empty lines and comments
      if (line.empty() || line.find("#") != std::string::npos) {
         continue;
      }

      // count the comma delimiters
      UInt_t count = 0;
      std::string::size_type pos = 0;
      while( (pos=line.find(',',pos)) != std::string::npos ) { count++; pos++; }

      if (nvar == 0) {
         nvar = count - 2;
      }
      if (count < 3 || nvar != count - 2) {
         Log() << kFATAL << "Missing comma delimiter(s)" << Endl;
      }

      Int_t type = -1;
      Double_t weight = -1.0;

      kNN::VarVec vvec(nvar, 0.0);

      UInt_t vcount = 0;
      std::string::size_type prev = 0;

      for (std::string::size_type ipos = 0; ipos < line.size(); ++ipos) {
         if (line[ipos] != ',' && ipos + 1 != line.size()) {
            continue;
         }

         if (!(ipos > prev)) {
            Log() << kFATAL << "Wrong substring limits" << Endl;
         }

         std::string vstring = line.substr(prev, ipos - prev);
         if (ipos + 1 == line.size()) {
            vstring = line.substr(prev, ipos - prev + 1);
         }

         if (vstring.empty()) {
            Log() << kFATAL << "Failed to parse string" << Endl;
         }

         if (vcount == 0) {
            // the first field is skipped
         }
         else if (vcount == 1) {
            type = std::atoi(vstring.c_str());
         }
         else if (vcount == 2) {
            weight = std::atof(vstring.c_str());
         }
         else if (vcount - 3 < vvec.size()) {
            vvec[vcount - 3] = std::atof(vstring.c_str());
         }
         else {
            Log() << kFATAL << "Wrong variable count" << Endl;
         }

         prev = ipos + 1;
         ++vcount;
      }

      fEvent.push_back(kNN::Event(vvec, weight, type));
   }

   Log() << kINFO << "Read " << fEvent.size() << " events from text file" << Endl;

   // create the kd-tree from the stored events
   MakeKNN();
}
// write training events to a ROOT file as a TTree
void TMVA::MethodKNN::WriteWeightsToStream(TFile &rf) const
{
   Log() << kINFO << "Starting WriteWeightsToStream(TFile &rf) function..." << Endl;

   if (fEvent.empty()) {
      Log() << kWARNING << "MethodKNN contains no events " << Endl;
      return;
   }

   kNN::Event *event = new kNN::Event();
   TTree *tree = new TTree("knn", "event tree");
   tree->SetDirectory(0);
   tree->Branch("event", "TMVA::kNN::Event", &event);

   Double_t size = 0.0;
   for (kNN::EventVec::const_iterator it = fEvent.begin(); it != fEvent.end(); ++it) {
      (*event) = (*it);
      size += tree->Fill();
   }

   rf.WriteTObject(tree, "knn", "Overwrite");

   // scale from bytes to megabytes
   size /= 1048576.0;

   Log() << kINFO << "Wrote " << size << " MB and " << fEvent.size()
         << " events to ROOT file" << Endl;

   delete tree;
   delete event;
}
// read training events from a ROOT file
void TMVA::MethodKNN::ReadWeightsFromStream(TFile &rf)
{
   Log() << kINFO << "Starting ReadWeightsFromStream(TFile &rf) function..." << Endl;

   if (!fEvent.empty()) {
      Log() << kINFO << "Erasing " << fEvent.size() << " previously stored events" << Endl;
      fEvent.clear();
   }

   TTree *tree = dynamic_cast<TTree *>(rf.Get("knn"));
   if (!tree) {
      Log() << kFATAL << "Failed to find knn tree" << Endl;
   }

   kNN::Event *event = new kNN::Event();
   tree->SetBranchAddress("event", &event);

   const Int_t nevent = tree->GetEntries();

   Double_t size = 0.0;
   for (Int_t i = 0; i < nevent; ++i) {
      size += tree->GetEntry(i);
      fEvent.push_back(*event);
   }

   // scale from bytes to megabytes
   size /= 1048576.0;

   Log() << kINFO << "Read " << size << " MB and " << fEvent.size()
         << " events from ROOT file" << Endl;

   delete event;

   // create the kd-tree from the stored events
   MakeKNN();
}
// write method-specific standalone response class (not implemented for kNN)
void TMVA::MethodKNN::MakeClassSpecific( std::ostream& fout, const TString& className ) const
{
   fout << "   // not implemented for class: \"" << className << "\"" << std::endl;
   fout << "};" << std::endl;
}
// print a help message describing the method and its options
void TMVA::MethodKNN::GetHelpMessage() const
{
   Log() << gTools().Color("bold") << "--- Short description:" << gTools().Color("reset") << Endl;
   Log() << "The k-nearest neighbor (k-NN) algorithm is a multi-dimensional classification" << Endl
         << "and regression algorithm. Similarly to other TMVA algorithms, k-NN uses a set of" << Endl
         << "training events for which a classification category/regression target is known." << Endl
         << "The k-NN method compares a test event to all training events using a distance" << Endl
         << "function, which is a Euclidean distance in a space defined by the input variables." << Endl
         << "The k-NN method, as implemented in TMVA, uses a kd-tree algorithm to perform a" << Endl
         << "quick search for the k events with the shortest distance to the test event. The method" << Endl
         << "returns the fraction of signal events among the k neighbors. It is recommended" << Endl
         << "that a histogram which stores the k-NN decision variable is binned with k+1 bins" << Endl
         << "between 0 and 1." << Endl;

   Log() << gTools().Color("bold") << "--- Performance tuning via configuration options: "
         << gTools().Color("reset") << Endl;
   Log() << "The k-NN method estimates a density of signal and background events in a" << Endl
         << "neighborhood around the test event. The method assumes that the density of the" << Endl
         << "signal and background events is uniform and constant within the neighborhood." << Endl
         << "k is an adjustable parameter that determines the average size of the" << Endl
         << "neighborhood. Small k values (less than 10) are sensitive to statistical" << Endl
         << "fluctuations, and large values (greater than 100) might not sufficiently capture" << Endl
         << "local differences between events in the training set. The evaluation time of the" << Endl
         << "k-NN method also grows with larger values of k." << Endl;

   Log() << "The k-NN method assigns equal weight to all input variables. Different scales" << Endl
         << "among the input variables are compensated for by the ScaleFrac parameter: the input" << Endl
         << "variables are scaled so that the widths of the central ScaleFrac*100% of events are" << Endl
         << "equal among all the input variables." << Endl;

   Log() << gTools().Color("bold") << "--- Additional configuration options: "
         << gTools().Color("reset") << Endl;
   Log() << "The method includes an option to use a Gaussian kernel to smooth out the k-NN" << Endl
         << "response. The kernel re-weights events using a distance to the test event." << Endl;
}
// polynomial kernel: (1 - |x|^3)^3 for |x| < 1, zero otherwise
Double_t TMVA::MethodKNN::PolnKernel(const Double_t value) const
{
   const Double_t avalue = TMath::Abs(value);

   if (!(avalue < 1.0)) {
      return 0.0;
   }

   const Double_t prod = 1.0 - avalue * avalue * avalue;

   return (prod * prod * prod);
}
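// Note (sketch): this is the tricube kernel w(u) = (1 - |u|^3)^3 on |u| < 1.
// GetMvaValue() feeds it the neighbor distance normalized by the kernel radius,
// so neighbors at the edge of the k-NN ball get weight ~0 and the closest ~1.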
// Gaussian kernel weight for the distance between two events
Double_t TMVA::MethodKNN::GausKernel(const kNN::Event &event_knn,
                                     const kNN::Event &event,
                                     const std::vector<Double_t> &svec) const
{
   if (event_knn.GetNVar() != event.GetNVar() || event_knn.GetNVar() != svec.size()) {
      Log() << kFATAL << "Mismatched vectors in Gaussian kernel function" << Endl;
      return 0.0;
   }

   // compute the exponent
   double sum_exp = 0.0;

   for (unsigned int ivar = 0; ivar < event_knn.GetNVar(); ++ivar) {

      const Double_t diff_ = event.GetVar(ivar) - event_knn.GetVar(ivar);
      const Double_t sigm_ = svec[ivar];
      if (!(sigm_ > 0.0)) {
         Log() << kFATAL << "Bad sigma value = " << sigm_ << Endl;
         return 0.0;
      }

      sum_exp += diff_*diff_/(2.0*sigm_*sigm_);
   }

   // return the unnormalized Gaussian: the constant normalization cancels in the signal fraction
   return std::exp(-sum_exp);
}
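// In formula form (sketch): for query event x and neighbor y,
//
//    w(x, y) = exp( - sum_i (y_i - x_i)^2 / (2 * sigma_i^2) )
//
// where sigma_i is the scaled per-variable RMS computed by getRMS().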
// kernel radius: largest (squared) distance among the k nearest neighbors
Double_t TMVA::MethodKNN::getKernelRadius(const kNN::List &rlist) const
{
   Double_t kradius = -1.0;
   UInt_t kcount = 0;
   const UInt_t knn = static_cast<UInt_t>(fnkNN);

   for (kNN::List::const_iterator lit = rlist.begin(); lit != rlist.end(); ++lit) {
      if (!(lit->second > 0.0)) continue;

      if (kradius < lit->second || kradius < 0.0) kradius = lit->second;

      ++kcount;
      if (kcount >= knn) break;
   }

   return kradius;
}
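// Note: lit->second holds the squared distance to the neighbor (GetMvaValue
// takes TMath::Sqrt of it), so the caller converts the result with
// kradius = 1/sqrt(kradius) before normalizing distances for PolnKernel().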
// per-variable RMS of the differences between the query event and its k nearest neighbors
const std::vector<Double_t> TMVA::MethodKNN::getRMS(const kNN::List &rlist, const kNN::Event &event_knn) const
{
   std::vector<Double_t> rvec;
   UInt_t kcount = 0;
   const UInt_t knn = static_cast<UInt_t>(fnkNN);

   for (kNN::List::const_iterator lit = rlist.begin(); lit != rlist.end(); ++lit) {
      if (!(lit->second > 0.0)) continue;

      const kNN::Node<kNN::Event> *node_ = lit->first;
      const kNN::Event &event_ = node_->GetEvent();

      if (rvec.empty()) {
         rvec.insert(rvec.end(), event_.GetNVar(), 0.0);
      }
      else if (rvec.size() != event_.GetNVar()) {
         Log() << kFATAL << "Wrong number of variables, should never happen!" << Endl;
      }

      for (unsigned int ivar = 0; ivar < event_.GetNVar(); ++ivar) {
         const Double_t diff_ = event_.GetVar(ivar) - event_knn.GetVar(ivar);
         rvec[ivar] += diff_*diff_;
      }

      ++kcount;
      if (kcount >= knn) break;
   }

   if (!(kcount > 0)) {
      Log() << kFATAL << "Bad event kcount = " << kcount << Endl;
   }

   for (unsigned int ivar = 0; ivar < rvec.size(); ++ivar) {
      if (!(rvec[ivar] > 0.0)) {
         Log() << kFATAL << "Bad RMS value = " << rvec[ivar] << Endl;
      }

      rvec[ivar] = std::abs(fSigmaFact)*std::sqrt(rvec[ivar]/kcount);
   }

   return rvec;
}
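// The sigma used by GausKernel() for variable i is therefore (sketch)
//
//    sigma_i = |SigmaFact| * sqrt( sum_k (x_i^k - x_i)^2 / kcount )
//
// i.e. the RMS of the neighbor-to-query differences, scaled by the SigmaFact option.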
// experimental feature: local linear discriminant built from the k nearest neighbors
Double_t TMVA::MethodKNN::getLDAValue(const kNN::List &rlist, const kNN::Event &event_knn)
{
   LDAEvents sig_vec, bac_vec;

   for (kNN::List::const_iterator lit = rlist.begin(); lit != rlist.end(); ++lit) {

      // get reference to the current node
      const kNN::Node<kNN::Event> &node = *(lit->first);
      const kNN::VarVec &tvec = node.GetEvent().GetVars();

      if (node.GetEvent().GetType() == 1) { // signal type = 1
         sig_vec.push_back(tvec);
      }
      else if (node.GetEvent().GetType() == 2) { // background type = 2
         bac_vec.push_back(tvec);
      }
      else {
         Log() << kFATAL << "Unknown type for training event" << Endl;
      }
   }

   fLDA.Initialize(sig_vec, bac_vec);

   return fLDA.GetProb(event_knn.GetVars(), 1);
}