// File-scope registration for the Category method.
// REGISTER_METHOD(Category) presumably registers this method with TMVA's
// ClassifierFactory (the factory AddMethod/ReadWeightsFromXML use to create
// sub-methods); ClassImp provides the ROOT class-implementation boilerplate.
79 REGISTER_METHOD(Category)
81 ClassImp(TMVA::MethodCategory);
// Constructor taking a job name, method title, data-set info and option string.
// Forwards jobName, methodTitle, theData and theOption to MethodCompositeBase
// with method type Types::kCategory.
// NOTE(review): the declaration of the `theData` parameter and the remainder of
// the initializer list / body are not visible in this excerpt.
86 TMVA::MethodCategory::MethodCategory( const TString& jobName,
87 const TString& methodTitle,
89 const TString& theOption )
90 : TMVA::MethodCompositeBase( jobName, Types::kCategory, methodTitle, theData, theOption),
// Constructor from a DataSetInfo and a weight-file name (presumably used when
// the method is reconstructed from stored weights).
// Forwards dsi and theWeightFile to MethodCompositeBase with type
// Types::kCategory and starts with no DataSetManager (fDataSetManager = NULL).
// Other initialisers and the body are not visible in this excerpt.
99 TMVA::MethodCategory::MethodCategory( DataSetInfo& dsi,
100 const TString& theWeightFile)
101 : TMVA::MethodCompositeBase( Types::kCategory, dsi, theWeightFile),
103 fDataSetManager(NULL)
// Destructor: deletes every TTreeFormula held in fCatFormulas. These are
// allocated with `new` in InitCircularTree, so this class owns them.
// NOTE(review): InitCircularTree also allocates fCatTree with `new`; its
// deletion is not visible in this excerpt — confirm it is not leaked.
110 TMVA::MethodCategory::~MethodCategory(
void )
112 std::vector<TTreeFormula*>::iterator formIt = fCatFormulas.begin();
113 std::vector<TTreeFormula*>::iterator lastF = fCatFormulas.end();
114 for(;formIt!=lastF; ++formIt)
delete *formIt;
// Checks whether the composite method supports the given analysis type.
// Each sub-method in fMethods is asked HasAnalysisType(type, numberClasses,
// numberTargets).
// The return statements are not visible in this excerpt; presumably this
// returns kFALSE as soon as one sub-method declines and kTRUE otherwise.
// TODO confirm what is returned when fMethods is empty.
122 Bool_t TMVA::MethodCategory::HasAnalysisType( Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets )
124 std::vector<IMethod*>::iterator itrMethod = fMethods.begin();
127 for(; itrMethod != fMethods.end(); ++itrMethod ) {
128 if ( !(*itrMethod)->HasAnalysisType(type, numberClasses, numberTargets) )
// Option-declaration hook for this method (body not visible in this excerpt).
137 void TMVA::MethodCategory::DeclareOptions()
// Adds a sub-classifier that is responsible for the events passing theCut.
//
// Parameters:
//   theCut       - selection that defines this category.
//   theVariables - ':'-separated list of input variables for the sub-classifier
//                  (split in CreateCategoryDSI); an empty string selects all
//                  variables of the primary DataSetInfo.
//   theMethod    - TMVA method type of the sub-classifier.
//   theTitle     - title of the sub-classifier (also names its DataSetInfo).
//   theOptions   - option string handed to the sub-classifier.
//
// Returns the created sub-method, or 0 if ClassifierFactory did not produce a
// MethodBase.
//
// Side effects:
//   - builds a category-specific DataSetInfo;
//   - configures the sub-method (weight-file dir, persistence, analysis type,
//     setup/options, output file, a "Method_<type>" TDirectory);
//   - disables the sub-method's own weight writing;
//   - appends it to fMethods / fCategoryCuts / fVars;
//   - registers a "<name>_cat<N>:=<cut>" spectator in the primary DataSetInfo
//     and records that spectator's index in fCategorySpecIdx.
144 TMVA::IMethod* TMVA::MethodCategory::AddMethod(
const TCut& theCut,
145 const TString& theVariables,
146 Types::EMVA theMethod ,
147 const TString& theTitle,
148 const TString& theOptions )
150 std::string addedMethodName(Types::Instance().GetMethodName(theMethod).Data());
152 Log() << kINFO <<
"Adding sub-classifier: " << addedMethodName <<
"::" << theTitle << Endl;
154 DataSetInfo& dsi = CreateCategoryDSI(theCut, theVariables, theTitle);
156 IMethod* addedMethod = ClassifierFactory::Instance().Create(addedMethodName,GetJobName(),theTitle,dsi,theOptions);
158 MethodBase *method = (
dynamic_cast<MethodBase*
>(addedMethod));
// NOTE(review): on this early return, addedMethod (if non-null) is neither
// deleted nor stored anywhere — possible leak; confirm factory ownership.
159 if(method==0)
return 0;
161 if(fModelPersistence) method->SetWeightFileDir(fFileDir);
162 method->SetModelPersistence(fModelPersistence);
163 method->SetAnalysisType( fAnalysisType );
164 method->SetupMethod();
165 method->ParseOptions();
166 method->ProcessSetup();
167 method->SetFile(fFile);
168 method->SetSilentFile(IsSilentFile());
// Reuse an existing "Method_<type>" directory under BaseDir(), else create it.
172 const TString dirName(Form(
"Method_%s",method->GetMethodTypeName().Data()));
173 TDirectory * dir = BaseDir()->GetDirectory(dirName);
174 if (dir != 0) method->SetMethodBaseDir( dir );
175 else method->SetMethodBaseDir( BaseDir()->mkdir(dirName,Form(
"Directory for all %s methods", method->GetMethodTypeName().Data())) );
181 method->CheckSetup();
// The sub-method's state is instead serialised by this composite
// (AddWeightsXMLTo calls WriteStateToXML for every sub-method).
184 method->DisableWriting( kTRUE );
187 fMethods.push_back(method);
188 fCategoryCuts.push_back(theCut);
189 fVars.push_back(theVariables);
191 DataSetInfo& primaryDSI = DataInfo();
// The category spectator is appended at the end of the spectator list, so its
// index is the current list size.
193 UInt_t newSpectatorIndex = primaryDSI.GetSpectatorInfos().size();
194 fCategorySpecIdx.push_back(newSpectatorIndex);
// fMethods.size() is taken after push_back, so categories are numbered from 1.
// ReadWeightsFromXML searches for "<name>_cat<fCategorySpecIdx.size()+1>".
196 primaryDSI.AddSpectator( Form(
"%s_cat%i:=%s", GetName(),(
int)fMethods.size(),theCut.GetTitle()),
197 Form(
"%s:%s",GetName(),method->GetName()),
// Builds the DataSetInfo "<theTitle>_dsi" used by one category's sub-method and
// registers it with fDataSetManager.
//
// The new DataSetInfo receives:
//   - all targets and all spectators of the primary DataSetInfo;
//   - as input variables, the names in theVariables (':'-separated):
//       * each name is matched by label, first against the primary variables,
//         then against the primary spectators (a matching spectator becomes an
//         input variable here);
//       * a name matching neither is reported with kFATAL;
//       * if theVariables is empty, all primary variables are used;
//   - for every class: the class itself, the primary class cut followed by
//     AddCut(theCut), and the class weight expression;
//   - the split options, root directory and normalization of the primary.
//
// Side effect: appends this category's variable index map (varMap) to
// fVarMaps; GetMvaValue passes it to Event::SetVariableArrangement. The
// `counter` bookkeeping that feeds varMap is not fully visible in this excerpt.
//
// Returns a reference to the newly allocated DataSetInfo.
206 TMVA::DataSetInfo& TMVA::MethodCategory::CreateCategoryDSI(
const TCut& theCut,
207 const TString& theVariables,
208 const TString& theTitle)
// NOTE(review): the name depends only on theTitle, so two categories with the
// same title get identically named DataSetInfos — confirm this is harmless.
211 TString dsiName=theTitle+
"_dsi";
212 DataSetInfo& oldDSI = DataInfo();
213 DataSetInfo* dsi =
new DataSetInfo(dsiName);
217 fDataSetManager->AddDataSetInfo(*dsi);
220 std::vector<VariableInfo>::iterator itrVarInfo;
222 for (itrVarInfo = oldDSI.GetTargetInfos().begin(); itrVarInfo != oldDSI.GetTargetInfos().end(); ++itrVarInfo)
223 dsi->AddTarget(*itrVarInfo);
225 for (itrVarInfo = oldDSI.GetSpectatorInfos().begin(); itrVarInfo != oldDSI.GetSpectatorInfos().end(); ++itrVarInfo)
226 dsi->AddSpectator(*itrVarInfo);
229 std::vector<TString> variables = gTools().SplitString(theVariables,
':' );
232 std::vector<UInt_t> varMap;
236 std::vector<TString>::iterator itrVariables;
237 Bool_t found = kFALSE;
240 for (itrVariables = variables.begin(); itrVariables != variables.end(); ++itrVariables) {
244 for (itrVarInfo = oldDSI.GetVariableInfos().begin(); itrVarInfo != oldDSI.GetVariableInfos().end(); ++itrVarInfo) {
245 if((*itrVariables==itrVarInfo->GetLabel()) ) {
248 dsi->AddVariable(*itrVarInfo);
249 varMap.push_back(counter);
256 for (itrVarInfo = oldDSI.GetSpectatorInfos().begin(); itrVarInfo != oldDSI.GetSpectatorInfos().end(); ++itrVarInfo) {
257 if((*itrVariables==itrVarInfo->GetLabel()) ) {
260 dsi->AddVariable(*itrVarInfo);
261 varMap.push_back(counter);
269 Log() << kFATAL <<
"The variable " << itrVariables->Data() <<
" was not found and could not be added " << Endl;
// Empty variable list: use every input variable of the primary DataSetInfo.
275 if (theVariables==
"") {
276 for (UInt_t i=0; i<oldDSI.GetVariableInfos().size(); i++) {
277 dsi->AddVariable(oldDSI.GetVariableInfos()[i]);
283 fVarMaps.push_back(varMap);
286 UInt_t nClasses=oldDSI.GetNClasses();
289 for (UInt_t i=0; i<nClasses; i++) {
290 className = oldDSI.GetClassInfo(i)->GetName();
291 dsi->AddClass(className);
292 dsi->SetCut(oldDSI.GetCut(i),className);
293 dsi->AddCut(theCut,className);
294 dsi->SetWeightExpression(oldDSI.GetWeightExpression(i),className);
298 dsi->SetSplitOptions(oldDSI.GetSplitOptions());
299 dsi->SetRootDir(oldDSI.GetRootDir());
300 TString norm(oldDSI.GetNormalization().Data());
301 dsi->SetNormalization(norm);
303 DataSetInfo& dsiReference= (*dsi);
// Initialisation hook (body not visible in this excerpt).
311 void TMVA::MethodCategory::Init()
// Sets up the in-memory tree and formulas that PassesCut uses to evaluate the
// category cuts directly from externally linked variables.
//
// - Does nothing unless every variable and every spectator of dsi has a
//   non-zero external link.
// - Otherwise creates the circular TTree "Circ<method name>" (SetCircular(1)).
//   It gets one "/F" branch per variable and per spectator whose type is not
//   'C', each bound to that variable's external link, read as Float_t.
// - Then builds one TTreeFormula "Category_<i>" per entry of fCategoryCuts on
//   that tree and stores it in fCatFormulas (the destructor deletes them).
//
// TDirectory::TContext(nullptr) resets the current directory for this scope,
// presumably so the new tree is not attached to an open file/directory.
318 void TMVA::MethodCategory::InitCircularTree(
const DataSetInfo& dsi)
322 std::vector<VariableInfo>::const_iterator viIt;
323 const std::vector<VariableInfo>& vars = dsi.GetVariableInfos();
324 const std::vector<VariableInfo>& specs = dsi.GetSpectatorInfos();
326 Bool_t hasAllExternalLinks = kTRUE;
327 for (viIt = vars.begin(); viIt != vars.end(); ++viIt)
328 if( viIt->GetExternalLink() == 0 ) {
329 hasAllExternalLinks = kFALSE;
332 for (viIt = specs.begin(); viIt != specs.end(); ++viIt)
333 if( viIt->GetExternalLink() == 0 ) {
334 hasAllExternalLinks = kFALSE;
338 if(!hasAllExternalLinks)
return;
345 TDirectory::TContext ctxt(
nullptr);
346 fCatTree =
new TTree(Form(
"Circ%s",GetMethodName().Data()),
"Circular Tree for categorization");
347 fCatTree->SetCircular(1);
350 for (viIt = vars.begin(); viIt != vars.end(); ++viIt) {
351 const VariableInfo& vi = *viIt;
352 fCatTree->Branch(vi.GetExpression(),(Float_t*)vi.GetExternalLink(), TString(vi.GetExpression())+TString(
"/F"));
354 for (viIt = specs.begin(); viIt != specs.end(); ++viIt) {
355 const VariableInfo& vi = *viIt;
356 if(vi.GetVarType()==
'C')
continue;
357 fCatTree->Branch(vi.GetExpression(),(Float_t*)vi.GetExternalLink(), TString(vi.GetExpression())+TString(
"/F"));
// One formula per category cut, indexed like fMethods / fCategoryCuts.
360 for(UInt_t cat=0; cat!=fCategoryCuts.size(); ++cat) {
361 fCatFormulas.push_back(
new TTreeFormula(Form(
"Category_%i",cat), fCategoryCuts[cat].GetTitle(), fCatTree));
// Trains every sub-classifier registered in fMethods.
//
// For each sub-method:
//   - set the analysis type;
//   - if it cannot handle that type with its own DataSetInfo's class/target
//     counts, log a warning and erase it from fMethods;
//   - if it has at least MinNoTrainingEvents (10) training events, train it
//     (bracketed by the "Train method" / "Training finished" log lines; the
//     training call itself is not visible in this excerpt);
//   - otherwise log a warning, an error and a kFATAL.
// For non-regression analyses, CreateRanking() is then requested from every
// sub-method that had enough training events.
// An info message is logged when fMethods is empty.
368 void TMVA::MethodCategory::Train()
371 const Int_t MinNoTrainingEvents = 10;
373 Types::EAnalysisType analysisType = GetAnalysisType();
376 Log() << kINFO <<
"Train all sub-classifiers for "
377 << (analysisType == Types::kRegression ?
"Regression" :
"Classification") <<
" ..." << Endl;
380 if (fMethods.empty()) {
381 Log() << kINFO <<
"...nothing found to train" << Endl;
385 std::vector<IMethod*>::iterator itrMethod;
388 for (itrMethod = fMethods.begin(); itrMethod != fMethods.end(); ++itrMethod ) {
// NOTE(review): any null check of mva before its first use is not visible in
// this excerpt — confirm one exists.
390 MethodBase* mva =
dynamic_cast<MethodBase*
>(*itrMethod);
392 mva->SetAnalysisType( analysisType );
393 if (!mva->HasAnalysisType( analysisType,
394 mva->DataInfo().GetNClasses(),
395 mva->DataInfo().GetNTargets() ) ) {
396 Log() << kWARNING <<
"Method " << mva->GetMethodTypeName() <<
" is not capable of handling " ;
397 if (analysisType == Types::kRegression)
398 Log() <<
"regression with " << mva->DataInfo().GetNTargets() <<
" targets." << Endl;
400 Log() <<
"classification with " << mva->DataInfo().GetNClasses() <<
" classes." << Endl;
// NOTE(review): erase() already returns an iterator to the next element, but
// the loop header still applies ++itrMethod. If control reaches it after this
// erase (e.g. via `continue`; the following lines are not visible), the next
// sub-method is skipped. Erasing the last element would increment end(),
// which is undefined behaviour. Advance the iterator only when nothing was
// erased.
401 itrMethod = fMethods.erase( itrMethod );
404 if (mva->Data()->GetNTrainingEvents() >= MinNoTrainingEvents) {
406 Log() << kINFO <<
"Train method: " << mva->GetMethodName() <<
" for "
407 << (analysisType == Types::kRegression ?
"Regression" :
"Classification") << Endl;
409 Log() << kINFO <<
"Training finished" << Endl;
413 Log() << kWARNING <<
"Method " << mva->GetMethodName()
414 <<
" not trained (training tree has less entries ["
415 << mva->Data()->GetNTrainingEvents()
416 <<
"] than required [" << MinNoTrainingEvents <<
"]" << Endl;
418 Log() << kERROR <<
" w/o training/test events for that category, I better stop here and let you fix " << Endl;
419 Log() << kFATAL <<
"that one first, otherwise things get too messy later ... " << Endl;
// Variable ranking is only requested for classification.
424 if (analysisType != Types::kRegression) {
427 Log() << kINFO <<
"Begin ranking of input variables..." << Endl;
428 for (itrMethod = fMethods.begin(); itrMethod != fMethods.end(); ++itrMethod) {
429 MethodBase* mva =
dynamic_cast<MethodBase*
>(*itrMethod);
430 if (mva && mva->Data()->GetNTrainingEvents() >= MinNoTrainingEvents) {
431 const Ranking* ranking = (*itrMethod)->CreateRanking();
435 Log() << kINFO <<
"No variable ranking supplied by classifier: "
436 <<
dynamic_cast<MethodBase*
>(*itrMethod)->GetMethodName() << Endl;
// Serialises this composite method under a new "Weights" child of parent.
//
// Writes:
//   - the "NSubMethods" attribute (size of fMethods);
//   - for each sub-method i, a "SubMethod" child with attributes "Index" (i),
//     "Method" ("<type>::<name>"), "Cut" (fCategoryCuts[i]) and "Variables"
//     (fVars[i]), followed by the sub-method's own state (WriteStateToXML).
// ReadWeightsFromXML reads the NSubMethods, Method, Cut and Variables
// attributes back.
445 void TMVA::MethodCategory::AddWeightsXMLTo(
void* parent )
const
447 void* wght = gTools().AddChild(parent,
"Weights");
448 gTools().AddAttr( wght,
"NSubMethods", fMethods.size() );
452 for (UInt_t i=0; i<fMethods.size(); i++) {
// NOTE(review): method is dereferenced below without a null check. AddMethod
// and ReadWeightsFromXML only push MethodBase pointers into fMethods, so the
// cast should succeed. An assertion or static_cast would make that explicit.
453 MethodBase* method =
dynamic_cast<MethodBase*
>(fMethods[i]);
454 submethod = gTools().AddChild(wght,
"SubMethod");
455 gTools().AddAttr(submethod,
"Index", i);
456 gTools().AddAttr(submethod,
"Method", method->GetMethodTypeName() +
"::" + method->GetMethodName());
457 gTools().AddAttr(submethod,
"Cut", fCategoryCuts[i]);
458 gTools().AddAttr(submethod,
"Variables", fVars[i]);
459 method->WriteStateToXML( submethod );
// Recreates the sub-classifiers from the "Weights" node that AddWeightsXMLTo
// writes.
//
// For each of the NSubMethods child nodes:
//   1. Read the Method ("<type>::<title>"), Cut and Variables attributes.
//   2. Split Method into:
//        - methodType: the text before "::", reduced to its last
//          space-separated token;
//        - methodTitle: the text after "::".
//   3. Build the category DataSetInfo with CreateCategoryDSI.
//   4. Create the sub-method via ClassifierFactory (failure is reported with
//      kFATAL), set it up and restore its state from the node.
//   5. Append it to fMethods / fCategoryCuts / fVars.
//   6. Find the matching "<name>_cat<i>" spectator (by label or expression) in
//      the primary DataSetInfo and record its index in fCategorySpecIdx.
// Finally calls InitCircularTree on the primary DataSetInfo.
466 void TMVA::MethodCategory::ReadWeightsFromXML(
void* wghtnode )
469 TString fullMethodName;
472 TString theCutString;
473 TString theVariables;
475 gTools().ReadAttr( wghtnode,
"NSubMethods", nSubMethods );
476 void* subMethodNode = gTools().GetChild(wghtnode);
478 Log() << kINFO <<
"Recreating sub-classifiers from XML-file " << Endl;
481 for (UInt_t i=0; i<nSubMethods; i++) {
482 gTools().ReadAttr( subMethodNode,
"Method", fullMethodName );
483 gTools().ReadAttr( subMethodNode,
"Cut", theCutString );
484 gTools().ReadAttr( subMethodNode,
"Variables", theVariables );
// Method type = text before "::"; keep only its last space-separated token.
487 methodType = fullMethodName(0,fullMethodName.Index(
"::"));
488 if (methodType.Contains(
" ")) methodType = methodType(methodType.Last(
' ')+1,methodType.Length());
// Method title = everything after the first "::".
491 titleLength = fullMethodName.Length()-fullMethodName.Index(
"::")-2;
492 methodTitle = fullMethodName(fullMethodName.Index(
"::")+2,titleLength);
495 DataSetInfo& dsi = CreateCategoryDSI(TCut(theCutString), theVariables, methodTitle);
// NOTE(review): the kFATAL message below streams `method` (presumably null at
// that point) instead of a name, so it prints a pointer value. methodType /
// methodTitle would make the message informative.
498 MethodBase* method =
dynamic_cast<MethodBase*
>( ClassifierFactory::Instance().Create( methodType.Data(),
501 Log() << kFATAL <<
"Could not create sub-method " << method <<
" from XML." << Endl;
503 method->SetupMethod();
504 method->ReadStateFromXML(subMethodNode);
506 fMethods.push_back(method);
507 fCategoryCuts.push_back(TCut(theCutString));
508 fVars.push_back(theVariables);
510 DataSetInfo& primaryDSI = DataInfo();
// NOTE(review): in the visible code, fCategorySpecIdx is only appended when a
// matching spectator is found. If none matches, it falls out of step with
// fMethods, and PassesCut indexes it by method index. The handling of the
// 10000 sentinel is not visible here — confirm.
512 UInt_t spectatorIdx = 10000;
516 std::vector<VariableInfo>& spectators=primaryDSI.GetSpectatorInfos();
517 std::vector<VariableInfo>::iterator itrVarInfo;
// Same 1-based numbering as the "<name>_cat<N>" spectator added in AddMethod.
518 TString specName= Form(
"%s_cat%i", GetName(),(
int)fCategorySpecIdx.size()+1);
520 for (itrVarInfo = spectators.begin(); itrVarInfo != spectators.end(); ++itrVarInfo, ++counter) {
521 if((specName==itrVarInfo->GetLabel()) || (specName==itrVarInfo->GetExpression())) {
522 spectatorIdx=counter;
523 fCategorySpecIdx.push_back(spectatorIdx);
528 subMethodNode = gTools().GetNextChild(subMethodNode);
531 InitCircularTree(DataInfo());
// Option-processing hook (body not visible in this excerpt).
538 void TMVA::MethodCategory::ProcessOptions()
// Writes a short description of the Category method to the log. The text
// states that categories are defined by cuts, that each can have its own
// classifier and variable set, and that the categories must be disjoint.
548 void TMVA::MethodCategory::GetHelpMessage()
const
551 Log() << gTools().Color(
"bold") <<
"--- Short description:" << gTools().Color(
"reset") << Endl;
553 Log() <<
"This method allows to define different categories of events. The" <<Endl;
554 Log() <<
"categories are defined via cuts on the variables. For each" << Endl;
555 Log() <<
"category, a different classifier and set of variables can be" <<Endl;
556 Log() <<
"specified. The categories which are defined for this method must" << Endl;
557 Log() <<
"be disjoint." << Endl;
// Ranking of input variables for the composite method (body not visible in
// this excerpt; Train requests rankings from the individual sub-methods).
563 const TMVA::Ranking* TMVA::MethodCategory::CreateRanking()
// Returns whether event ev belongs to the category of sub-method methodIdx.
//
// Two evaluation paths are visible:
//   - Formula path: evaluate fCatFormulas[methodIdx] (built by
//     InitCircularTree); a value > 0.5 means "passes".
//   - Spectator path: read ev's spectator at fCategorySpecIdx[methodIdx] (the
//     category spectator registered in AddMethod); a value > 0.5 means
//     "passes".
// An out-of-range methodIdx is reported with kFATAL on either path. The
// condition selecting the path is not visible in this excerpt (presumably
// whether the circular tree exists — confirm).
570 Bool_t TMVA::MethodCategory::PassesCut(
const Event* ev, UInt_t methodIdx )
576 if (methodIdx>=fCatFormulas.size()) {
577 Log() << kFATAL <<
"Large method index " << methodIdx <<
", number of category formulas = "
578 << fCatFormulas.size() << Endl;
580 TTreeFormula* f = fCatFormulas[methodIdx];
581 return f->EvalInstance(0) > 0.5;
587 if (methodIdx>=fCategorySpecIdx.size()) {
588 Log() << kFATAL <<
"Unknown method index " << methodIdx <<
" maximum allowed index="
589 << fCategorySpecIdx.size() << Endl;
591 UInt_t spectatorIdx = fCategorySpecIdx[methodIdx];
592 Float_t specVal = ev->GetSpectator(spectatorIdx);
593 Bool_t pass = (specVal>0.5);
// Returns the MVA response for the current event, taken from the single
// sub-method whose category (PassesCut) contains it. Returns 0 when there are
// no sub-methods.
//
// Logs a warning when no category matches and kFATAL when more than one
// matches (categories must be disjoint). The sub-method is evaluated with the
// event's variables rearranged by fVarMaps[methodToUse], so it sees its own
// input subset; the arrangement is cleared afterwards. err/errUpper are passed
// through to the sub-method.
// The loop body that counts matches and picks methodToUse is not visible in
// this excerpt.
601 Double_t TMVA::MethodCategory::GetMvaValue( Double_t* err, Double_t* errUpper )
603 if (fMethods.empty())
return 0;
605 UInt_t methodToUse = 0;
606 const Event* ev = GetEvent();
609 Int_t suitableCutsN = 0;
611 for (UInt_t i=0; i<fMethods.size(); ++i) {
612 if (PassesCut(ev, i)) {
618 if (suitableCutsN == 0) {
619 Log() << kWARNING <<
"Event does not lie within the cut of any sub-classifier." << Endl;
623 if (suitableCutsN > 1) {
624 Log() << kFATAL <<
"The defined categories are not disjoint." << Endl;
629 ev->SetVariableArrangement(&fVarMaps[methodToUse]);
// NOTE(review): the dynamic_cast result is dereferenced without a null check.
// GetRegressionValues checks it; doing the same here would be consistent.
630 Double_t mvaValue =
dynamic_cast<MethodBase*
>(fMethods[methodToUse])->GetMvaValue(ev,err,errUpper);
631 ev->SetVariableArrangement(0);
// Returns the regression targets for the current event from the single
// sub-method whose category (PassesCut) contains it.
//
// Falls back to MethodBase::GetRegressionValues() when:
//   - there are no sub-methods;
//   - no category matches (warning);
//   - more than one category matches (kFATAL);
//   - the selected entry is not a MethodBase (kFATAL; the guarding condition
//     is not visible in this excerpt).
// NOTE(review): unlike GetMvaValue, the visible code does not apply
// fVarMaps[methodToUse] via SetVariableArrangement before evaluating the
// sub-method. A sub-method trained on a variable subset may therefore see
// misaligned inputs — confirm.
641 const std::vector<Float_t> &TMVA::MethodCategory::GetRegressionValues()
643 if (fMethods.empty())
return MethodBase::GetRegressionValues();
645 UInt_t methodToUse = 0;
646 const Event* ev = GetEvent();
649 Int_t suitableCutsN = 0;
651 for (UInt_t i=0; i<fMethods.size(); ++i) {
652 if (PassesCut(ev, i)) {
658 if (suitableCutsN == 0) {
659 Log() << kWARNING <<
"Event does not lie within the cut of any sub-classifier." << Endl;
660 return MethodBase::GetRegressionValues();
663 if (suitableCutsN > 1) {
664 Log() << kFATAL <<
"The defined categories are not disjoint." << Endl;
665 return MethodBase::GetRegressionValues();
667 MethodBase* meth =
dynamic_cast<MethodBase*
>(fMethods[methodToUse]);
669 Log() << kFATAL <<
"method not found in Category Regression method" << Endl;
670 return MethodBase::GetRegressionValues();
673 return meth->GetRegressionValues(ev);