58 TMVA::ResultsMulticlass::ResultsMulticlass(
const DataSetInfo* dsi, TString resultsName )
59 : Results( dsi, resultsName ),
61 fLogger( new MsgLogger(Form(
"ResultsMultiClass%s",resultsName.Data()) , kINFO) ),
63 fAchievableEff(dsi->GetNClasses()),
64 fAchievablePur(dsi->GetNClasses()),
65 fBestCuts(dsi->GetNClasses(),std::vector<Double_t>(dsi->GetNClasses()))
72 TMVA::ResultsMulticlass::~ResultsMulticlass()
79 void TMVA::ResultsMulticlass::SetValue( std::vector<Float_t>& value, Int_t ievt )
81 if (ievt >= (Int_t)fMultiClassValues.size()) fMultiClassValues.resize( ievt+1 );
82 fMultiClassValues[ievt] = value;
89 TMatrixD TMVA::ResultsMulticlass::GetConfusionMatrix(Double_t effB)
91 const DataSet *ds = GetDataSet();
92 const DataSetInfo *dsi = GetDataSetInfo();
93 ds->SetCurrentType(GetTreeType());
95 UInt_t numClasses = dsi->GetNClasses();
96 TMatrixD mat(numClasses, numClasses);
99 for (UInt_t iRow = 0; iRow < numClasses; ++iRow) {
100 for (UInt_t iCol = 0; iCol < numClasses; ++iCol) {
104 mat(iRow, iCol) = std::numeric_limits<double>::quiet_NaN();
107 std::vector<Float_t> valueVector;
108 std::vector<Bool_t> classVector;
109 std::vector<Float_t> weightVector;
111 for (UInt_t iEvt = 0; iEvt < ds->GetNEvents(); ++iEvt) {
112 const Event *ev = ds->GetEvent(iEvt);
113 const UInt_t cls = ev->GetClass();
114 const Float_t weight = ev->GetWeight();
115 const Float_t mvaValue = fMultiClassValues[iEvt][iRow];
117 if (cls != iRow && cls != iCol) {
121 classVector.push_back(cls == iRow);
122 weightVector.push_back(weight);
123 valueVector.push_back(mvaValue);
126 ROCCurve roc(valueVector, classVector, weightVector);
127 mat(iRow, iCol) = roc.GetEffSForEffB(effB);
136 Double_t TMVA::ResultsMulticlass::EstimatorFunction( std::vector<Double_t> & cutvalues ){
138 DataSet* ds = GetDataSet();
139 ds->SetCurrentType( GetTreeType() );
143 Float_t positives[2] = {0, 0};
145 for (Int_t ievt = 0; ievt < ds->GetNEvents(); ievt++) {
146 UInt_t evClass = fEventClasses[ievt];
147 Float_t w = fEventWeights[ievt];
149 Bool_t break_outer_loop =
false;
150 for (UInt_t icls = 0; icls < cutvalues.size(); ++icls) {
151 auto value = fMultiClassValues[ievt][icls];
152 auto cutvalue = cutvalues.at(icls);
153 if (cutvalue < 0. ? (-value < cutvalue) : (+value <= cutvalue)) {
154 break_outer_loop =
true;
159 if (break_outer_loop) {
163 Bool_t isEvCurrClass = (evClass == fClassToOptimize);
164 positives[isEvCurrClass] += w;
167 const Float_t truePositive = positives[1];
168 const Float_t falsePositive = positives[0];
170 Float_t eff = truePositive / fClassSumWeights[fClassToOptimize];
171 Float_t pur = truePositive / (truePositive + falsePositive);
172 Float_t effTimesPur = eff*pur;
174 Float_t toMinimize = std::numeric_limits<float>::max();
175 if (effTimesPur > std::numeric_limits<float>::min())
176 toMinimize = 1./(effTimesPur);
178 fAchievableEff.at(fClassToOptimize) = eff;
179 fAchievablePur.at(fClassToOptimize) = pur;
188 std::vector<Double_t> TMVA::ResultsMulticlass::GetBestMultiClassCuts(UInt_t targetClass){
190 const DataSetInfo* dsi = GetDataSetInfo();
191 Log() << kINFO <<
"Calculating best set of cuts for class "
192 << dsi->GetClassInfo( targetClass )->GetName() << Endl;
194 fClassToOptimize = targetClass;
195 std::vector<Interval*> ranges(dsi->GetNClasses(),
new Interval(-1,1));
197 fClassSumWeights.clear();
198 fEventWeights.clear();
199 fEventClasses.clear();
201 for (UInt_t icls = 0; icls < dsi->GetNClasses(); ++icls) {
202 fClassSumWeights.push_back(0);
205 DataSet *ds = GetDataSet();
206 for (Int_t ievt = 0; ievt < ds->GetNEvents(); ievt++) {
207 const Event *ev = ds->GetEvent(ievt);
208 fClassSumWeights[ev->GetClass()] += ev->GetWeight();
209 fEventWeights.push_back(ev->GetWeight());
210 fEventClasses.push_back(ev->GetClass());
213 const TString name(
"MulticlassGA" );
214 const TString opts(
"PopSize=100:Steps=30" );
215 GeneticFitter mg( *
this, name, ranges, opts);
217 std::vector<Double_t> result;
220 fBestCuts.at(targetClass) = result;
223 for( std::vector<Double_t>::iterator it = result.begin(); it<result.end(); ++it ){
224 Log() << kINFO <<
" cutValue[" <<dsi->GetClassInfo( n )->GetName() <<
"] = " << (*it) <<
";"<< Endl;
240 void TMVA::ResultsMulticlass::CreateMulticlassPerformanceHistos(TString prefix)
243 Log() << kINFO <<
"Creating multiclass performance histograms..." << Endl;
245 DataSet *ds = GetDataSet();
246 ds->SetCurrentType(GetTreeType());
247 const DataSetInfo *dsi = GetDataSetInfo();
249 UInt_t numClasses = dsi->GetNClasses();
251 std::vector<std::vector<Float_t>> *rawMvaRes = GetValueVector();
256 for (
size_t iClass = 0; iClass < numClasses; ++iClass) {
258 TString className = dsi->GetClassInfo(iClass)->GetName();
259 TString name = Form(
"%s_rejBvsS_%s", prefix.Data(), className.Data());
260 TString title = Form(
"%s_%s", prefix.Data(), className.Data());
263 if ( DoesExist(name) ) {
268 std::vector<Float_t> mvaRes;
269 std::vector<Bool_t> mvaResTypes;
270 std::vector<Float_t> mvaResWeights;
275 mvaRes.reserve(rawMvaRes->size());
276 for (
auto item : *rawMvaRes) {
277 mvaRes.push_back(item[iClass]);
280 auto eventCollection = ds->GetEventCollection();
281 mvaResTypes.reserve(eventCollection.size());
282 mvaResWeights.reserve(eventCollection.size());
283 for (
auto ev : eventCollection) {
284 mvaResTypes.push_back(ev->GetClass() == iClass);
285 mvaResWeights.push_back(ev->GetWeight());
289 ROCCurve *roc =
new ROCCurve(mvaRes, mvaResTypes, mvaResWeights);
290 TGraph *rocGraph =
new TGraph(*(roc->GetROCCurve()));
294 rocGraph->SetName(name);
295 rocGraph->SetTitle(title);
304 for (
size_t iClass = 0; iClass < numClasses; ++iClass) {
305 for (
size_t jClass = 0; jClass < numClasses; ++jClass) {
306 if (iClass == jClass) {
310 auto eventCollection = ds->GetEventCollection();
313 std::vector<Float_t> mvaRes;
314 std::vector<Bool_t> mvaResTypes;
315 std::vector<Float_t> mvaResWeights;
317 mvaRes.reserve(rawMvaRes->size());
318 mvaResTypes.reserve(eventCollection.size());
319 mvaResWeights.reserve(eventCollection.size());
321 for (
size_t iEvent = 0; iEvent < eventCollection.size(); ++iEvent) {
322 Event *ev = eventCollection[iEvent];
324 if (ev->GetClass() == iClass || ev->GetClass() == jClass) {
325 Float_t output_value = (*rawMvaRes)[iEvent][iClass];
326 mvaRes.push_back(output_value);
327 mvaResTypes.push_back(ev->GetClass() == iClass);
328 mvaResWeights.push_back(ev->GetWeight());
333 ROCCurve *roc =
new ROCCurve(mvaRes, mvaResTypes, mvaResWeights);
334 TGraph *rocGraph =
new TGraph(*(roc->GetROCCurve()));
338 TString iClassName = dsi->GetClassInfo(iClass)->GetName();
339 TString jClassName = dsi->GetClassInfo(jClass)->GetName();
340 TString name = Form(
"%s_1v1rejBvsS_%s_vs_%s", prefix.Data(), iClassName.Data(), jClassName.Data());
341 TString title = Form(
"%s_%s_vs_%s", prefix.Data(), iClassName.Data(), jClassName.Data());
342 rocGraph->SetName(name);
343 rocGraph->SetTitle(title);
354 void TMVA::ResultsMulticlass::CreateMulticlassHistos( TString prefix, Int_t nbins, Int_t )
356 Log() << kINFO <<
"Creating multiclass response histograms..." << Endl;
358 DataSet* ds = GetDataSet();
359 ds->SetCurrentType( GetTreeType() );
360 const DataSetInfo* dsi = GetDataSetInfo();
362 std::vector<std::vector<TH1F*> > histos;
363 Float_t xmin = 0.-0.0002;
364 Float_t xmax = 1.+0.0002;
365 for (UInt_t iCls = 0; iCls < dsi->GetNClasses(); iCls++) {
366 histos.push_back(std::vector<TH1F*>(0));
367 for (UInt_t jCls = 0; jCls < dsi->GetNClasses(); jCls++) {
368 TString name(Form(
"%s_%s_prob_for_%s",prefix.Data(),
369 dsi->GetClassInfo( jCls )->GetName(),
370 dsi->GetClassInfo( iCls )->GetName()));
373 if ( DoesExist(name) ) {
377 histos.at(iCls).push_back(
new TH1F(name,name,nbins,xmin,xmax));
381 for (Int_t ievt=0; ievt<ds->GetNEvents(); ievt++) {
382 const Event* ev = ds->GetEvent(ievt);
383 Int_t cls = ev->GetClass();
384 Float_t w = ev->GetWeight();
385 for (UInt_t jCls = 0; jCls < dsi->GetNClasses(); jCls++) {
386 histos.at(cls).at(jCls)->Fill(fMultiClassValues[ievt][jCls],w);
389 for (UInt_t iCls = 0; iCls < dsi->GetNClasses(); iCls++) {
390 for (UInt_t jCls = 0; jCls < dsi->GetNClasses(); jCls++) {
391 gTools().NormHist( histos.at(iCls).at(jCls) );
392 Store(histos.at(iCls).at(jCls));