// Control bars created by TMVA::BDT(dataset, fin); an entry is deleted and
// reset to null by TMVA::BDT_DeleteTBar(i).
30 std::vector<TControlBar*> TMVA::BDT_Global__cbar;
// Static dialog-instance pointer, initially null.
// NOTE(review): presumably maintained by StatDialogBDT::Delete() — confirm.
32 TMVA::StatDialogBDT* TMVA::StatDialogBDT::fThis = 0;
// Class-wide DecisionTreeNode flag: false by default, set to true by the
// StatDialogBDT constructor.
33 bool TMVA::DecisionTreeNode::fgIsTraining =
false;
// Slot connected (in the constructor) to fInput's ValueSet(Long_t) signal:
// copies the value currently shown in the number entry into fItree,
// converted to Int_t.
35 void TMVA::StatDialogBDT::SetItree()
37 fItree = Int_t(fInput->GetNumber());
// Slot connected (in the constructor) to the "Draw" button's Clicked() signal.
// NOTE(review): the body is not shown in this excerpt.
40 void TMVA::StatDialogBDT::Redraw()
// Slot connected (in the constructor) to the "Close" button's Clicked() signal.
// NOTE(review): the body is not shown in this excerpt.
45 void TMVA::StatDialogBDT::Close()
// Build the decision-tree selection dialog and wire it to the SetItree(),
// Redraw() and Close() slots. The dialog contains:
//   - a label showing the valid tree range [0, fNtrees-1],
//   - a number entry for the tree index, limited to that range,
//   - a row with "Close" and "Draw" buttons.
//
// dataset  - dataset name (DrawTree builds the image path from fDataset;
//            presumably initialised from this argument — not shown)
// p        - parent window of the TGMainFrame
// wfile    - weight file to read trees from (presumably stored in fWfile)
// methName - method name, stored in fMethName
// itree    - initial tree index (presumably stored in fItree, which seeds
//            the number entry)
50 TMVA::StatDialogBDT::StatDialogBDT(TString dataset,
const TGWindow* p, TString wfile, TString methName, Int_t itree )
60 fMethName( methName ),
// Initial main-frame size.
63 UInt_t totalWidth = 500;
64 UInt_t totalHeight = 200;
// NOTE(review): this changes a class-wide static for every DecisionTreeNode;
// no reset is visible in this excerpt — confirm it is restored when done.
68 TMVA::DecisionTreeNode::fgIsTraining=
true;
// fNtrees is used below for the label text and the entry limits, so it must
// already be set here (presumably by GetNtrees() on a line not shown).
74 fMain =
new TGMainFrame(p, totalWidth, totalHeight, kMainFrame | kVerticalFrame);
76 TGLabel *sigLab =
new TGLabel( fMain, Form(
"Decision tree [%i-%i]",0,fNtrees-1 ) );
77 fMain->AddFrame(sigLab,
new TGLayoutHints(kLHintsLeft | kLHintsTop,5,5,5,5));
// Tree-index entry: initial value fItree, digit width 5, widget id -1,
// restricted to [0, fNtrees-1].
// NOTE(review): the style is the magic value 5 cast to TGNumberFormat::EStyle;
// a named enumerator would make the intent explicit.
79 fInput =
new TGNumberEntry(fMain, (Double_t) fItree,5,-1,(TGNumberFormat::EStyle) 5);
80 fMain->AddFrame(fInput,
new TGLayoutHints(kLHintsLeft | kLHintsTop,5,5,5,5));
81 fInput->Resize(100,24);
82 fInput->SetLimits(TGNumberFormat::kNELLimitMinMax,0,fNtrees-1);
// Button row at the bottom: "Close" left-aligned, "Draw" right-aligned.
84 fButtons =
new TGHorizontalFrame(fMain, totalWidth,30);
86 fCloseButton =
new TGTextButton(fButtons,
"&Close");
87 fButtons->AddFrame(fCloseButton,
new TGLayoutHints(kLHintsLeft | kLHintsTop));
89 fDrawButton =
new TGTextButton(fButtons,
"&Draw");
90 fButtons->AddFrame(fDrawButton,
new TGLayoutHints(kLHintsRight | kLHintsTop,15));
92 fMain->AddFrame(fButtons,
new TGLayoutHints(kLHintsLeft | kLHintsBottom,5,5,5,5));
// Window title, position at (0,0), map children and shrink to the default size.
94 fMain->SetWindowName(
"Decision tree");
95 fMain->SetWMPosition(0,0);
96 fMain->MapSubwindows();
97 fMain->Resize(fMain->GetDefaultSize());
// Signal/slot wiring: editing the entry updates fItree; the buttons call
// Redraw() and Close().
100 fInput->Connect(
"ValueSet(Long_t)",
"TMVA::StatDialogBDT",
this,
"SetItree()");
104 fDrawButton->Connect(
"Clicked()",
"TMVA::StatDialogBDT",
this,
"Redraw()");
106 fCloseButton->Connect(
"Clicked()",
"TMVA::StatDialogBDT",
this,
"Close()");
// NOTE(review): the body of UpdateCanvases() is not shown in this excerpt,
// and no callers are visible.
109 void TMVA::StatDialogBDT::UpdateCanvases()
// Determine how many trees the weight file fWfile contains, store the count
// in fNtrees, and print it.
//   - Text weight files (name not ending in ".xml"): scan the file until a
//     token containing "NTrees" is found, then parse the count from `dummy`
//     with double quotes removed (the scanning body is partly not shown).
//   - XML weight files: walk the children of the root element and read the
//     "NTrees" attribute of the "Weights" node.
114 void TMVA::StatDialogBDT::GetNtrees()
// NOTE(review): EndsWith(".xml") is case-sensitive, so "*.XML" files take
// the text path.
116 if(!fWfile.EndsWith(
".xml") ){
117 std::ifstream fin( fWfile );
119 cout <<
"*** ERROR: Weight file: " << fWfile <<
" does not exist" << endl;
// Search for the "NTrees" keyword; the error messages below are presumably
// the end-of-file branch (enclosing condition not shown).
127 while (!dummy.Contains(
"NTrees")) {
132 cout <<
"*** Huge problem: could not locate term \"NTrees\" in BDT weight file: "
134 cout <<
"==> panic abort (please contact the TMVA authors)" << endl;
// Strip any double quotes before converting the token to an integer.
140 fNtrees = dummy.ReplaceAll(
"\"",
"").Atoi();
// XML path.
// NOTE(review): no xmlengine().FreeDoc(doc) is visible in this excerpt —
// confirm the parsed document is released.
144 void* doc = TMVA::gTools().xmlengine().ParseFile(fWfile);
145 void* rootnode = TMVA::gTools().xmlengine().DocGetRootElement(doc);
146 void* ch = TMVA::gTools().xmlengine().GetChild(rootnode);
148 TString nodeName = TString( TMVA::gTools().xmlengine().GetNodeName(ch) );
149 if(nodeName==
"Weights") {
150 TMVA::gTools().ReadAttr( ch,
"NTrees", fNtrees );
153 ch = TMVA::gTools().xmlengine().GetNext(ch);
156 cout <<
"--- Found " << fNtrees <<
" decision trees in weight file" << endl;
// Recursively draw node n and its subtree.
//   x, y    - centre of this node's box (the box is a TPaveText in NDC)
//   xscale  - horizontal offset from this node to each child; halved at
//             every level of recursion
//   yscale  - vertical distance to the next level
//   vars    - variable names indexed by the node's selector, used for the
//             cut label
// For each existing child, an edge (TLine) is created and the child is drawn
// recursively. The box is filled with a colour picked from the node purity.
// It shows the number of events (when > 0) and, for node type 0, the cut.
// The purity text is formatted into `buffer`; its AddText line is not shown.
164 void TMVA::StatDialogBDT::DrawNode( TMVA::DecisionTreeNode *n,
165 Double_t x, Double_t y,
166 Double_t xscale, Double_t yscale, TString * vars)
// Half-width and half-height of the node box.
168 Float_t xsize=xscale*1.5;
169 Float_t ysize=yscale/3;
// NOTE(review): a half-width above 0.15 is clamped to 0.1, not to 0.15 —
// confirm the mismatched threshold is intended.
170 if (xsize>0.15) xsize=0.1;
171 if (n->GetLeft() != NULL){
172 TLine *a1 =
new TLine(x-xscale/4,y-ysize,x-xscale,y-ysize*2);
175 DrawNode((TMVA::DecisionTreeNode*) n->GetLeft(), x-xscale, y-yscale, xscale/2, yscale, vars);
177 if (n->GetRight() != NULL){
178 TLine *a1 =
new TLine(x+xscale/4,y-ysize,x+xscale,y-ysize*2);
181 DrawNode((TMVA::DecisionTreeNode*) n->GetRight(), x+xscale, y-yscale, xscale/2, yscale, vars );
185 TPaveText *t =
new TPaveText(x-xsize,y-ysize,x+xsize,y+ysize,
"NDC");
189 t->SetFillStyle(1001);
// Colour index = gradient start + purity in percent.
// NOTE(review): pur == 1.0 yields fColorOffset+100, one past the 100 palette
// entries set up in DrawTree (fColorOffset .. fColorOffset+99); consider
// clamping the index to 99.
192 Double_t pur=n->GetPurity();
193 t->SetFillColor(fColorOffset+Int_t(pur*100));
// NOTE(review): `buffer` is declared on a line not shown; sprintf is
// unbounded — snprintf(buffer, sizeof(buffer), ...) would be safer.
196 sprintf( buffer,
"N=%f", n->GetNEvents() );
197 if (n->GetNEvents()>0) t->AddText(buffer);
198 sprintf( buffer,
"S/(S+B)=%4.3f", n->GetPurity() );
// Node type 0 (presumably an internal node): add the cut as
// "<var>><value>" when GetCutType() is true, otherwise "<var><<value>".
// `+` binds tighter than `+=`, so the formatted value is appended to the
// temporary vars[...]+">" (or "<").
201 if (n->GetNodeType() == 0){
202 if (n->GetCutType()){
203 t->AddText(TString(vars[n->GetSelector()]+
">"+=::Form(
"%5.3g",n->GetCutValue())));
205 t->AddText(TString(vars[n->GetSelector()]+
"<"+=::Form(
"%5.3g",n->GetCutValue())));
// Read tree number itree from fWfile.
// Allocates a new TMVA::DecisionTree (`d`) and points `vars` at a new[]
// array of nVars+1 names: the input-variable names followed by the
// pseudo-name "FisherCrit".
// Text and XML weight files are both handled. Both paths print an error when
// itree >= fNtrees (what they do afterwards is not shown).
// NOTE(review): `vars` (new[]) and `d` (new) are heap-allocated; presumably
// `d` is returned and the caller must release both — confirm. `d` is
// allocated before the error checks, so an early return there would leak it.
213 TMVA::DecisionTree* TMVA::StatDialogBDT::ReadTree( TString* &vars, Int_t itree )
215 cout <<
"--- Reading Tree " << itree <<
" from weight file: " << fWfile << endl;
216 TMVA::DecisionTree *d =
new TMVA::DecisionTree();
217 if(!fWfile.EndsWith(
".xml") ){
218 std::ifstream fin( fWfile );
220 cout <<
"*** ERROR: Weight file: " << fWfile <<
" does not exist" << endl;
228 if (itree >= fNtrees) {
229 cout <<
"*** ERROR: requested decision tree: " << itree
230 <<
", but number of trained trees only: " << fNtrees << endl;
// Skip to the "#VAR" header, then three more tokens.
// NOTE(review): this loop has no EOF check; if "#VAR" is absent, `fin >> dummy`
// keeps failing and the loop may never terminate.
237 while (!dummy.Contains(
"#VAR")) fin >> dummy;
238 fin >> dummy >> dummy >> dummy;
242 fin >> dummy >> nVars;
// One extra slot for the trailing "FisherCrit" entry.
245 vars =
new TString[nVars+1];
// Each variable entry is read as its name followed by four ignored tokens.
246 for (Int_t i = 0; i < nVars; i++) fin >> vars[i] >> dummy >> dummy >> dummy >> dummy;
247 vars[nVars]=
"FisherCrit";
// Skip lines until one contains "Tree <itree>".
// NOTE(review): `buffer` and `line` are declared on lines not shown (`line`
// must hold at least 256 chars); no EOF check is visible in this loop.
251 sprintf(buffer,
"Tree %d",itree);
253 while (!dummy.Contains(buffer)) {
254 fin.getline(line,256);
255 dummy = TString(line);
// Same range check, presumably for the XML branch (the enclosing lines are
// not shown).
263 if (itree >= fNtrees) {
264 cout <<
"*** ERROR: requested decision tree: " << itree
265 <<
", but number of trained trees only: " << fNtrees << endl;
// XML path: read the variable names from the "Variables" node, stop at the
// "Weights" node, then step to its itree-th child.
// NOTE(review): if no "Weights" node exists, `ch` ends up null before
// GetChild(ch) — confirm the loop condition (not shown) guards this. No
// FreeDoc(doc) is visible in this excerpt.
271 void* doc = TMVA::gTools().xmlengine().ParseFile(fWfile);
272 void* rootnode = TMVA::gTools().xmlengine().DocGetRootElement(doc);
273 void* ch = TMVA::gTools().xmlengine().GetChild(rootnode);
275 TString nodeName = TString( TMVA::gTools().xmlengine().GetNodeName(ch) );
276 if(nodeName==
"Variables"){
277 TMVA::gTools().ReadAttr( ch,
"NVar", nVars);
278 vars =
new TString[nVars+1];
279 void* varnode = TMVA::gTools().xmlengine().GetChild(ch);
280 for (Int_t i = 0; i < nVars; i++){
281 TMVA::gTools().ReadAttr( varnode,
"Expression", vars[i]);
282 varnode = TMVA::gTools().xmlengine().GetNext(varnode);
284 vars[nVars]=
"FisherCrit";
286 if(nodeName==
"Weights")
break;
287 ch = TMVA::gTools().xmlengine().GetNext(ch);
// Enter the "Weights" node and advance to tree number itree.
289 ch = TMVA::gTools().xmlengine().GetChild(ch);
290 for (
int i=0; i<itree; i++) ch = TMVA::gTools().xmlengine().GetNext(ch);
// Read tree itree with ReadTree() and draw it into fCanvas. The canvas is
// created on first use as "c1" (1000x600 at 200,0) and cleared otherwise.
// The result is saved as an image named <fDataset>/plots/<fMethName>_<itree>.
// Node boxes are coloured from a 100-step gradient: red at purity 0, blue at
// purity 1. The gradient is installed as the palette of the "Plain" style.
// Legend boxes show the tree number and the pure signal/background colours.
298 void TMVA::StatDialogBDT::DrawTree( Int_t itree )
// NOTE(review): `d` is dereferenced without a visible null check; ReadTree's
// error paths are not shown.
301 TMVA::DecisionTree* d = ReadTree( vars, itree );
// Vertical spacing per tree level.
304 UInt_t depth = d->GetTotalTreeDepth();
305 Double_t ystep = 1.0/(depth + 1.0);
307 cout <<
"--- Tree depth: " << depth << endl;
// NOTE(review): GetStyle returns null if no "Plain" style exists; no check
// is visible before TMVAStyle is used below.
309 TStyle* TMVAStyle = gROOT->GetStyle(
"Plain");
// Two-stop gradient: (r,g,b) = (1,0,0) at 0 and (0,0,1) at 1.
// NOTE(review): this runs on every DrawTree call and appears to allocate a
// fresh block of 100 colours each time — consider creating it once.
313 Double_t r[2] = {1., 0.};
314 Double_t g[2] = {0., 0.};
315 Double_t b[2] = {0., 1.};
316 Double_t stop[2] = {0., 1.0};
317 fColorOffset = TColor::CreateGradientColorTable(2, stop, r, g, b, 100);
319 Int_t MyPalette[100];
320 for (
int i=0;i<100;i++) MyPalette[i] = fColorOffset+i;
321 TMVAStyle->SetPalette(100, MyPalette);
// Remember the style's canvas colour; it is restored after the image is saved.
325 Int_t canvasColor = TMVAStyle->GetCanvasColor();
327 TString cbuffer = Form(
"Reading weight file: %s", fWfile.Data() );
328 TString tbuffer = Form(
"Decision Tree no.: %d", itree );
329 if (!fCanvas) fCanvas =
new TCanvas(
"c1", cbuffer, 200, 0, 1000, 600 );
330 else fCanvas->Clear();
// Root at the horizontal centre, half a level below the top edge; its
// children are offset horizontally by 0.25.
333 DrawNode( (TMVA::DecisionTreeNode*)d->GetRoot(), 0.5, 1.-0.5*ystep, 0.25, ystep ,vars);
// Legend boxes hang downwards from `yup` (defined on a line not shown).
337 Double_t ydown=yup-ystep/2.5;
338 Double_t dy= ystep/2.5 * 0.2;
340 TPaveText *whichTree =
new TPaveText(0.85,ydown,0.98,yup,
"NDC");
341 whichTree->SetBorderSize(1);
342 whichTree->SetFillStyle(1001);
343 whichTree->SetFillColor( TColor::GetColor(
"#ffff33" ) );
344 whichTree->AddText( tbuffer );
347 TPaveText *signalleaf =
new TPaveText(0.02,ydown ,0.15,yup,
"NDC");
348 signalleaf->SetBorderSize(1);
349 signalleaf->SetFillStyle(1001);
350 signalleaf->SetFillColor( getSigColorF() );
351 signalleaf->AddText(
"Pure Signal Nodes");
352 signalleaf->SetTextColor( getSigColorT() );
// Move one box height plus a small gap (dy) down for the background legend.
355 ydown = ydown - ystep/2.5 -dy;
356 yup = yup - ystep/2.5 -dy;
357 TPaveText *backgroundleaf =
new TPaveText(0.02,ydown,0.15,yup,
"NDC");
358 backgroundleaf->SetBorderSize(1);
359 backgroundleaf->SetFillStyle(1001);
360 backgroundleaf->SetFillColor( getBkgColorF() );
362 backgroundleaf->AddText(
"Pure Backgr. Nodes");
363 backgroundleaf->SetTextColor( getBkgColorT() );
364 backgroundleaf->Draw();
// Save the canvas as an image below <fDataset>/plots/.
368 TString fname = fDataset+Form(
"/plots/%s_%i", fMethName.Data(), itree );
369 cout <<
"--- Creating image: " << fname << endl;
370 TMVAGlob::imgconv( fCanvas, fname );
372 TMVAStyle->SetCanvasColor( canvasColor );
// Entry point for a TMVA output file.
// Opens `fin` and collects every sub-directory of <dataset>/Method_BDT, along
// with its "TrainingPath" and "WeightFileName" TObjStrings. It then shows a
// vertical control bar "Choose weight file:" with one button per entry.
// Each button runs TMVA::BDT("<dataset>",0,"<fname>","<method>"), i.e. the
// TMVA::BDT(dataset, itree, wfile, methName, ...) overload with tree 0.
// The bar is recorded in BDT_Global__cbar.
379 void TMVA::BDT(TString dataset,
const TString& fin )
384 TMVAGlob::DestroyCanvases();
387 TFile* file = TMVAGlob::OpenFile( fin );
// NOTE(review): the result of file->GetDirectory(dataset) is dereferenced
// without a null check, so a missing dataset directory crashes here instead
// of reaching the "cannot find directory" message.
389 TDirectory* dir = file->GetDirectory(dataset.Data())->GetDirectory(
"Method_BDT" );
391 cout <<
"*** Error in macro \"BDT.C\": cannot find directory \"Method_BDT\" in file: " << fin << endl;
// For each key: method name, training path and weight-file name.
396 TIter next( dir->GetListOfKeys() );
398 std::vector<TString> methname;
399 std::vector<TString> path;
400 std::vector<TString> wfile;
401 while ((key = (TKey*)next())) {
402 TDirectory* mdir = dir->GetDirectory( key->GetName() );
404 cout <<
"*** Error in macro \"BDT.C\": cannot find sub-directory: " << key->GetName()
405 <<
" in directory: " << dir->GetName() << endl;
410 TObjString* strPath = (TObjString*)mdir->Get(
"TrainingPath" );
411 TObjString* strWFile = (TObjString*)mdir->Get(
"WeightFileName" );
412 if (!strPath || !strWFile) {
413 cout <<
"*** Error in macro \"BDT.C\": could not find TObjStrings \"TrainingPath\" and/or \"WeightFileName\" *** " << endl;
414 cout <<
"*** Maybe you are using TMVA >= 3.8.15 with an older training target file ? *** " << endl;
418 methname.push_back( key->GetName() );
419 path .push_back( strPath->GetString() );
420 wfile .push_back( strWFile->GetString() );
424 TControlBar* cbar =
new TControlBar(
"vertical",
"Choose weight file:", 50, 50 );
425 BDT_Global__cbar.push_back(cbar);
427 for (UInt_t im=0; im<path.size(); im++) {
428 TString fname = path[im];
// Ensure the path ends in '/'.
// NOTE(review): fname[fname.Length()-1] reads index -1 when TrainingPath is
// empty.
429 if (fname[fname.Length()-1] !=
'/') fname +=
"/";
// NOTE(review): wfile[im] is not visibly appended to fname (one original line
// is missing just above) — confirm the button passes the full weight-file path.
431 TString macro = Form(
"TMVA::BDT(\"%s\",0,\"%s\",\"%s\")",dataset.Data(), fname.Data(), methname[im].Data() );
432 cbar->AddButton( fname, macro,
"Plot decision trees from this weight file",
"button" );
436 cbar->SetTextColor(
"blue");
// Tear-down helper for control bar i.
// Calls StatDialogBDT::Delete() (presumably disposing of the dialog
// instance) and destroys the TMVA canvases. Then deletes the i-th control
// bar created by TMVA::BDT(dataset, fin) and leaves a null entry in
// BDT_Global__cbar.
// NOTE(review): no bounds check on i is visible in this excerpt.
442 void TMVA::BDT_DeleteTBar(
int i)
445 StatDialogBDT::Delete();
446 TMVAGlob::DestroyCanvases();
448 delete BDT_Global__cbar[i];
449 BDT_Global__cbar[i] = 0;
// Show the tree-selection dialog for weight file `wfile` and draw tree itree.
// Any existing dialog and TMVA canvases are removed first. The TMVA style
// is initialised according to useTMVAStyle.
// NOTE(review): wfile is overwritten with the hard-coded default
// "<dataset>/weights/TMVAnalysis_test_BDT.weights.txt"; any condition
// guarding this (e.g. an empty-wfile check) is on a line not shown — confirm.
454 void TMVA::BDT(TString dataset, Int_t itree, TString wfile , TString methName , Bool_t useTMVAStyle )
457 StatDialogBDT::Delete();
458 TMVAGlob::DestroyCanvases();
460 wfile = dataset+
"/weights/TMVAnalysis_test_BDT.weights.txt";
// Existence check only for non-XML weight files.
462 if(!wfile.EndsWith(
".xml") ){
463 std::ifstream fin( wfile );
465 cout <<
"*** ERROR: Weight file: " << wfile <<
" does not exist" << endl;
// NOTE(review): leftover debug output ("test1", no newline) — consider removing.
469 std::cout <<
"test1";
471 TMVAGlob::Initialize( useTMVAStyle );
// gGui is not deleted here; presumably StatDialogBDT::Delete() releases the
// previous dialog on the next call — confirm.
473 StatDialogBDT* gGui =
new StatDialogBDT(dataset, gClient->GetRoot(), wfile, methName, itree );
475 gGui->DrawTree(itree );