Logo ROOT   6.30.04
Reference Guide
 All Namespaces Files Pages
BinarySearchTree.cxx
Go to the documentation of this file.
1 // @(#)root/tmva $Id$
2 // Author: Andreas Hoecker, Joerg Stelzer, Helge Voss, Kai Voss
3 
4 /**********************************************************************************
5  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6  * Package: TMVA *
7  * Class : BinarySearchTree *
8  * Web : http://tmva.sourceforge.net *
9  * *
10  * Description: *
11  * Implementation (see header file for description) *
12  * *
13  * Authors (alphabetical): *
14  * Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland *
15  * Joerg Stelzer <stelzer@cern.ch> - DESY, Germany *
16  * Helge Voss <Helge.Voss@cern.ch> - MPI-K Heidelberg, Germany *
17  * Kai Voss <Kai.Voss@cern.ch> - U. of Victoria, Canada *
18  * *
19  * Copyright (c) 2005: *
20  * CERN, Switzerland *
21  * U. of Victoria, Canada *
22  * MPI-K Heidelberg, Germany *
23  * LAPP, Annecy, France *
24  * *
25  * Redistribution and use in source and binary forms, with or without *
26  * modification, are permitted according to the terms listed in LICENSE *
27  * (http://tmva.sourceforge.net/LICENSE) *
28  * *
29  **********************************************************************************/
30 
31 /*! \class TMVA::BinarySearchTree
32 \ingroup TMVA
33 
34 A simple Binary search tree including a volume search method.
35 
36 */
37 
38 #include <stdexcept>
39 #include <cstdlib>
40 #include <queue>
41 #include <algorithm>
42 
43 #include "TMath.h"
44 
45 #include "TMatrixDBase.h"
46 #include "TObjString.h"
47 #include "TTree.h"
48 
49 #include "TMVA/MsgLogger.h"
50 #include "TMVA/MethodBase.h"
51 #include "TMVA/Tools.h"
52 #include "TMVA/Event.h"
53 #include "TMVA/BinarySearchTree.h"
54 
55 #include "TMVA/BinaryTree.h"
56 #include "TMVA/Types.h"
57 #include "TMVA/Node.h"
58 
// ROOT class-implementation macro for TMVA::BinarySearchTree (hooks the class
// into ROOT's dictionary / I/O machinery).
ClassImp(TMVA::BinarySearchTree);
60 
61 ////////////////////////////////////////////////////////////////////////////////
62 /// default constructor
63 
64 TMVA::BinarySearchTree::BinarySearchTree( void ) :
65 BinaryTree(),
66  fPeriod ( 1 ),
67  fCurrentDepth( 0 ),
68  fStatisticsIsValid( kFALSE ),
69  fSumOfWeights( 0 ),
70  fCanNormalize( kFALSE )
71 {
72  fNEventsW[0]=fNEventsW[1]=0.;
73 }
74 
75 ////////////////////////////////////////////////////////////////////////////////
76 /// copy constructor that creates a true copy, i.e. a completely independent tree
77 
78 TMVA::BinarySearchTree::BinarySearchTree( const BinarySearchTree &b)
79  : BinaryTree(),
80  fPeriod ( b.fPeriod ),
81  fCurrentDepth( 0 ),
82  fStatisticsIsValid( kFALSE ),
83  fSumOfWeights( b.fSumOfWeights ),
84  fCanNormalize( kFALSE )
85 {
86  fNEventsW[0]=fNEventsW[1]=0.;
87  Log() << kFATAL << " Copy constructor not implemented yet " << Endl;
88 }
89 
90 ////////////////////////////////////////////////////////////////////////////////
91 /// destructor
92 
93 TMVA::BinarySearchTree::~BinarySearchTree( void )
94 {
95  for(std::vector< std::pair<Double_t, const TMVA::Event*> >::iterator pIt = fNormalizeTreeTable.begin();
96  pIt != fNormalizeTreeTable.end(); ++pIt) {
97  delete pIt->second;
98  }
99 }
100 
101 ////////////////////////////////////////////////////////////////////////////////
102 /// re-create a new tree (decision tree or search tree) from XML
103 
104 TMVA::BinarySearchTree* TMVA::BinarySearchTree::CreateFromXML(void* node, UInt_t tmva_Version_Code ) {
105  std::string type("");
106  gTools().ReadAttr(node,"type", type);
107  BinarySearchTree* bt = new BinarySearchTree();
108  bt->ReadXML( node, tmva_Version_Code );
109  return bt;
110 }
111 
112 ////////////////////////////////////////////////////////////////////////////////
113 /// insert a new "event" in the binary tree
114 
115 void TMVA::BinarySearchTree::Insert( const Event* event )
116 {
117  fCurrentDepth=0;
118  fStatisticsIsValid = kFALSE;
119 
120  if (this->GetRoot() == NULL) { // If the list is empty...
121  this->SetRoot( new BinarySearchTreeNode(event)); //Make the new node the root.
122  // have to use "s" for start as "r" for "root" would be the same as "r" for "right"
123  this->GetRoot()->SetPos('s');
124  this->GetRoot()->SetDepth(0);
125  fNNodes = 1;
126  fSumOfWeights = event->GetWeight();
127  ((BinarySearchTreeNode*)this->GetRoot())->SetSelector((UInt_t)0);
128  this->SetPeriode(event->GetNVariables());
129  }
130  else {
131  // sanity check:
132  if (event->GetNVariables() != (UInt_t)this->GetPeriode()) {
133  Log() << kFATAL << "<Insert> event vector length != Periode specified in Binary Tree" << Endl
134  << "--- event size: " << event->GetNVariables() << " Periode: " << this->GetPeriode() << Endl
135  << "--- and all this when trying filling the "<<fNNodes+1<<"th Node" << Endl;
136  }
137  // insert a new node at the propper position
138  this->Insert(event, this->GetRoot());
139  }
140 
141  // normalise the tree to speed up searches
142  if (fCanNormalize) fNormalizeTreeTable.push_back( std::make_pair(0.0,new const Event(*event)) );
143 }
144 
145 ////////////////////////////////////////////////////////////////////////////////
146 /// private internal function to insert a event (node) at the proper position
147 
148 void TMVA::BinarySearchTree::Insert( const Event *event,
149  Node *node )
150 {
151  fCurrentDepth++;
152  fStatisticsIsValid = kFALSE;
153 
154  if (node->GoesLeft(*event)){ // If the adding item is less than the current node's data...
155  if (node->GetLeft() != NULL){ // If there is a left node...
156  // Add the new event to the left node
157  this->Insert(event, node->GetLeft());
158  }
159  else { // If there is not a left node...
160  // Make the new node for the new event
161  BinarySearchTreeNode* current = new BinarySearchTreeNode(event);
162  fNNodes++;
163  fSumOfWeights += event->GetWeight();
164  current->SetSelector(fCurrentDepth%((Int_t)event->GetNVariables()));
165  current->SetParent(node); // Set the new node's previous node.
166  current->SetPos('l');
167  current->SetDepth( node->GetDepth() + 1 );
168  node->SetLeft(current); // Make it the left node of the current one.
169  }
170  }
171  else if (node->GoesRight(*event)) { // If the adding item is less than or equal to the current node's data...
172  if (node->GetRight() != NULL) { // If there is a right node...
173  // Add the new node to it.
174  this->Insert(event, node->GetRight());
175  }
176  else { // If there is not a right node...
177  // Make the new node.
178  BinarySearchTreeNode* current = new BinarySearchTreeNode(event);
179  fNNodes++;
180  fSumOfWeights += event->GetWeight();
181  current->SetSelector(fCurrentDepth%((Int_t)event->GetNVariables()));
182  current->SetParent(node); // Set the new node's previous node.
183  current->SetPos('r');
184  current->SetDepth( node->GetDepth() + 1 );
185  node->SetRight(current); // Make it the left node of the current one.
186  }
187  }
188  else Log() << kFATAL << "<Insert> neither left nor right :)" << Endl;
189 }
190 
191 ////////////////////////////////////////////////////////////////////////////////
192 ///search the tree to find the node matching "event"
193 
194 TMVA::BinarySearchTreeNode* TMVA::BinarySearchTree::Search( Event* event ) const
195 {
196  return this->Search( event, this->GetRoot() );
197 }
198 
199 ////////////////////////////////////////////////////////////////////////////////
200 /// Private, recursive, function for searching.
201 
202 TMVA::BinarySearchTreeNode* TMVA::BinarySearchTree::Search(Event* event, Node* node) const
203 {
204  if (node != NULL) { // If the node is not NULL...
205  // If we have found the node...
206  if (((BinarySearchTreeNode*)(node))->EqualsMe(*event))
207  return (BinarySearchTreeNode*)node; // Return it
208  if (node->GoesLeft(*event)) // If the node's data is greater than the search item...
209  return this->Search(event, node->GetLeft()); //Search the left node.
210  else //If the node's data is less than the search item...
211  return this->Search(event, node->GetRight()); //Search the right node.
212  }
213  else return NULL; //If the node is NULL, return NULL.
214 }
215 
216 ////////////////////////////////////////////////////////////////////////////////
217 /// return the sum of event (node) weights
218 
219 Double_t TMVA::BinarySearchTree::GetSumOfWeights( void ) const
220 {
221  if (fSumOfWeights <= 0) {
222  Log() << kWARNING << "you asked for the SumOfWeights, which is not filled yet"
223  << " I call CalcStatistics which hopefully fixes things"
224  << Endl;
225  }
226  if (fSumOfWeights <= 0) Log() << kFATAL << " Zero events in your Search Tree" <<Endl;
227 
228  return fSumOfWeights;
229 }
230 
231 ////////////////////////////////////////////////////////////////////////////////
232 /// return the sum of event (node) weights
233 
234 Double_t TMVA::BinarySearchTree::GetSumOfWeights( Int_t theType ) const
235 {
236  if (fSumOfWeights <= 0) {
237  Log() << kWARNING << "you asked for the SumOfWeights, which is not filled yet"
238  << " I call CalcStatistics which hopefully fixes things"
239  << Endl;
240  }
241  if (fSumOfWeights <= 0) Log() << kFATAL << " Zero events in your Search Tree" <<Endl;
242 
243  return fNEventsW[ ( theType == Types::kSignal) ? 0 : 1 ];
244 }
245 
246 ////////////////////////////////////////////////////////////////////////////////
247 /// create the search tree from the event collection
248 /// using ONLY the variables specified in "theVars"
249 
250 Double_t TMVA::BinarySearchTree::Fill( const std::vector<Event*>& events, const std::vector<Int_t>& theVars,
251  Int_t theType )
252 {
253  fPeriod = theVars.size();
254  return Fill(events, theType);
255 }
256 
257 ////////////////////////////////////////////////////////////////////////////////
258 /// create the search tree from the events in a TTree
259 /// using ALL the variables specified included in the Event
260 
261 Double_t TMVA::BinarySearchTree::Fill( const std::vector<Event*>& events, Int_t theType )
262 {
263  UInt_t n=events.size();
264 
265  UInt_t nevents = 0;
266  if (fSumOfWeights != 0) {
267  Log() << kWARNING
268  << "You are filling a search three that is not empty.. "
269  << " do you know what you are doing?"
270  << Endl;
271  }
272  for (UInt_t ievt=0; ievt<n; ievt++) {
273  // insert event into binary tree
274  if (theType == -1 || (Int_t(events[ievt]->GetClass()) == theType) ) {
275  this->Insert( events[ievt] );
276  nevents++;
277  fSumOfWeights += events[ievt]->GetWeight();
278  }
279  } // end of event loop
280  CalcStatistics(0);
281 
282  return fSumOfWeights;
283 }
284 
285 ////////////////////////////////////////////////////////////////////////////////
286 /// normalises the binary-search tree to reduce the branch length and hence speed up the
287 /// search procedure (on average).
288 
289 void TMVA::BinarySearchTree::NormalizeTree ( std::vector< std::pair<Double_t, const TMVA::Event*> >::iterator leftBound,
290  std::vector< std::pair<Double_t, const TMVA::Event*> >::iterator rightBound,
291  UInt_t actDim )
292 {
293 
294  if (leftBound == rightBound) return;
295 
296  if (actDim == fPeriod) actDim = 0;
297  for (std::vector< std::pair<Double_t, const TMVA::Event*> >::iterator i=leftBound; i!=rightBound; ++i) {
298  i->first = i->second->GetValue( actDim );
299  }
300 
301  std::sort( leftBound, rightBound );
302 
303  std::vector< std::pair<Double_t, const TMVA::Event*> >::iterator leftTemp = leftBound;
304  std::vector< std::pair<Double_t, const TMVA::Event*> >::iterator rightTemp = rightBound;
305 
306  // meet in the middle
307  while (true) {
308  --rightTemp;
309  if (rightTemp == leftTemp ) {
310  break;
311  }
312  ++leftTemp;
313  if (leftTemp == rightTemp) {
314  break;
315  }
316  }
317 
318  std::vector< std::pair<Double_t, const TMVA::Event*> >::iterator mid = leftTemp;
319  std::vector< std::pair<Double_t, const TMVA::Event*> >::iterator midTemp = mid;
320 
321  if (mid!=leftBound)--midTemp;
322 
323  while (mid != leftBound && mid->second->GetValue( actDim ) == midTemp->second->GetValue( actDim )) {
324  --mid;
325  --midTemp;
326  }
327 
328  Insert( mid->second );
329 
330  // Print(std::cout);
331  // std::cout << std::endl << std::endl;
332 
333  NormalizeTree( leftBound, mid, actDim+1 );
334  ++mid;
335  // Print(std::cout);
336  // std::cout << std::endl << std::endl;
337  NormalizeTree( mid, rightBound, actDim+1 );
338 
339 
340  return;
341 }
342 
343 ////////////////////////////////////////////////////////////////////////////////
344 /// Normalisation of tree
345 
346 void TMVA::BinarySearchTree::NormalizeTree()
347 {
348  SetNormalize( kFALSE );
349  Clear( NULL );
350  this->SetRoot(NULL);
351  NormalizeTree( fNormalizeTreeTable.begin(), fNormalizeTreeTable.end(), 0 );
352 }
353 
354 ////////////////////////////////////////////////////////////////////////////////
355 /// clear nodes
356 
357 void TMVA::BinarySearchTree::Clear( Node* n )
358 {
359  BinarySearchTreeNode* currentNode = (BinarySearchTreeNode*)(n == NULL ? this->GetRoot() : n);
360 
361  if (currentNode->GetLeft() != 0) Clear( currentNode->GetLeft() );
362  if (currentNode->GetRight() != 0) Clear( currentNode->GetRight() );
363 
364  if (n != NULL) delete n;
365 
366  return;
367 }
368 
369 ////////////////////////////////////////////////////////////////////////////////
370 /// search the whole tree and add up all weights of events that
371 /// lie within the given volume
372 
373 Double_t TMVA::BinarySearchTree::SearchVolume( Volume* volume,
374  std::vector<const BinarySearchTreeNode*>* events )
375 {
376  return SearchVolume( this->GetRoot(), volume, 0, events );
377 }
378 
379 ////////////////////////////////////////////////////////////////////////////////
380 /// recursively walk through the daughter nodes and add up all weights of events that
381 /// lie within the given volume
382 
383 Double_t TMVA::BinarySearchTree::SearchVolume( Node* t, Volume* volume, Int_t depth,
384  std::vector<const BinarySearchTreeNode*>* events )
385 {
386  if (t==NULL) return 0; // Are we at an outer leave?
387 
388  BinarySearchTreeNode* st = (BinarySearchTreeNode*)t;
389 
390  Double_t count = 0.0;
391  if (InVolume( st->GetEventV(), volume )) {
392  count += st->GetWeight();
393  if (NULL != events) events->push_back( st );
394  }
395  if (st->GetLeft()==NULL && st->GetRight()==NULL) {
396 
397  return count; // Are we at an outer leave?
398  }
399 
400  Bool_t tl, tr;
401  Int_t d = depth%this->GetPeriode();
402  if (d != st->GetSelector()) {
403  Log() << kFATAL << "<SearchVolume> selector in Searchvolume "
404  << d << " != " << "node "<< st->GetSelector() << Endl;
405  }
406  tl = (*(volume->fLower))[d] < st->GetEventV()[d]; // Should we descend left?
407  tr = (*(volume->fUpper))[d] >= st->GetEventV()[d]; // Should we descend right?
408 
409  if (tl) count += SearchVolume( st->GetLeft(), volume, (depth+1), events );
410  if (tr) count += SearchVolume( st->GetRight(), volume, (depth+1), events );
411 
412  return count;
413 }
414 
415 ////////////////////////////////////////////////////////////////////////////////
416 /// test if the data points are in the given volume
417 
418 Bool_t TMVA::BinarySearchTree::InVolume(const std::vector<Float_t>& event, Volume* volume ) const
419 {
420 
421  Bool_t result = false;
422  for (UInt_t ivar=0; ivar< fPeriod; ivar++) {
423  result = ( (*(volume->fLower))[ivar] < event[ivar] &&
424  (*(volume->fUpper))[ivar] >= event[ivar] );
425  if (!result) break;
426  }
427  return result;
428 }
429 
430 ////////////////////////////////////////////////////////////////////////////////
431 /// calculate basic statistics (mean, rms for each variable)
432 
433 void TMVA::BinarySearchTree::CalcStatistics( Node* n )
434 {
435  if (fStatisticsIsValid) return;
436 
437  BinarySearchTreeNode * currentNode = (BinarySearchTreeNode*)n;
438 
439  // default, start at the tree top, then descend recursively
440  if (n == NULL) {
441  fSumOfWeights = 0;
442  for (Int_t sb=0; sb<2; sb++) {
443  fNEventsW[sb] = 0;
444  fMeans[sb] = std::vector<Float_t>(fPeriod);
445  fRMS[sb] = std::vector<Float_t>(fPeriod);
446  fMin[sb] = std::vector<Float_t>(fPeriod);
447  fMax[sb] = std::vector<Float_t>(fPeriod);
448  fSum[sb] = std::vector<Double_t>(fPeriod);
449  fSumSq[sb] = std::vector<Double_t>(fPeriod);
450  for (UInt_t j=0; j<fPeriod; j++) {
451  fMeans[sb][j] = fRMS[sb][j] = fSum[sb][j] = fSumSq[sb][j] = 0;
452  fMin[sb][j] = FLT_MAX;
453  fMax[sb][j] = -FLT_MAX;
454  }
455  }
456  currentNode = (BinarySearchTreeNode*) this->GetRoot();
457  if (currentNode == NULL) return; // no root-node
458  }
459 
460  const std::vector<Float_t> & evtVec = currentNode->GetEventV();
461  Double_t weight = currentNode->GetWeight();
462  // Int_t type = currentNode->IsSignal();
463  // Int_t type = currentNode->IsSignal() ? 0 : 1;
464  Int_t type = Int_t(currentNode->GetClass())== Types::kSignal ? 0 : 1;
465 
466  fNEventsW[type] += weight;
467  fSumOfWeights += weight;
468 
469  for (UInt_t j=0; j<fPeriod; j++) {
470  Float_t val = evtVec[j];
471  fSum[type][j] += val*weight;
472  fSumSq[type][j] += val*val*weight;
473  if (val < fMin[type][j]) fMin[type][j] = val;
474  if (val > fMax[type][j]) fMax[type][j] = val;
475  }
476 
477  if ( (currentNode->GetLeft() != NULL) ) CalcStatistics( currentNode->GetLeft() );
478  if ( (currentNode->GetRight() != NULL) ) CalcStatistics( currentNode->GetRight() );
479 
480  if (n == NULL) { // i.e. the root node
481  for (Int_t sb=0; sb<2; sb++) {
482  for (UInt_t j=0; j<fPeriod; j++) {
483  if (fNEventsW[sb] == 0) { fMeans[sb][j] = fRMS[sb][j] = 0; continue; }
484  fMeans[sb][j] = fSum[sb][j]/fNEventsW[sb];
485  fRMS[sb][j] = TMath::Sqrt(fSumSq[sb][j]/fNEventsW[sb] - fMeans[sb][j]*fMeans[sb][j]);
486  }
487  }
488  fStatisticsIsValid = kTRUE;
489  }
490 
491  return;
492 }
493 
494 ////////////////////////////////////////////////////////////////////////////////
495 /// recursively walk through the daughter nodes and add up all weights of events that
496 /// lie within the given volume a maximum number of events can be given
497 
498 Int_t TMVA::BinarySearchTree::SearchVolumeWithMaxLimit( Volume *volume, std::vector<const BinarySearchTreeNode*>* events,
499  Int_t max_points )
500 {
501  if (this->GetRoot() == NULL) return 0; // Are we at an outer leave?
502 
503  std::queue< std::pair< const BinarySearchTreeNode*, Int_t > > queue;
504  std::pair< const BinarySearchTreeNode*, Int_t > st = std::make_pair( (const BinarySearchTreeNode*)this->GetRoot(), 0 );
505  queue.push( st );
506 
507  Int_t count = 0;
508 
509  while ( !queue.empty() ) {
510  st = queue.front(); queue.pop();
511 
512  if (count == max_points)
513  return count;
514 
515  if (InVolume( st.first->GetEventV(), volume )) {
516  count++;
517  if (NULL != events) events->push_back( st.first );
518  }
519 
520  Bool_t tl, tr;
521  Int_t d = st.second;
522  if ( d == Int_t(this->GetPeriode()) ) d = 0;
523 
524  if (d != st.first->GetSelector()) {
525  Log() << kFATAL << "<SearchVolume> selector in Searchvolume "
526  << d << " != " << "node "<< st.first->GetSelector() << Endl;
527  }
528 
529  tl = (*(volume->fLower))[d] < st.first->GetEventV()[d] && st.first->GetLeft() != NULL; // Should we descend left?
530  tr = (*(volume->fUpper))[d] >= st.first->GetEventV()[d] && st.first->GetRight() != NULL; // Should we descend right?
531 
532  if (tl) queue.push( std::make_pair( (const BinarySearchTreeNode*)st.first->GetLeft(), d+1 ) );
533  if (tr) queue.push( std::make_pair( (const BinarySearchTreeNode*)st.first->GetRight(), d+1 ) );
534  }
535 
536  return count;
537 }