Logo ROOT   6.30.04
Reference Guide
 All Namespaces Files Pages
BDT_Reg.cxx
Go to the documentation of this file.
1 #include "TMVA/BDT_Reg.h"
2 #include <iostream>
3 #include <iomanip>
4 #include <fstream>
5 
6 
7 
8 #include "RQ_OBJECT.h"
9 
10 #include "TROOT.h"
11 #include "TStyle.h"
12 #include "TPad.h"
13 #include "TCanvas.h"
14 #include "TLine.h"
15 #include "TFile.h"
16 #include "TColor.h"
17 #include "TPaveText.h"
18 #include "TObjString.h"
19 #include "TControlBar.h"
20 
21 #include "TGWindow.h"
22 #include "TGButton.h"
23 #include "TGLabel.h"
24 #include "TGNumberEntry.h"
25 
26 #include "TMVA/DecisionTree.h"
27 #include "TMVA/Tools.h"
28 #include "TXMLEngine.h"
29 
30 std::vector<TControlBar*> TMVA::BDTReg_Global__cbar;
31 
32 TMVA::StatDialogBDTReg* TMVA::StatDialogBDTReg::fThis = 0;
33 
34 void TMVA::StatDialogBDTReg::SetItree()
35 {
36  fItree = Int_t(fInput->GetNumber());
37 }
38 
39 void TMVA::StatDialogBDTReg::Redraw()
40 {
41  UpdateCanvases();
42 }
43 
44 void TMVA::StatDialogBDTReg::Close()
45 {
46  delete this;
47 }
48 
49 TMVA::StatDialogBDTReg::StatDialogBDTReg(TString dataset, const TGWindow* p, TString wfile, TString methName, Int_t itree )
50  : fMain( 0 ),
51  fItree(itree),
52  fNtrees(0),
53  fCanvas(0),
54  fDataset(dataset),
55  fInput(0),
56  fButtons(0),
57  fDrawButton(0),
58  fCloseButton(0),
59  fWfile( wfile ),
60  fMethName( methName )
61 {
62  UInt_t totalWidth = 500;
63  UInt_t totalHeight = 200;
64 
65  fThis = this;
66 
67  // read number of decision trees from weight file
68  GetNtrees();
69 
70  // main frame
71  fMain = new TGMainFrame(p, totalWidth, totalHeight, kMainFrame | kVerticalFrame);
72 
73  TGLabel *sigLab = new TGLabel( fMain, Form( "Regression tree [%i-%i]",0,fNtrees-1 ) );
74  fMain->AddFrame(sigLab, new TGLayoutHints(kLHintsLeft | kLHintsTop,5,5,5,5));
75 
76  fInput = new TGNumberEntry(fMain, (Double_t) fItree,5,-1,(TGNumberFormat::EStyle) 5);
77  fMain->AddFrame(fInput, new TGLayoutHints(kLHintsLeft | kLHintsTop,5,5,5,5));
78  fInput->Resize(100,24);
79  fInput->SetLimits(TGNumberFormat::kNELLimitMinMax,0,fNtrees-1);
80 
81  fButtons = new TGHorizontalFrame(fMain, totalWidth,30);
82 
83  fCloseButton = new TGTextButton(fButtons,"&Close");
84  fButtons->AddFrame(fCloseButton, new TGLayoutHints(kLHintsLeft | kLHintsTop));
85 
86  fDrawButton = new TGTextButton(fButtons,"&Draw");
87  fButtons->AddFrame(fDrawButton, new TGLayoutHints(kLHintsRight | kLHintsTop,15));
88 
89  fMain->AddFrame(fButtons,new TGLayoutHints(kLHintsLeft | kLHintsBottom,5,5,5,5));
90 
91  fMain->SetWindowName("Regression tree");
92  fMain->SetWMPosition(0,0);
93  fMain->MapSubwindows();
94  fMain->Resize(fMain->GetDefaultSize());
95  fMain->MapWindow();
96 
97  fInput->Connect("ValueSet(Long_t)","TMVA::StatDialogBDTReg",this, "SetItree()");
98 
99  // doesn't seem to exist .. gives an 'error message' and seems to work just fine without ... :)
100  // fDrawButton->Connect("Clicked()","TGNumberEntry",fInput, "ValueSet(Long_t)");
101  fDrawButton->Connect("Clicked()", "TMVA::StatDialogBDTReg", this, "Redraw()");
102 
103  fCloseButton->Connect("Clicked()", "TMVA::StatDialogBDTReg", this, "Close()");
104 }
105 
106 void TMVA::StatDialogBDTReg::UpdateCanvases()
107 {
108  DrawTree( fItree );
109 }
110 
111 void TMVA::StatDialogBDTReg::GetNtrees()
112 {
113  if(!fWfile.EndsWith(".xml") ){
114  std::ifstream fin( fWfile );
115  if (!fin.good( )) { // file not found --> Error
116  std::cout << "*** ERROR: Weight file: " << fWfile << " does not exist" << std::endl;
117  return;
118  }
119 
120  TString dummy = "";
121 
122  // read total number of trees, and check whether requested tree is in range
123  Int_t nc = 0;
124  while (!dummy.Contains("NTrees")) {
125  fin >> dummy;
126  nc++;
127  if (nc > 200) {
128  std::cout << std::endl;
129  std::cout << "*** Huge problem: could not locate term \"NTrees\" in BDT weight file: "
130  << fWfile << std::endl;
131  std::cout << "==> panic abort (please contact the TMVA authors)" << std::endl;
132  std::cout << std::endl;
133  exit(1);
134  }
135  }
136  fin >> dummy;
137  fNtrees = dummy.ReplaceAll("\"","").Atoi();
138  fin.close();
139  }
140  else{
141  void* doc = TMVA::gTools().xmlengine().ParseFile(fWfile);
142  void* rootnode = TMVA::gTools().xmlengine().DocGetRootElement(doc);
143  void* ch = TMVA::gTools().xmlengine().GetChild(rootnode);
144  while(ch){
145  TString nodeName = TString( TMVA::gTools().xmlengine().GetNodeName(ch) );
146  if(nodeName=="Weights") {
147  TMVA::gTools().ReadAttr( ch, "NTrees", fNtrees );
148  break;
149  }
150  ch = TMVA::gTools().xmlengine().GetNext(ch);
151  }
152  }
153  std::cout << "--- Found " << fNtrees << " decision trees in weight file" << std::endl;
154 
155 }
156 
157 ////////////////////////////////////////////////////////////////////////////////
158 /// recursively puts an entries in the histogram for the node and its daughters
159 ///
160 
161 void TMVA::StatDialogBDTReg::DrawNode( TMVA::DecisionTreeNode *n,
162  Double_t x, Double_t y,
163  Double_t xscale, Double_t yscale, TString * vars)
164 {
165  Float_t xsize=xscale*1.5;
166  Float_t ysize=yscale/3;
167  if (xsize>0.15) xsize=0.1;
168  if (n->GetLeft() != NULL){
169  TLine *a1 = new TLine(x-xscale/4,y-ysize,x-xscale,y-ysize*2);
170  a1->SetLineWidth(2);
171  a1->Draw();
172  DrawNode((TMVA::DecisionTreeNode*) n->GetLeft(), x-xscale, y-yscale, xscale/2, yscale, vars);
173  }
174  if (n->GetRight() != NULL){
175  TLine *a1 = new TLine(x+xscale/4,y-ysize,x+xscale,y-ysize*2);
176  a1->SetLineWidth(2);
177  a1->Draw();
178  DrawNode((TMVA::DecisionTreeNode*) n->GetRight(), x+xscale, y-yscale, xscale/2, yscale, vars );
179  }
180 
181  // TPaveText *t = new TPaveText(x-xscale/2,y-yscale/2,x+xscale/2,y+yscale/2, "NDC");
182  TPaveText *t = new TPaveText(x-xsize,y-ysize,x+xsize,y+ysize, "NDC");
183 
184  t->SetBorderSize(1);
185 
186  t->SetFillStyle(1001);
187  if (n->GetNodeType() == 1) { t->SetFillColor( getSigColorF() ); t->SetTextColor( getSigColorT() ); }
188  else if (n->GetNodeType() == -1) { t->SetFillColor( getBkgColorF() ); t->SetTextColor( getBkgColorT() ); }
189  else if (n->GetNodeType() == 0) { t->SetFillColor( getIntColorF() ); t->SetTextColor( getIntColorT() ); }
190 
191  char buffer[25];
192  // sprintf( buffer, "N=%f", n->GetNEvents() );
193  // t->AddText(buffer);
194  sprintf( buffer, "R=%4.1f +- %4.1f", n->GetResponse(),n->GetRMS() );
195  t->AddText(buffer);
196 
197  if (n->GetNodeType() == 0){
198  if (n->GetCutType()){
199  t->AddText(TString(vars[n->GetSelector()]+">"+=::Form("%5.3g",n->GetCutValue())));
200  }else{
201  t->AddText(TString(vars[n->GetSelector()]+"<"+=::Form("%5.3g",n->GetCutValue())));
202  }
203  }
204 
205  t->Draw();
206 
207  return;
208 }
209 
210 TMVA::DecisionTree* TMVA::StatDialogBDTReg::ReadTree( TString* &vars, Int_t itree )
211 {
212  std::cout << "--- Reading Tree " << itree << " from weight file: " << fWfile << std::endl;
213  TMVA::DecisionTree *d = new TMVA::DecisionTree();
214 
215 
216  if(!fWfile.EndsWith(".xml") ){
217 
218  std::ifstream fin( fWfile );
219  if (!fin.good( )) { // file not found --> Error
220  std::cout << "*** ERROR: Weight file: " << fWfile << " does not exist" << std::endl;
221  delete d;
222  d = nullptr;
223  return 0;
224  }
225  TString dummy = "";
226 
227  if (itree >= fNtrees) {
228  std::cout << "*** ERROR: requested decision tree: " << itree
229  << ", but number of trained trees only: " << fNtrees << std::endl;
230  delete d;
231  d = nullptr;
232  return 0;
233  }
234 
235  // file header with name
236  while (!dummy.Contains("#VAR")) fin >> dummy;
237  fin >> dummy >> dummy >> dummy; // the rest of header line
238 
239  // number of variables
240  Int_t nVars;
241  fin >> dummy >> nVars;
242 
243  // variable mins and maxes
244  vars = new TString[nVars+1];
245  for (Int_t i = 0; i < nVars; i++) fin >> vars[i] >> dummy >> dummy >> dummy >> dummy;
246  vars[nVars]="FisherCrit";
247 
248  char buffer[20];
249  char line[256];
250  sprintf(buffer,"Tree %d",itree);
251 
252  while (!dummy.Contains(buffer)) {
253  fin.getline(line,256);
254  dummy = TString(line);
255  }
256 
257  d->Read(fin);
258 
259  fin.close();
260  }
261  else{
262  if (itree >= fNtrees) {
263  std::cout << "*** ERROR: requested decision tree: " << itree
264  << ", but number of trained trees only: " << fNtrees << std::endl;
265  delete d;
266  d = nullptr;
267  return 0;
268  }
269  Int_t nVars;
270  void* doc = TMVA::gTools().xmlengine().ParseFile(fWfile);
271  void* rootnode = TMVA::gTools().xmlengine().DocGetRootElement(doc);
272  void* ch = TMVA::gTools().xmlengine().GetChild(rootnode);
273  while(ch){
274  TString nodeName = TString( TMVA::gTools().xmlengine().GetNodeName(ch) );
275  if(nodeName=="Variables"){
276  TMVA::gTools().ReadAttr( ch, "NVar", nVars);
277  vars = new TString[nVars+1];
278  void* varnode = TMVA::gTools().xmlengine().GetChild(ch);
279  for (Int_t i = 0; i < nVars; i++){
280  TMVA::gTools().ReadAttr( varnode, "Expression", vars[i]);
281  varnode = TMVA::gTools().xmlengine().GetNext(varnode);
282  }
283  vars[nVars]="FisherCrit";
284  }
285  if(nodeName=="Weights") break;
286  ch = TMVA::gTools().xmlengine().GetNext(ch);
287  }
288  ch = TMVA::gTools().xmlengine().GetChild(ch);
289  for (int i=0; i<itree; i++) ch = TMVA::gTools().xmlengine().GetNext(ch);
290  d->ReadXML(ch);
291  }
292  return d;
293 }
294 
295 ////////////////////////////////////////////////////////////////////////////////
296 
297 void TMVA::StatDialogBDTReg::DrawTree( Int_t itree )
298 {
299  TString *vars;
300 
301  TMVA::DecisionTree* d = ReadTree( vars, itree );
302  if (d == 0) return;
303 
304  UInt_t depth = d->GetTotalTreeDepth();
305  Double_t ystep = 1.0/(depth + 1.0);
306 
307  std::cout << "--- Tree depth: " << depth << std::endl;
308 
309  TStyle* TMVAStyle = gROOT->GetStyle("Plain"); // our style is based on Plain
310  Int_t canvasColor = TMVAStyle->GetCanvasColor(); // backup
311 
312  TString cbuffer = Form( "Reading weight file: %s", fWfile.Data() );
313  TString tbuffer = Form( "Regression Tree no.: %d", itree );
314  if (!fCanvas) fCanvas = new TCanvas( "c1", cbuffer, 200, 0, 1000, 600 );
315  else fCanvas->Clear();
316  fCanvas->Draw();
317  DrawNode( (TMVA::DecisionTreeNode*)d->GetRoot(), 0.5, 1.-0.5*ystep, 0.25, ystep ,vars);
318 
319  // make the legend
320  Double_t yup=0.99;
321  Double_t ydown=yup-ystep/2.5;
322  Double_t dy= ystep/2.5 * 0.2;
323 
324  TPaveText *whichTree = new TPaveText(0.85,ydown,0.98,yup, "NDC");
325  whichTree->SetBorderSize(1);
326  whichTree->SetFillStyle(1001);
327  whichTree->SetFillColor( TColor::GetColor( "#ffff33" ) );
328  whichTree->AddText( tbuffer );
329  whichTree->Draw();
330 
331  TPaveText *intermediate = new TPaveText(0.02,ydown,0.15,yup, "NDC");
332  intermediate->SetBorderSize(1);
333  intermediate->SetFillStyle(1001);
334  intermediate->SetFillColor( getIntColorF() );
335  intermediate->AddText("Intermediate Nodes");
336  intermediate->SetTextColor( getIntColorT() );
337  intermediate->Draw();
338 
339  ydown = ydown - ystep/2.5 -dy;
340  yup = yup - ystep/2.5 -dy;
341  TPaveText *signalleaf = new TPaveText(0.02,ydown ,0.15,yup, "NDC");
342  signalleaf->SetBorderSize(1);
343  signalleaf->SetFillStyle(1001);
344  signalleaf->SetFillColor( getSigColorF() );
345  signalleaf->AddText("Leaf Nodes");
346  signalleaf->SetTextColor( getSigColorT() );
347  signalleaf->Draw();
348  /*
349  ydown = ydown - ystep/2.5 -dy;
350  yup = yup - ystep/2.5 -dy;
351  TPaveText *backgroundleaf = new TPaveText(0.02,ydown,0.15,yup, "NDC");
352  backgroundleaf->SetBorderSize(1);
353  backgroundleaf->SetFillStyle(1001);
354  backgroundleaf->SetFillColor( kBkgColorF );
355 
356  backgroundleaf->AddText("Backgr. Leaf Nodes");
357  backgroundleaf->SetTextColor( kBkgColorT );
358  backgroundleaf->Draw();
359  */
360  fCanvas->Update();
361  TString fname = fDataset+Form("/plots/%s_%i", fMethName.Data(), itree );
362  std::cout << "--- Creating image: " << fname << std::endl;
363  TMVAGlob::imgconv( fCanvas, fname );
364 
365  TMVAStyle->SetCanvasColor( canvasColor );
366 }
367 
368 // ========================================================================================
369 
370 // intermediate GUI
371 void TMVA::BDT_Reg(TString dataset, const TString& fin )
372 {
373  // --- read the available BDT weight files
374 
375  // destroy all open cavases
376  TMVAGlob::DestroyCanvases();
377 
378  // checks if file with name "fin" is already open, and if not opens one
379  TFile* file = TMVAGlob::OpenFile( fin );
380 
381  TDirectory* dir = file->GetDirectory(dataset.Data())->GetDirectory( "Method_BDT" );
382  if (!dir) {
383  std::cout << "*** Error in macro \"BDT_Reg.C\": cannot find directory \"Method_BDT\" in file: " << fin << std::endl;
384  return;
385  }
386 
387  // read all directories
388  TIter next( dir->GetListOfKeys() );
389  TKey *key(0);
390  std::vector<TString> methname;
391  std::vector<TString> path;
392  std::vector<TString> wfile;
393  while ((key = (TKey*)next())) {
394  TDirectory* mdir = dir->GetDirectory( key->GetName() );
395  if (!mdir) {
396  std::cout << "*** Error in macro \"BDT_Reg.C\": cannot find sub-directory: " << key->GetName()
397  << " in directory: " << dir->GetName() << std::endl;
398  return;
399  }
400 
401  // retrieve weight file name and path
402  TObjString* strPath = (TObjString*)mdir->Get( "TrainingPath" );
403  TObjString* strWFile = (TObjString*)mdir->Get( "WeightFileName" );
404  if (!strPath || !strWFile) {
405  std::cout << "*** Error in macro \"BDT_Reg.C\": could not find TObjStrings \"TrainingPath\" and/or \"WeightFileName\" *** " << std::endl;
406  std::cout << "*** Maybe you are using TMVA >= 3.8.15 with an older training target file ? *** " << std::endl;
407  return;
408  }
409 
410  methname.push_back( key->GetName() );
411  path .push_back( strPath->GetString() );
412  wfile .push_back( strWFile->GetString() );
413  }
414 
415  // create the control bar
416  TControlBar* cbar = new TControlBar( "vertical", "Choose weight file:", 50, 50 );
417  BDTReg_Global__cbar.push_back(cbar);
418 
419  for (UInt_t im=0; im<path.size(); im++) {
420  TString fname = path[im];
421  if (fname[fname.Length()-1] != '/') fname += "/";
422  fname += wfile[im];
423  TString macro = Form( "TMVA::BDT_Reg(\"%s\",0,\"%s\",\"%s\")",dataset.Data(), fname.Data(), methname[im].Data() );
424  cbar->AddButton( fname, macro, "Plot decision trees from this weight file", "button" );
425  }
426 
427  // set the style
428  cbar->SetTextColor("blue");
429 
430  // draw
431  cbar->Show();
432 }
433 
434 void TMVA::BDTReg_DeleteTBar(int i)
435 {
436  // destroy all open canvases
437  StatDialogBDTReg::Delete();
438  TMVAGlob::DestroyCanvases();
439 
440  delete BDTReg_Global__cbar[i];
441  BDTReg_Global__cbar[i] = 0;
442 }
443 
444 // input: - No. of tree
445 // - the weight file from which the tree is read
446 void TMVA::BDT_Reg(TString dataset, Int_t itree, TString wfile , TString methName, Bool_t useTMVAStyle )
447 {
448  // destroy possibly existing dialog windows and/or canvases
449  StatDialogBDTReg::Delete();
450  TMVAGlob::DestroyCanvases();
451  if(wfile=="")
452  wfile = dataset+"/weights/TMVARegression_BDT.weights.xml";
453 
454  // quick check if weight file exist
455  if(!wfile.EndsWith(".xml") ){
456  std::ifstream fin( wfile );
457  if (!fin.good( )) { // file not found --> Error
458  std::cout << "*** ERROR: Weight file: " << wfile << " does not exist" << std::endl;
459  return;
460  }
461  }
462  std::cout << "test1";
463  // set style and remove existing canvas'
464  TMVAGlob::Initialize( useTMVAStyle );
465 
466  StatDialogBDTReg* gGui = new StatDialogBDTReg(dataset, gClient->GetRoot(), wfile, methName, itree );
467 
468  gGui->DrawTree( itree );
469 
470  gGui->RaiseDialog();
471 }
472