// ROOT class-implementation macro for TMVA::BinarySearchTree.
// NOTE(review): presumably paired with a ClassDef in the class header
// (not part of this extract) — confirm.
59 ClassImp(TMVA::BinarySearchTree);
// Default constructor.
// The visible initializers do two things:
//  - mark the cached statistics as stale (fStatisticsIsValid = kFALSE);
//  - disable buffering of event copies for NormalizeTree (fCanNormalize = kFALSE).
// The body zeroes both per-class weight sums in fNEventsW.
// NOTE(review): some initializer lines of this constructor are not
// visible in this extract.
64 TMVA::BinarySearchTree::BinarySearchTree(
void ) :
68 fStatisticsIsValid( kFALSE ),
70 fCanNormalize( kFALSE )
72 fNEventsW[0]=fNEventsW[1]=0.;
// Copy constructor — NOT implemented.
// Of the visible initializers, only fPeriod and fSumOfWeights are taken
// from b. The statistics are marked stale, normalisation buffering is
// disabled and fNEventsW is zeroed. The body then logs a kFATAL message.
// NOTE(review): nothing visible here copies b's nodes. Presumably kFATAL
// aborts — confirm. Deleting the copy constructor (= delete) would turn
// misuse into a compile-time error instead of a run-time one.
78 TMVA::BinarySearchTree::BinarySearchTree(
const BinarySearchTree &b)
80 fPeriod ( b.fPeriod ),
82 fStatisticsIsValid( kFALSE ),
83 fSumOfWeights( b.fSumOfWeights ),
84 fCanNormalize( kFALSE )
86 fNEventsW[0]=fNEventsW[1]=0.;
87 Log() << kFATAL <<
" Copy constructor not implemented yet " << Endl;
// Destructor.
// Iterates over fNormalizeTreeTable. Its entries hold heap copies of
// events, made with `new const Event(*event)` in Insert(const Event*).
// NOTE(review): the loop body is not visible in this extract. Presumably
// it deletes each pIt->second — confirm; otherwise those copies leak.
93 TMVA::BinarySearchTree::~BinarySearchTree(
void )
95 for(std::vector< std::pair<Double_t, const TMVA::Event*> >::iterator pIt = fNormalizeTreeTable.begin();
96 pIt != fNormalizeTreeTable.end(); ++pIt) {
// Factory: builds a new BinarySearchTree from an XML node.
// Steps:
//  1. Reads the node's "type" attribute into a local string. The value is
//     not used in the visible lines.
//  2. Heap-allocates a tree with the default constructor.
//  3. Lets ReadXML populate the tree, passing tmva_Version_Code along.
// NOTE(review): the return statement is not visible here. Presumably it
// returns bt, in which case the caller owns the new tree — confirm.
104 TMVA::BinarySearchTree* TMVA::BinarySearchTree::CreateFromXML(
void* node, UInt_t tmva_Version_Code ) {
105 std::string type(
"");
106 gTools().ReadAttr(node,
"type", type);
107 BinarySearchTree* bt =
new BinarySearchTree();
108 bt->ReadXML( node, tmva_Version_Code );
// Inserts one event into the tree.
// Any insertion invalidates the cached statistics.
// Empty tree: the event becomes the root.
//  - Position 's', depth 0, selector 0.
//  - The tree period (number of variables) is taken from the event.
//  - The running weight sum is reset to this event's weight.
// Non-empty tree:
//  - The event must have exactly GetPeriode() variables; a mismatch is
//    reported with kFATAL.
//  - The event is placed by the recursive Insert(event, node) overload,
//    starting at the root.
// When fCanNormalize is set, a heap copy of the event is also appended to
// fNormalizeTreeTable with a placeholder key of 0.0. NormalizeTree later
// overwrites that key with the event's value in the current split variable.
// NOTE(review): the exact if/else bracing is not visible in this extract.
115 void TMVA::BinarySearchTree::Insert(
const Event* event )
// The tree content changes, so previously computed statistics are stale.
118 fStatisticsIsValid = kFALSE;
// Empty tree: the first event becomes the root node.
120 if (this->GetRoot() == NULL) {
121 this->SetRoot(
new BinarySearchTreeNode(event));
123 this->GetRoot()->SetPos(
's');
124 this->GetRoot()->SetDepth(0);
// Start the running weight sum with the root event's weight.
126 fSumOfWeights =
event->GetWeight();
127 ((BinarySearchTreeNode*)this->GetRoot())->SetSelector((UInt_t)0);
// The root fixes the dimensionality expected of all later events.
128 this->SetPeriode(event->GetNVariables());
// Every later event must match that dimensionality.
132 if (event->GetNVariables() != (UInt_t)this->GetPeriode()) {
133 Log() << kFATAL <<
"<Insert> event vector length != Periode specified in Binary Tree" << Endl
134 <<
"--- event size: " <<
event->GetNVariables() <<
" Periode: " << this->GetPeriode() << Endl
135 <<
"--- and all this when trying filling the "<<fNNodes+1<<
"th Node" << Endl;
138 this->Insert(event, this->GetRoot());
// Keep a heap copy of the event for a later NormalizeTree() rebuild.
142 if (fCanNormalize) fNormalizeTreeTable.push_back( std::make_pair(0.0,
new const Event(*event)) );
// Recursive helper: descends from `node` until it finds a free child slot.
// node->GoesLeft / GoesRight decide the direction.
//  - If the chosen child exists, the call recurses into it.
//  - Otherwise a new BinarySearchTreeNode is attached there. It gets its
//    parent link, its position ('l' or 'r'), depth = parent depth + 1, and
//    selector = fCurrentDepth % number of variables. The event weight is
//    added to fSumOfWeights.
// An event that goes neither left nor right is reported with kFATAL.
// NOTE(review): the declaration of the `node` parameter and any update of
// fCurrentDepth are not visible in this extract. The selector relies on
// fCurrentDepth tracking the insertion depth — confirm.
148 void TMVA::BinarySearchTree::Insert(
const Event *event,
152 fStatisticsIsValid = kFALSE;
154 if (node->GoesLeft(*event)){
155 if (node->GetLeft() != NULL){
// Left slot occupied: keep descending.
157 this->Insert(event, node->GetLeft());
// Left slot free: attach the event as a new left leaf.
161 BinarySearchTreeNode* current =
new BinarySearchTreeNode(event);
163 fSumOfWeights +=
event->GetWeight();
164 current->SetSelector(fCurrentDepth%((Int_t)event->GetNVariables()));
165 current->SetParent(node);
166 current->SetPos(
'l');
167 current->SetDepth( node->GetDepth() + 1 );
168 node->SetLeft(current);
171 else if (node->GoesRight(*event)) {
172 if (node->GetRight() != NULL) {
// Right slot occupied: keep descending.
174 this->Insert(event, node->GetRight());
// Right slot free: attach the event as a new right leaf.
178 BinarySearchTreeNode* current =
new BinarySearchTreeNode(event);
180 fSumOfWeights +=
event->GetWeight();
181 current->SetSelector(fCurrentDepth%((Int_t)event->GetNVariables()));
182 current->SetParent(node);
183 current->SetPos(
'r');
184 current->SetDepth( node->GetDepth() + 1 );
185 node->SetRight(current);
188 else Log() << kFATAL <<
"<Insert> neither left nor right :)" << Endl;
// Looks up the node whose event equals `event`, starting at the root.
// Delegates to the recursive Search(event, node) overload.
194 TMVA::BinarySearchTreeNode* TMVA::BinarySearchTree::Search( Event* event )
const
196 return this->Search( event, this->GetRoot() );
// Recursive lookup.
// Returns `node` if its event equals `event` (EqualsMe). Otherwise the
// search continues in the left or right subtree, as chosen by
// node->GoesLeft.
// NOTE(review): a NULL-node guard and the not-found return value are not
// visible in this extract. Presumably a missing subtree ends the search
// with a null result — confirm.
202 TMVA::BinarySearchTreeNode* TMVA::BinarySearchTree::Search(Event* event, Node* node)
const
206 if (((BinarySearchTreeNode*)(node))->EqualsMe(*event))
207 return (BinarySearchTreeNode*)node;
208 if (node->GoesLeft(*event))
209 return this->Search(event, node->GetLeft());
211 return this->Search(event, node->GetRight());
// Returns the total event weight stored in the tree (fSumOfWeights).
// If the cached sum is not positive, a warning is printed that says
// CalcStatistics will be called. If the sum is still not positive
// afterwards, a kFATAL "Zero events" message is logged.
// NOTE(review): the CalcStatistics call itself is not visible in this
// extract. This method is const while CalcStatistics(Node*) is not, so the
// call presumably goes through a const_cast — confirm.
219 Double_t TMVA::BinarySearchTree::GetSumOfWeights(
void )
const
221 if (fSumOfWeights <= 0) {
222 Log() << kWARNING <<
"you asked for the SumOfWeights, which is not filled yet"
223 <<
" I call CalcStatistics which hopefully fixes things"
226 if (fSumOfWeights <= 0) Log() << kFATAL <<
" Zero events in your Search Tree" <<Endl;
228 return fSumOfWeights;
// Returns the summed event weight of one class: fNEventsW[0] when theType
// is Types::kSignal, fNEventsW[1] for any other value.
// It first runs the same "not filled yet" warning and kFATAL check on the
// total fSumOfWeights as the no-argument overload.
// NOTE(review): only the total is checked, never the per-class entry.
// fNEventsW is accumulated in CalcStatistics, so the returned value can be
// stale if statistics were not recomputed after the last Insert — confirm
// that the hidden lines call CalcStatistics.
234 Double_t TMVA::BinarySearchTree::GetSumOfWeights( Int_t theType )
const
236 if (fSumOfWeights <= 0) {
237 Log() << kWARNING <<
"you asked for the SumOfWeights, which is not filled yet"
238 <<
" I call CalcStatistics which hopefully fixes things"
241 if (fSumOfWeights <= 0) Log() << kFATAL <<
" Zero events in your Search Tree" <<Endl;
// Index 0 = signal, 1 = everything else (same convention as CalcStatistics).
243 return fNEventsW[ ( theType == Types::kSignal) ? 0 : 1 ];
// Fills the tree from `events` after setting the tree period to
// theVars.size().
// In the visible lines only the size of theVars is used, not the indices
// it contains. The actual filling is delegated to Fill(events, theType).
// NOTE(review): the declaration of theType is not visible in this extract.
// For an empty tree, Insert(const Event*) calls SetPeriode with the first
// event's GetNVariables(), which presumably overrides the value set here —
// confirm.
250 Double_t TMVA::BinarySearchTree::Fill(
const std::vector<Event*>& events,
const std::vector<Int_t>& theVars,
253 fPeriod = theVars.size();
254 return Fill(events, theType);
// Inserts every event of `events` whose class equals theType and returns
// the resulting fSumOfWeights. theType == -1 selects all events.
// A "not empty" message is emitted when the tree already holds weight.
// NOTE(review): Insert(const Event*) already adds the event weight to
// fSumOfWeights (or sets it, for the root), and this loop adds it again.
// The returned sum therefore double-counts, unless a line not visible in
// this extract (e.g. a CalcStatistics call) recomputes it — confirm.
261 Double_t TMVA::BinarySearchTree::Fill(
const std::vector<Event*>& events, Int_t theType )
263 UInt_t n=events.size();
266 if (fSumOfWeights != 0) {
268 <<
"You are filling a search three that is not empty.. "
269 <<
" do you know what you are doing?"
272 for (UInt_t ievt=0; ievt<n; ievt++) {
// -1 acts as a wildcard for the class selection.
274 if (theType == -1 || (Int_t(events[ievt]->GetClass()) == theType) ) {
275 this->Insert( events[ievt] );
277 fSumOfWeights += events[ievt]->GetWeight();
282 return fSumOfWeights;
// Recursively re-inserts the buffered events in [leftBound, rightBound)
// of fNormalizeTreeTable, choosing a split element at each level.
// actDim cycles through the variables and wraps to 0 at fPeriod.
// At each level:
//  1. Each entry's key is set to its event's value in variable actDim.
//  2. The range is sorted.
//  3. A split position `mid` is chosen (presumably the median) and moved
//     back past entries with equal values.
//  4. The event at mid is inserted.
//  5. The call recurses into the parts left and right of mid with the next
//     variable.
// NOTE(review): the median-search loop, the body of the equal-value loop
// and any adjustment of mid between the recursive calls are not visible in
// this extract. std::sort on pairs breaks ties by comparing Event pointer
// values.
289 void TMVA::BinarySearchTree::NormalizeTree ( std::vector< std::pair<Double_t, const TMVA::Event*> >::iterator leftBound,
290 std::vector< std::pair<Double_t, const TMVA::Event*> >::iterator rightBound,
// Empty range: nothing to insert.
294 if (leftBound == rightBound)
return;
// Cycle through the variables.
296 if (actDim == fPeriod) actDim = 0;
// Use the value in the current split variable as the sort key.
297 for (std::vector< std::pair<Double_t, const TMVA::Event*> >::iterator i=leftBound; i!=rightBound; ++i) {
298 i->first = i->second->GetValue( actDim );
301 std::sort( leftBound, rightBound );
303 std::vector< std::pair<Double_t, const TMVA::Event*> >::iterator leftTemp = leftBound;
304 std::vector< std::pair<Double_t, const TMVA::Event*> >::iterator rightTemp = rightBound;
309 if (rightTemp == leftTemp ) {
313 if (leftTemp == rightTemp) {
318 std::vector< std::pair<Double_t, const TMVA::Event*> >::iterator mid = leftTemp;
319 std::vector< std::pair<Double_t, const TMVA::Event*> >::iterator midTemp = mid;
321 if (mid!=leftBound)--midTemp;
// Move back over entries equal in this variable, so equal values end up on
// one side of the split.
323 while (mid != leftBound && mid->second->GetValue( actDim ) == midTemp->second->GetValue( actDim )) {
328 Insert( mid->second );
333 NormalizeTree( leftBound, mid, actDim+1 );
337 NormalizeTree( mid, rightBound, actDim+1 );
// Rebuilds the tree from the event copies buffered in fNormalizeTreeTable,
// starting the split cycle at variable 0.
// Normalisation buffering is switched off first, presumably by clearing
// fCanNormalize. The order matters: Insert appends to fNormalizeTreeTable
// while that flag is set, which would invalidate the iterators passed below.
// NOTE(review): lines that presumably clear the existing nodes before the
// rebuild are not visible in this extract — confirm.
346 void TMVA::BinarySearchTree::NormalizeTree()
348 SetNormalize( kFALSE );
351 NormalizeTree( fNormalizeTreeTable.begin(), fNormalizeTreeTable.end(), 0 );
// Recursively deletes the subtrees below `n`, or below the root when n is
// NULL.
// The node passed in is deleted last, after its children. When called with
// NULL, the root node object itself is not deleted here.
// NOTE(review): with n == NULL on an empty tree, currentNode is NULL and
// the GetLeft() call below dereferences it. Unless a guard exists on a
// line not visible in this extract, consider an early return for that case.
357 void TMVA::BinarySearchTree::Clear( Node* n )
359 BinarySearchTreeNode* currentNode = (BinarySearchTreeNode*)(n == NULL ? this->GetRoot() : n);
361 if (currentNode->GetLeft() != 0) Clear( currentNode->GetLeft() );
362 if (currentNode->GetRight() != 0) Clear( currentNode->GetRight() );
364 if (n != NULL)
delete n;
// Returns the summed weight of all nodes whose event lies inside `volume`,
// searching from the root at depth 0.
// If `events` is non-NULL, the matching nodes are appended to it.
373 Double_t TMVA::BinarySearchTree::SearchVolume( Volume* volume,
374 std::vector<const BinarySearchTreeNode*>* events )
376 return SearchVolume( this->GetRoot(), volume, 0, events );
// Recursive range search below node t. A NULL node contributes 0.
// If the node's event is inside `volume`, its weight is added (and the node
// is appended to `events`, when that is non-NULL).
// The search then descends only into children whose half-space can overlap
// the volume along this level's split variable d = depth % GetPeriode():
//  - left, if the volume's lower edge is below the node's value;
//  - right, if the upper edge is at or above it.
// A mismatch between d and the node's stored selector is reported with
// kFATAL.
// NOTE(review): the body of the leaf-node branch and the declarations of
// tl/tr are not visible in this extract.
383 Double_t TMVA::BinarySearchTree::SearchVolume( Node* t, Volume* volume, Int_t depth,
384 std::vector<const BinarySearchTreeNode*>* events )
386 if (t==NULL)
return 0;
388 BinarySearchTreeNode* st = (BinarySearchTreeNode*)t;
390 Double_t count = 0.0;
391 if (InVolume( st->GetEventV(), volume )) {
392 count += st->GetWeight();
393 if (NULL != events) events->push_back( st );
395 if (st->GetLeft()==NULL && st->GetRight()==NULL) {
// Split variable for this depth; must agree with the selector set in Insert.
401 Int_t d = depth%this->GetPeriode();
402 if (d != st->GetSelector()) {
403 Log() << kFATAL <<
"<SearchVolume> selector in Searchvolume "
404 << d <<
" != " <<
"node "<< st->GetSelector() << Endl;
// Prune subtrees that cannot intersect the volume in variable d.
406 tl = (*(volume->fLower))[d] < st->GetEventV()[d];
407 tr = (*(volume->fUpper))[d] >= st->GetEventV()[d];
409 if (tl) count += SearchVolume( st->GetLeft(), volume, (depth+1), events );
410 if (tr) count += SearchVolume( st->GetRight(), volume, (depth+1), events );
// Tests whether `event` lies inside `volume`: for each of the first fPeriod
// variables, lower < value <= upper (lower edge exclusive, upper edge
// inclusive).
// NOTE(review): `result` is overwritten on every iteration. Unless a line
// not visible in this extract breaks out as soon as it becomes false, only
// the last variable decides the outcome — confirm.
418 Bool_t TMVA::BinarySearchTree::InVolume(
const std::vector<Float_t>& event, Volume* volume )
const
421 Bool_t result =
false;
422 for (UInt_t ivar=0; ivar< fPeriod; ivar++) {
423 result = ( (*(volume->fLower))[ivar] <
event[ivar] &&
424 (*(volume->fUpper))[ivar] >=
event[ivar] );
// Computes per-class statistics by walking the tree recursively: weight
// sums, and for each of the fPeriod variables the mean, RMS, minimum and
// maximum. Index 0 holds Types::kSignal events; index 1 holds all others.
// Returns immediately while the cached statistics are still valid.
// Steps:
//  1. The per-variable arrays are reinitialised (min/max to +/-FLT_MAX).
//  2. The walk starts at the root.
//  3. Each node adds its weight to fNEventsW and fSumOfWeights, adds its
//     weighted values to fSum and fSumSq, and updates fMin and fMax.
//  4. Mean = sum/weight and RMS = sqrt(sumSq/weight - mean^2). Classes
//     with zero weight get mean and RMS 0.
// NOTE(review): the conditions guarding the initialisation and the final
// mean/RMS pass (presumably n == NULL) are not visible in this extract.
//  - If that final pass ran at every recursion level, fStatisticsIsValid
//    would become kTRUE after the first leaf and cut the walk short.
//  - Resets of fNEventsW and fSumOfWeights are not visible either, yet both
//    are accumulated with += below.
//  - The RMS expression can take the sqrt of a tiny negative number through
//    floating-point cancellation, giving NaN; consider clamping at 0.
433 void TMVA::BinarySearchTree::CalcStatistics( Node* n )
435 if (fStatisticsIsValid)
return;
437 BinarySearchTreeNode * currentNode = (BinarySearchTreeNode*)n;
// (Re)initialise the per-class, per-variable accumulators.
442 for (Int_t sb=0; sb<2; sb++) {
444 fMeans[sb] = std::vector<Float_t>(fPeriod);
445 fRMS[sb] = std::vector<Float_t>(fPeriod);
446 fMin[sb] = std::vector<Float_t>(fPeriod);
447 fMax[sb] = std::vector<Float_t>(fPeriod);
448 fSum[sb] = std::vector<Double_t>(fPeriod);
449 fSumSq[sb] = std::vector<Double_t>(fPeriod);
450 for (UInt_t j=0; j<fPeriod; j++) {
451 fMeans[sb][j] = fRMS[sb][j] = fSum[sb][j] = fSumSq[sb][j] = 0;
452 fMin[sb][j] = FLT_MAX;
453 fMax[sb][j] = -FLT_MAX;
// Start the walk at the root; nothing to do for an empty tree.
456 currentNode = (BinarySearchTreeNode*) this->GetRoot();
457 if (currentNode == NULL)
return;
460 const std::vector<Float_t> & evtVec = currentNode->GetEventV();
461 Double_t weight = currentNode->GetWeight();
// Class index: 0 for signal, 1 for everything else.
464 Int_t type = Int_t(currentNode->GetClass())== Types::kSignal ? 0 : 1;
466 fNEventsW[type] += weight;
467 fSumOfWeights += weight;
469 for (UInt_t j=0; j<fPeriod; j++) {
470 Float_t val = evtVec[j];
471 fSum[type][j] += val*weight;
472 fSumSq[type][j] += val*val*weight;
473 if (val < fMin[type][j]) fMin[type][j] = val;
474 if (val > fMax[type][j]) fMax[type][j] = val;
// Recurse into both subtrees.
477 if ( (currentNode->GetLeft() != NULL) ) CalcStatistics( currentNode->GetLeft() );
478 if ( (currentNode->GetRight() != NULL) ) CalcStatistics( currentNode->GetRight() );
// Derive means and RMS from the accumulated weighted sums.
481 for (Int_t sb=0; sb<2; sb++) {
482 for (UInt_t j=0; j<fPeriod; j++) {
483 if (fNEventsW[sb] == 0) { fMeans[sb][j] = fRMS[sb][j] = 0;
continue; }
484 fMeans[sb][j] = fSum[sb][j]/fNEventsW[sb];
485 fRMS[sb][j] = TMath::Sqrt(fSumSq[sb][j]/fNEventsW[sb] - fMeans[sb][j]*fMeans[sb][j]);
// Mark the cache valid so later calls return immediately.
488 fStatisticsIsValid = kTRUE;
// Breadth-first variant of SearchVolume that checks a match counter against
// max_points. The declaration of that parameter is not visible here.
// Returns 0 for an empty tree.
// Nodes are visited from a FIFO queue of (node, depth) pairs:
//  - Nodes inside `volume` are appended to `events` when it is non-NULL.
//  - Children are queued only if they exist and can overlap the volume
//    along split variable d, which wraps to 0 at GetPeriode().
//  - A mismatch between d and the node's selector is reported with kFATAL.
// NOTE(review): the initial queue push, the counter updates, the action
// taken when count == max_points (presumably stopping the search), the
// declarations of d/tl/tr, and the end of the function are outside the
// visible lines of this extract.
498 Int_t TMVA::BinarySearchTree::SearchVolumeWithMaxLimit( Volume *volume, std::vector<const BinarySearchTreeNode*>* events,
501 if (this->GetRoot() == NULL)
return 0;
// Work queue of (node, depth) pairs, seeded from the root at depth 0.
503 std::queue< std::pair< const BinarySearchTreeNode*, Int_t > > queue;
504 std::pair< const BinarySearchTreeNode*, Int_t > st = std::make_pair( (
const BinarySearchTreeNode*)this->GetRoot(), 0 );
509 while ( !queue.empty() ) {
510 st = queue.front(); queue.pop();
512 if (count == max_points)
515 if (InVolume( st.first->GetEventV(), volume )) {
517 if (NULL != events) events->push_back( st.first );
// Wrap the split variable back to 0 after the last dimension.
522 if ( d == Int_t(this->GetPeriode()) ) d = 0;
524 if (d != st.first->GetSelector()) {
525 Log() << kFATAL <<
"<SearchVolume> selector in Searchvolume "
526 << d <<
" != " <<
"node "<< st.first->GetSelector() << Endl;
// Queue only existing children whose side of the split can overlap the volume.
529 tl = (*(volume->fLower))[d] < st.first->GetEventV()[d] && st.first->GetLeft() != NULL;
530 tr = (*(volume->fUpper))[d] >= st.first->GetEventV()[d] && st.first->GetRight() != NULL;
532 if (tl) queue.push( std::make_pair( (
const BinarySearchTreeNode*)st.first->GetLeft(), d+1 ) );
533 if (tr) queue.push( std::make_pair( (
const BinarySearchTreeNode*)st.first->GetRight(), d+1 ) );