Logo ROOT   6.30.04
Reference Guide
 All Namespaces Files Pages
BDT.cxx
Go to the documentation of this file.
1 #include "TMVA/BDT.h"
2 #include <iostream>
3 #include <iomanip>
4 #include <fstream>
5 
6 
7 
8 #include "RQ_OBJECT.h"
9 
10 #include "TROOT.h"
11 #include "TStyle.h"
12 #include "TPad.h"
13 #include "TCanvas.h"
14 #include "TLine.h"
15 #include "TFile.h"
16 #include "TColor.h"
17 #include "TPaveText.h"
18 #include "TObjString.h"
19 #include "TControlBar.h"
20 
21 #include "TGWindow.h"
22 #include "TGButton.h"
23 #include "TGLabel.h"
24 #include "TGNumberEntry.h"
25 
26 #include "TMVA/DecisionTree.h"
27 #include "TMVA/Tools.h"
28 #include "TXMLEngine.h"
29 
30 std::vector<TControlBar*> TMVA::BDT_Global__cbar;
31 
32 TMVA::StatDialogBDT* TMVA::StatDialogBDT::fThis = 0;
33 bool TMVA::DecisionTreeNode::fgIsTraining = false;
34 
35 void TMVA::StatDialogBDT::SetItree()
36 {
37  fItree = Int_t(fInput->GetNumber());
38 }
39 
40 void TMVA::StatDialogBDT::Redraw()
41 {
42  UpdateCanvases();
43 }
44 
45 void TMVA::StatDialogBDT::Close()
46 {
47  delete this;
48 }
49 
50 TMVA::StatDialogBDT::StatDialogBDT(TString dataset, const TGWindow* p, TString wfile, TString methName, Int_t itree )
51  : fMain( 0 ),
52  fItree(itree),
53  fNtrees(0),
54  fCanvas(0),
55  fInput(0),
56  fButtons(0),
57  fDrawButton(0),
58  fCloseButton(0),
59  fWfile( wfile ),
60  fMethName( methName ),
61  fDataset(dataset)
62 {
63  UInt_t totalWidth = 500;
64  UInt_t totalHeight = 200;
65 
66  fThis = this;
67 
68  TMVA::DecisionTreeNode::fgIsTraining=true;
69 
70  // read number of decision trees from weight file
71  GetNtrees();
72 
73  // main frame
74  fMain = new TGMainFrame(p, totalWidth, totalHeight, kMainFrame | kVerticalFrame);
75 
76  TGLabel *sigLab = new TGLabel( fMain, Form( "Decision tree [%i-%i]",0,fNtrees-1 ) );
77  fMain->AddFrame(sigLab, new TGLayoutHints(kLHintsLeft | kLHintsTop,5,5,5,5));
78 
79  fInput = new TGNumberEntry(fMain, (Double_t) fItree,5,-1,(TGNumberFormat::EStyle) 5);
80  fMain->AddFrame(fInput, new TGLayoutHints(kLHintsLeft | kLHintsTop,5,5,5,5));
81  fInput->Resize(100,24);
82  fInput->SetLimits(TGNumberFormat::kNELLimitMinMax,0,fNtrees-1);
83 
84  fButtons = new TGHorizontalFrame(fMain, totalWidth,30);
85 
86  fCloseButton = new TGTextButton(fButtons,"&Close");
87  fButtons->AddFrame(fCloseButton, new TGLayoutHints(kLHintsLeft | kLHintsTop));
88 
89  fDrawButton = new TGTextButton(fButtons,"&Draw");
90  fButtons->AddFrame(fDrawButton, new TGLayoutHints(kLHintsRight | kLHintsTop,15));
91 
92  fMain->AddFrame(fButtons,new TGLayoutHints(kLHintsLeft | kLHintsBottom,5,5,5,5));
93 
94  fMain->SetWindowName("Decision tree");
95  fMain->SetWMPosition(0,0);
96  fMain->MapSubwindows();
97  fMain->Resize(fMain->GetDefaultSize());
98  fMain->MapWindow();
99 
100  fInput->Connect("ValueSet(Long_t)","TMVA::StatDialogBDT",this, "SetItree()");
101 
102  // doesn't seem to exist .. gives an 'error message' and seems to work just fine without ... :)
103  // fDrawButton->Connect("ValueSet(Long_t)","TGNumberEntry",fInput, "Clicked()");
104  fDrawButton->Connect("Clicked()", "TMVA::StatDialogBDT", this, "Redraw()");
105 
106  fCloseButton->Connect("Clicked()", "TMVA::StatDialogBDT", this, "Close()");
107 }
108 
109 void TMVA::StatDialogBDT::UpdateCanvases()
110 {
111  DrawTree(fItree );
112 }
113 
114 void TMVA::StatDialogBDT::GetNtrees()
115 {
116  if(!fWfile.EndsWith(".xml") ){
117  std::ifstream fin( fWfile );
118  if (!fin.good( )) { // file not found --> Error
119  cout << "*** ERROR: Weight file: " << fWfile << " does not exist" << endl;
120  return;
121  }
122 
123  TString dummy = "";
124 
125  // read total number of trees, and check whether requested tree is in range
126  Int_t nc = 0;
127  while (!dummy.Contains("NTrees")) {
128  fin >> dummy;
129  nc++;
130  if (nc > 200) {
131  cout << endl;
132  cout << "*** Huge problem: could not locate term \"NTrees\" in BDT weight file: "
133  << fWfile << endl;
134  cout << "==> panic abort (please contact the TMVA authors)" << endl;
135  cout << endl;
136  exit(1);
137  }
138  }
139  fin >> dummy;
140  fNtrees = dummy.ReplaceAll("\"","").Atoi();
141  fin.close();
142  }
143  else{
144  void* doc = TMVA::gTools().xmlengine().ParseFile(fWfile);
145  void* rootnode = TMVA::gTools().xmlengine().DocGetRootElement(doc);
146  void* ch = TMVA::gTools().xmlengine().GetChild(rootnode);
147  while(ch){
148  TString nodeName = TString( TMVA::gTools().xmlengine().GetNodeName(ch) );
149  if(nodeName=="Weights") {
150  TMVA::gTools().ReadAttr( ch, "NTrees", fNtrees );
151  break;
152  }
153  ch = TMVA::gTools().xmlengine().GetNext(ch);
154  }
155  }
156  cout << "--- Found " << fNtrees << " decision trees in weight file" << endl;
157 
158 }
159 
160 ////////////////////////////////////////////////////////////////////////////////
161 /// recursively puts an entries in the histogram for the node and its daughters
162 ///
163 
164 void TMVA::StatDialogBDT::DrawNode( TMVA::DecisionTreeNode *n,
165  Double_t x, Double_t y,
166  Double_t xscale, Double_t yscale, TString * vars)
167 {
168  Float_t xsize=xscale*1.5;
169  Float_t ysize=yscale/3;
170  if (xsize>0.15) xsize=0.1; //xscale/2;
171  if (n->GetLeft() != NULL){
172  TLine *a1 = new TLine(x-xscale/4,y-ysize,x-xscale,y-ysize*2);
173  a1->SetLineWidth(2);
174  a1->Draw();
175  DrawNode((TMVA::DecisionTreeNode*) n->GetLeft(), x-xscale, y-yscale, xscale/2, yscale, vars);
176  }
177  if (n->GetRight() != NULL){
178  TLine *a1 = new TLine(x+xscale/4,y-ysize,x+xscale,y-ysize*2);
179  a1->SetLineWidth(2);
180  a1->Draw();
181  DrawNode((TMVA::DecisionTreeNode*) n->GetRight(), x+xscale, y-yscale, xscale/2, yscale, vars );
182  }
183 
184  // TPaveText *t = new TPaveText(x-xscale/2,y-yscale/2,x+xscale/2,y+yscale/2, "NDC");
185  TPaveText *t = new TPaveText(x-xsize,y-ysize,x+xsize,y+ysize, "NDC");
186 
187  t->SetBorderSize(1);
188 
189  t->SetFillStyle(1001);
190 
191 
192  Double_t pur=n->GetPurity();
193  t->SetFillColor(fColorOffset+Int_t(pur*100));
194 
195  char buffer[25];
196  sprintf( buffer, "N=%f", n->GetNEvents() );
197  if (n->GetNEvents()>0) t->AddText(buffer);
198  sprintf( buffer, "S/(S+B)=%4.3f", n->GetPurity() );
199  t->AddText(buffer);
200 
201  if (n->GetNodeType() == 0){
202  if (n->GetCutType()){
203  t->AddText(TString(vars[n->GetSelector()]+">"+=::Form("%5.3g",n->GetCutValue())));
204  }else{
205  t->AddText(TString(vars[n->GetSelector()]+"<"+=::Form("%5.3g",n->GetCutValue())));
206  }
207  }
208 
209  t->Draw();
210 
211  return;
212 }
213 TMVA::DecisionTree* TMVA::StatDialogBDT::ReadTree( TString* &vars, Int_t itree )
214 {
215  cout << "--- Reading Tree " << itree << " from weight file: " << fWfile << endl;
216  TMVA::DecisionTree *d = new TMVA::DecisionTree();
217  if(!fWfile.EndsWith(".xml") ){
218  std::ifstream fin( fWfile );
219  if (!fin.good( )) { // file not found --> Error
220  cout << "*** ERROR: Weight file: " << fWfile << " does not exist" << endl;
221  delete d;
222  d = nullptr;
223  return 0;
224  }
225 
226  TString dummy = "";
227 
228  if (itree >= fNtrees) {
229  cout << "*** ERROR: requested decision tree: " << itree
230  << ", but number of trained trees only: " << fNtrees << endl;
231  delete d;
232  d = nullptr;
233  return 0;
234  }
235 
236  // file header with name
237  while (!dummy.Contains("#VAR")) fin >> dummy;
238  fin >> dummy >> dummy >> dummy; // the rest of header line
239 
240  // number of variables
241  Int_t nVars;
242  fin >> dummy >> nVars;
243 
244  // variable mins and maxes
245  vars = new TString[nVars+1]; // last one is if "fisher cut criterium"
246  for (Int_t i = 0; i < nVars; i++) fin >> vars[i] >> dummy >> dummy >> dummy >> dummy;
247  vars[nVars]="FisherCrit";
248 
249  char buffer[20];
250  char line[256];
251  sprintf(buffer,"Tree %d",itree);
252 
253  while (!dummy.Contains(buffer)) {
254  fin.getline(line,256);
255  dummy = TString(line);
256  }
257 
258  d->Read(fin);
259 
260  fin.close();
261  }
262  else{
263  if (itree >= fNtrees) {
264  cout << "*** ERROR: requested decision tree: " << itree
265  << ", but number of trained trees only: " << fNtrees << endl;
266  delete d;
267  d = nullptr;
268  return 0;
269  }
270  Int_t nVars;
271  void* doc = TMVA::gTools().xmlengine().ParseFile(fWfile);
272  void* rootnode = TMVA::gTools().xmlengine().DocGetRootElement(doc);
273  void* ch = TMVA::gTools().xmlengine().GetChild(rootnode);
274  while(ch){
275  TString nodeName = TString( TMVA::gTools().xmlengine().GetNodeName(ch) );
276  if(nodeName=="Variables"){
277  TMVA::gTools().ReadAttr( ch, "NVar", nVars);
278  vars = new TString[nVars+1];
279  void* varnode = TMVA::gTools().xmlengine().GetChild(ch);
280  for (Int_t i = 0; i < nVars; i++){
281  TMVA::gTools().ReadAttr( varnode, "Expression", vars[i]);
282  varnode = TMVA::gTools().xmlengine().GetNext(varnode);
283  }
284  vars[nVars]="FisherCrit";
285  }
286  if(nodeName=="Weights") break;
287  ch = TMVA::gTools().xmlengine().GetNext(ch);
288  }
289  ch = TMVA::gTools().xmlengine().GetChild(ch);
290  for (int i=0; i<itree; i++) ch = TMVA::gTools().xmlengine().GetNext(ch);
291  d->ReadXML(ch);
292  }
293  return d;
294 }
295 
296 ////////////////////////////////////////////////////////////////////////////////
297 
298 void TMVA::StatDialogBDT::DrawTree( Int_t itree )
299 {
300  TString *vars;
301  TMVA::DecisionTree* d = ReadTree( vars, itree );
302  if (d == 0) return;
303 
304  UInt_t depth = d->GetTotalTreeDepth();
305  Double_t ystep = 1.0/(depth + 1.0);
306 
307  cout << "--- Tree depth: " << depth << endl;
308 
309  TStyle* TMVAStyle = gROOT->GetStyle("Plain"); // our style is based on Plain
310 
311 
312 
313  Double_t r[2] = {1., 0.};
314  Double_t g[2] = {0., 0.};
315  Double_t b[2] = {0., 1.};
316  Double_t stop[2] = {0., 1.0};
317  fColorOffset = TColor::CreateGradientColorTable(2, stop, r, g, b, 100);
318 
319  Int_t MyPalette[100];
320  for (int i=0;i<100;i++) MyPalette[i] = fColorOffset+i;
321  TMVAStyle->SetPalette(100, MyPalette);
322 
323 
324 
325  Int_t canvasColor = TMVAStyle->GetCanvasColor(); // backup
326 
327  TString cbuffer = Form( "Reading weight file: %s", fWfile.Data() );
328  TString tbuffer = Form( "Decision Tree no.: %d", itree );
329  if (!fCanvas) fCanvas = new TCanvas( "c1", cbuffer, 200, 0, 1000, 600 );
330  else fCanvas->Clear();
331  fCanvas->Draw();
332 
333  DrawNode( (TMVA::DecisionTreeNode*)d->GetRoot(), 0.5, 1.-0.5*ystep, 0.25, ystep ,vars);
334 
335  // make the legend
336  Double_t yup=0.99;
337  Double_t ydown=yup-ystep/2.5;
338  Double_t dy= ystep/2.5 * 0.2;
339 
340  TPaveText *whichTree = new TPaveText(0.85,ydown,0.98,yup, "NDC");
341  whichTree->SetBorderSize(1);
342  whichTree->SetFillStyle(1001);
343  whichTree->SetFillColor( TColor::GetColor( "#ffff33" ) );
344  whichTree->AddText( tbuffer );
345  whichTree->Draw();
346 
347  TPaveText *signalleaf = new TPaveText(0.02,ydown ,0.15,yup, "NDC");
348  signalleaf->SetBorderSize(1);
349  signalleaf->SetFillStyle(1001);
350  signalleaf->SetFillColor( getSigColorF() );
351  signalleaf->AddText("Pure Signal Nodes");
352  signalleaf->SetTextColor( getSigColorT() );
353  signalleaf->Draw();
354 
355  ydown = ydown - ystep/2.5 -dy;
356  yup = yup - ystep/2.5 -dy;
357  TPaveText *backgroundleaf = new TPaveText(0.02,ydown,0.15,yup, "NDC");
358  backgroundleaf->SetBorderSize(1);
359  backgroundleaf->SetFillStyle(1001);
360  backgroundleaf->SetFillColor( getBkgColorF() );
361 
362  backgroundleaf->AddText("Pure Backgr. Nodes");
363  backgroundleaf->SetTextColor( getBkgColorT() );
364  backgroundleaf->Draw();
365 
366 
367  fCanvas->Update();
368  TString fname = fDataset+Form("/plots/%s_%i", fMethName.Data(), itree );
369  cout << "--- Creating image: " << fname << endl;
370  TMVAGlob::imgconv( fCanvas, fname );
371 
372  TMVAStyle->SetCanvasColor( canvasColor );
373 }
374 
375 // ========================================================================================
376 
377 
378 // intermediate GUI
379 void TMVA::BDT(TString dataset, const TString& fin )
380 {
381  // --- read the available BDT weight files
382 
383  // destroy all open cavases
384  TMVAGlob::DestroyCanvases();
385 
386  // checks if file with name "fin" is already open, and if not opens one
387  TFile* file = TMVAGlob::OpenFile( fin );
388 
389  TDirectory* dir = file->GetDirectory(dataset.Data())->GetDirectory( "Method_BDT" );
390  if (!dir) {
391  cout << "*** Error in macro \"BDT.C\": cannot find directory \"Method_BDT\" in file: " << fin << endl;
392  return;
393  }
394 
395  // read all directories
396  TIter next( dir->GetListOfKeys() );
397  TKey *key(0);
398  std::vector<TString> methname;
399  std::vector<TString> path;
400  std::vector<TString> wfile;
401  while ((key = (TKey*)next())) {
402  TDirectory* mdir = dir->GetDirectory( key->GetName() );
403  if (!mdir) {
404  cout << "*** Error in macro \"BDT.C\": cannot find sub-directory: " << key->GetName()
405  << " in directory: " << dir->GetName() << endl;
406  return;
407  }
408 
409  // retrieve weight file name and path
410  TObjString* strPath = (TObjString*)mdir->Get( "TrainingPath" );
411  TObjString* strWFile = (TObjString*)mdir->Get( "WeightFileName" );
412  if (!strPath || !strWFile) {
413  cout << "*** Error in macro \"BDT.C\": could not find TObjStrings \"TrainingPath\" and/or \"WeightFileName\" *** " << endl;
414  cout << "*** Maybe you are using TMVA >= 3.8.15 with an older training target file ? *** " << endl;
415  return;
416  }
417 
418  methname.push_back( key->GetName() );
419  path .push_back( strPath->GetString() );
420  wfile .push_back( strWFile->GetString() );
421  }
422 
423  // create the control bar
424  TControlBar* cbar = new TControlBar( "vertical", "Choose weight file:", 50, 50 );
425  BDT_Global__cbar.push_back(cbar);
426 
427  for (UInt_t im=0; im<path.size(); im++) {
428  TString fname = path[im];
429  if (fname[fname.Length()-1] != '/') fname += "/";
430  fname += wfile[im];
431  TString macro = Form( "TMVA::BDT(\"%s\",0,\"%s\",\"%s\")",dataset.Data(), fname.Data(), methname[im].Data() );
432  cbar->AddButton( fname, macro, "Plot decision trees from this weight file", "button" );
433  }
434 
435  // set the style
436  cbar->SetTextColor("blue");
437 
438  // draw
439  cbar->Show();
440 }
441 
442 void TMVA::BDT_DeleteTBar(int i)
443 {
444  // destroy all open canvases
445  StatDialogBDT::Delete();
446  TMVAGlob::DestroyCanvases();
447 
448  delete BDT_Global__cbar[i];
449  BDT_Global__cbar[i] = 0;
450 }
451 
452 // input: - No. of tree
453 // - the weight file from which the tree is read
454 void TMVA::BDT(TString dataset, Int_t itree, TString wfile , TString methName , Bool_t useTMVAStyle )
455 {
456  // destroy possibly existing dialog windows and/or canvases
457  StatDialogBDT::Delete();
458  TMVAGlob::DestroyCanvases();
459  if(wfile=="")
460  wfile = dataset+"/weights/TMVAnalysis_test_BDT.weights.txt";
461  // quick check if weight file exist
462  if(!wfile.EndsWith(".xml") ){
463  std::ifstream fin( wfile );
464  if (!fin.good( )) { // file not found --> Error
465  cout << "*** ERROR: Weight file: " << wfile << " does not exist" << endl;
466  return;
467  }
468  }
469  std::cout << "test1";
470  // set style and remove existing canvas'
471  TMVAGlob::Initialize( useTMVAStyle );
472 
473  StatDialogBDT* gGui = new StatDialogBDT(dataset, gClient->GetRoot(), wfile, methName, itree );
474 
475  gGui->DrawTree(itree );
476 
477  gGui->RaiseDialog();
478 }
479