// Constructor of the cost-complexity pruning tool.
// qualityIndex: separation/quality-index calculator used by
//   InitTreePruningMetaData() to compute each node's cost
//   R(t) = (s+b) * GetSeparationIndex(s,b). It may be NULL, in which case each
//   node's own GetSeparationIndex() is used instead.
//   Ownership passes to this object: the destructor deletes fQualityIndexTool.
// The message logger fLogger is heap-allocated here under the name
// "CostComplexityPruneTool".
69 CostComplexityPruneTool::CostComplexityPruneTool( SeparationBase* qualityIndex ) :
71 fLogger(new MsgLogger(
"CostComplexityPruneTool") )
// keep the (possibly NULL) quality-index tool; see the ownership note above
81 fQualityIndexTool = qualityIndex;
// raise the logger threshold to kWARNING, presumably silencing the
// kDEBUG/kINFO messages this tool emits — confirm against MsgLogger
84 fLogger->SetMinType( kWARNING );
90 CostComplexityPruneTool::~CostComplexityPruneTool( ) {
91 if(fQualityIndexTool != NULL)
delete fQualityIndexTool;
// Steers cost-complexity pruning of a single decision tree:
//  1. switches the tool to automatic mode when isAutomatic is true,
//  2. evaluates the unpruned tree on the validation sample (presumably only in
//     automatic mode — the enclosing condition is not visible in this extract),
//  3. fills the per-node pruning metadata via InitTreePruningMetaData(),
//  4. builds the pruning sequence and stop index fOptimalK (the Optimize call
//     is not visible in this extract — presumably Optimize(dt, W)),
//  5. copies the chosen prefix of the pruning sequence into a new PruningInfo.
// dt               - tree to prune; a NULL tree is rejected.
// validationSample - events used to measure tree quality; required (non-NULL)
//                    when automatic mode is active.
// isAutomatic      - (declaration not visible here) true calls SetAutomatic().
// Errors thrown as std::string by steps 3 and 4 are caught and logged at kERROR.
// NOTE(review): the return type and return statements are not visible in this
// extract; the PruningInfo allocated with `new` below is presumably returned
// and then owned by the caller — confirm.
99 CostComplexityPruneTool::CalculatePruningInfo( DecisionTree* dt,
100 const IPruneTool::EventSample* validationSample,
// A true flag switches to automatic mode; false leaves the current mode as is.
103 if( isAutomatic ) SetAutomatic();
// Need a tree and, in automatic mode, also a sample to measure tree quality.
105 if( dt == NULL || (IsAutomatic() && validationSample == NULL) ) {
// Evaluate the unpruned tree: W = total weight of the validation sample,
// Q = quality of the unpruned tree on it. Q/W is the normalized quality that
// is logged below and reported when no pruning is done.
117 dt->ApplyValidationSample(validationSample);
118 W = dt->GetSumWeights(validationSample);
120 Q = dt->TestPrunedTreeQuality();
122 Log() << kDEBUG <<
"Node purity limit is: " << dt->GetNodePurityLimit() << Endl;
123 Log() << kDEBUG <<
"Sum of weights in pruning validation sample: " << W << Endl;
124 Log() << kDEBUG <<
"Quality of tree prior to any pruning is " << Q/W << Endl;
// Compute R(t), R(T_t), alpha, ... on every node, starting from the root.
129 InitTreePruningMetaData((DecisionTreeNode*)dt->GetRoot());
131 catch(
const std::string &error) {
132 Log() << kERROR <<
"Couldn't initialize the tree meta data because of error ("
133 << error <<
")" << Endl;
137 Log() << kDEBUG <<
"Automatic cost complexity pruning is " << (IsAutomatic()?
"on":
"off") <<
"." << Endl;
// Handler for std::string errors from the sequence-optimization step.
142 catch(
const std::string &error) {
143 Log() << kERROR <<
"Error optimizing pruning sequence ("
144 << error <<
")" << Endl;
148 Log() << kDEBUG <<
"Index of pruning sequence to stop at: " << fOptimalK << Endl;
// Result object for the caller (see the ownership note above).
150 PruningInfo* info =
new PruningInfo();
// No usable stop index (branch condition not visible; presumably fOptimalK < 0):
// report the tree as unpruned with zero strength and its unpruned quality Q/W.
155 info->PruneStrength = 0;
156 info->QualityIndex = Q/W;
157 info->PruneSequence.clear();
158 Log() << kINFO <<
"no proper pruning could be calculated. Tree "
159 << dt->GetTreeID() <<
" will not be pruned. Do not worry if this "
160 <<
" happens for a few trees " << Endl;
// Otherwise: report the normalized quality stored at the stop index and hand
// over the first fOptimalK entries of the pruning sequence.
// NOTE(review): Optimize() may set fOptimalK = int(fPruneStrength/100.0 * size)
// (presumably in the non-automatic branch). A strength of 100 or more then
// gives fOptimalK >= size, and the fQualityIndexList / fPruneStrengthList
// lookups below read past the end. Consider clamping fOptimalK to size-1.
163 info->QualityIndex = fQualityIndexList[fOptimalK]/W;
164 Log() << kDEBUG <<
" prune until k=" << fOptimalK <<
" with alpha="<<fPruneStrengthList[fOptimalK]<< Endl;
165 for( Int_t i = 0; i < fOptimalK; i++ ){
166 info->PruneSequence.push_back(fPruneSequence[i]);
// Reported strength: the alpha recorded at the stop index (presumably in
// automatic mode) or the fixed fPruneStrength member (presumably otherwise);
// the selecting condition is not visible in this extract.
169 info->PruneStrength = fPruneStrengthList[fOptimalK];
172 info->PruneStrength = fPruneStrength;
// Recursively fills the cost-complexity pruning metadata of the subtree rooted
// at n (a NULL node is ignored). For every node it sets:
//   NodeR            R(t) = (s+b) * separation index, where s/b come from
//                    GetNSigEvents()/GetNBkgEvents(). The index comes from
//                    fQualityIndexTool when set, else from the node itself.
//   NTerminal        number of terminal nodes in the subtree (1 for a leaf)
//   SubTreeR         R(T_t): sum of the leaf costs over those terminal nodes
//   Alpha, CC        interior: (R(t) - R(T_t)) / (NTerminal - 1); leaf: +inf
//   AlphaMinSubtree  smallest Alpha anywhere in the subtree
// Interior nodes are also marked non-terminal and leaves terminal.
182 void CostComplexityPruneTool::InitTreePruningMetaData( DecisionTreeNode* n ) {
183 if( n == NULL )
return;
185 Double_t s = n->GetNSigEvents();
186 Double_t b = n->GetNBkgEvents();
// R(t): the cost of n if it were turned into a leaf
188 if (fQualityIndexTool) n->SetNodeR( (s+b)*fQualityIndexTool->GetSeparationIndex(s,b));
189 else n->SetNodeR( (s+b)*n->GetSeparationIndex() );
// Interior node: only nodes with BOTH children are treated as interior.
191 if(n->GetLeft() != NULL && n->GetRight() != NULL) {
192 n->SetTerminal(kFALSE);
// Post-order traversal: the children's values are combined below.
194 InitTreePruningMetaData(n->GetLeft());
195 InitTreePruningMetaData(n->GetRight());
197 n->SetNTerminal( n->GetLeft()->GetNTerminal() +
198 n->GetRight()->GetNTerminal());
200 n->SetSubTreeR( (n->GetLeft()->GetSubTreeR() +
201 n->GetRight()->GetSubTreeR()));
// CART weakest-link value: the alpha at which collapsing this branch into a
// leaf costs no more than keeping it. NTerminal >= 2 here (each child
// contributes at least 1), so the denominator is never zero.
203 n->SetAlpha( ((n->GetNodeR() - n->GetSubTreeR()) /
204 (n->GetNTerminal() - 1)));
208 n->SetAlphaMinSubtree( std::min(n->GetAlpha(), std::min(n->GetLeft()->GetAlphaMinSubtree(),
209 n->GetRight()->GetAlphaMinSubtree())));
210 n->SetCC(n->GetAlpha());
// Terminal node (the `else` of the check above is not visible in this extract).
// Its subtree cost is its own cost R(t). An infinite alpha means a leaf never
// lowers an ancestor's AlphaMinSubtree.
213 n->SetNTerminal( 1 ); n->SetTerminal( );
214 if (fQualityIndexTool) n->SetSubTreeR(((s+b)*fQualityIndexTool->GetSeparationIndex(s,b)));
215 else n->SetSubTreeR( (s+b)*n->GetSeparationIndex() );
216 n->SetAlpha(std::numeric_limits<double>::infinity( ));
217 n->SetAlphaMinSubtree(std::numeric_limits<double>::infinity( ));
218 n->SetCC(n->GetAlpha());
// Builds the cost-complexity pruning sequence of dt and picks the stop index
// fOptimalK. The tree is pruned IN PLACE (dt->PruneNodeInPlace), one "weakest
// link" at a time, while the root still has more than one terminal node.
// When the root itself would be the weakest link, it logs
// "Caught trying to prune the root node!". The statement that follows,
// presumably a break, is not visible in this extract.
// After each pruning step it appends one entry to each list:
//   fPruneSequence     - the node that was pruned
//   fPruneStrengthList - the running alpha (non-decreasing, via TMath::Max)
//   fQualityIndexList  - validation quality / weights, or the placeholder 1.0
//                        (the selecting condition is not visible here)
// weights: normalization for TestPrunedTreeQuality(); presumably the
//          validation-sample weight sum W from CalculatePruningInfo().
// Relies on the node metadata filled by InitTreePruningMetaData().
237 void CostComplexityPruneTool::Optimize( DecisionTree* dt, Double_t weights ) {
// start far below any node's alpha so the first TMath::Max takes the root's value
239 Double_t alpha = -1.0e10;
// tolerance for matching the left child's AlphaMinSubtree during the descent
240 Double_t epsilon = std::numeric_limits<double>::epsilon();
242 fQualityIndexList.clear();
243 fPruneSequence.clear();
244 fPruneStrengthList.clear();
246 DecisionTreeNode* R = (DecisionTreeNode*)dt->GetRoot();
// normalized quality of the still-unpruned tree: the baseline that a pruning
// step must beat in the minimum search further down
251 qmin = dt->TestPrunedTreeQuality()/weights;
263 while(R->GetNTerminal() > 1) {
// the prune strength never decreases from one step to the next
266 alpha = TMath::Max(R->GetAlphaMinSubtree(), alpha);
// AlphaMinSubtree <= Alpha always holds (it is a min including Alpha), so >=
// means the root's own alpha is the smallest: the next weakest link is the root.
268 if( R->GetAlphaMinSubtree() >= R->GetAlpha() ) {
269 Log() << kDEBUG <<
"\nCaught trying to prune the root node!" << Endl;
// Descend from the root towards the weakest link. The loop continues while the
// subtree minimum lies strictly below t's own alpha, following the left child
// when it carries that minimum (within epsilon). The branch bodies are not
// visible in this extract.
274 DecisionTreeNode* t = R;
277 while(t->GetAlphaMinSubtree() < t->GetAlpha()) {
282 if(TMath::Abs(t->GetAlphaMinSubtree() - t->GetLeft()->GetAlphaMinSubtree()) < epsilon) {
// guard against the descent ending at the root (condition not visible here)
290 Log() << kDEBUG <<
"\nCaught trying to prune the root node!" << Endl;
// n keeps the pruned node; t is subsequently used to refresh the ancestors
294 DecisionTreeNode* n = t;
303 dt->PruneNodeInPlace(t);
// Refresh each ancestor with the same interior-node formulas as
// InitTreePruningMetaData(). The upward loop header is not visible here.
307 t->SetNTerminal(t->GetLeft()->GetNTerminal() + t->GetRight()->GetNTerminal());
308 t->SetSubTreeR(t->GetLeft()->GetSubTreeR() + t->GetRight()->GetSubTreeR());
309 t->SetAlpha((t->GetNodeR() - t->GetSubTreeR())/(t->GetNTerminal() - 1));
310 t->SetAlphaMinSubtree(std::min(t->GetAlpha(), std::min(t->GetLeft()->GetAlphaMinSubtree(),
311 t->GetRight()->GetAlphaMinSubtree())));
312 t->SetCC(t->GetAlpha());
316 Log() << kDEBUG <<
"after this pruning step I would have " << R->GetNTerminal() <<
" remaining terminal nodes " << Endl;
// normalized validation quality of the pruned tree ...
319 Double_t q = dt->TestPrunedTreeQuality()/weights;
320 fQualityIndexList.push_back(q);
// ... or a placeholder when quality is not measured
323 fQualityIndexList.push_back(1.0);
325 fPruneSequence.push_back(n);
326 fPruneStrengthList.push_back(alpha);
// no pruning step was recorded (the handling is not visible in this extract)
329 if(fPruneSequence.empty()) {
// Automatic stop index: scan for the lowest quality index strictly below the
// unpruned baseline qmin (the index bookkeeping is not visible here).
336 for(UInt_t i = 0; i < fQualityIndexList.size(); i++) {
337 if(fQualityIndexList[i] < qmin) {
338 qmin = fQualityIndexList[i];
// Fixed-strength stop index: fPruneStrength is read as a percentage of the
// sequence length.
// NOTE(review): fPruneStrength >= 100 gives fOptimalK >= fPruneSequence.size().
// That is out of range for the fQualityIndexList/fPruneStrengthList lookups
// that CalculatePruningInfo() makes with fOptimalK. Consider clamping to size()-1.
346 fOptimalK = int(fPruneStrength/100.0 * fPruneSequence.size() );
347 Log() << kDEBUG <<
"SequenzeSize="<<fPruneSequence.size()
348 <<
" fOptimalK " << fOptimalK << Endl;
// Debug summary. The size()-1 loop bounds assume non-empty lists. All three
// lists gain one entry per pruning step, so an empty sequence is presumably
// handled by the fPruneSequence.empty() check above.
352 Log() << kDEBUG <<
"\n************ Summary for Tree " << dt->GetTreeID() <<
" *******" << Endl
353 <<
"Number of trees in the sequence: " << fPruneSequence.size() << Endl;
355 Log() << kDEBUG <<
"Pruning strength parameters: [";
356 for(UInt_t i = 0; i < fPruneStrengthList.size()-1; i++)
357 Log() << kDEBUG << fPruneStrengthList[i] <<
", ";
358 Log() << kDEBUG << fPruneStrengthList[fPruneStrengthList.size()-1] <<
"]" << Endl;
360 Log() << kDEBUG <<
"Misclassification rates: [";
361 for(UInt_t i = 0; i < fQualityIndexList.size()-1; i++)
362 Log() << kDEBUG << fQualityIndexList[i] <<
", ";
363 Log() << kDEBUG << fQualityIndexList[fQualityIndexList.size()-1] <<
"]" << Endl;
365 Log() << kDEBUG <<
"Prune index: " << fOptimalK+1 << Endl;