// File-scope list of control bars: BDT_Reg(dataset, fin) appends each bar it
// creates, and BDTReg_DeleteTBar(i) deletes the entry at index i and nulls it.
30 std::vector<TControlBar*> TMVA::BDTReg_Global__cbar;
// Definition of the dialog class's static fThis pointer, initialised to null.
32 TMVA::StatDialogBDTReg* TMVA::StatDialogBDTReg::fThis = 0;
// Slot connected to fInput's "ValueSet(Long_t)" signal by the constructor:
// stores the number currently shown in the entry widget into fItree,
// converted to Int_t.
34 void TMVA::StatDialogBDTReg::SetItree()
36 fItree = Int_t(fInput->GetNumber());
// Slot connected to the Draw button's "Clicked()" signal by the constructor.
// NOTE(review): the body is not visible in this listing; presumably it
// redraws the currently selected tree -- confirm against the full file.
39 void TMVA::StatDialogBDTReg::Redraw()
// Slot connected to the Close button's "Clicked()" signal by the constructor.
// NOTE(review): the body is not visible in this listing -- confirm what is
// torn down (dialog, canvas) against the full file.
44 void TMVA::StatDialogBDTReg::Close()
// Builds the tree-selection dialog.
//
// Parameters:
//   dataset   dataset name (its use is in lines not visible in this listing)
//   p         parent window handed to the TGMainFrame
//   wfile     weight-file name
//   methName  method name
//   itree     index of the tree initially selected
//
// The window holds a label showing the valid index range, a number entry
// seeded with fItree, and a button row with "Close" and "Draw".
// NOTE(review): member initialisers and braces are not visible here; fNtrees
// and fItree are read below, so they must be set earlier (presumably via the
// initialiser list and GetNtrees()) -- confirm.
49 TMVA::StatDialogBDTReg::StatDialogBDTReg(TString dataset,
const TGWindow* p, TString wfile, TString methName, Int_t itree )
// Requested initial size of the main frame; the frame is resized to its
// default size further down once the sub-windows are mapped.
62 UInt_t totalWidth = 500;
63 UInt_t totalHeight = 200;
71 fMain =
new TGMainFrame(p, totalWidth, totalHeight, kMainFrame | kVerticalFrame);
// Label advertising the selectable tree indices 0 .. fNtrees-1.
73 TGLabel *sigLab =
new TGLabel( fMain, Form(
"Regression tree [%i-%i]",0,fNtrees-1 ) );
74 fMain->AddFrame(sigLab,
new TGLayoutHints(kLHintsLeft | kLHintsTop,5,5,5,5));
// Number entry pre-filled with fItree; the style is passed as the raw value 5
// cast to TGNumberFormat::EStyle.
76 fInput =
new TGNumberEntry(fMain, (Double_t) fItree,5,-1,(TGNumberFormat::EStyle) 5);
77 fMain->AddFrame(fInput,
new TGLayoutHints(kLHintsLeft | kLHintsTop,5,5,5,5));
78 fInput->Resize(100,24);
// Restrict the entry to the same range shown in the label.
79 fInput->SetLimits(TGNumberFormat::kNELLimitMinMax,0,fNtrees-1);
// Horizontal row that holds the two buttons.
81 fButtons =
new TGHorizontalFrame(fMain, totalWidth,30);
83 fCloseButton =
new TGTextButton(fButtons,
"&Close");
84 fButtons->AddFrame(fCloseButton,
new TGLayoutHints(kLHintsLeft | kLHintsTop));
86 fDrawButton =
new TGTextButton(fButtons,
"&Draw");
87 fButtons->AddFrame(fDrawButton,
new TGLayoutHints(kLHintsRight | kLHintsTop,15));
89 fMain->AddFrame(fButtons,
new TGLayoutHints(kLHintsLeft | kLHintsBottom,5,5,5,5));
91 fMain->SetWindowName(
"Regression tree");
92 fMain->SetWMPosition(0,0);
93 fMain->MapSubwindows();
94 fMain->Resize(fMain->GetDefaultSize());
// Signal/slot wiring: a value change updates fItree, Draw triggers Redraw(),
// Close triggers Close().
97 fInput->Connect(
"ValueSet(Long_t)",
"TMVA::StatDialogBDTReg",
this,
"SetItree()");
101 fDrawButton->Connect(
"Clicked()",
"TMVA::StatDialogBDTReg",
this,
"Redraw()");
103 fCloseButton->Connect(
"Clicked()",
"TMVA::StatDialogBDTReg",
this,
"Close()");
// NOTE(review): only the signature is visible in this listing; presumably it
// refreshes the drawing canvas -- confirm against the full file.
106 void TMVA::StatDialogBDTReg::UpdateCanvases()
// Determines the number of trees stored in the weight file fWfile and stores
// it in fNtrees, then prints the count.
//
// Two formats are handled, selected by the file-name suffix:
//   - anything not ending in ".xml" is read as a text stream;
//   - ".xml" files are parsed with the TMVA XML engine.
// NOTE(review): loop headers, local declarations and early returns are not
// visible in this listing.
111 void TMVA::StatDialogBDTReg::GetNtrees()
113 if(!fWfile.EndsWith(
".xml") ){
// Text format: open the file and report an error if it cannot be read.
114 std::ifstream fin( fWfile );
116 std::cout <<
"*** ERROR: Weight file: " << fWfile <<
" does not exist" << std::endl;
// Scan forward until a token containing "NTrees" is found; the messages
// below report the case where it cannot be located.
124 while (!dummy.Contains(
"NTrees")) {
128 std::cout << std::endl;
129 std::cout <<
"*** Huge problem: could not locate term \"NTrees\" in BDT weight file: "
130 << fWfile << std::endl;
131 std::cout <<
"==> panic abort (please contact the TMVA authors)" << std::endl;
132 std::cout << std::endl;
// Strip double quotes from the token and convert the remainder to an integer.
137 fNtrees = dummy.ReplaceAll(
"\"",
"").Atoi();
// XML format: walk the children of the root element looking for "Weights",
// whose "NTrees" attribute carries the count.
// NOTE(review): no release of 'doc' is visible in this listing -- confirm
// whether the parsed document is freed.
141 void* doc = TMVA::gTools().xmlengine().ParseFile(fWfile);
142 void* rootnode = TMVA::gTools().xmlengine().DocGetRootElement(doc);
143 void* ch = TMVA::gTools().xmlengine().GetChild(rootnode);
145 TString nodeName = TString( TMVA::gTools().xmlengine().GetNodeName(ch) );
146 if(nodeName==
"Weights") {
147 TMVA::gTools().ReadAttr( ch,
"NTrees", fNtrees );
150 ch = TMVA::gTools().xmlengine().GetNext(ch);
153 std::cout <<
"--- Found " << fNtrees <<
" decision trees in weight file" << std::endl;
// Recursively draws node n and its subtree.
//
// Parameters:
//   n       node to draw; it is dereferenced without a null check
//   x, y    centre of this node's box (the box is created with the "NDC" option)
//   xscale  horizontal offset to the children; halved at each recursion level
//   yscale  vertical offset to the children; passed down unchanged
//   vars    array of variable names, indexed by n->GetSelector()
// NOTE(review): the declaration of 'buffer' and the Draw() calls are not
// visible in this listing.
161 void TMVA::StatDialogBDTReg::DrawNode( TMVA::DecisionTreeNode *n,
162 Double_t x, Double_t y,
163 Double_t xscale, Double_t yscale, TString * vars)
// Half-extents of the node box. A half-width above 0.15 is replaced by 0.1
// (not by 0.15).
165 Float_t xsize=xscale*1.5;
166 Float_t ysize=yscale/3;
167 if (xsize>0.15) xsize=0.1;
// Left child: connecting line, then recurse with the centre shifted left and
// down by one step.
168 if (n->GetLeft() != NULL){
169 TLine *a1 =
new TLine(x-xscale/4,y-ysize,x-xscale,y-ysize*2);
172 DrawNode((TMVA::DecisionTreeNode*) n->GetLeft(), x-xscale, y-yscale, xscale/2, yscale, vars);
// Right child: mirror image of the left-child handling.
174 if (n->GetRight() != NULL){
175 TLine *a1 =
new TLine(x+xscale/4,y-ysize,x+xscale,y-ysize*2);
178 DrawNode((TMVA::DecisionTreeNode*) n->GetRight(), x+xscale, y-yscale, xscale/2, yscale, vars );
// Text box for this node, centred on (x, y).
182 TPaveText *t =
new TPaveText(x-xsize,y-ysize,x+xsize,y+ysize,
"NDC");
186 t->SetFillStyle(1001);
// Colours by node type: 1 uses the Sig colours, -1 the Bkg colours, 0 the
// Int (intermediate) colours.
187 if (n->GetNodeType() == 1) { t->SetFillColor( getSigColorF() ); t->SetTextColor( getSigColorT() ); }
188 else if (n->GetNodeType() == -1) { t->SetFillColor( getBkgColorF() ); t->SetTextColor( getBkgColorT() ); }
189 else if (n->GetNodeType() == 0) { t->SetFillColor( getIntColorF() ); t->SetTextColor( getIntColorT() ); }
// Format the node's response and RMS.
// NOTE(review): sprintf is unbounded and the size of 'buffer' is not visible
// here -- confirm that it is large enough.
194 sprintf( buffer,
"R=%4.1f +- %4.1f", n->GetResponse(),n->GetRMS() );
// Nodes of type 0 also show their cut: "<var>><value>" when GetCutType() is
// true, "<var><<value>" otherwise.
197 if (n->GetNodeType() == 0){
198 if (n->GetCutType()){
199 t->AddText(TString(vars[n->GetSelector()]+
">"+=::Form(
"%5.3g",n->GetCutValue())));
201 t->AddText(TString(vars[n->GetSelector()]+
"<"+=::Form(
"%5.3g",n->GetCutValue())));
// Reads tree number itree from the weight file fWfile.
//
// Parameters:
//   vars   output: set to a new[]-allocated array of nVars+1 TStrings holding
//          the variable names, with "FisherCrit" in the last slot
//   itree  zero-based index of the tree; itree >= fNtrees is reported as an error
// A TMVA::DecisionTree is allocated with new at the top of the function.
// NOTE(review): local declarations, return statements and the calls that fill
// the tree are not visible in this listing, and no delete for 'vars' or the
// tree is visible here -- confirm ownership with the callers.
210 TMVA::DecisionTree* TMVA::StatDialogBDTReg::ReadTree( TString* &vars, Int_t itree )
212 std::cout <<
"--- Reading Tree " << itree <<
" from weight file: " << fWfile << std::endl;
213 TMVA::DecisionTree *d =
new TMVA::DecisionTree();
// ---- text weight file (name does not end in ".xml") ----
216 if(!fWfile.EndsWith(
".xml") ){
218 std::ifstream fin( fWfile );
220 std::cout <<
"*** ERROR: Weight file: " << fWfile <<
" does not exist" << std::endl;
227 if (itree >= fNtrees) {
228 std::cout <<
"*** ERROR: requested decision tree: " << itree
229 <<
", but number of trained trees only: " << fNtrees << std::endl;
// Skip tokens up to the "#VAR" marker, discard three more, then read the
// variable count.
236 while (!dummy.Contains(
"#VAR")) fin >> dummy;
237 fin >> dummy >> dummy >> dummy;
241 fin >> dummy >> nVars;
// One extra slot is allocated for the "FisherCrit" entry.
244 vars =
new TString[nVars+1];
// For each variable keep the first token as its name and discard the next four.
245 for (Int_t i = 0; i < nVars; i++) fin >> vars[i] >> dummy >> dummy >> dummy >> dummy;
246 vars[nVars]=
"FisherCrit";
// Build the marker "Tree <itree>" and read line by line (getline is capped at
// 256, so at most 255 characters per read) until a line contains it.
// NOTE(review): the sizes of 'buffer' and 'line' are not visible here.
250 sprintf(buffer,
"Tree %d",itree);
252 while (!dummy.Contains(buffer)) {
253 fin.getline(line,256);
254 dummy = TString(line);
// ---- XML weight file ----
262 if (itree >= fNtrees) {
263 std::cout <<
"*** ERROR: requested decision tree: " << itree
264 <<
", but number of trained trees only: " << fNtrees << std::endl;
270 void* doc = TMVA::gTools().xmlengine().ParseFile(fWfile);
271 void* rootnode = TMVA::gTools().xmlengine().DocGetRootElement(doc);
272 void* ch = TMVA::gTools().xmlengine().GetChild(rootnode);
// Walk the root's children: "Variables" supplies NVar and each variable's
// "Expression" attribute; the walk stops at "Weights".
274 TString nodeName = TString( TMVA::gTools().xmlengine().GetNodeName(ch) );
275 if(nodeName==
"Variables"){
276 TMVA::gTools().ReadAttr( ch,
"NVar", nVars);
277 vars =
new TString[nVars+1];
278 void* varnode = TMVA::gTools().xmlengine().GetChild(ch);
279 for (Int_t i = 0; i < nVars; i++){
280 TMVA::gTools().ReadAttr( varnode,
"Expression", vars[i]);
281 varnode = TMVA::gTools().xmlengine().GetNext(varnode);
283 vars[nVars]=
"FisherCrit";
285 if(nodeName==
"Weights")
break;
286 ch = TMVA::gTools().xmlengine().GetNext(ch);
// Descend to the first child of the current node and advance itree siblings
// to reach the requested tree.
288 ch = TMVA::gTools().xmlengine().GetChild(ch);
289 for (
int i=0; i<itree; i++) ch = TMVA::gTools().xmlengine().GetNext(ch);
// Draws tree number itree on fCanvas, adds the legend boxes and writes the
// picture to "<fDataset>/plots/<fMethName>_<itree>" via TMVAGlob::imgconv.
//
// Parameters:
//   itree  index of the tree to read and draw
// NOTE(review): the declarations of 'vars' and 'yup', several Draw() calls and
// any clean-up are not visible in this listing.
297 void TMVA::StatDialogBDTReg::DrawTree( Int_t itree )
301 TMVA::DecisionTree* d = ReadTree( vars, itree );
// One vertical step per tree level, plus one.
304 UInt_t depth = d->GetTotalTreeDepth();
305 Double_t ystep = 1.0/(depth + 1.0);
307 std::cout <<
"--- Tree depth: " << depth << std::endl;
// Save the "Plain" style's canvas colour; it is written back at the end of
// this function.
309 TStyle* TMVAStyle = gROOT->GetStyle(
"Plain");
310 Int_t canvasColor = TMVAStyle->GetCanvasColor();
312 TString cbuffer = Form(
"Reading weight file: %s", fWfile.Data() );
313 TString tbuffer = Form(
"Regression Tree no.: %d", itree );
// Create the canvas on first use, otherwise clear and reuse it.
314 if (!fCanvas) fCanvas =
new TCanvas(
"c1", cbuffer, 200, 0, 1000, 600 );
315 else fCanvas->Clear();
// The root is centred horizontally, half a step below the top; its children
// are offset horizontally by 0.25.
317 DrawNode( (TMVA::DecisionTreeNode*)d->GetRoot(), 0.5, 1.-0.5*ystep, 0.25, ystep ,vars);
// Legend geometry: each box is ystep/2.5 high; dy is the gap between boxes.
321 Double_t ydown=yup-ystep/2.5;
322 Double_t dy= ystep/2.5 * 0.2;
// Box on the right showing which tree is drawn.
324 TPaveText *whichTree =
new TPaveText(0.85,ydown,0.98,yup,
"NDC");
325 whichTree->SetBorderSize(1);
326 whichTree->SetFillStyle(1001);
327 whichTree->SetFillColor( TColor::GetColor(
"#ffff33" ) );
328 whichTree->AddText( tbuffer );
// Legend entry for intermediate nodes.
331 TPaveText *intermediate =
new TPaveText(0.02,ydown,0.15,yup,
"NDC");
332 intermediate->SetBorderSize(1);
333 intermediate->SetFillStyle(1001);
334 intermediate->SetFillColor( getIntColorF() );
335 intermediate->AddText(
"Intermediate Nodes");
336 intermediate->SetTextColor( getIntColorT() );
337 intermediate->Draw();
// Shift down by one box height plus the gap for the next legend entry.
339 ydown = ydown - ystep/2.5 -dy;
340 yup = yup - ystep/2.5 -dy;
// Legend entry for leaf nodes; it uses the Sig colours.
341 TPaveText *signalleaf =
new TPaveText(0.02,ydown ,0.15,yup,
"NDC");
342 signalleaf->SetBorderSize(1);
343 signalleaf->SetFillStyle(1001);
344 signalleaf->SetFillColor( getSigColorF() );
345 signalleaf->AddText(
"Leaf Nodes");
346 signalleaf->SetTextColor( getSigColorT() );
// Export the canvas as an image.
361 TString fname = fDataset+Form(
"/plots/%s_%i", fMethName.Data(), itree );
362 std::cout <<
"--- Creating image: " << fname << std::endl;
363 TMVAGlob::imgconv( fCanvas, fname );
365 TMVAStyle->SetCanvasColor( canvasColor );
// Builds a control bar with one button per BDT method found in the file fin.
//
// Parameters:
//   dataset  name of the top-level directory inside the file
//   fin      name of the ROOT file, opened through TMVAGlob::OpenFile
// Each button runs TMVA::BDT_Reg("<dataset>",0,"<path>","<method name>").
// NOTE(review): null checks, 'continue'/'return' statements and the
// declaration of 'key' are not visible in this listing.
371 void TMVA::BDT_Reg(TString dataset,
const TString& fin )
376 TMVAGlob::DestroyCanvases();
379 TFile* file = TMVAGlob::OpenFile( fin );
// NOTE(review): the two GetDirectory calls are chained, with no check on the
// intermediate result visible here -- confirm behaviour when 'dataset' is
// absent from the file.
381 TDirectory* dir = file->GetDirectory(dataset.Data())->GetDirectory(
"Method_BDT" );
383 std::cout <<
"*** Error in macro \"BDT_Reg.C\": cannot find directory \"Method_BDT\" in file: " << fin << std::endl;
// Collect, for every key under Method_BDT, the method name plus the
// "TrainingPath" and "WeightFileName" strings stored in its directory.
388 TIter next( dir->GetListOfKeys() );
390 std::vector<TString> methname;
391 std::vector<TString> path;
392 std::vector<TString> wfile;
393 while ((key = (TKey*)next())) {
394 TDirectory* mdir = dir->GetDirectory( key->GetName() );
396 std::cout <<
"*** Error in macro \"BDT_Reg.C\": cannot find sub-directory: " << key->GetName()
397 <<
" in directory: " << dir->GetName() << std::endl;
402 TObjString* strPath = (TObjString*)mdir->Get(
"TrainingPath" );
403 TObjString* strWFile = (TObjString*)mdir->Get(
"WeightFileName" );
404 if (!strPath || !strWFile) {
405 std::cout <<
"*** Error in macro \"BDT_Reg.C\": could not find TObjStrings \"TrainingPath\" and/or \"WeightFileName\" *** " << std::endl;
406 std::cout <<
"*** Maybe you are using TMVA >= 3.8.15 with an older training target file ? *** " << std::endl;
410 methname.push_back( key->GetName() );
411 path .push_back( strPath->GetString() );
412 wfile .push_back( strWFile->GetString() );
// Register the new bar in the file-scope list.
416 TControlBar* cbar =
new TControlBar(
"vertical",
"Choose weight file:", 50, 50 );
417 BDTReg_Global__cbar.push_back(cbar);
// One button per collected entry. The path gets a trailing '/' if it lacks
// one, and it is the path (not the collected wfile entry) that is passed as
// the third argument of the generated call.
419 for (UInt_t im=0; im<path.size(); im++) {
420 TString fname = path[im];
421 if (fname[fname.Length()-1] !=
'/') fname +=
"/";
423 TString macro = Form(
"TMVA::BDT_Reg(\"%s\",0,\"%s\",\"%s\")",dataset.Data(), fname.Data(), methname[im].Data() );
424 cbar->AddButton( fname, macro,
"Plot decision trees from this weight file",
"button" );
428 cbar->SetTextColor(
"blue");
// Tears down the GUI: deletes the dialog, destroys the canvases, then deletes
// the control bar stored at index i of BDTReg_Global__cbar and nulls the slot.
//
// Parameters:
//   i  index into BDTReg_Global__cbar; no bounds check is visible in this listing
434 void TMVA::BDTReg_DeleteTBar(
int i)
437 StatDialogBDTReg::Delete();
438 TMVAGlob::DestroyCanvases();
440 delete BDTReg_Global__cbar[i];
// Null the slot so the vector does not keep a dangling pointer.
441 BDTReg_Global__cbar[i] = 0;
// Opens the tree-viewer dialog for one weight file and draws tree itree.
//
// Parameters:
//   dataset       dataset name; used to build the default weight-file path
//   itree         index of the tree drawn first
//   wfile         weight-file name
//   methName      method name handed to the dialog
//   useTMVAStyle  forwarded to TMVAGlob::Initialize
446 void TMVA::BDT_Reg(TString dataset, Int_t itree, TString wfile , TString methName, Bool_t useTMVAStyle )
// Remove any dialog and canvases left over from a previous call.
449 StatDialogBDTReg::Delete();
450 TMVAGlob::DestroyCanvases();
// NOTE(review): the condition guarding this default, if any, is not visible
// in this listing; as shown, wfile is overwritten -- confirm.
452 wfile = dataset+
"/weights/TMVARegression_BDT.weights.xml";
// For non-XML weight files, open the file as an existence check.
455 if(!wfile.EndsWith(
".xml") ){
456 std::ifstream fin( wfile );
458 std::cout <<
"*** ERROR: Weight file: " << wfile <<
" does not exist" << std::endl;
// NOTE(review): prints "test1" with no newline; this looks like leftover
// debug output -- confirm whether it is intended.
462 std::cout <<
"test1";
464 TMVAGlob::Initialize( useTMVAStyle );
// The dialog is allocated with new; no matching delete is visible in this
// listing (StatDialogBDTReg::Delete() is called at the top of this function
// and in BDTReg_DeleteTBar).
466 StatDialogBDTReg* gGui =
new StatDialogBDTReg(dataset, gClient->GetRoot(), wfile, methName, itree );
468 gGui->DrawTree( itree );