/// Constructor taking the validation data as an EventList.
/// Stores validationSample in fValidationSample and sets fValidationDataSet to NULL.
/// NOTE(review): this listing appears to have lost some lines (braces, parts of
/// the initializer list, else keywords); the comments below describe only the
/// statements that are visible here.
69 CCPruner::CCPruner( DecisionTree* t_max,
const EventList* validationSample,
70 SeparationBase* qualityIndex ) :
72 fValidationSample(validationSample),
73 fValidationDataSet(NULL),
// No quality index was supplied: allocate a MisClassificationError to use instead.
78 if(qualityIndex == NULL) {
80 fQualityIndex =
new MisClassificationError();
// Use the caller-supplied quality index (presumably the else branch -- confirm).
84 fQualityIndex = qualityIndex;
/// Constructor taking the validation data as a DataSet.
/// Sets fValidationSample to NULL and stores validationSample in fValidationDataSet.
/// NOTE(review): this listing appears to have lost some lines (braces, parts of
/// the initializer list, else keywords); the comments below describe only the
/// statements that are visible here.
92 CCPruner::CCPruner( DecisionTree* t_max,
const DataSet* validationSample,
93 SeparationBase* qualityIndex ) :
95 fValidationSample(NULL),
96 fValidationDataSet(validationSample),
// No quality index was supplied: allocate a MisClassificationError to use instead.
101 if(qualityIndex == NULL) {
103 fQualityIndex =
new MisClassificationError();
// Use the caller-supplied quality index (presumably the else branch -- confirm).
107 fQualityIndex = qualityIndex;
/// Destructor: deletes fQualityIndex, but only when fOwnQIndex is set.
/// NOTE(review): fOwnQIndex presumably marks a quality index that was allocated
/// by the constructor; the place where it is assigned is not visible here -- confirm.
115 CCPruner::~CCPruner( )
117 if(fOwnQIndex)
delete fQualityIndex;
/// Builds a CCTreeWrapper around fTree and prunes it step by step, recording for
/// every step the pruned node (fPruneSequence), the pruning strength
/// (fPruneStrengthList) and a quality value (fQualityIndexList); afterwards it
/// sets fOptimalK. When fDebug is set, progress is written to "costcomplexity.log".
/// NOTE(review): several control-flow lines (braces, else keywords, some loop
/// headers, returns and declarations) are not visible in this listing; comments
/// marked "presumably" rely on that missing structure and need confirming.
124 void CCPruner::Optimize( )
// A positive fAlpha acts as a stop condition for the pruning loop.
126 Bool_t HaveStopCondition = fAlpha > 0;
// Wrapper tree on which all pruning bookkeeping below is performed.
129 CCTreeWrapper* dTWrapper =
new CCTreeWrapper(fTree, fQualityIndex);
// Relative tolerance used when comparing minimum alpha-c values during the descent.
132 Double_t epsilon = std::numeric_limits<double>::epsilon();
// Running pruning strength; starts far below any value compared against it.
133 Double_t alpha = -1.0e10;
135 std::ofstream outfile;
136 if (fDebug) outfile.open(
"costcomplexity.log");
// Without a stop condition, one of the two validation samples is required.
137 if(!HaveStopCondition && (fValidationSample == NULL && fValidationDataSet == NULL) ) {
138 if (fDebug) outfile <<
"ERROR: no validation sample, so cannot optimize pruning!" << std::endl;
140 if (fDebug) outfile.close();
144 CCTreeWrapper::CCTreeNode* R = dTWrapper->GetRoot();
// Keep pruning while the root still has more than one leaf daughter.
145 while(R->GetNLeafDaughters() > 1) {
// alpha only ever increases: it follows the largest root minimum alpha-c seen so far.
146 if(R->GetMinAlphaC() > alpha)
147 alpha = R->GetMinAlphaC();
// With a stop condition, end the loop once alpha exceeds fAlpha.
149 if(HaveStopCondition && alpha > fAlpha)
break;
// Descend from the root while the node's minimum alpha-c is below its own alpha-c.
151 CCTreeWrapper::CCTreeNode* t = R;
153 while(t->GetMinAlphaC() < t->GetAlphaC()) {
// Go left when the left daughter's minimum alpha-c equals t's within the relative
// tolerance epsilon; the step to the right daughter is presumably the else case.
155 if(fabs(t->GetMinAlphaC() - t->GetLeftDaughter()->GetMinAlphaC())/fabs(t->GetMinAlphaC()) < epsilon)
156 t = t->GetLeftDaughter();
158 t = t->GetRightDaughter();
// Debug message for the case that the node selected for pruning is the root
// (the guarding condition is not visible in this listing).
162 if (fDebug) outfile << std::endl <<
"Caught trying to prune the root node!" << std::endl;
// Keep a handle on the selected node; t is reassigned further down.
166 CCTreeWrapper::CCTreeNode* n = t;
169 outfile <<
"===========================" << std::endl
170 <<
"Pruning branch listed below" << std::endl
171 <<
"===========================" << std::endl;
172 t->PrintRec( outfile );
// Check for a node that has neither a left nor a right daughter.
175 if (!(t->GetLeftDaughter()) && !(t->GetRightDaughter()) ) {
// Prune the branch rooted at t.
178 dTWrapper->PruneNode(t);
// Recompute t's bookkeeping from its two daughters: leaf count, resubstitution
// estimate, alpha-c and minimum alpha-c. Presumably applied to each ancestor on
// the way back up to the root; the enclosing loop is not visible here.
182 t->SetNLeafDaughters(t->GetLeftDaughter()->GetNLeafDaughters() + t->GetRightDaughter()->GetNLeafDaughters());
183 t->SetResubstitutionEstimate(t->GetLeftDaughter()->GetResubstitutionEstimate() +
184 t->GetRightDaughter()->GetResubstitutionEstimate());
// alpha-c = (node estimate - subtree estimate) / (number of leaves - 1).
185 t->SetAlphaC((t->GetNodeResubstitutionEstimate() - t->GetResubstitutionEstimate())/(t->GetNLeafDaughters() - 1));
186 t->SetMinAlphaC(std::min(t->GetAlphaC(), std::min(t->GetLeftDaughter()->GetMinAlphaC(),
187 t->GetRightDaughter()->GetMinAlphaC())));
// Without a stop condition, measure the quality of the pruned tree, using the
// DataSet when it is available and the EventList otherwise, and record it.
190 if(!HaveStopCondition) {
192 if (fValidationDataSet != NULL) q = dTWrapper->TestTreeQuality(fValidationDataSet);
193 else q = dTWrapper->TestTreeQuality(fValidationSample);
194 fQualityIndexList.push_back(q);
// Placeholder quality of 1.0 (presumably the branch taken with a stop condition).
197 fQualityIndexList.push_back(1.0);
// Record the pruned decision-tree node and the pruning strength of this step.
199 fPruneSequence.push_back(n->GetDTNode());
200 fPruneStrengthList.push_back(alpha);
// Without a stop condition, search the recorded quality values for the largest one.
203 Double_t qmax = -1.0e6;
204 if(!HaveStopCondition) {
205 for(UInt_t i = 0; i < fQualityIndexList.size(); i++) {
206 if(fQualityIndexList[i] > qmax) {
207 qmax = fQualityIndexList[i];
// Select the last entry of the prune sequence (presumably the stop-condition branch).
214 fOptimalK = fPruneSequence.size() - 1;
// Summary written to the log file.
218 outfile << std::endl <<
"************ Summary **************" << std::endl
219 <<
"Number of trees in the sequence: " << fPruneSequence.size() << std::endl;
221 outfile <<
"Pruning strength parameters: [";
// NOTE(review): if fPruneStrengthList is a std::vector and is empty here, the
// unsigned size()-1 wraps around and the indexing below is out of range -- confirm
// whether an empty list can reach this point.
222 for(UInt_t i = 0; i < fPruneStrengthList.size()-1; i++)
223 outfile << fPruneStrengthList[i] <<
", ";
224 outfile << fPruneStrengthList[fPruneStrengthList.size()-1] <<
"]" << std::endl;
226 outfile <<
"Misclassification rates: [";
// NOTE(review): same size()-1 pattern as for the pruning strengths above.
227 for(UInt_t i = 0; i < fQualityIndexList.size()-1; i++)
228 outfile << fQualityIndexList[i] <<
", ";
229 outfile << fQualityIndexList[fQualityIndexList.size()-1] <<
"]" << std::endl;
// The index is printed one-based (fOptimalK+1).
231 outfile <<
"Optimal index: " << fOptimalK+1 << std::endl;
/// Returns a vector holding the entries of fPruneSequence at indices
/// 0 .. fOptimalK-1, in order. Nothing is copied when fOptimalK is zero or
/// negative, so the returned vector is empty in that case.
/// NOTE(review): the closing braces of this function are not visible in this listing.
240 std::vector<DecisionTreeNode*> CCPruner::GetOptimalPruneSequence( )
const
242 std::vector<DecisionTreeNode*> optimalSequence;
// A negative fOptimalK skips the copy entirely.
243 if( fOptimalK >= 0 ) {
// The entry at index fOptimalK itself is not included (strict i < fOptimalK).
244 for( Int_t i = 0; i < fOptimalK; i++ ) {
245 optimalSequence.push_back(fPruneSequence[i]);
248 return optimalSequence;