50 TMVA::kNN::Event::Event()
60 TMVA::kNN::Event::Event(
const VarVec &var,
const Double_t weight,
const Short_t type)
70 TMVA::kNN::Event::Event(
const VarVec &var,
const Double_t weight,
const Short_t type,
const VarVec &tvec)
81 TMVA::kNN::Event::~Event()
88 TMVA::kNN::VarType TMVA::kNN::Event::GetDist(
const Event &other)
const
90 const UInt_t nvar = GetNVar();
92 if (nvar != other.GetNVar()) {
93 std::cerr <<
"Distance: two events have different dimensions" << std::endl;
98 for (UInt_t ivar = 0; ivar < nvar; ++ivar) {
99 sum += GetDist(other.GetVar(ivar), ivar);
107 void TMVA::kNN::Event::SetTargets(
const VarVec &tvec)
114 const TMVA::kNN::VarVec& TMVA::kNN::Event::GetTargets()
const
121 const TMVA::kNN::VarVec& TMVA::kNN::Event::GetVars()
const
129 void TMVA::kNN::Event::Print()
const
137 void TMVA::kNN::Event::Print(std::ostream& os)
const
139 Int_t dp = os.precision();
141 for (UInt_t ivar = 0; ivar != GetNVar(); ++ivar) {
149 os << std::setfill(
' ') << std::setw(5) << std::setprecision(3) << GetVar(ivar);
156 os <<
" no variables";
158 os << std::setprecision(dp);
164 std::ostream& TMVA::kNN::operator<<(std::ostream& os,
const TMVA::kNN::Event& event)
173 TMVA::kNN::ModulekNN::ModulekNN()
176 fLogger( new MsgLogger(
"ModulekNN") )
183 TMVA::kNN::ModulekNN::~ModulekNN()
186 delete fTree; fTree = 0;
194 void TMVA::kNN::ModulekNN::Clear()
212 void TMVA::kNN::ModulekNN::Add(
const Event &event)
215 Log() << kFATAL <<
"<Add> Cannot add event: tree is already built" << Endl;
220 fDimn =
event.GetNVar();
222 else if (fDimn != event.GetNVar()) {
223 Log() << kFATAL <<
"ModulekNN::Add() - number of dimension does not match previous events" << Endl;
227 fEvent.push_back(event);
229 for (UInt_t ivar = 0; ivar < fDimn; ++ivar) {
230 fVar[ivar].push_back(event.GetVar(ivar));
233 std::map<Short_t, UInt_t>::iterator cit = fCount.find(event.GetType());
234 if (cit == fCount.end()) {
235 fCount[
event.GetType()] = 1;
245 Bool_t TMVA::kNN::ModulekNN::Fill(
const UShort_t odepth,
const UInt_t ifrac,
const std::string &option)
248 Log() << kFATAL <<
"ModulekNN::Fill - tree has already been created" << Endl;
255 if (option.find(
"trim") != std::string::npos) {
256 for (std::map<Short_t, UInt_t>::const_iterator it = fCount.begin(); it != fCount.end(); ++it) {
257 if (min == 0 || min > it->second) {
262 Log() << kINFO <<
"<Fill> Will trim all event types to " << min <<
" events" << Endl;
269 for (EventVec::const_iterator event = fEvent.begin();
event != fEvent.end(); ++event) {
270 std::map<Short_t, UInt_t>::iterator cit = fCount.find(event->GetType());
271 if (cit == fCount.end()) {
272 fCount[
event->GetType()] = 1;
274 else if (cit->second < min) {
281 for (UInt_t d = 0; d < fDimn; ++d) {
282 fVar[d].push_back(event->GetVar(d));
285 evec.push_back(*event);
288 Log() << kINFO <<
"<Fill> Erased " << fEvent.size() - evec.size() <<
" events" << Endl;
297 for (VarMap::iterator it = fVar.begin(); it != fVar.end(); ++it) {
298 std::sort((it->second).begin(), (it->second).end());
301 if (option.find(
"metric") != std::string::npos && ifrac > 0) {
302 ComputeMetric(ifrac);
306 for (VarMap::iterator it = fVar.begin(); it != fVar.end(); ++it) {
307 std::sort((it->second).begin(), (it->second).end());
315 fTree = Optimize(odepth);
318 Log() << kFATAL <<
"ModulekNN::Fill() - failed to create tree" << Endl;
322 for (EventVec::const_iterator event = fEvent.begin();
event != fEvent.end(); ++event) {
323 fTree->Add(*event, 0);
325 std::map<Short_t, UInt_t>::iterator cit = fCount.find(event->GetType());
326 if (cit == fCount.end()) {
327 fCount[
event->GetType()] = 1;
334 for (std::map<Short_t, UInt_t>::const_iterator it = fCount.begin(); it != fCount.end(); ++it) {
335 Log() << kINFO <<
"<Fill> Class " << it->first <<
" has " << std::setw(8)
336 << it->second <<
" events" << Endl;
348 Bool_t TMVA::kNN::ModulekNN::Find(Event event,
const UInt_t nfind,
const std::string &option)
const
351 Log() << kFATAL <<
"ModulekNN::Find() - tree has not been filled" << Endl;
354 if (fDimn != event.GetNVar()) {
355 Log() << kFATAL <<
"ModulekNN::Find() - number of dimension does not match training events" << Endl;
359 Log() << kFATAL <<
"ModulekNN::Find() - requested 0 nearest neighbors" << Endl;
365 if (!fVarScale.empty()) {
366 event = Scale(event);
373 if(option.find(
"weight") != std::string::npos)
378 kNN::Find<kNN::Event>(fkNNList, fTree, event, Double_t(nfind), 0.0);
384 kNN::Find<kNN::Event>(fkNNList, fTree, event, nfind);
393 Bool_t TMVA::kNN::ModulekNN::Find(
const UInt_t nfind,
const std::string &option)
const
395 if (fCount.empty() || !fTree) {
398 typedef std::map<Short_t, UInt_t>::const_iterator const_iterator;
399 TTHREAD_TLS_DECL_ARG(const_iterator,cit,fCount.end());
401 if (cit == fCount.end()) {
402 cit = fCount.begin();
405 const Short_t etype = (cit++)->first;
407 if (option ==
"flat") {
409 for (UInt_t d = 0; d < fDimn; ++d) {
410 VarMap::const_iterator vit = fVar.find(d);
411 if (vit == fVar.end()) {
415 const std::vector<Double_t> &vvec = vit->second;
422 const VarType min = vvec.front();
423 const VarType max = vvec.back();
424 const VarType width = max - min;
426 if (width < 0.0 || width > 0.0) {
427 dvec.push_back(min + width*GetRndmThreadLocal().Rndm());
434 const Event event(dvec, 1.0, etype);
449 TMVA::kNN::Node<TMVA::kNN::Event>* TMVA::kNN::ModulekNN::Optimize(
const UInt_t odepth)
451 if (fVar.empty() || fDimn != fVar.size()) {
452 Log() << kWARNING <<
"<Optimize> Cannot build a tree" << Endl;
456 const UInt_t size = (fVar.begin()->second).size();
458 Log() << kWARNING <<
"<Optimize> Cannot build a tree without events" << Endl;
462 VarMap::const_iterator it = fVar.begin();
463 for (; it != fVar.end(); ++it) {
464 if ((it->second).size() != size) {
465 Log() << kWARNING <<
"<Optimize> # of variables doesn't match between dimensions" << Endl;
470 if (
double(fDimn*size) < TMath::Power(2.0,
double(odepth))) {
471 Log() << kWARNING <<
"<Optimize> Optimization depth exceeds number of events" << Endl;
475 Log() << kHEADER <<
"Optimizing tree for " << fDimn <<
" variables with " << size <<
" values" << Endl;
477 std::vector<Node<Event> *> pvec, cvec;
480 if (it == fVar.end() || (it->second).size() < 2) {
481 Log() << kWARNING <<
"<Optimize> Missing 0 variable" << Endl;
485 const Event pevent(VarVec(fDimn, (it->second)[size/2]), -1.0, -1);
487 Node<Event> *tree =
new Node<Event>(0, pevent, 0);
489 pvec.push_back(tree);
491 for (UInt_t depth = 1; depth < odepth; ++depth) {
492 const UInt_t mod = depth % fDimn;
494 VarMap::const_iterator vit = fVar.find(mod);
495 if (vit == fVar.end()) {
496 Log() << kFATAL <<
"Missing " << mod <<
" variable" << Endl;
499 const std::vector<Double_t> &dvec = vit->second;
501 if (dvec.size() < 2) {
502 Log() << kFATAL <<
"Missing " << mod <<
" variable" << Endl;
507 for (std::vector<Node<Event> *>::iterator pit = pvec.begin(); pit != pvec.end(); ++pit) {
508 Node<Event> *parent = *pit;
510 const VarType lmedian = dvec[size*ichild/(2*pvec.size() + 1)];
513 const VarType rmedian = dvec[size*ichild/(2*pvec.size() + 1)];
516 const Event levent(VarVec(fDimn, lmedian), -1.0, -1);
517 const Event revent(VarVec(fDimn, rmedian), -1.0, -1);
519 Node<Event> *lchild =
new Node<Event>(parent, levent, mod);
520 Node<Event> *rchild =
new Node<Event>(parent, revent, mod);
522 parent->SetNodeL(lchild);
523 parent->SetNodeR(rchild);
525 cvec.push_back(lchild);
526 cvec.push_back(rchild);
542 void TMVA::kNN::ModulekNN::ComputeMetric(
const UInt_t ifrac)
548 Log() << kFATAL <<
"ModulekNN::ComputeMetric - fraction can not exceed 100%" << Endl;
551 if (!fVarScale.empty()) {
552 Log() << kFATAL <<
"ModulekNN::ComputeMetric - metric is already computed" << Endl;
555 if (fEvent.size() < 100) {
556 Log() << kFATAL <<
"ModulekNN::ComputeMetric - number of events is too small" << Endl;
560 const UInt_t lfrac = (100 - ifrac)/2;
561 const UInt_t rfrac = 100 - (100 - ifrac)/2;
563 Log() << kINFO <<
"Computing scale factor for 1d distributions: "
564 <<
"(ifrac, bottom, top) = (" << ifrac <<
"%, " << lfrac <<
"%, " << rfrac <<
"%)" << Endl;
568 for (VarMap::const_iterator vit = fVar.begin(); vit != fVar.end(); ++vit) {
569 const std::vector<Double_t> &dvec = vit->second;
571 std::vector<Double_t>::const_iterator beg_it = dvec.end();
572 std::vector<Double_t>::const_iterator end_it = dvec.end();
575 for (std::vector<Double_t>::const_iterator dit = dvec.begin(); dit != dvec.end(); ++dit, ++dist) {
577 if ((100*dist)/dvec.size() == lfrac && beg_it == dvec.end()) {
581 if ((100*dist)/dvec.size() == rfrac && end_it == dvec.end()) {
586 if (beg_it == dvec.end() || end_it == dvec.end()) {
587 beg_it = dvec.begin();
590 assert(beg_it != end_it &&
"Empty vector");
595 const Double_t lpos = *beg_it;
596 const Double_t rpos = *end_it;
598 if (!(lpos < rpos)) {
599 Log() << kFATAL <<
"ModulekNN::ComputeMetric() - min value is greater than max value" << Endl;
610 fVarScale[vit->first] = rpos - lpos;
615 for (UInt_t ievent = 0; ievent < fEvent.size(); ++ievent) {
616 fEvent[ievent] = Scale(fEvent[ievent]);
618 for (UInt_t ivar = 0; ivar < fDimn; ++ivar) {
619 fVar[ivar].push_back(fEvent[ievent].GetVar(ivar));
628 const TMVA::kNN::Event TMVA::kNN::ModulekNN::Scale(
const Event &event)
const
630 if (fVarScale.empty()) {
634 if (event.GetNVar() != fVarScale.size()) {
635 Log() << kFATAL <<
"ModulekNN::Scale() - mismatched metric and event size" << Endl;
639 VarVec vvec(event.GetNVar(), 0.0);
641 for (UInt_t ivar = 0; ivar <
event.GetNVar(); ++ivar) {
642 std::map<int, Double_t>::const_iterator fit = fVarScale.find(ivar);
643 if (fit == fVarScale.end()) {
644 Log() << kFATAL <<
"ModulekNN::Scale() - failed to find scale for " << ivar << Endl;
648 if (fit->second > 0.0) {
649 vvec[ivar] =
event.GetVar(ivar)/fit->second;
652 Log() << kFATAL <<
"Variable " << ivar <<
" has zero width" << Endl;
656 return Event(vvec, event.GetWeight(),
event.GetType(),
event.GetTargets());
662 void TMVA::kNN::ModulekNN::Print()
const
670 void TMVA::kNN::ModulekNN::Print(std::ostream &os)
const
672 os <<
"----------------------------------------------------------------------"<< std::endl;
673 os <<
"Printing knn result" << std::endl;
674 os << fkNNEvent << std::endl;
678 std::map<Short_t, Double_t> min, max;
680 os <<
"Printing " << fkNNList.size() <<
" nearest neighbors" << std::endl;
681 for (List::const_iterator it = fkNNList.begin(); it != fkNNList.end(); ++it) {
682 os << ++count <<
": " << it->second <<
": " << it->first->GetEvent() << std::endl;
684 const Event &
event = it->first->GetEvent();
685 for (UShort_t ivar = 0; ivar <
event.GetNVar(); ++ivar) {
686 if (min.find(ivar) == min.end()) {
687 min[ivar] =
event.GetVar(ivar);
689 else if (min[ivar] > event.GetVar(ivar)) {
690 min[ivar] =
event.GetVar(ivar);
693 if (max.find(ivar) == max.end()) {
694 max[ivar] =
event.GetVar(ivar);
696 else if (max[ivar] < event.GetVar(ivar)) {
697 max[ivar] =
event.GetVar(ivar);
702 if (min.size() == max.size()) {
703 for (std::map<Short_t, Double_t>::const_iterator mit = min.begin(); mit != min.end(); ++mit) {
704 const Short_t i = mit->first;
705 Log() << kINFO <<
"(var, min, max) = (" << i <<
"," << min[i] <<
", " << max[i] <<
")" << Endl;
709 os <<
"----------------------------------------------------------------------" << std::endl;