64 TMVA::Rule::Rule( RuleEnsemble *re,
65 const std::vector< const Node * >& nodes )
70 , fCoefficient ( 0.0 )
72 , fImportanceRef ( 1.0 )
73 , fRuleEnsemble ( re )
76 , fLogger( new MsgLogger(
"RuleFit") )
84 fCut =
new RuleCut( nodes );
85 fSSB = fCut->GetPurity();
86 fSSBNeve = fCut->GetCutNeve();
92 TMVA::Rule::Rule( RuleEnsemble *re )
97 , fCoefficient ( 0.0 )
99 , fImportanceRef ( 1.0 )
100 , fRuleEnsemble ( re )
103 , fLogger( new MsgLogger(
"RuleFit") )
115 , fCoefficient ( 0.0 )
116 , fImportance ( 0.0 )
117 , fImportanceRef ( 1.0 )
118 , fRuleEnsemble ( 0 )
121 , fLogger( new MsgLogger(
"RuleFit") )
137 Bool_t TMVA::Rule::ContainsVariable(UInt_t iv)
const
139 Bool_t found = kFALSE;
140 Bool_t doneLoop = kFALSE;
141 UInt_t nvars = fCut->GetNvars();
145 found = (fCut->GetSelector(i) == iv);
147 doneLoop = (found || (i==nvars));
154 void TMVA::Rule::SetMsgType( EMsgType t )
156 fLogger->SetMinType(t);
170 Bool_t TMVA::Rule::Equal(
const Rule& other, Bool_t useCutValue, Double_t mindist )
const
173 if (mindist<0) useCutValue=kFALSE;
174 Double_t d = RuleDist( other, useCutValue );
176 if (useCutValue) rval = ( (!(d<0)) && (d<mindist) );
177 else rval = (!(d<0));
190 Double_t TMVA::Rule::RuleDist(
const Rule& other, Bool_t useCutValue )
const
192 if (fCut->GetNvars()!=other.GetRuleCut()->GetNvars())
return -1.0;
194 const UInt_t nvars = fCut->GetNvars();
200 Double_t vminA,vmaxA;
201 Double_t vminB,vmaxB;
212 const RuleCut *otherCut = other.GetRuleCut();
213 while ((equal) && (in<nvars)) {
215 equal = ( (fCut->GetSelector(in) == (otherCut->GetSelector(in))) &&
216 (fCut->GetCutDoMin(in) == (otherCut->GetCutDoMin(in))) &&
217 (fCut->GetCutDoMax(in) == (otherCut->GetCutDoMax(in))) );
221 sel = fCut->GetSelector(in);
222 vminA = fCut->GetCutMin(in);
223 vmaxA = fCut->GetCutMax(in);
224 vminB = other.GetRuleCut()->GetCutMin(in);
225 vmaxB = other.GetRuleCut()->GetCutMax(in);
227 rms = fRuleEnsemble->GetRuleFit()->GetMethodBase()->GetRMS(sel);
230 if (fCut->GetCutDoMin(in))
231 smin = ( rms>0 ? (vminA-vminB)/rms : 0 );
232 if (fCut->GetCutDoMax(in))
233 smax = ( rms>0 ? (vmaxA-vmaxB)/rms : 0 );
234 sumdc2 += smin*smin + smax*smax;
240 if (!useCutValue) sumdc2 = (equal ? 0.0:-1.0);
241 else sumdc2 = (equal ? sqrt(sumdc2) : -1.0);
249 Bool_t TMVA::Rule::operator==(
const Rule& other )
const
251 return this->Equal( other, kTRUE, 1e-3 );
257 Bool_t TMVA::Rule::operator<(
const Rule& other )
const
259 return (fImportance < other.GetImportance());
265 std::ostream& TMVA::operator<< ( std::ostream& os,
const Rule& rule )
274 const TString & TMVA::Rule::GetVarName( Int_t i )
const
276 return fRuleEnsemble->GetMethodBase()->GetInputLabel(i);
282 void TMVA::Rule::Copy(
const Rule& other )
285 SetRuleEnsemble( other.GetRuleEnsemble() );
286 fCut =
new RuleCut( *(other.GetRuleCut()) );
287 fSSB = other.GetSSB();
288 fSSBNeve = other.GetSSBNeve();
289 SetCoefficient(other.GetCoefficient());
290 SetSupport( other.GetSupport() );
291 SetSigma( other.GetSigma() );
292 SetNorm( other.GetNorm() );
294 SetImportanceRef( other.GetImportanceRef() );
301 void TMVA::Rule::Print( std::ostream& os )
const
303 const UInt_t nvars = fCut->GetNvars();
304 if (nvars<1) os <<
" *** WARNING - <EMPTY RULE> ***" << std::endl;
307 Double_t valmin, valmax;
309 os <<
" Importance = " << Form(
"%1.4f", fImportance/fImportanceRef) << std::endl;
310 os <<
" Coefficient = " << Form(
"%1.4f", fCoefficient) << std::endl;
311 os <<
" Support = " << Form(
"%1.4f", fSupport) << std::endl;
312 os <<
" S/(S+B) = " << Form(
"%1.4f", fSSB) << std::endl;
314 for ( UInt_t i=0; i<nvars; i++) {
316 sel = fCut->GetSelector(i);
317 valmin = fCut->GetCutMin(i);
318 valmax = fCut->GetCutMax(i);
320 os << Form(
"* Cut %2d",i+1) <<
" : " << std::flush;
321 if (fCut->GetCutDoMin(i)) os << Form(
"%10.3g",valmin) <<
" < " << std::flush;
322 else os <<
" " << std::flush;
323 os << GetVarName(sel) << std::flush;
324 if (fCut->GetCutDoMax(i)) os <<
" < " << Form(
"%10.3g",valmax) << std::flush;
325 else os <<
" " << std::flush;
333 void TMVA::Rule::PrintLogger(
const char *title)
const
335 const UInt_t nvars = fCut->GetNvars();
336 if (nvars<1) Log() << kWARNING <<
"BUG TRAP: EMPTY RULE!!!" << Endl;
339 Double_t valmin, valmax;
341 if (title) Log() << kINFO << title;
343 <<
"Importance = " << Form(
"%1.4f", fImportance/fImportanceRef) << Endl;
345 for ( UInt_t i=0; i<nvars; i++) {
347 Log() << kINFO <<
" ";
348 sel = fCut->GetSelector(i);
349 valmin = fCut->GetCutMin(i);
350 valmax = fCut->GetCutMax(i);
352 Log() << kINFO << Form(
"Cut %2d",i+1) <<
" : ";
353 if (fCut->GetCutDoMin(i)) Log() << kINFO << Form(
"%10.3g",valmin) <<
" < ";
354 else Log() << kINFO <<
" ";
355 Log() << kINFO << GetVarName(sel);
356 if (fCut->GetCutDoMax(i)) Log() << kINFO <<
" < " << Form(
"%10.3g",valmax);
357 else Log() << kINFO <<
" ";
365 void TMVA::Rule::PrintRaw( std::ostream& os )
const
367 Int_t dp = os.precision();
368 const UInt_t nvars = fCut->GetNvars();
370 << std::setprecision(10)
371 << fImportance <<
" "
372 << fImportanceRef <<
" "
373 << fCoefficient <<
" "
380 os <<
"N(cuts): " << nvars << std::endl;
381 for ( UInt_t i=0; i<nvars; i++) {
382 os <<
"Cut " << i <<
" : " << std::flush;
383 os << fCut->GetSelector(i)
384 << std::setprecision(10)
385 <<
" " << fCut->GetCutMin(i)
386 <<
" " << fCut->GetCutMax(i)
387 <<
" " << (fCut->GetCutDoMin(i) ?
"T":
"F")
388 <<
" " << (fCut->GetCutDoMax(i) ?
"T":
"F")
391 os << std::setprecision(dp);
396 void* TMVA::Rule::AddXMLTo(
void* parent )
const
398 void* rule = gTools().AddChild( parent,
"Rule" );
399 const UInt_t nvars = fCut->GetNvars();
401 gTools().AddAttr( rule,
"Importance", fImportance );
402 gTools().AddAttr( rule,
"Ref", fImportanceRef );
403 gTools().AddAttr( rule,
"Coeff", fCoefficient );
404 gTools().AddAttr( rule,
"Support", fSupport );
405 gTools().AddAttr( rule,
"Sigma", fSigma );
406 gTools().AddAttr( rule,
"Norm", fNorm );
407 gTools().AddAttr( rule,
"SSB", fSSB );
408 gTools().AddAttr( rule,
"SSBNeve", fSSBNeve );
409 gTools().AddAttr( rule,
"Nvars", nvars );
411 for (UInt_t i=0; i<nvars; i++) {
412 void* cut = gTools().AddChild( rule,
"Cut" );
413 gTools().AddAttr( cut,
"Selector", fCut->GetSelector(i) );
414 gTools().AddAttr( cut,
"Min", fCut->GetCutMin(i) );
415 gTools().AddAttr( cut,
"Max", fCut->GetCutMax(i) );
416 gTools().AddAttr( cut,
"DoMin", (fCut->GetCutDoMin(i) ?
"T":
"F") );
417 gTools().AddAttr( cut,
"DoMax", (fCut->GetCutDoMax(i) ?
"T":
"F") );
426 void TMVA::Rule::ReadFromXML(
void* wghtnode )
428 TString nodeName = TString( gTools().GetName(wghtnode) );
429 if (nodeName !=
"Rule") Log() << kFATAL <<
"<ReadFromXML> Unexpected node name: " << nodeName << Endl;
431 gTools().ReadAttr( wghtnode,
"Importance", fImportance );
432 gTools().ReadAttr( wghtnode,
"Ref", fImportanceRef );
433 gTools().ReadAttr( wghtnode,
"Coeff", fCoefficient );
434 gTools().ReadAttr( wghtnode,
"Support", fSupport );
435 gTools().ReadAttr( wghtnode,
"Sigma", fSigma );
436 gTools().ReadAttr( wghtnode,
"Norm", fNorm );
437 gTools().ReadAttr( wghtnode,
"SSB", fSSB );
438 gTools().ReadAttr( wghtnode,
"SSBNeve", fSSBNeve );
441 gTools().ReadAttr( wghtnode,
"Nvars", nvars );
442 if (fCut)
delete fCut;
443 fCut =
new RuleCut();
444 fCut->SetNvars( nvars );
447 void* ch = gTools().GetChild( wghtnode );
453 gTools().ReadAttr( ch,
"Selector", ui );
454 fCut->SetSelector( i, ui );
455 gTools().ReadAttr( ch,
"Min", d );
456 fCut->SetCutMin ( i, d );
457 gTools().ReadAttr( ch,
"Max", d );
458 fCut->SetCutMax ( i, d );
459 gTools().ReadAttr( ch,
"DoMin", c );
460 fCut->SetCutDoMin( i, (c ==
'T' ? kTRUE : kFALSE ) );
461 gTools().ReadAttr( ch,
"DoMax", c );
462 fCut->SetCutDoMax( i, (c ==
'T' ? kTRUE : kFALSE ) );
465 ch = gTools().GetNextChild(ch);
469 if (i != nvars) Log() << kFATAL <<
"<ReadFromXML> Mismatch in number of cuts: " << i <<
" != " << nvars << Endl;
475 void TMVA::Rule::ReadRaw( std::istream& istr )
489 istr >> dummy >> nvars;
490 Double_t cutmin,cutmax;
494 if (fCut)
delete fCut;
495 fCut =
new RuleCut();
496 fCut->SetNvars( nvars );
497 for ( UInt_t i=0; i<nvars; i++) {
498 istr >> dummy >> idum;
500 istr >> sel >> cutmin >> cutmax >> bA >> bB;
501 fCut->SetSelector(i,sel);
502 fCut->SetCutMin(i,cutmin);
503 fCut->SetCutMax(i,cutmax);
504 fCut->SetCutDoMin(i,(bA==
'T' ? kTRUE:kFALSE));
505 fCut->SetCutDoMax(i,(bB==
'T' ? kTRUE:kFALSE));