// Constructor: Reader configured only by the option string `theOption`
// (input variables presumably added afterwards, e.g. via AddVariable).
// Visible steps: the Configurable base receives theOption; fDataSetManager starts
// NULL and is then allocated around fDataInputHandler; fDataSetInfo is registered
// with it; a MsgLogger bound to this object is created; the config name is set
// from GetName().
// NOTE(review): several initializer/body lines are omitted in this excerpt; how
// `verbose` is consumed is not visible here — confirm against the full source.
// NOTE(review): fDataSetManager and fLogger are created with raw `new`;
// fDataSetManager is deleted in ~Reader, fLogger's deletion is not visible.
// The same init sequence is repeated verbatim in every constructor — a delegating
// constructor or shared private init helper would remove the duplication.
128 TMVA::Reader::Reader(
const TString& theOption, Bool_t verbose )
129 : Configurable( theOption ),
130 fDataSetManager( NULL ),
135 fCalculateError(kFALSE),
137 fMvaEventErrorUpper( 0 ),
// Create the dataset manager and register this Reader's dataset info with it.
140 fDataSetManager =
new DataSetManager( fDataInputHandler );
141 fDataSetManager->AddDataSetInfo(fDataSetInfo);
142 fLogger =
new MsgLogger(
this);
143 SetConfigName( GetName() );
// Constructor: Reader whose input variables are given as a vector of TString
// names/expressions; each entry is declared, in vector order, via
// DataInfo().AddVariable.
// The setup sequence (dataset manager, dataset-info registration, logger, config
// name) is the same as in the option-only constructor.
// NOTE(review): `inputVars` is only read here; a const reference would be more
// accurate (not changed — it would alter the public signature).
// Intermediate lines are omitted in this excerpt.
153 TMVA::Reader::Reader( std::vector<TString>& inputVars,
const TString& theOption, Bool_t verbose )
154 : Configurable( theOption ),
155 fDataSetManager( NULL ),
160 fCalculateError(kFALSE),
162 fMvaEventErrorUpper( 0 ),
165 fDataSetManager =
new DataSetManager( fDataInputHandler );
166 fDataSetManager->AddDataSetInfo(fDataSetInfo);
167 fLogger =
new MsgLogger(
this);
168 SetConfigName( GetName() );
// Declare each name as an input variable (only the expression is passed,
// no external data link).
174 for (std::vector<TString>::iterator ivar = inputVars.begin(); ivar != inputVars.end(); ++ivar)
175 DataInfo().AddVariable( *ivar );
// Constructor: Reader whose input variables are given as a vector of std::string
// names/expressions; each entry is declared, in vector order, via
// DataInfo().AddVariable (converted through c_str()).
// The setup sequence (dataset manager, dataset-info registration, logger, config
// name) is the same as in the option-only constructor.
// NOTE(review): `inputVars` is only read here; a const reference would be more
// accurate (not changed — it would alter the public signature).
// Intermediate lines are omitted in this excerpt.
183 TMVA::Reader::Reader( std::vector<std::string>& inputVars,
const TString& theOption, Bool_t verbose )
184 : Configurable( theOption ),
185 fDataSetManager( NULL ),
190 fCalculateError(kFALSE),
192 fMvaEventErrorUpper( 0 ),
195 fDataSetManager =
new DataSetManager( fDataInputHandler );
196 fDataSetManager->AddDataSetInfo(fDataSetInfo);
197 fLogger =
new MsgLogger(
this);
198 SetConfigName( GetName() );
// Declare each name as an input variable (only the expression is passed,
// no external data link).
204 for (std::vector<std::string>::iterator ivar = inputVars.begin(); ivar != inputVars.end(); ++ivar)
205 DataInfo().AddVariable( ivar->c_str() );
// Constructor: Reader whose input variables are packed into a single std::string
// `varNames`. The string is handed to DecodeVarNames; the std::string overload of
// that function splits it on ':' and declares each piece as a variable.
// The setup sequence (dataset manager, dataset-info registration, logger, config
// name) is the same as in the option-only constructor.
// Intermediate lines are omitted in this excerpt.
213 TMVA::Reader::Reader(
const std::string& varNames,
const TString& theOption, Bool_t verbose )
214 : Configurable( theOption ),
215 fDataSetManager( NULL ),
220 fCalculateError(kFALSE),
222 fMvaEventErrorUpper( 0 ),
225 fDataSetManager =
new DataSetManager( fDataInputHandler );
226 fDataSetManager->AddDataSetInfo(fDataSetInfo);
227 fLogger =
new MsgLogger(
this);
228 SetConfigName( GetName() );
234 DecodeVarNames(varNames);
// Constructor: Reader whose input variables are packed into a single TString
// `varNames`. The string is handed to the TString overload of DecodeVarNames
// (visible there: pieces are delimited by ':' and '@' characters are removed).
// The setup sequence (dataset manager, dataset-info registration, logger, config
// name) is the same as in the option-only constructor.
// Intermediate lines are omitted in this excerpt.
241 TMVA::Reader::Reader(
const TString& varNames,
const TString& theOption, Bool_t verbose )
242 : Configurable( theOption ),
243 fDataSetManager( NULL ),
248 fCalculateError(kFALSE),
250 fMvaEventErrorUpper( 0 ),
253 fDataSetManager =
new DataSetManager( fDataInputHandler );
254 fDataSetManager->AddDataSetInfo(fDataSetInfo);
255 fLogger =
new MsgLogger(
this);
256 SetConfigName( GetName() );
262 DecodeVarNames(varNames);
// Declare the options recognised in the Reader's option string and bind them to
// members:
//   "V"      -> fVerbose
//   "Color"  -> fColor
//   "Silent" -> fSilent
//   "Error"  -> fCalculateError (when set, EvaluateMVA requests error estimates)
// If gTools().CheckForSilentOption reports the option string as silent, logger
// output is inhibited right away.
269 void TMVA::Reader::DeclareOptions()
271 if (gTools().CheckForSilentOption( GetOptions() )) Log().InhibitOutput();
273 DeclareOptionRef( fVerbose,
"V",
"Verbose flag" );
274 DeclareOptionRef( fColor,
"Color",
"Color flag (default True)" );
275 DeclareOptionRef( fSilent,
"Silent",
"Boolean silent flag (default False)" );
276 DeclareOptionRef( fCalculateError,
"Error",
"Calculates errors (default False)" );
// Destructor: deletes the DataSetManager allocated in the constructors, then walks
// every booked entry of fMethodMap.
// NOTE(review): the loop body after the dynamic_cast is omitted in this excerpt;
// presumably each booked MethodBase is deleted there — confirm. Deletion of fLogger
// is not visible here.
282 TMVA::Reader::~Reader(
void )
284 delete fDataSetManager;
288 for (
auto it=fMethodMap.begin(); it!=fMethodMap.end(); it++){
289 MethodBase * kl =
dynamic_cast<TMVA::MethodBase*
>(it->second);
// Apply the parsed options. If Verbose() is set, the logger's minimum message type
// becomes kVERBOSE. The color and silent flags are pushed into gConfig().
// NOTE(review): gConfig() appears to be a shared/global configuration object, so
// these two settings presumably affect all TMVA output in the process, not only
// this Reader — confirm.
297 void TMVA::Reader::Init(
void )
299 if (Verbose()) fLogger->SetMinType( kVERBOSE );
301 gConfig().SetUseColor( fColor );
302 gConfig().SetSilent ( fSilent );
// Declare a float input variable `expression` whose value is linked through
// `datalink`. It is registered with type 'F'. Two empty strings, 0, 0 and kFALSE
// are passed for the remaining fields (presumably title/unit, min/max and a flag —
// confirm against DataSetInfo::AddVariable).
// NOTE(review): the pointer is handed over as void*; presumably it is read at
// evaluation time, so *datalink must stay valid while the Reader uses it — confirm.
308 void TMVA::Reader::AddVariable(
const TString& expression, Float_t* datalink )
310 DataInfo().AddVariable( expression,
"",
"", 0, 0,
'F', kFALSE ,(
void*)datalink );
// Deprecated: declare an integer input variable linked through `datalink`
// (registered with type 'I'). A kFATAL message asks callers to pass all
// variables as floats instead.
// NOTE(review): the identical kFATAL message is logged twice. If kFATAL
// terminates or throws, the second log and the AddVariable call are unreachable;
// otherwise the message is simply duplicated. One of the two should be removed.
315 void TMVA::Reader::AddVariable(
const TString& expression, Int_t* datalink )
317 Log() << kFATAL <<
"Reader::AddVariable( const TString& expression, Int_t* datalink ), this function is deprecated, please provide all variables to the reader as floats" << Endl;
319 Log() << kFATAL <<
"Reader::AddVariable( const TString& expression, Int_t* datalink ), this function is deprecated, please provide all variables to the reader as floats" << Endl;
320 DataInfo().AddVariable(expression,
"",
"", 0, 0,
'I', kFALSE, (
void*)datalink );
// Declare a float spectator `expression` linked through `datalink`, registered
// via DataInfo().AddSpectator with type 'F'. The remaining arguments match
// AddVariable(Float_t*).
// NOTE(review): presumably spectators are carried along but not fed to the
// classifier; *datalink must stay valid while the Reader uses it — confirm.
326 void TMVA::Reader::AddSpectator(
const TString& expression, Float_t* datalink )
328 DataInfo().AddSpectator( expression,
"",
"", 0, 0,
'F', kFALSE ,(
void*)datalink );
// Declare an integer spectator `expression` linked through `datalink`, registered
// via DataInfo().AddSpectator with type 'I'.
// NOTE(review): unlike AddVariable(Int_t*), this integer overload is not flagged
// as deprecated; confirm whether integer spectators are still supported.
334 void TMVA::Reader::AddSpectator(
const TString& expression, Int_t* datalink )
336 DataInfo().AddSpectator(expression,
"",
"", 0, 0,
'I', kFALSE, (
void*)datalink );
// Return the method type name stored in weight file `filename`: the text before
// "::" in the stored full method name (presumably e.g. "BDT" from "BDT::name").
//  - ".xml" files: the "Method" attribute of the XML root element is used.
//  - other files: text lines are read (512-char buffer) until one starts with
//    "Method", and that whole line is used.
// If the extracted prefix still contains a space, only the part after the last
// space is kept (presumably to drop a leading label on the text-format line).
// A file that cannot be opened triggers a kFATAL log (guard line omitted here).
342 TString TMVA::Reader::GetMethodTypeFromFile(
const TString& filename )
344 std::ifstream fin( filename );
346 Log() << kFATAL <<
"<BookMVA> fatal error: "
347 <<
"unable to open input weight file: " << filename << Endl;
350 TString fullMethodName(
"");
351 if (filename.EndsWith(
".xml")) {
// NOTE(review): the result of ParseFile is not checked before use on the visible lines.
353 void* doc = gTools().xmlengine().ParseFile(filename,gTools().xmlenginebuffersize());
354 void* rootnode = gTools().xmlengine().DocGetRootElement(doc);
355 gTools().ReadAttr(rootnode,
"Method", fullMethodName);
356 gTools().xmlengine().FreeDoc(doc);
// Text weight-file format.
// NOTE(review): this loop never checks the stream state. If no line starts with
// "Method" (EOF), or a line does not fit in the buffer (getline sets failbit),
// every later getline extracts nothing and the condition never changes, so the
// loop never terminates. Consider `while (fin.getline(buf,512)) { ... }`.
360 fin.getline(buf,512);
361 while (!TString(buf).BeginsWith(
"Method")) fin.getline(buf,512);
362 fullMethodName = TString(buf);
// NOTE(review): confirm the behaviour when "::" is absent (Index returns -1).
365 TString methodType = fullMethodName(0,fullMethodName.Index(
"::"));
366 if (methodType.Contains(
" ")) methodType = methodType(methodType.Last(
' ')+1,methodType.Length());
// Book a method under the user-chosen tag `methodTag` from `weightfile`.
// A duplicate tag triggers kFATAL. The method type is read from the file
// (GetMethodTypeFromFile) and the method is booked through the EMVA overload.
// kCategory methods receive this Reader's DataSetManager. The result is stored
// in fMethodMap[methodTag] and returned.
// NOTE(review): dynamic_cast yields null if the booked IMethod is not a
// MethodBase; in that case null is stored under methodTag and returned, and the
// booked object is no longer referenced here.
373 TMVA::IMethod* TMVA::Reader::BookMVA(
const TString& methodTag,
const TString& weightfile )
376 if (fMethodMap.find( methodTag ) != fMethodMap.end())
377 Log() << kFATAL <<
"<BookMVA> method tag \"" << methodTag <<
"\" already exists!" << Endl;
379 TString methodType(GetMethodTypeFromFile(weightfile));
381 Log() << kINFO <<
"Booking \"" << methodTag <<
"\" of type \"" << methodType <<
"\" from " << weightfile <<
"." << Endl;
383 MethodBase* method =
dynamic_cast<MethodBase*
>(this->BookMVA( Types::Instance().GetMethodType(methodType),
// Category methods need access to the Reader's dataset manager.
// NOTE(review): after the kFATAL (its guarding condition is omitted here) methCat
// is dereferenced unconditionally; this is only safe if kFATAL never returns.
385 if( method && method->GetMethodType() == Types::kCategory ){
386 MethodCategory *methCat = (
dynamic_cast<MethodCategory*
>(method));
388 Log() << kFATAL <<
"Method with type kCategory cannot be casted to MethodCategory. /Reader" << Endl;
389 methCat->fDataSetManager = fDataSetManager;
392 return fMethodMap[methodTag] = method;
// Create a method of type `methodType` through ClassifierFactory, configured with
// this Reader's DataInfo() and `weightfile`, then prepare it:
// SetupMethod -> DeclareCompatibilityOptions -> ReadStateFromFile -> CheckSetup
// (some intermediate lines are omitted in this excerpt).
// A created object that is not a MethodBase is returned as-is, without setup.
// No fMethodMap insertion is visible here; the tag-based overload does that.
398 TMVA::IMethod* TMVA::Reader::BookMVA( TMVA::Types::EMVA methodType,
const TString& weightfile )
401 ClassifierFactory::Instance().Create(Types::Instance().GetMethodName(methodType).Data(), DataInfo(), weightfile);
403 MethodBase *method = (
dynamic_cast<MethodBase*
>(im));
405 if (method==0)
return im;
// NOTE(review): unlike the other two BookMVA overloads, a failed MethodCategory
// cast is reported only at kERROR (guard presumably on an omitted line), after
// which methCat — null in that case — is dereferenced. This should be kFATAL or
// an early return.
407 if( method->GetMethodType() == Types::kCategory ){
408 MethodCategory *methCat = (
dynamic_cast<MethodCategory*
>(method));
410 Log() << kERROR <<
"Method with type kCategory cannot be casted to MethodCategory. /Reader" << Endl;
411 methCat->fDataSetManager = fDataSetManager;
414 method->SetupMethod();
418 method->DeclareCompatibilityOptions();
421 method->ReadStateFromFile();
424 method->CheckSetup();
426 Log() << kINFO <<
"Booked classifier \"" << method->GetMethodName()
427 <<
"\" of type: \"" << method->GetMethodTypeName() <<
"\"" << Endl;
// Like the weight-file overload, but the method is created with an empty
// weight-file name and its state is loaded from the in-memory XML string
// `xmlstr` (ReadStateFromXMLString).
// Sequence: SetupMethod -> DeclareCompatibilityOptions -> ReadStateFromXMLString
// -> CheckSetup (some intermediate lines are omitted in this excerpt).
// NOTE(review): returns 0 when the created object is not a MethodBase, whereas
// the weight-file overload returns the object itself. If `im` is non-null here it
// is neither returned nor (visibly) deleted — possible leak; confirm ownership.
434 TMVA::IMethod* TMVA::Reader::BookMVA( TMVA::Types::EMVA methodType,
const char* xmlstr )
438 ClassifierFactory::Instance().Create(Types::Instance().GetMethodName(methodType).Data(), DataInfo(),
"");
440 MethodBase *method = (
dynamic_cast<MethodBase*
>(im));
442 if(!method)
return 0;
// Category methods need access to the Reader's dataset manager.
// NOTE(review): methCat is dereferenced right after the kFATAL; safe only if
// kFATAL never returns.
444 if( method->GetMethodType() == Types::kCategory ){
445 MethodCategory *methCat = (
dynamic_cast<MethodCategory*
>(method));
447 Log() << kFATAL <<
"Method with type kCategory cannot be casted to MethodCategory. /Reader" << Endl;
448 methCat->fDataSetManager = fDataSetManager;
451 method->SetupMethod();
455 method->DeclareCompatibilityOptions();
458 method->ReadStateFromXMLString( xmlstr );
461 method->CheckSetup();
463 Log() << kINFO <<
"Booked classifier \"" << method->GetMethodName()
464 <<
"\" of type: \"" << method->GetMethodTypeName() <<
"\"" << Endl;
// Evaluate the method booked as `methodTag` on the explicit input values in
// `inputVec`, wrapped in a temporary Event.
// Returns 0 if the tag does not resolve to a MethodBase (FindMVA logs unknown
// tags). For kCuts methods, `aux` is used as the test signal efficiency.
// A NaN input is reported at kERROR; per the message the result is then -999
// (that return is on an omitted line).
// The MVA error estimate is requested only when fCalculateError ("Error" option)
// is set.
// NOTE(review): tmpEvent is heap-allocated on every call and its delete is not
// visible in this excerpt. Make sure every path (including the NaN early return)
// frees it, e.g. by holding it in a std::unique_ptr.
473 Double_t TMVA::Reader::EvaluateMVA(
const std::vector<Float_t>& inputVec,
const TString& methodTag, Double_t aux )
476 IMethod* imeth = FindMVA( methodTag );
477 MethodBase* meth =
dynamic_cast<TMVA::MethodBase*
>(imeth);
478 if(meth==0)
return 0;
481 Event* tmpEvent=
new Event(inputVec, DataInfo().GetNVariables());
482 for (UInt_t i=0; i<inputVec.size(); i++){
483 if (TMath::IsNaN(inputVec[i])) {
484 Log() << kERROR << i <<
"-th variable of the event is NaN --> return MVA value -999, \n that's all I can do, please fix or remove this event." << Endl;
// NOTE(review): no null check of the MethodCuts cast is visible before use.
490 if (meth->GetMethodType() == TMVA::Types::kCuts) {
491 TMVA::MethodCuts* mc =
dynamic_cast<TMVA::MethodCuts*
>(meth);
493 mc->SetTestSignalEfficiency( aux );
495 Double_t val = meth->GetMvaValue( tmpEvent, (fCalculateError?&fMvaEventError:0));
// Double-precision convenience overload: copies `inputVec` (narrowed to Float_t)
// into the member scratch buffer fTmpEvalVec and forwards to the Float_t overload.
// NOTE(review): because fTmpEvalVec is a member, concurrent calls on the same
// Reader would race on it (this overload is not reentrant per Reader).
// The size comparison before resize() is redundant: resizing to the current size
// is a no-op.
504 Double_t TMVA::Reader::EvaluateMVA(
const std::vector<Double_t>& inputVec,
const TString& methodTag, Double_t aux )
507 if(fTmpEvalVec.size() != inputVec.size())
508 fTmpEvalVec.resize(inputVec.size());
510 for (UInt_t idx=0; idx!=inputVec.size(); idx++ )
511 fTmpEvalVec[idx]=inputVec[idx];
513 return EvaluateMVA( fTmpEvalVec, methodTag, aux );
// Evaluate the method booked as `methodTag` on that method's current event
// (kl->GetEvent(); presumably filled from the AddVariable data links) and return
// the MVA value via EvaluateMVA(MethodBase*, aux).
// Unknown tag: the requested tag and all booked tags are listed at kINFO, then a
// kFATAL is raised. A tag that is not a MethodBase raises kFATAL (guard omitted).
// NaN inputs are reported at kERROR; the -999 return named in the message is on
// an omitted line.
// NOTE(review): in "Check calling string" the kFATAL type is streamed *after* the
// text; confirm that MsgLogger applies it to text already buffered.
// The same lookup/diagnostic code is repeated in EvaluateRegression,
// EvaluateMulticlass, GetProba and GetRarity — a shared helper would remove it.
519 Double_t TMVA::Reader::EvaluateMVA(
const TString& methodTag, Double_t aux )
523 std::map<TString, IMethod*>::iterator it = fMethodMap.find( methodTag );
524 if (it == fMethodMap.end()) {
525 Log() << kINFO <<
"<EvaluateMVA> unknown classifier in map; "
526 <<
"you looked for \"" << methodTag <<
"\" within available methods: " << Endl;
527 for (it = fMethodMap.begin(); it!=fMethodMap.end(); ++it) Log() <<
"--> " << it->first << Endl;
528 Log() <<
"Check calling string" << kFATAL << Endl;
531 else method = it->second;
533 MethodBase * kl =
dynamic_cast<TMVA::MethodBase*
>(method);
536 Log() << kFATAL << methodTag <<
" is not a method" << Endl;
540 const Event* ev = kl->GetEvent();
541 for (UInt_t i=0; i<ev->GetNVariables(); i++){
542 if (TMath::IsNaN(ev->GetValue(i))) {
543 Log() << kERROR << i <<
"-th variable of the event is NaN --> return MVA value -999, \n that's all I can do, please fix or remove this event." << Endl;
547 return this->EvaluateMVA( kl, aux );
// Evaluate an already-resolved `method` on its current event. For kCuts methods,
// `aux` first sets the test signal efficiency. When fCalculateError is set, error
// estimates are written into fMvaEventError and fMvaEventErrorUpper; otherwise
// null pointers are passed.
// NOTE(review): no null check of the MethodCuts cast is visible before use.
553 Double_t TMVA::Reader::EvaluateMVA( MethodBase* method, Double_t aux )
557 if (method->GetMethodType() == TMVA::Types::kCuts) {
558 TMVA::MethodCuts* mc =
dynamic_cast<TMVA::MethodCuts*
>(method);
560 mc->SetTestSignalEfficiency( aux );
563 return method->GetMvaValue( (fCalculateError?&fMvaEventError:0),
564 (fCalculateError?&fMvaEventErrorUpper:0) );
// Return the regression target values of the method booked as `methodTag`,
// evaluated on that method's current event via EvaluateRegression(MethodBase*, ...).
// Lookup failures are handled as in EvaluateMVA(tag, aux): booked tags are listed
// at kINFO, then kFATAL. Non-MethodBase tags raise kFATAL (guard omitted here).
// NaN inputs only produce a kERROR message; no early return is visible.
// NOTE(review): the diagnostic label reads "<EvaluateMVA>" (copy/paste). The NaN
// check is repeated inside EvaluateRegression(MethodBase*, ...), so a NaN is
// reported twice when called through this overload.
570 const std::vector< Float_t >& TMVA::Reader::EvaluateRegression(
const TString& methodTag, Double_t aux )
574 std::map<TString, IMethod*>::iterator it = fMethodMap.find( methodTag );
575 if (it == fMethodMap.end()) {
576 Log() << kINFO <<
"<EvaluateMVA> unknown method in map; "
577 <<
"you looked for \"" << methodTag <<
"\" within available methods: " << Endl;
578 for (it = fMethodMap.begin(); it!=fMethodMap.end(); ++it) Log() <<
"--> " << it->first << Endl;
579 Log() <<
"Check calling string" << kFATAL << Endl;
581 else method = it->second;
583 MethodBase * kl =
dynamic_cast<TMVA::MethodBase*
>(method);
586 Log() << kFATAL << methodTag <<
" is not a method" << Endl;
589 const Event* ev = kl->GetEvent();
590 for (UInt_t i=0; i<ev->GetNVariables(); i++){
591 if (TMath::IsNaN(ev->GetValue(i))) {
592 Log() << kERROR << i <<
"-th variable of the event is NaN, \n regression values might evaluate to .. what do I know. \n sorry this warning is all I can do, please fix or remove this event." << Endl;
596 return this->EvaluateRegression( kl, aux );
// Return the regression values of `method` for its current event, after reporting
// any NaN input variable at kERROR (evaluation proceeds regardless).
// The unnamed Double_t argument is unused.
// NOTE(review): the result is returned by const reference, which relies on
// GetRegressionValues() returning a reference to storage owned by `method`, not a
// temporary. Presumably that storage is overwritten by the next evaluation — confirm.
604 const std::vector< Float_t >& TMVA::Reader::EvaluateRegression( MethodBase* method, Double_t )
606 const Event* ev = method->GetEvent();
607 for (UInt_t i=0; i<ev->GetNVariables(); i++){
608 if (TMath::IsNaN(ev->GetValue(i))) {
609 Log() << kERROR << i <<
"-th variable of the event is NaN, \n regression values might evaluate to .. what do I know. \n sorry this warning is all I can do, please fix or remove this event." << Endl;
612 return method->GetRegressionValues();
// Return the regression value for target index `tgtNumber` of the method booked
// as `methodTag`. An out-of-range index makes vector::at throw std::out_of_range.
// That exception is caught and reported at kWARNING; the fallback return value
// is on an omitted line.
619 Float_t TMVA::Reader::EvaluateRegression( UInt_t tgtNumber,
const TString& methodTag, Double_t aux )
622 return EvaluateRegression(methodTag, aux).at(tgtNumber);
624 catch (std::out_of_range &) {
625 Log() << kWARNING <<
"Regression could not be evaluated for target-number " << tgtNumber << Endl;
// Return the per-class outputs of the method booked as `methodTag`, evaluated on
// that method's current event via EvaluateMulticlass(MethodBase*, ...).
// Lookup failures are handled as in EvaluateMVA(tag, aux): booked tags are listed
// at kINFO, then kFATAL. Non-MethodBase tags raise kFATAL (guard omitted here).
// NaN inputs only produce a kERROR message; no early return is visible.
// NOTE(review): the label reads "<EvaluateMVA>", and the NaN message talks about
// "regression values" although this is the multiclass path (copy/paste). The NaN
// check is repeated in the MethodBase* overload, so NaNs are reported twice.
635 const std::vector< Float_t >& TMVA::Reader::EvaluateMulticlass(
const TString& methodTag, Double_t aux )
639 std::map<TString, IMethod*>::iterator it = fMethodMap.find( methodTag );
640 if (it == fMethodMap.end()) {
641 Log() << kINFO <<
"<EvaluateMVA> unknown method in map; "
642 <<
"you looked for \"" << methodTag <<
"\" within available methods: " << Endl;
643 for (it = fMethodMap.begin(); it!=fMethodMap.end(); ++it) Log() <<
"--> " << it->first << Endl;
644 Log() <<
"Check calling string" << kFATAL << Endl;
646 else method = it->second;
648 MethodBase * kl =
dynamic_cast<TMVA::MethodBase*
>(method);
651 Log() << kFATAL << methodTag <<
" is not a method" << Endl;
655 const Event* ev = kl->GetEvent();
656 for (UInt_t i=0; i<ev->GetNVariables(); i++){
657 if (TMath::IsNaN(ev->GetValue(i))) {
658 Log() << kERROR << i <<
"-th variable of the event is NaN, \n regression values might evaluate to .. what do I know. \n sorry this warning is all I can do, please fix or remove this event." << Endl;
662 return this->EvaluateMulticlass( kl, aux );
// Return the per-class outputs of `method` for its current event, after reporting
// any NaN input variable at kERROR (evaluation proceeds regardless).
// The unnamed Double_t argument is unused.
// NOTE(review): the NaN message mentions "regression values" (copy/paste from the
// regression path). Returning by const reference relies on GetMulticlassValues()
// returning a reference, not a temporary — confirm.
670 const std::vector< Float_t >& TMVA::Reader::EvaluateMulticlass( MethodBase* method, Double_t )
672 const Event* ev = method->GetEvent();
673 for (UInt_t i=0; i<ev->GetNVariables(); i++){
674 if (TMath::IsNaN(ev->GetValue(i))) {
675 Log() << kERROR << i <<
"-th variable of the event is NaN, \n regression values might evaluate to .. what do I know. \n sorry this warning is all I can do, please fix or remove this event." << Endl;
678 return method->GetMulticlassValues();
// Return the output for class index `clsNumber` of the method booked as
// `methodTag`. An out-of-range index makes vector::at throw std::out_of_range.
// That exception is caught and reported at kWARNING; the fallback return value
// is on an omitted line.
685 Float_t TMVA::Reader::EvaluateMulticlass( UInt_t clsNumber,
const TString& methodTag, Double_t aux )
688 return EvaluateMulticlass(methodTag, aux).at(clsNumber);
690 catch (std::out_of_range &) {
691 Log() << kWARNING <<
"Multiclass could not be evaluated for class-number " << clsNumber << Endl;
// Look up the method booked under `methodTag` and return it if present.
// A missing tag is reported at kERROR; the fallback return (presumably 0) is on
// an omitted line.
700 TMVA::IMethod* TMVA::Reader::FindMVA(
const TString& methodTag )
702 std::map<TString, IMethod*>::iterator it = fMethodMap.find( methodTag );
703 if (it != fMethodMap.end())
return it->second;
704 Log() << kERROR <<
"Method " << methodTag <<
" not found!" << Endl;
// Return the method booked under `methodTag` as a MethodCuts. The result is null
// when the method is not a MethodCuts, or when FindMVA yields a null pointer
// (a pointer dynamic_cast of null is null).
712 TMVA::MethodCuts* TMVA::Reader::FindCutsMVA(
const TString& methodTag )
714 return dynamic_cast<MethodCuts*
>(FindMVA(methodTag));
// Return the signal probability for MVA output `mvaVal` of the method booked as
// `methodTag`, via MethodBase::GetProba(mvaVal, ap_sig). `ap_sig` is presumably
// the a-priori signal fraction — confirm.
// If mvaVal equals the sentinel -9999999, the MVA value is first computed on the
// method's current event. NaN inputs are reported at kERROR.
// NOTE(review): for an unknown tag the booked names are printed (prefixed "M")
// *before* the kFATAL message that announces them. That message also streams
// `method`, which has not been assigned from the map at this point (its
// declaration is on an omitted line); it should print only methodTag.
// No null check of `kl` is visible before use.
720 Double_t TMVA::Reader::GetProba(
const TString& methodTag, Double_t ap_sig, Double_t mvaVal )
723 std::map<TString, IMethod*>::iterator it = fMethodMap.find( methodTag );
724 if (it == fMethodMap.end()) {
725 for (it = fMethodMap.begin(); it!=fMethodMap.end(); ++it) Log() <<
"M" << it->first << Endl;
726 Log() << kFATAL <<
"<EvaluateMVA> unknown classifier in map: " << method <<
"; "
727 <<
"you looked for " << methodTag<<
" while the available methods are : " << Endl;
729 else method = it->second;
731 MethodBase* kl =
dynamic_cast<MethodBase*
>(method);
735 const Event* ev = kl->GetEvent();
736 for (UInt_t i=0; i<ev->GetNVariables(); i++){
737 if (TMath::IsNaN(ev->GetValue(i))) {
738 Log() << kERROR << i <<
"-th variable of the event is NaN --> return MVA value -999, \n that's all I can do, please fix or remove this event." << Endl;
// Sentinel meaning "compute the MVA value now". An exact floating-point compare
// only works if callers pass this literal; a named constant would make the
// contract explicit.
743 if (mvaVal == -9999999) mvaVal = kl->GetMvaValue();
745 return kl->GetProba( mvaVal, ap_sig );
// Return the rarity of MVA output `mvaVal` for the method booked as `methodTag`,
// via MethodBase::GetRarity(mvaVal).
// If mvaVal equals the sentinel -9999999, the MVA value is first computed on the
// method's current event. NaN inputs are reported at kERROR.
// NOTE(review): for an unknown tag the booked names are printed (prefixed "M")
// before the kFATAL message that announces them. That message streams `method`,
// which has not been assigned from the map at this point (declaration omitted);
// it should print only methodTag. No null check of `kl` is visible before use.
751 Double_t TMVA::Reader::GetRarity(
const TString& methodTag, Double_t mvaVal )
754 std::map<TString, IMethod*>::iterator it = fMethodMap.find( methodTag );
755 if (it == fMethodMap.end()) {
756 for (it = fMethodMap.begin(); it!=fMethodMap.end(); ++it) Log() <<
"M" << it->first << Endl;
757 Log() << kFATAL <<
"<EvaluateMVA> unknown classifier in map: \"" << method <<
"\"; "
758 <<
"you looked for \"" << methodTag<<
"\" while the available methods are : " << Endl;
760 else method = it->second;
762 MethodBase* kl =
dynamic_cast<MethodBase*
>(method);
766 const Event* ev = kl->GetEvent();
767 for (UInt_t i=0; i<ev->GetNVariables(); i++){
768 if (TMath::IsNaN(ev->GetValue(i))) {
769 Log() << kERROR << i <<
"-th variable of the event is NaN --> return MVA value -999, \n that's all I can do, please fix or remove this event." << Endl;
// Sentinel meaning "compute the MVA value now" (exact floating-point compare).
774 if (mvaVal == -9999999) mvaVal = kl->GetMvaValue();
776 return kl->GetRarity( mvaVal );
// Split `varNames` on ':' and declare each piece, in order, as an input variable.
// Example: "x:y:z" declares x, y and z. An empty string declares nothing (the
// loop condition is false at once because f == length == 0).
// NOTE(review): empty segments are not skipped. A leading ':', "a::b", or a
// trailing ':' each declare a variable with an empty name.
786 void TMVA::Reader::DecodeVarNames(
const std::string& varNames )
788 size_t ipos = 0, f = 0;
789 while (f != varNames.length()) {
790 f = varNames.find(
':', ipos );
// npos (no further ':') is larger than length(): take the rest of the string.
791 if (f > varNames.length()) f = varNames.length();
792 std::string subs = varNames.substr( ipos, f-ipos ); ipos = f+1;
793 DataInfo().AddVariable( subs.c_str() );
800 void TMVA::Reader::DecodeVarNames(
const TString& varNames )
803 Int_t n = varNames.Length();
806 for (
int i=0; i< n+1 ; i++) {
807 format.Append(varNames(i));
808 if (varNames(i) ==
':' || i == n) {
811 format_obj.ReplaceAll(
"@",
"");
812 DataInfo().AddVariable( format_obj );