// Default constructor: regression mode is switched off and a message logger
// named "SVWorkingSet" is created at kINFO level.
51 TMVA::SVWorkingSet::SVWorkingSet()
52 : fdoRegression(kFALSE),
62 fLogger( new MsgLogger(
"SVWorkingSet", kINFO ) )
// Constructor taking the training events and the kernel function.
//   inputVectors   - events to work on; the pointer itself is stored in fInputData
//   kernelFunction - kernel stored in fKFunction and handed to the kernel matrix
//   tol            - tolerance (presumably stored in fTolerance - TODO confirm)
//   doreg          - selects regression mode (stored in fdoRegression)
69 TMVA::SVWorkingSet::SVWorkingSet(std::vector<TMVA::SVEvent*>*inputVectors, SVKernelFunction* kernelFunction,
70 Float_t tol, Bool_t doreg)
71 : fdoRegression(doreg),
72 fInputData(inputVectors),
74 fKFunction(kernelFunction),
80 fLogger( new MsgLogger(
"SVWorkingSet", kINFO ) )
// Allocate the kernel matrix for these events and this kernel function.
82 fKMatrix =
new TMVA::SVKernelMatrix(inputVectors, kernelFunction);
// Give every event its kernel-matrix line and its position within fInputData.
84 for( UInt_t i = 0; i < fInputData->size(); i++){
85 pt = fKMatrix->GetLine(i);
86 fInputData->at(i)->SetLine(pt);
87 fInputData->at(i)->SetNs(i);
// Regression only: the error cache starts out as the event's own target.
88 if(fdoRegression) fInputData->at(i)->SetErrorCache(fInputData->at(i)->GetTarget());
// Random index into fInputData, used to pick the starting events below.
91 UInt_t kk = rand.Integer(fInputData->size());
// Event 0 serves as both the low and the up event; the two thresholds are its
// target shifted down / up by the tolerance.
93 fTEventLow = fTEventUp =fInputData->at(0);
94 fB_low = fTEventUp ->GetTarget() - fTolerance;
95 fB_up = fTEventLow->GetTarget() + fTolerance;
// An event with type flag -1 at the random index becomes the low event; a new
// random index is drawn afterwards (presumably to retry on a mismatch - TODO confirm).
99 if(fInputData->at(kk)->GetTypeFlag()==-1){
100 fTEventLow = fInputData->at(kk);
103 kk = rand.Integer(fInputData->size());
// Same selection for the up event, which needs type flag 1.
107 if (fInputData->at(kk)->GetTypeFlag()==1) {
108 fTEventUp = fInputData->at(kk);
111 kk = rand.Integer(fInputData->size());
// Seed the error caches of the two selected events.
// NOTE(review): the low event is seeded with the *up* event's target - confirm
// that fTEventLow->GetTarget() is not what was intended here.
114 fTEventUp ->SetErrorCache(fTEventUp->GetTarget());
115 fTEventLow->SetErrorCache(fTEventUp->GetTarget());
// Destructor: deletes the kernel matrix if one was allocated and resets the
// pointer to 0.
121 TMVA::SVWorkingSet::~SVWorkingSet()
123 if (fKMatrix != 0) {
delete fKMatrix; fKMatrix = 0;}
// Examines one event during classification training: its error cache is
// refreshed, the value is checked against the fB_up / fB_low thresholds and,
// unless the event counts as converged, an optimisation step with a partner
// event is attempted.
//   jevt - event to examine
// Returns kFALSE when the event is considered converged and kTRUE when
// TakeStep() succeeds.
129 Bool_t TMVA::SVWorkingSet::ExamineExample( TMVA::SVEvent* jevt )
// Events with index 0 take their error value from the cache.
132 Float_t fErrorC_J = 0.;
133 if( jevt->GetIdx()==0) fErrorC_J = jevt->GetErrorCache();
// Line pointer stored on the event; it is indexed with k in the sum below.
135 Float_t *fKVals = jevt->GetLine();
137 std::vector<TMVA::SVEvent*>::iterator idIter;
// Accumulate alpha * type flag * kernel value over all events with alpha > 0.
140 for(idIter = fInputData->begin(); idIter != fInputData->end(); ++idIter){
141 if((*idIter)->GetAlpha()>0)
142 fErrorC_J += (*idIter)->GetAlpha()*(*idIter)->GetTypeFlag()*fKVals[k];
// Subtract the event's own type flag and store the result as its error cache.
147 fErrorC_J -= jevt->GetTypeFlag();
148 jevt->SetErrorCache(fErrorC_J);
// An index +1 event lying below fB_up, or an index -1 event lying above fB_low,
// is a candidate for tightening that threshold (presumably done in the branch
// bodies - TODO confirm).
150 if((jevt->GetIdx() == 1) && (fErrorC_J < fB_up )){
154 else if ((jevt->GetIdx() == -1)&&(fErrorC_J > fB_low)) {
// The event counts as converged unless it violates one of the thresholds by
// more than 2*fTolerance: fB_low for index >= 0, fB_up for index <= 0.
159 Bool_t converged = kTRUE;
161 if((jevt->GetIdx()>=0) && (fB_low - fErrorC_J > 2*fTolerance)) {
166 if((jevt->GetIdx()<=0) && (fErrorC_J - fB_up > 2*fTolerance)) {
171 if (converged)
return kFALSE;
// For index-0 events the partner is the threshold event on the side with the
// larger distance to fErrorC_J.
173 if(jevt->GetIdx()==0){
174 if(fB_low - fErrorC_J > fErrorC_J - fB_up) ievt = fTEventLow;
175 else ievt = fTEventUp;
// Try to optimise the pair (ievt, jevt).
178 if (TakeStep(ievt, jevt))
return kTRUE;
// Attempts a joint update of the alphas of the event pair (ievt, jevt) for
// classification.
//   ievt, jevt - the two events whose alphas are updated together
// Returns kFALSE right away when both pointers refer to the same event or when
// the bounds l and h coincide.
185 Bool_t TMVA::SVWorkingSet::TakeStep(TMVA::SVEvent* ievt,TMVA::SVEvent* jevt )
187 if (ievt == jevt)
return kFALSE;
188 std::vector<TMVA::SVEvent*>::iterator idIter;
// Small constant used in the comparisons of lobj/hobj and of the alpha change.
189 const Float_t epsilon = 1e-8;
191 Float_t type_I, type_J;
192 Float_t errorC_I, errorC_J;
193 Float_t alpha_I, alpha_J;
195 Float_t newAlpha_I, newAlpha_J;
198 Float_t l, h, lobj = 0, hobj = 0;
// Snapshot type flag, alpha and error cache of both events.
201 type_I = ievt->GetTypeFlag();
202 alpha_I = ievt->GetAlpha();
203 errorC_I = ievt->GetErrorCache();
205 type_J = jevt->GetTypeFlag();
206 alpha_J = jevt->GetAlpha();
207 errorC_J = jevt->GetErrorCache();
// Product of the two type flags, truncated to an integer.
209 s = Int_t( type_I * type_J );
// Cost weights of the two events; c_i is used below as the upper limit for
// newAlpha_I.
211 Float_t c_i = ievt->GetCweight();
213 Float_t c_j = jevt->GetCweight();
// gamma is the sum of the alphas for equal type flags and their difference
// otherwise; it is compared with the cost weights (presumably to choose the
// bounds l and h - TODO confirm).
217 if (type_I == type_J) {
218 Float_t gamma = alpha_I + alpha_J;
248 Float_t gamma = alpha_I - alpha_J;
251 if ( gamma >= (c_i - c_j) )
258 if ( (c_i - c_j) >= gamma)
// Nothing to optimise when the two bounds coincide.
265 if (l == h)
return kFALSE;
// Kernel-matrix entries for the pair and eta = 2*K_ij - K_ii - K_jj.
266 Float_t kernel_II, kernel_IJ, kernel_JJ;
268 kernel_II = fKMatrix->GetElement(ievt->GetNs(),ievt->GetNs());
269 kernel_IJ = fKMatrix->GetElement(ievt->GetNs(), jevt->GetNs());
270 kernel_JJ = fKMatrix->GetElement(jevt->GetNs(),jevt->GetNs());
272 eta = 2*kernel_IJ - kernel_II - kernel_JJ;
// Step for alpha_J, then clip the result to [l, h].
// NOTE(review): this divides by eta - confirm that an enclosing condition
// excludes eta == 0.
274 newAlpha_J = alpha_J + (type_J*( errorC_J - errorC_I ))/eta;
275 if (newAlpha_J < l) newAlpha_J = l;
276 else if (newAlpha_J > h) newAlpha_J = h;
// Alternative choice of newAlpha_J: evaluate lobj and hobj at the two bounds
// and move to l or h, whichever gives the larger value; keep alpha_J when the
// two agree within epsilon.
283 Float_t c_J = type_J*( errorC_I - errorC_J ) - eta * alpha_J;
284 lobj = c_I * l * l + c_J * l;
285 hobj = c_I * h * h + c_J * h;
287 if (lobj > hobj + epsilon) newAlpha_J = l;
288 else if (lobj < hobj - epsilon) newAlpha_J = h;
289 else newAlpha_J = alpha_J;
// Relative change of alpha_J below epsilon (presumably the step is rejected in
// that case - TODO confirm).
292 if (TMath::Abs( newAlpha_J - alpha_J ) < ( epsilon * ( newAlpha_J + alpha_J+ epsilon ))){
// alpha_I moves by the opposite amount, scaled by s.
296 newAlpha_I = alpha_I - s*( newAlpha_J - alpha_J );
// If newAlpha_I left the range [0, c_i], shift the excess onto newAlpha_J.
298 if (newAlpha_I < 0) {
299 newAlpha_J += s* newAlpha_I;
302 else if (newAlpha_I > c_i) {
303 Float_t temp = newAlpha_I - c_i;
304 newAlpha_J += s * temp;
// Signed alpha changes of both events.
308 Float_t dL_I = type_I * ( newAlpha_I - alpha_I );
309 Float_t dL_J = type_J * ( newAlpha_J - alpha_J );
// Propagate the change to the error cache of every event with index 0.
312 for(idIter = fInputData->begin(); idIter != fInputData->end(); ++idIter){
314 if((*idIter)->GetIdx()==0){
315 Float_t ii = fKMatrix->GetElement(ievt->GetNs(), (*idIter)->GetNs());
316 Float_t jj = fKMatrix->GetElement(jevt->GetNs(), (*idIter)->GetNs());
318 (*idIter)->UpdateErrorCache(dL_I * ii + dL_J * jj);
// Store the new alphas on the pair.
321 ievt->SetAlpha(newAlpha_I);
322 jevt->SetAlpha(newAlpha_J);
// Error caches of the pair are set from the values snapshotted on entry.
328 ievt->SetErrorCache(errorC_I + dL_I*kernel_II + dL_J*kernel_IJ);
329 jevt->SetErrorCache(errorC_J + dL_I*kernel_IJ + dL_J*kernel_JJ);
// Scan the index-0 events: the largest error cache above fB_low becomes the new
// fB_low / fTEventLow, the smallest below fB_up the new fB_up / fTEventUp.
336 for(idIter = fInputData->begin(); idIter != fInputData->end(); ++idIter){
337 if((*idIter)->GetIdx()==0){
338 if((*idIter)->GetErrorCache()> fB_low){
339 fB_low = (*idIter)->GetErrorCache();
340 fTEventLow = (*idIter);
342 if( (*idIter)->GetErrorCache()< fB_up){
343 fB_up =(*idIter)->GetErrorCache();
344 fTEventUp = (*idIter);
// The two updated events are considered for fB_low as well.
350 if (fB_low < TMath::Max(ievt->GetErrorCache(), jevt->GetErrorCache())) {
351 if (ievt->GetErrorCache() > fB_low) {
352 fB_low = ievt->GetErrorCache();
356 fB_low = jevt->GetErrorCache();
// ... and for fB_up.
// NOTE(review): this branch uses TMath::Max and compares against fB_low; by
// symmetry with the fB_low branch above one would expect TMath::Min and fB_up -
// confirm.
361 if (fB_up > TMath::Max(ievt->GetErrorCache(), jevt->GetErrorCache())) {
362 if (ievt->GetErrorCache()< fB_low) {
363 fB_up =ievt->GetErrorCache();
367 fB_up =jevt->GetErrorCache() ;
// Termination test: returns kTRUE once fB_up exceeds fB_low - 2*fTolerance,
// i.e. the two thresholds are within 2*fTolerance of each other or have crossed.
376 Bool_t TMVA::SVWorkingSet::Terminated()
378 if((fB_up > fB_low - 2*fTolerance))
return kTRUE;
// Training loop: events are examined repeatedly until a pass changes nothing
// and no full pass is pending, or until the external exit flag is raised.
//   nMaxIter - iteration limit; a warning is logged once numit reaches it
385 void TMVA::SVWorkingSet::Train(UInt_t nMaxIter)
// numChanged counts successful examinations, examineAll requests a pass over
// every event, deltaChanges counts consecutive passes with an unchanged count.
// NOTE(review): numChangedOld is a Float_t although it only ever holds the
// Int_t numChanged - confirm that the mixed-type comparison below is intended.
388 Int_t numChanged = 0;
389 Int_t examineAll = 1;
391 Float_t numChangedOld = 0;
392 Int_t deltaChanges = 0;
395 std::vector<TMVA::SVEvent*>::iterator idIter;
397 while ((numChanged > 0) || (examineAll > 0)) {
// Publish the current iteration number if an observer pointer is set.
398 if (fIPyCurrentIter) *fIPyCurrentIter = numit;
// Leave the loop when the external exit flag is set.
399 if (fExitFromTraining && *fExitFromTraining)
break;
// Pass over every event, using the classification or the regression variant.
402 for (idIter = fInputData->begin(); idIter!=fInputData->end(); ++idIter){
403 if(!fdoRegression) numChanged += (UInt_t)ExamineExample(*idIter);
404 else numChanged += (UInt_t)ExamineExampleReg(*idIter);
// Pass restricted to the events for which IsInI0() holds.
408 for (idIter = fInputData->begin(); idIter!=fInputData->end(); ++idIter) {
409 if ((*idIter)->IsInI0()) {
410 if(!fdoRegression) numChanged += (UInt_t)ExamineExample(*idIter);
411 else numChanged += (UInt_t)ExamineExampleReg(*idIter);
// A full pass is followed by a restricted one; a full pass is requested again
// when fewer than 10 changes happened or the count stalled for more than 3 passes.
420 if (examineAll == 1) examineAll = 0;
421 else if (numChanged == 0 || numChanged < 10 || deltaChanges > 3 ) examineAll = 1;
// Count how many passes in a row produced the same number of changes.
423 if (numChanged == numChangedOld) deltaChanges++;
424 else deltaChanges = 0;
425 numChangedOld = numChanged;
// Iteration limit reached: log a warning.
428 if (numit >= nMaxIter) {
430 <<
"Max number of iterations exceeded. "
431 <<
"Training may not be completed. Try use less Cost parameter" << Endl;
// Classifies an event by where its alpha lies relative to 0 and to its cost
// weight, separately for type flag 1 and type flag -1 (presumably each case
// assigns the index later read through GetIdx() - TODO confirm).
//   event - event to classify
// The tests for alpha == 0 and alpha == cost weight are exact floating-point
// comparisons, so they only match when alpha was set to exactly those values.
439 void TMVA::SVWorkingSet::SetIndex( TMVA::SVEvent* event )
// Alpha strictly between 0 and the cost weight.
441 if( (0< event->GetAlpha()) && (event->GetAlpha()<
event->GetCweight()))
// Type flag 1: alpha at the lower limit 0 or at the cost weight.
444 if( event->GetTypeFlag() == 1){
445 if( event->GetAlpha() == 0)
447 else if( event->GetAlpha() ==
event->GetCweight() )
// Type flag -1: same two limit cases.
450 if( event->GetTypeFlag() == -1){
451 if( event->GetAlpha() == 0)
453 else if( event->GetAlpha() ==
event->GetCweight() )
// Statistics helper: counts the events in fInputData whose alpha is non-zero
// (presumably the count is then reported - TODO confirm).
460 void TMVA::SVWorkingSet::PrintStat()
462 std::vector<TMVA::SVEvent*>::iterator idIter;
464 for( idIter = fInputData->begin(); idIter != fInputData->end(); ++idIter)
465 if((*idIter)->GetAlpha() !=0) counter++;
// Rebuilds the list of support vectors: a previously built fSupVec is deleted,
// a new empty vector is allocated, and every event in fInputData with a
// non-zero GetDeltaAlpha() is appended. Only event pointers are stored, the
// events themselves are not copied.
// Presumably returns fSupVec - TODO confirm.
470 std::vector<TMVA::SVEvent*>* TMVA::SVWorkingSet::GetSupportVectors()
472 std::vector<TMVA::SVEvent*>::iterator idIter;
// Drop the result of an earlier call before allocating a fresh vector.
473 if( fSupVec != 0) {
delete fSupVec; fSupVec = 0; }
474 fSupVec =
new std::vector<TMVA::SVEvent*>(0);
476 for( idIter = fInputData->begin(); idIter != fInputData->end(); ++idIter){
477 if((*idIter)->GetDeltaAlpha() !=0){
478 fSupVec->push_back((*idIter));
// Regression counterpart of TakeStep(): attempts a joint update of the alpha /
// alpha_p values of the event pair (ievt, jevt).
//   ievt, jevt - the two events to update together
// Returns kFALSE when both pointers refer to the same event or when none of the
// four values changed significantly (see IsDiffSignificant()).
486 Bool_t TMVA::SVWorkingSet::TakeStepReg(TMVA::SVEvent* ievt,TMVA::SVEvent* jevt )
488 if (ievt == jevt)
return kFALSE;
489 std::vector<TMVA::SVEvent*>::iterator idIter;
// Comparison threshold, tied to the training tolerance.
490 const Float_t epsilon = 0.001*fTolerance;
// Kernel-matrix entries for the pair; eta = K_ii + K_jj - 2*K_ij and gamma is
// the sum of the two delta-alphas.
492 const Float_t kernel_II = fKMatrix->GetElement(ievt->GetNs(),ievt->GetNs());
493 const Float_t kernel_IJ = fKMatrix->GetElement(ievt->GetNs(),jevt->GetNs());
494 const Float_t kernel_JJ = fKMatrix->GetElement(jevt->GetNs(),jevt->GetNs());
497 const Float_t eta = -2*kernel_IJ + kernel_II + kernel_JJ;
498 const Float_t gamma = ievt->GetDeltaAlpha() + jevt->GetDeltaAlpha();
// One flag per update case below plus a termination flag; all start as kFALSE.
503 Bool_t caseA, caseB, caseC, caseD, terminated;
504 caseA = caseB = caseC = caseD = terminated = kFALSE;
// Working copies of alpha and alpha_p for both events, plus their cost weights.
505 Float_t b_alpha_i, b_alpha_j, b_alpha_i_p, b_alpha_j_p;
506 const Float_t b_cost_i = ievt->GetCweight();
507 const Float_t b_cost_j = jevt->GetCweight();
509 b_alpha_i = ievt->GetAlpha();
510 b_alpha_j = jevt->GetAlpha();
511 b_alpha_i_p = ievt->GetAlpha_p();
512 b_alpha_j_p = jevt->GetAlpha_p();
// Difference of the two error caches.
515 Float_t deltafi = ievt->GetErrorCache()-jevt->GetErrorCache();
// Typed zero so that TMath::Max is called with two Float_t arguments.
519 const Float_t null = 0.;
521 Float_t tmp_alpha_i, tmp_alpha_j;
522 tmp_alpha_i = tmp_alpha_j = 0.;
// Case A - update the pair (alpha_i, alpha_j): clip the stepped alpha_j to
// [low, high], move alpha_i by the opposite amount and accept the result only if
// one of the two changed significantly.
// NOTE(review): every case divides by eta - confirm eta == 0 is excluded.
525 if((caseA == kFALSE) && (b_alpha_i > 0 || (b_alpha_i_p == 0 && deltafi > 0)) && (b_alpha_j > 0 || (b_alpha_j_p == 0 && deltafi < 0)))
528 low = TMath::Max( null, gamma - b_cost_j );
529 high = TMath::Min( b_cost_i , gamma);
532 tmp_alpha_j = b_alpha_j - (deltafi/eta);
533 tmp_alpha_j = TMath::Min(tmp_alpha_j,high );
534 tmp_alpha_j = TMath::Max(low ,tmp_alpha_j);
535 tmp_alpha_i = b_alpha_i - (tmp_alpha_j - b_alpha_j);
538 if( IsDiffSignificant(b_alpha_j,tmp_alpha_j, epsilon) || IsDiffSignificant(b_alpha_i,tmp_alpha_i, epsilon)){
539 b_alpha_j = tmp_alpha_j;
540 b_alpha_i = tmp_alpha_i;
// Case B - update the pair (alpha_i, alpha_j_p); the step uses deltafi reduced
// by 2*epsilon.
549 else if((caseB==kFALSE) && (b_alpha_i>0 || (b_alpha_i_p==0 && deltafi >2*epsilon )) && (b_alpha_j_p>0 || (b_alpha_j==0 && deltafi>2*epsilon)))
552 low = TMath::Max( null, gamma );
553 high = TMath::Min( b_cost_i , b_cost_j + gamma);
557 tmp_alpha_j = b_alpha_j_p - ((deltafi-2*epsilon)/eta);
558 tmp_alpha_j = TMath::Min(tmp_alpha_j,high);
559 tmp_alpha_j = TMath::Max(low,tmp_alpha_j);
560 tmp_alpha_i = b_alpha_i - (tmp_alpha_j - b_alpha_j_p);
563 if( IsDiffSignificant(b_alpha_j_p,tmp_alpha_j, epsilon) || IsDiffSignificant(b_alpha_i,tmp_alpha_i, epsilon)){
564 b_alpha_j_p = tmp_alpha_j;
565 b_alpha_i = tmp_alpha_i;
// Case C - update the pair (alpha_i_p, alpha_j); the step uses deltafi increased
// by 2*epsilon.
573 else if((caseC==kFALSE) && (b_alpha_i_p>0 || (b_alpha_i==0 && deltafi < -2*epsilon )) && (b_alpha_j>0 || (b_alpha_j_p==0 && deltafi< -2*epsilon)))
576 low = TMath::Max(null, -gamma );
577 high = TMath::Min(b_cost_i, -gamma+b_cost_j);
580 tmp_alpha_j = b_alpha_j - ((deltafi+2*epsilon)/eta);
581 tmp_alpha_j = TMath::Min(tmp_alpha_j,high );
582 tmp_alpha_j = TMath::Max(low ,tmp_alpha_j);
583 tmp_alpha_i = b_alpha_i_p - (tmp_alpha_j - b_alpha_j);
586 if( IsDiffSignificant(b_alpha_j,tmp_alpha_j, epsilon) || IsDiffSignificant(b_alpha_i_p,tmp_alpha_i, epsilon)){
587 b_alpha_j = tmp_alpha_j;
588 b_alpha_i_p = tmp_alpha_i;
// Case D - update the pair (alpha_i_p, alpha_j_p).
596 else if((caseD == kFALSE) &&
597 (b_alpha_i_p>0 || (b_alpha_i==0 && deltafi <0 )) &&
598 (b_alpha_j_p>0 || (b_alpha_j==0 && deltafi >0 )))
601 low = TMath::Max(null,-gamma - b_cost_j);
602 high = TMath::Min(b_cost_i, -gamma);
605 tmp_alpha_j = b_alpha_j_p + (deltafi/eta);
606 tmp_alpha_j = TMath::Min(tmp_alpha_j,high );
607 tmp_alpha_j = TMath::Max(low ,tmp_alpha_j);
608 tmp_alpha_i = b_alpha_i_p - (tmp_alpha_j - b_alpha_j_p);
610 if( IsDiffSignificant(b_alpha_j_p,tmp_alpha_j, epsilon) || IsDiffSignificant(b_alpha_i_p,tmp_alpha_i, epsilon)){
611 b_alpha_j_p = tmp_alpha_j;
612 b_alpha_i_p = tmp_alpha_i;
// Adjust deltafi using the events' delta-alphas and the kernel entries.
624 deltafi += ievt->GetDeltaAlpha()*(kernel_II - kernel_IJ) + jevt->GetDeltaAlpha()*(kernel_IJ - kernel_JJ);
// Commit only if at least one of the four working values differs significantly
// from what is stored on the events.
626 if( IsDiffSignificant(b_alpha_i, ievt->GetAlpha(), epsilon) ||
627 IsDiffSignificant(b_alpha_j, jevt->GetAlpha(), epsilon) ||
628 IsDiffSignificant(b_alpha_i_p, ievt->GetAlpha_p(), epsilon) ||
629 IsDiffSignificant(b_alpha_j_p, jevt->GetAlpha_p(), epsilon) ){
// Per-event change applied to the error caches below.
// NOTE(review): computed as old delta-alpha + new alpha_p - old alpha; confirm
// this equals the intended change of the event's delta-alpha.
633 const Float_t diff_alpha_i = ievt->GetDeltaAlpha()+b_alpha_i_p - ievt->GetAlpha();
634 const Float_t diff_alpha_j = jevt->GetDeltaAlpha()+b_alpha_j_p - jevt->GetAlpha();
// Propagate the change to the error cache of every event with index 0.
638 for(idIter = fInputData->begin(); idIter != fInputData->end(); ++idIter){
641 if((*idIter)->GetIdx()==0){
642 Float_t k_ii = fKMatrix->GetElement(ievt->GetNs(), (*idIter)->GetNs());
643 Float_t k_jj = fKMatrix->GetElement(jevt->GetNs(), (*idIter)->GetNs());
645 (*idIter)->UpdateErrorCache(diff_alpha_i * k_ii + diff_alpha_j * k_jj);
// Store the four updated values on the two events.
650 ievt->SetAlpha(b_alpha_i);
651 jevt->SetAlpha(b_alpha_j);
652 ievt->SetAlpha_p(b_alpha_i_p);
653 jevt->SetAlpha_p(b_alpha_j_p);
// Refresh the thresholds: events outside I3 may raise fB_low, events outside I2
// may lower fB_up; the matching event is remembered in each case.
662 for(idIter = fInputData->begin(); idIter != fInputData->end(); ++idIter){
663 if((!(*idIter)->IsInI3()) && ((*idIter)->GetErrorCache()> fB_low)){
664 fB_low = (*idIter)->GetErrorCache();
665 fTEventLow = (*idIter);
668 if((!(*idIter)->IsInI2()) && ((*idIter)->GetErrorCache()< fB_up)){
669 fB_up =(*idIter)->GetErrorCache();
670 fTEventUp = (*idIter);
674 }
else return kFALSE;
// Regression counterpart of ExamineExample(): refreshes the event's error
// cache, compares it (shifted by +/- feps) with the fB_up / fB_low thresholds
// and, unless the event counts as converged, attempts a step with a partner.
//   jevt - event to examine
// Returns kFALSE when the event is considered converged and kTRUE when
// TakeStepReg() succeeds.
680 Bool_t TMVA::SVWorkingSet::ExamineExampleReg(TMVA::SVEvent* jevt)
// Events in I0 take their error value from the cache.
684 Float_t fErrorC_J = 0.;
685 if( jevt->IsInI0()) {
686 fErrorC_J = jevt->GetErrorCache();
// Line pointer stored on the event; it is indexed with k in the sum below.
689 Float_t *fKVals = jevt->GetLine();
691 std::vector<TMVA::SVEvent*>::iterator idIter;
// Subtract delta-alpha * kernel value over all events, add the event's target
// and store the result as its error cache.
694 for(idIter = fInputData->begin(); idIter != fInputData->end(); ++idIter){
695 fErrorC_J -= (*idIter)->GetDeltaAlpha()*fKVals[k];
699 fErrorC_J += jevt->GetTarget();
700 jevt->SetErrorCache(fErrorC_J);
// Tighten the thresholds with the new value shifted by +/- feps; which shift and
// which threshold apply depends on the event's index set (I2 and I3 are tested
// explicitly below).
703 if(fErrorC_J + feps < fB_up ){
704 fB_up = fErrorC_J + feps;
707 else if(fErrorC_J -feps > fB_low) {
708 fB_low = fErrorC_J - feps;
711 }
else if((jevt->IsInI2()) && (fErrorC_J + feps > fB_low)){
712 fB_low = fErrorC_J + feps;
714 }
else if((jevt->IsInI3()) && (fErrorC_J - feps < fB_up)){
715 fB_up = fErrorC_J - feps;
// Optimality checks: the event counts as converged unless a threshold is
// violated by more than 2*fTolerance. Each group below tests fB_low and fB_up
// with its own +/- feps shift and, when both are candidates, compares the two
// violations (presumably to pick the partner event - TODO confirm).
720 Bool_t converged = kTRUE;
723 if( fB_low -fErrorC_J + feps > 2*fTolerance){
726 if(fErrorC_J-feps-fB_up > fB_low-fErrorC_J+feps){
729 }
else if(fErrorC_J -feps - fB_up > 2*fTolerance){
732 if(fB_low - fErrorC_J+feps > fErrorC_J-feps -fB_up){
740 if( fB_low -fErrorC_J - feps > 2*fTolerance){
743 if(fErrorC_J+feps-fB_up > fB_low-fErrorC_J-feps){
746 }
else if(fErrorC_J + feps - fB_up > 2*fTolerance){
749 if(fB_low - fErrorC_J-feps > fErrorC_J+feps -fB_up){
757 if( fB_low -fErrorC_J - feps > 2*fTolerance){
760 if(fErrorC_J+feps-fB_up > fB_low-fErrorC_J-feps){
763 }
else if(fErrorC_J - feps - fB_up > 2*fTolerance){
766 if(fB_low - fErrorC_J+feps > fErrorC_J-feps -fB_up){
// Single-sided checks: only fB_up in the first, only fB_low in the second.
774 if( fErrorC_J + feps -fB_up > 2*fTolerance){
782 if(fB_low -fErrorC_J +feps > 2*fTolerance){
// Nothing to do for a converged event; otherwise try to optimise the pair.
788 if(converged)
return kFALSE;
789 if (TakeStepReg(ievt, jevt))
return kTRUE;
// Relative-difference test for two values.
//   a_i, a_j - values to compare
//   eps      - scale factor of the threshold
// Returns kTRUE when |a_i - a_j| exceeds eps * (a_i + a_j + eps); the threshold
// therefore grows with the sum of the two values.
793 Bool_t TMVA::SVWorkingSet::IsDiffSignificant(Float_t a_i, Float_t a_j, Float_t eps)
795 if( TMath::Abs(a_i - a_j) > eps*(a_i + a_j + eps))
return kTRUE;