// EfficiencyPlotWrapper: holds one ROC-curve canvas (fCanvas) and its legend
// (fLegend) for a single class (1-vs-rest) or a class pair (1-vs-1).
// Only part of the declaration is visible in this excerpt.
89 class EfficiencyPlotWrapper {
// Adjusts the legend geometry and creates the canvas and legend.
// `i` is forwarded to newEfficiencyCanvas, where it offsets the window position.
100 EfficiencyPlotWrapper(TString name, TString title, TString dataset,
size_t i);
// Styles `graph` and draws a clone of it on the current pad.
// The meaning of the Int_t return value is not visible here (TODO confirm).
102 Int_t addGraph(TGraph *graph);
// Adds a line-style ("l") legend entry labelled `methodTitle` and resizes the legend.
103 void addLegendEntry(TString methodTitle, TGraph *graph);
// Factory helpers used by the constructor. Ownership of the returned ROOT
// objects is not visible here (presumably ROOT / the wrapper; confirm).
113 TCanvas *newEfficiencyCanvas(TString name, TString title,
size_t i);
114 TLegend *newEfficiencyLegend();
// Maps a class name to the plot wrapper that collects that class's ROC curves.
// The map stores raw pointers and does not itself manage the wrappers' lifetime.
117 using classcanvasmap_t = std::map<TString, EfficiencyPlotWrapper *>;
// One entry per discovered ROC graph: (method title, class name, graph).
118 using roccurvelist_t = std::vector<std::tuple<TString, TString, TGraph *>>;
// Button-type string constant. Its users are not visible in this excerpt.
// NOTE(review): only the pointee is const; the pointer itself is a mutable
// object with external linkage. If this definition lives in a header included
// by several translation units it would be multiply defined. Consider
// `constexpr const char *` (or `inline constexpr`), after confirming where it lives.
121 const char *BUTTON_TYPE =
"button";
// Forward declarations of the multiclass ROC helpers defined further below:
//  - getclassnames: class names recorded in a TMVA output file.
//  - getRocCurves: collect matching ROC TGraphs from a dataset directory.
//  - plotEfficienciesMulticlass: dispatch collected curves to per-class plots.
125 std::vector<TString> getclassnames(TString dataset, TString fin);
126 roccurvelist_t getRocCurves(TDirectory *binDir, TString methodPrefix, TString graphNameRef);
127 void plotEfficienciesMulticlass(roccurvelist_t rocCurves, classcanvasmap_t classCanvasMap);
// Returns the class names stored under "<dataset>/InputVariables_Id" in the
// TMVA output file `fin`, as reported by TMVAGlob::GetClassNames.
// @param dataset  name of the dataset directory inside the file.
// @param fin      path of the TMVA output file; opened via TMVAGlob::OpenFile.
137 std::vector<TString> TMVA::getclassnames(TString dataset, TString fin)
139 TFile *file = TMVA::TMVAGlob::OpenFile(fin);
// NOTE(review): neither `file` nor the result of file->GetDirectory(dataset)
// is checked for null before the next line dereferences it. Callers in this
// file compare OpenFile's result against nullptr, so null is possible. A bad
// path or a missing dataset directory would crash here before the error
// message below is reached.
140 TDirectory *dir = (TDirectory *)file->GetDirectory(dataset)->GetDirectory(
"InputVariables_Id");
// Error path for a missing InputVariables_Id directory. The guarding condition
// and the value returned on this path are not visible in this excerpt.
142 std::cout <<
"Could not locate directory '" << dataset <<
"/InputVariables_Id' in file: " << fin << std::endl;
146 auto classnames = TMVA::TMVAGlob::GetClassNames(dir);
// Walks every method directory under `binDir` and every method-title
// sub-directory, and collects the TGraph objects whose name
//  - contains `graphNameRef`,
//  - begins with `methodPrefix`, and
//  - does not contain "Train".
// Each match is returned as (method title, class name, graph). The class name
// is the part of the graph name after its last '_'.
// @param binDir        dataset directory of the TMVA output file.
// @param methodPrefix  required graph-name prefix (callers pass "MVA_").
// @param graphNameRef  required graph-name substring selecting the ROC type.
153 roccurvelist_t TMVA::getRocCurves(TDirectory *binDir, TString methodPrefix, TString graphNameRef)
155 roccurvelist_t rocCurves;
158 UInt_t nm = TMVAGlob::GetListOfMethods(methods, binDir);
// No-methods diagnostic; the guarding condition is not visible here.
160 cout <<
"ups .. no methods found in to plot ROC curve for ... give up" << endl;
164 TIter next(&methods);
// Outer loop: one key per MVA method directory.
168 while ((key = (TKey *)next())) {
169 TDirectory *mDir = (TDirectory *)key->ReadObj();
171 TMVAGlob::GetListOfTitles(mDir, titles);
174 TIter nextTitle(&titles);
// Middle loop: one sub-directory per method title (method instance).
177 while ((titkey = TMVAGlob::NextKey(nextTitle,
"TDirectory"))) {
178 titDir = (TDirectory *)titkey->ReadObj();
180 TMVAGlob::GetMethodTitle(methodTitle, titDir);
183 TIter nextKey(titDir->GetListOfKeys());
// Inner loop: every TGraph stored for this method title.
// NOTE(review): every graph is materialised with ReadObj(), including those that
// do not match the filter. No release of non-matching graphs is visible here.
185 while ((hkey2 = TMVAGlob::NextKey(nextKey,
"TGraph"))) {
187 TGraph *h = (TGraph *)hkey2->ReadObj();
188 TString hname = h->GetName();
// Exclude training-sample curves; keep only graphs of the requested ROC type.
189 if (hname.Contains(graphNameRef) && hname.BeginsWith(methodPrefix) && !hname.Contains(
"Train")) {
// NOTE(review): TString::Last returns a signed position (-1 when absent) but
// it is stored in a UInt_t. This is safe only because hname contains
// graphNameRef, and both graphNameRef values used in this file contain '_'.
// A class name that itself contains '_' would be cut down to its last
// fragment by this extraction.
192 UInt_t index = hname.Last(
'_');
193 TString classname = hname(index + 1, hname.Length() - (index + 1));
195 rocCurves.push_back(std::make_tuple(methodTitle, classname, h));
// For each (method title, class name, graph) entry, looks up the wrapper for
// that class in `classCanvasMap`, draws the graph on it and adds a legend entry.
// Graphs whose class has no wrapper raise std::out_of_range from map::at. That
// case is reported and skipped.
// NOTE(review): both arguments are taken by value, so the vector and the map are
// copied. The map holds raw pointers, so the wrappers themselves are shared, not
// copied. `const &` parameters would avoid the copies.
215 void TMVA::plotEfficienciesMulticlass(roccurvelist_t rocCurves, classcanvasmap_t classCanvasMap)
217 for (
auto &item : rocCurves) {
219 TString methodTitle = std::get<0>(item);
220 TString classname = std::get<1>(item);
221 TGraph *h = std::get<2>(item);
// map::at throws std::out_of_range for an unknown class name (caught below).
224 EfficiencyPlotWrapper *plotWrapper = classCanvasMap.at(classname);
225 plotWrapper->addGraph(h);
226 plotWrapper->addLegendEntry(methodTitle, h);
227 }
catch (
const std::out_of_range &oor) {
228 cout << Form(
"ERROR: Class %s discovered among plots but was not found by TMVAMulticlassGui. Skipping.",
// Entry point for 1-vs-rest ROC plots: initialises TMVA plotting globals and
// delegates to plotEfficienciesMulticlass1vsRest. Note the argument order
// differs between the two functions.
// `useTMVAStyle` is declared in a part of the signature not visible here.
242 void TMVA::efficienciesMulticlass1vsRest(TString dataset, TString filename_input, EEfficiencyPlotType plotType,
246 TMVAGlob::Initialize(useTMVAStyle);
247 plotEfficienciesMulticlass1vsRest(dataset, plotType, filename_input);
// Draws one "<class> vs rest" ROC canvas per class found in `filename_input`.
// Each canvas shows the "_rejBvsS_" curves of every "MVA_" method for that class.
// Only plotType == kRejBvsEffS is supported.
258 void TMVA::plotEfficienciesMulticlass1vsRest(TString dataset, EEfficiencyPlotType plotType, TString filename_input)
261 if (plotType != EEfficiencyPlotType::kRejBvsEffS) {
// NOTE(review): this message (and the file-not-found one below) ends without a
// newline or std::endl, so following output runs onto the same line.
262 std::cout <<
"For multiclass, only rejB vs effS is currently implemented.";
267 TFile *file = TMVAGlob::OpenFile(filename_input);
268 if (file ==
nullptr) {
269 std::cout <<
"ERROR: filename \"" << filename_input <<
"\" is not found.";
// NOTE(review): binDir is passed to getRocCurves without a visible null check.
272 auto binDir = file->GetDirectory(dataset.Data());
// getclassnames opens the file again by name, rather than reusing `file`.
275 auto classnames = getclassnames(dataset, filename_input);
276 TString methodPrefix =
"MVA_";
277 TString graphNameRef =
"_rejBvsS_";
// One wrapper (canvas + legend) per class, keyed by class name. iPlot (declared
// outside this excerpt) staggers the canvas positions.
279 classcanvasmap_t classCanvasMap;
280 for (
auto &classname : classnames) {
281 TString name = Form(
"roc_%s_vs_rest", classname.Data());
282 TString title = Form(
"ROC Curve %s vs rest", classname.Data());
283 EfficiencyPlotWrapper *plotWrapper =
new EfficiencyPlotWrapper(name, title, dataset, iPlot++);
284 classCanvasMap.emplace(classname.Data(), plotWrapper);
287 roccurvelist_t rocCurves = getRocCurves(binDir, methodPrefix, graphNameRef);
288 plotEfficienciesMulticlass(rocCurves, classCanvasMap);
// Per-wrapper post-processing; the loop body is not visible in this excerpt.
290 for (
auto const &item : classCanvasMap) {
291 auto plotWrapper = item.second;
// Opens a vertical TControlBar titled "1v1 ROC curve comparison" with one
// button per class in `fin`. Each button runs
// TMVA::plotEfficienciesMulticlass1vs1(dataset, fin, <class>) as a command string.
// The third format argument is cut off here; presumably the class name, to confirm.
302 void TMVA::efficienciesMulticlass1vs1(TString dataset, TString fin)
304 std::cout <<
"--- Running Roc1v1Gui for input file: " << fin << std::endl;
306 TMVAGlob::Initialize();
309 TString title =
"1v1 ROC curve comparison";
310 TControlBar *cbar =
new TControlBar(
"vertical", title, 50, 50);
313 auto classnames = getclassnames(dataset, fin);
// NOTE(review): dataset and fin are pasted between escaped quotes into the
// command string without escaping, so a path containing '"' or '\' would
// produce a malformed command.
316 for (
auto &classname : classnames) {
317 cbar->AddButton(Form(
"Class: %s", classname.Data()),
318 Form(
"TMVA::plotEfficienciesMulticlass1vs1(\"%s\", \"%s\", \"%s\")", dataset.Data(), fin.Data(),
323 cbar->SetTextColor(
"blue");
326 gROOT->SaveContext();
// Draws one ROC canvas per class other than `baseClassname`. Each canvas has
// `baseClassname` as signal and that class as background, and shows the
// "_1v1rejBvsS_<base>_vs_" curves of every "MVA_" method.
345 void TMVA::plotEfficienciesMulticlass1vs1(TString dataset, TString fin, TString baseClassname)
348 TMVAGlob::Initialize();
// NOTE(review): getclassnames opens and dereferences the file before the
// nullptr check on OpenFile's result below, so a missing file is not reported
// by that check.
350 auto classnames = getclassnames(dataset, fin);
353 TString methodPrefix =
"MVA_";
354 TString graphNameRef = Form(
"_1v1rejBvsS_%s_vs_", baseClassname.Data());
356 TFile *file = TMVAGlob::OpenFile(fin);
357 if (file ==
nullptr) {
// NOTE(review): no trailing newline on this message.
358 std::cout <<
"ERROR: filename \"" << fin <<
"\" is not found.";
361 auto binDir = file->GetDirectory(dataset.Data());
363 classcanvasmap_t classCanvasMap;
364 for (
auto &classname : classnames) {
// The base class gets no canvas of its own; the skip body is not visible here.
366 if (classname == baseClassname) {
370 TString name = Form(
"1v1roc_%s_vs_%s", baseClassname.Data(), classname.Data());
371 TString title = Form(
"ROC Curve %s (Sig) vs %s (Bkg)", baseClassname.Data(), classname.Data());
372 EfficiencyPlotWrapper *plotWrapper =
new EfficiencyPlotWrapper(name, title, dataset, iPlot++);
373 classCanvasMap.emplace(classname.Data(), plotWrapper);
376 roccurvelist_t rocCurves = getRocCurves(binDir, methodPrefix, graphNameRef);
377 plotEfficienciesMulticlass(rocCurves, classCanvasMap);
// Per-wrapper post-processing; the loop body is not visible in this excerpt.
379 for (
auto const &item : classCanvasMap) {
380 auto plotWrapper = item.second;
// Constructor: sets up the legend geometry, then creates the canvas and legend.
// Initialisation of the other members is not visible in this excerpt.
393 EfficiencyPlotWrapper::EfficiencyPlotWrapper(TString name, TString title, TString dataset,
size_t i)
// Recomputes the legend's top y-coordinate (NDC) from its previous value and fdyH.
401 fy0H = 1 - fy0H + fdyH + 0.07;
408 fCanvas = newEfficiencyCanvas(name, title, i);
409 fLegend = newEfficiencyLegend();
// Gives `graph` a 3-pixel line in the wrapper's current colour (fColor) and
// draws a clone of it, so the original graph object is not attached to the pad.
417 Int_t EfficiencyPlotWrapper::addGraph(TGraph *graph)
419 graph->SetLineWidth(3);
420 graph->SetLineColor(fColor);
// Special-cases colour indices 5, 10 and 11. The body is not visible here;
// presumably it skips palette entries that are hard to see. Confirm.
422 if (fColor == 5 || fColor == 10 || fColor == 11) {
427 graph->DrawClone(
"");
// Adds a line-style ("l") legend entry labelled `methodTitle` for `graph`, then
// moves the legend's top edge up in proportion to the number of methods.
441 void EfficiencyPlotWrapper::addLegendEntry(TString methodTitle, TGraph *graph)
443 fLegend->AddEntry(graph, methodTitle,
"l");
// Extra height = fdyH * (min(10, fNumMethods) - 3) / 4. The growth is capped at
// 10 methods, and the height shrinks (negative offset) for fewer than 3.
// fNumMethods is presumably the number of entries added so far; its update is
// not visible here.
445 Float_t dyH_local = fdyH * (Float_t(TMath::Min((UInt_t)10, fNumMethods) - 3.0) / 4.0);
446 fLegend->SetY2(fy0H + dyH_local);
// Creates a 650x500 canvas whose window is offset by 50 px per index `i`, and a
// 500x500-bin TH2F frame spanning [x1,x2] x [y1,y2] (bounds declared outside
// this excerpt). The frame has signal-efficiency / background-rejection axis
// titles and TMVA frame styling.
460 TCanvas *EfficiencyPlotWrapper::newEfficiencyCanvas(TString name, TString title,
size_t i)
462 TCanvas *c =
new TCanvas(name, title, 200 + i * 50, 0 + i * 50, 650, 500);
468 TString xtit =
"Signal Efficiency";
469 TString ytit =
"Background Rejection (1 - eff)";
// NOTE(review): the frame's object name is built from `title` (human-readable,
// contains spaces), not from the canvas `name`. Using `name` would give a
// cleaner identifier.
475 TH2F *frame =
new TH2F(Form(
"%s_%s", title.Data(),
"frame"), title, 500, x1, x2, 500, y1, y2);
476 frame->GetXaxis()->SetTitle(xtit);
477 frame->GetYaxis()->SetTitle(ytit);
478 TMVA::TMVAGlob::SetFrameStyle(frame, 1.0);
// Creates the method legend, spanning x in [fx0L, fx0L + fdxL] and
// y in [fy0H - fdyH, fy0H], with the header "MVA Method:" and a 0.4 margin.
487 TLegend *EfficiencyPlotWrapper::newEfficiencyLegend()
489 TLegend *legend =
new TLegend(fx0L, fy0H - fdyH, fx0L + fdxL, fy0H);
491 legend->SetHeader(
"MVA Method:");
492 legend->SetMargin(0.4);
// Writes the canvas via TMVAGlob::imgconv, using the base path
// "<dataset>/plots/<canvas name>". Output format(s) are decided by imgconv.
502 void EfficiencyPlotWrapper::save()
504 TString fname = fDataset +
"/plots/" + fCanvas->GetName();
505 TMVA::TMVAGlob::imgconv(fCanvas, fname);