26 #ifndef ROOT_TMVA_NodekNN
27 #define ROOT_TMVA_NodekNN
// Member declarations of the binary tree node template Node<T>. The class
// header and all lines that carry no declaration are missing from this listing.
// Constructor: the definition initialises fVarDis with event.GetVar(mod).
71 Node(
const Node *parent,
const T &event, Int_t mod);
// Inserts `event` into the subtree rooted at this node. `depth` is the
// recursion depth; the definition asserts fMod == depth % event.GetNVar().
74 const Node* Add(
const T &event, UInt_t depth);
// Setters for the left / right child link (bodies not visible here). The
// destructor deletes fNodeL and fNodeR, so a node handed in is presumably
// owned by this node afterwards - TODO confirm.
76 void SetNodeL(Node *node);
77 void SetNodeR(Node *node);
// Read-only reference to an event; presumably the fEvent member that
// GetWeight() reads - TODO confirm.
79 const T& GetEvent()
const;
// Read-only links to the left child, right child and parent node.
81 const Node* GetNodeL()
const;
82 const Node* GetNodeR()
const;
83 const Node* GetNodeP()
const;
// Weight of the stored event; the definition returns fEvent.GetWeight().
85 Double_t GetWeight()
const;
// Presumably return fVarDis / fVarMin / fVarMax (bodies not visible). Add()
// compares new values against fVarDis and widens fVarMin / fVarMax.
87 Float_t GetVarDis()
const;
88 Float_t GetVarMin()
const;
89 Float_t GetVarMax()
const;
// Index of the variable this node works on; Find() passes it to event.GetVar().
91 UInt_t GetMod()
const;
// Writes a description of this node and its children to `os`; every line is
// prefixed with `offset`, which defaults to the empty string.
94 void Print(std::ostream& os,
const std::string &offset =
"")
const;
// NOTE(review): declaration only in this listing; presumably left undefined
// to forbid assignment - confirm against the full class.
102 const Node& operator=(
const Node &);
// Value of variable `mod` of the event the node was constructed with; constant
// for the lifetime of the node.
113 const Float_t fVarDis;
// Free-function declarations working on Node<T> trees.
// Count-limited search: fills `nlist` with (node, distance) pairs for `event`,
// starting at `node`; `nfind` limits the number of entries kept.
123 UInt_t Find(std::list<std::pair<
const Node<T> *, Float_t> > &nlist,
124 const Node<T> *node,
const T &event, UInt_t nfind);
// Weight-limited search: same list layout, but `nfind` is compared against
// `ncurr`, a running sum of node weights. `ncurr` is passed by value.
129 UInt_t Find(std::list<std::pair<
const Node<T> *, Float_t> > &nlist,
130 const Node<T> *node,
const T &event, Double_t nfind, Double_t ncurr);
// Depth of `node`, computed by following parent links.
134 UInt_t Depth(
const Node<T> *node);
// Inline setters for the left and right child link. Only the signatures are
// present in this listing; presumably they assign fNodeL / fNodeR - TODO confirm.
144 inline void Node<T>::SetNodeL(Node<T> *node)
150 inline void Node<T>::SetNodeR(Node<T> *node)
// Inline const accessors; bodies are not present in this listing.
// GetEvent: reference to an event, presumably the fEvent member - TODO confirm.
156 inline const T& Node<T>::GetEvent()
const
// GetNodeL / GetNodeR / GetNodeP: read-only pointers to the left child, the
// right child and the parent; Find() and Depth() test them for null.
162 inline const Node<T>* Node<T>::GetNodeL()
const
168 inline const Node<T>* Node<T>::GetNodeR()
const
174 inline const Node<T>* Node<T>::GetNodeP()
const
// Weight of this node: forwards to the weight of the stored event.
180 inline Double_t Node<T>::GetWeight()
const
182 return fEvent.GetWeight();
// Inline const accessors; bodies are not present in this listing.
// Presumably they return fVarDis, fVarMin, fVarMax and fMod - TODO confirm.
// Find() uses GetVarDis() to pick the child to visit first and
// GetVarMin() / GetVarMax() in its range tests.
186 inline Float_t Node<T>::GetVarDis()
const
192 inline Float_t Node<T>::GetVarMin()
const
198 inline Float_t Node<T>::GetVarMax()
const
204 inline UInt_t Node<T>::GetMod()
const
// Depth of `node`, obtained by recursing up the parent chain.
// The `if` branch that ends the recursion is not present in this listing.
213 inline UInt_t Depth(
const Node<T> *node)
// Recursive step: one more than the depth computed for the parent.
216 else return Depth(node->GetNodeP()) + 1;
// Constructor. Takes the parent node, the event, and `mod`, the index of the
// event variable handled by this node. Only one member initialiser is present
// in this listing; the others and the body are missing.
224 TMVA::kNN::Node<T>::Node(
const Node<T> *parent,
const T &event,
const Int_t mod)
// fVarDis is fixed to the event's value of variable `mod`.
229 fVarDis(event.GetVar(mod)),
// Destructor: deletes the left and right child nodes when they are set, which
// in turn runs their destructors on their own children.
237 TMVA::kNN::Node<T>::~Node()
239 if (fNodeL)
delete fNodeL;
240 if (fNodeR)
delete fNodeR;
// Inserts `event` into the subtree rooted at this node and returns a node
// pointer (the return statements after the `new` expressions are not present
// in this listing). `depth` is the recursion depth of this node.
249 const TMVA::kNN::Node<T>* TMVA::kNN::Node<T>::Add(
const T &event,
const UInt_t depth)
// The variable index of a node must cycle with depth over the event's variables.
252 assert(fMod == depth % event.GetNVar() &&
"Wrong recursive depth in Node<>::Add");
254 const Float_t value =
event.GetVar(fMod);
// Widen the recorded [fVarMin, fVarMax] range to include the new value.
256 fVarMin = std::min(fVarMin, value);
257 fVarMax = std::max(fVarMax, value);
// Values below fVarDis go to the left side: recurse into an existing left
// child, or create one (the test selecting between the two is not visible).
260 if (value < fVarDis) {
263 return fNodeL->Add(event, depth + 1);
266 fNodeL =
new Node<T>(
this, event, (depth + 1) % event.GetNVar());
// Remaining lines are the same two actions for the right side; presumably
// the `else` branch of the comparison above.
272 return fNodeR->Add(event, depth + 1);
275 fNodeR =
new Node<T>(
this, event, (depth + 1) % event.GetNVar());
// Argument-free Print overload. Its body is not present in this listing;
// presumably it forwards to Print(std::ostream&, const std::string&) - TODO confirm.
285 void TMVA::kNN::Node<T>::Print()
const
// Writes a description of this node, followed by its children, to `os`.
// Every output line starts with `offset`; children are printed through the
// same function with `offset` extended, so nested nodes appear indented.
// The conditions guarding the individual messages are missing from this
// listing; only the output statements are shown.
// Fix: the two messages announcing the recursion read "PrInt_t ..." (an
// int -> Int_t search-and-replace accident) and now read "Print ...".
292 void TMVA::kNN::Node<T>::Print(std::ostream& os,
const std::string &offset)
const
294 os << offset <<
"-----------------------------------------------------------" << std::endl;
// Header line: variable index and weight of this node.
295 os << offset <<
"Node: mod " << fMod
297 <<
" with weight: " << GetWeight() << std::endl
// Report which children exist.
301 os << offset <<
"Has left node " << std::endl;
304 os << offset <<
"Has right node" << std::endl;
// Announce and print the left subtree with a longer offset.
308 os << offset <<
"Print left node " << std::endl;
309 fNodeL->Print(os, offset +
" ");
// Announce and print the right subtree with a longer offset.
312 os << offset <<
"Print right node" << std::endl;
313 fNodeR->Print(os, offset +
" ");
// Node without children; the guarded statement is not present in this listing.
316 if (!fNodeL && !fNodeR) {
// Count-limited recursive search. Visits `node` and its subtrees and collects
// (node, distance to `event`) pairs in `nlist`; `nfind` is the list size at
// which a new entry is only accepted if it is closer than the last entry.
// The return value accumulates the values returned by the recursive calls.
// Many statements (returns, flag assignments, closing braces) are missing from
// this listing; the comments describe the visible lines only.
332 UInt_t TMVA::kNN::Find(std::list<std::pair<
const TMVA::kNN::Node<T> *, Float_t> > &nlist,
333 const TMVA::kNN::Node<T> *node,
const T &event,
const UInt_t nfind)
// Nothing to do for a null node or a request for zero entries.
335 if (!node || nfind < 1) {
// Query value of the variable this node works on.
339 const Float_t value =
event.GetVar(node->GetMod());
// Only nodes with positive weight enter the block opened here (its closing
// brace is not visible).
341 if (node->GetWeight() > 0.0) {
343 Float_t max_dist = 0.0;
345 if (!nlist.empty()) {
// Distance stored in the last list entry.
347 max_dist = nlist.back().second;
// With the list at nfind entries, test whether the query lies above the
// node's GetVarMax() or below its GetVarMin() by more than max_dist along
// this variable. The guarded statements are not visible; presumably the
// node and its subtree are skipped - TODO confirm.
349 if (nlist.size() == nfind) {
350 if (value > node->GetVarMax() &&
351 event.GetDist(node->GetVarMax(), node->GetMod()) > max_dist) {
354 if (value < node->GetVarMin() &&
355 event.GetDist(node->GetVarMin(), node->GetMod()) > max_dist) {
// Distance between the query and the event stored in this node.
361 const Float_t distance =
event.GetDist(node->GetEvent());
363 Bool_t insert_this = kFALSE;
364 Bool_t remove_back = kFALSE;
// Three cases: list below nfind entries; list at nfind entries, where the
// node must beat max_dist; anything else is reported as a logic error.
366 if (nlist.size() < nfind) {
369 else if (nlist.size() == nfind) {
370 if (distance < max_dist) {
376 std::cerr <<
"TMVA::kNN::Find() - logic error in recursive procedure" << std::endl;
// Locate the first entry with a larger distance (the loop presumably breaks
// there) and insert the node in front of it.
383 typename std::list<std::pair<const Node<T> *, Float_t> >::iterator lit = nlist.begin();
386 for (; lit != nlist.end(); ++lit) {
387 if (distance < lit->second) {
395 nlist.insert(lit, std::pair<
const Node<T> *, Float_t>(node, distance));
// Recursion. With two children, the side selected by value < GetVarDis() is
// searched first: left then right, or (second pair below) right then left.
404 if (node->GetNodeL() && node->GetNodeR()) {
405 if (value < node->GetVarDis()) {
406 count += Find(nlist, node->GetNodeL(), event, nfind);
407 count += Find(nlist, node->GetNodeR(), event, nfind);
410 count += Find(nlist, node->GetNodeR(), event, nfind);
411 count += Find(nlist, node->GetNodeL(), event, nfind);
// Remaining lines recurse into whichever child exists.
415 if (node->GetNodeL()) {
416 count += Find(nlist, node->GetNodeL(), event, nfind);
418 if (node->GetNodeR()) {
419 count += Find(nlist, node->GetNodeR(), event, nfind);
// Weight-limited recursive search. Same traversal and list layout as the
// count-limited overload, but the limit `nfind` is compared against `ncurr`,
// a running sum of the weights of listed nodes. `ncurr` is taken by value, so
// changes made in one call are not seen by its caller.
// Many statements (returns, flag assignments, closing braces) are missing from
// this listing; the comments describe the visible lines only.
440 UInt_t TMVA::kNN::Find(std::list<std::pair<
const TMVA::kNN::Node<T> *, Float_t> > &nlist,
441 const TMVA::kNN::Node<T> *node,
const T &event,
const Double_t nfind, Double_t ncurr)
// NOTE(review): as written, the second test is true for every nfind >= 0, so
// this guard fires for all non-negative limits. The count-limited overload
// rejects nfind < 1 instead; confirm whether `!(nfind > 0.0)` was intended.
444 if (!node || !(nfind < 0.0)) {
// Query value of the variable this node works on.
448 const Float_t value =
event.GetVar(node->GetMod());
// Only nodes with positive weight enter the block opened here (its closing
// brace is not visible).
450 if (node->GetWeight() > 0.0) {
452 Float_t max_dist = 0.0;
454 if (!nlist.empty()) {
// Distance stored in the last list entry.
456 max_dist = nlist.back().second;
// Once ncurr has reached nfind, test whether the query lies above the node's
// GetVarMax() or below its GetVarMin() by more than max_dist along this
// variable. The guarded statements are not visible; presumably the node and
// its subtree are skipped - TODO confirm.
458 if (!(ncurr < nfind)) {
459 if (value > node->GetVarMax() &&
460 event.GetDist(node->GetVarMax(), node->GetMod()) > max_dist) {
463 if (value < node->GetVarMin() &&
464 event.GetDist(node->GetVarMin(), node->GetMod()) > max_dist) {
// Distance between the query and the event stored in this node.
470 const Float_t distance =
event.GetDist(node->GetEvent());
472 Bool_t insert_this = kFALSE;
// The first branch of this decision is not visible. With a non-empty list the
// node must beat max_dist; the fall-through case is reported as a logic error.
477 else if (!nlist.empty()) {
478 if (distance < max_dist) {
483 std::cerr <<
"TMVA::kNN::Find() - logic error in recursive procedure" << std::endl;
// Walk the list up to the first entry with a larger distance (the loop
// presumably breaks there), adding the weight of every entry passed to ncurr.
493 typename std::list<std::pair<const Node<T> *, Float_t> >::iterator lit = nlist.begin();
496 for (; lit != nlist.end(); ++lit) {
497 if (distance < lit->second) {
501 ncurr += lit -> first -> GetWeight();
// Insert the node at that position; `lit` then refers to the new entry.
504 lit = nlist.insert(lit, std::pair<
const Node<T> *, Float_t>(node, distance));
// Continue summing weights from the new entry on, until ncurr reaches nfind.
506 for (; lit != nlist.end(); ++lit) {
507 ncurr += lit -> first -> GetWeight();
508 if (!(ncurr < nfind)) {
// Drop the tail of the list from `lit` on. A statement between these two
// lines is missing here; presumably it advances `lit` past the entry that
// reached the limit - TODO confirm.
514 if(lit != nlist.end())
516 nlist.erase(lit, nlist.end());
// Recursion. With two children, the side selected by value < GetVarDis() is
// searched first: left then right, or (second pair below) right then left.
522 if (node->GetNodeL() && node->GetNodeR()) {
523 if (value < node->GetVarDis()) {
524 count += Find(nlist, node->GetNodeL(), event, nfind, ncurr);
525 count += Find(nlist, node->GetNodeR(), event, nfind, ncurr);
528 count += Find(nlist, node->GetNodeR(), event, nfind, ncurr);
529 count += Find(nlist, node->GetNodeL(), event, nfind, ncurr);
// Remaining lines recurse into whichever child exists.
533 if (node->GetNodeL()) {
534 count += Find(nlist, node->GetNodeL(), event, nfind, ncurr);
536 if (node->GetNodeR()) {
537 count += Find(nlist, node->GetNodeR(), event, nfind, ncurr);