52 ClassImp(TMVA::VariableDecorrTransform);
57 TMVA::VariableDecorrTransform::VariableDecorrTransform( DataSetInfo& dsi )
58 : VariableTransformBase( dsi, Types::kDecorrelated,
"Deco" )
65 TMVA::VariableDecorrTransform::~VariableDecorrTransform()
67 for (std::vector<TMatrixD*>::iterator it = fDecorrMatrices.begin(); it != fDecorrMatrices.end(); ++it) {
68 if ((*it) != 0)
delete (*it);
75 void TMVA::VariableDecorrTransform::Initialize()
82 Bool_t TMVA::VariableDecorrTransform::PrepareTransformation (
const std::vector<Event*>& events)
86 if (!IsEnabled() || IsCreated())
return kTRUE;
88 Log() << kINFO <<
"Preparing the Decorrelation transformation..." << Endl;
90 Int_t inputSize = fGet.size();
91 SetNVariables(inputSize);
93 if (inputSize > 200) {
94 Log() << kINFO <<
"----------------------------------------------------------------------------"
97 <<
": More than 200 variables, will not calculate decorrelation matrix "
99 Log() << kINFO <<
"----------------------------------------------------------------------------"
104 CalcSQRMats( events, GetNClasses() );
114 std::vector<TString>* TMVA::VariableDecorrTransform::GetTransformationStrings( Int_t cls )
const
116 Int_t whichMatrix = cls;
120 if (cls < 0 || cls > GetNClasses()) whichMatrix = GetNClasses();
122 TMatrixD* m = fDecorrMatrices.at(whichMatrix);
124 if (whichMatrix == GetNClasses() )
125 Log() << kFATAL <<
"Transformation matrix all classes is not defined"
128 Log() << kFATAL <<
"Transformation matrix for class " << whichMatrix <<
" is not defined"
132 const Int_t nvar = fGet.size();
133 std::vector<TString>* strVec =
new std::vector<TString>;
136 for (Int_t ivar=0; ivar<nvar; ivar++) {
138 for (Int_t jvar=0; jvar<nvar; jvar++) {
139 str += ((*m)(ivar,jvar) > 0) ?
" + " :
" - ";
141 Char_t type = fGet.at(jvar).first;
142 Int_t idx = fGet.at(jvar).second;
146 str += Form(
"%10.5g*[%s]", TMath::Abs((*m)(ivar,jvar)), Variables()[idx].GetLabel().Data() );
149 str += Form(
"%10.5g*[%s]", TMath::Abs((*m)(ivar,jvar)), Targets()[idx].GetLabel().Data() );
152 str += Form(
"%10.5g*[%s]", TMath::Abs((*m)(ivar,jvar)), Spectators()[idx].GetLabel().Data() );
155 Log() << kFATAL <<
"VariableDecorrTransform::GetTransformationStrings : unknown type '" << type <<
"'." << Endl;
158 strVec->push_back( str );
167 const TMVA::Event* TMVA::VariableDecorrTransform::Transform(
const TMVA::Event*
const ev, Int_t cls )
const
170 Log() << kFATAL <<
"Transformation matrix not yet created"
173 Int_t whichMatrix = cls;
176 if (cls < 0 || cls >= (
int) fDecorrMatrices.size()) whichMatrix = fDecorrMatrices.size()-1;
183 TMatrixD* m = fDecorrMatrices.at(whichMatrix);
185 if (whichMatrix == GetNClasses() )
186 Log() << kFATAL <<
"Transformation matrix all classes is not defined"
189 Log() << kFATAL <<
"Transformation matrix for class " << whichMatrix <<
" is not defined"
193 if (fTransformedEvent==0 || fTransformedEvent->GetNVariables()!=ev->GetNVariables()) {
194 if (fTransformedEvent!=0) {
delete fTransformedEvent; fTransformedEvent = 0; }
195 fTransformedEvent =
new Event();
199 const Int_t nvar = fGet.size();
201 std::vector<Float_t> input;
202 std::vector<Char_t> mask;
203 Bool_t hasMaskedEntries = GetInput( ev, input, mask );
205 if( hasMaskedEntries ){
206 UInt_t numMasked = std::count(mask.begin(), mask.end(), (Char_t)kTRUE);
207 UInt_t numOK = std::count(mask.begin(), mask.end(), (Char_t)kFALSE);
208 if( numMasked>0 && numOK>0 ){
209 Log() << kFATAL <<
"You mixed variables and targets in the decorrelation transformation. This is not possible." << Endl;
211 SetOutput( fTransformedEvent, input, mask, ev );
212 return fTransformedEvent;
215 TVectorD vec( nvar );
216 for (Int_t ivar=0; ivar<nvar; ivar++) vec(ivar) = input.at(ivar);
222 for (Int_t ivar=0; ivar<nvar; ivar++) input.push_back( vec(ivar) );
224 SetOutput( fTransformedEvent, input, mask, ev );
226 return fTransformedEvent;
233 const TMVA::Event* TMVA::VariableDecorrTransform::InverseTransform(
const TMVA::Event*
const , Int_t )
const
235 Log() << kFATAL <<
"Inverse transformation for decorrelation transformation not yet implemented. Hence, this transformation cannot be applied together with regression if targets should be transformed. Please contact the authors if necessary." << Endl;
238 return fBackTransformedEvent;
244 void TMVA::VariableDecorrTransform::CalcSQRMats(
const std::vector< Event*>& events, Int_t maxCls )
247 for (std::vector<TMatrixD*>::iterator it = fDecorrMatrices.begin();
248 it != fDecorrMatrices.end(); ++it)
249 if (0 != (*it) ) {
delete (*it); *it=0; }
253 const UInt_t matNum = (maxCls<=1)?maxCls:maxCls+1;
254 fDecorrMatrices.resize( matNum, (TMatrixD*) 0 );
256 std::vector<TMatrixDSym*>* covMat = gTools().CalcCovarianceMatrices( events, maxCls,
this );
259 for (UInt_t cls=0; cls<matNum; cls++) {
260 TMatrixD* sqrMat = gTools().GetSQRootMatrix( covMat->at(cls) );
262 Log() << kFATAL <<
"<GetSQRMats> Zero pointer returned for SQR matrix" << Endl;
263 fDecorrMatrices[cls] = sqrMat;
264 delete (*covMat)[cls];
272 void TMVA::VariableDecorrTransform::WriteTransformationToStream( std::ostream& o )
const
275 Int_t dp = o.precision();
276 for (std::vector<TMatrixD*>::const_iterator itm = fDecorrMatrices.begin(); itm != fDecorrMatrices.end(); ++itm) {
277 o <<
"# correlation matrix " << std::endl;
278 TMatrixD* mat = (*itm);
279 o << cls <<
" " << mat->GetNrows() <<
" x " << mat->GetNcols() << std::endl;
280 for (Int_t row = 0; row<mat->GetNrows(); row++) {
281 for (Int_t col = 0; col<mat->GetNcols(); col++) {
282 o << std::setprecision(12) << std::setw(20) << (*mat)[row][col] <<
" ";
288 o <<
"##" << std::endl;
289 o << std::setprecision(dp);
295 void TMVA::VariableDecorrTransform::AttachXMLTo(
void* parent)
297 void* trf = gTools().AddChild(parent,
"Transform");
298 gTools().AddAttr(trf,
"Name",
"Decorrelation");
300 VariableTransformBase::AttachXMLTo( trf );
302 for (std::vector<TMatrixD*>::const_iterator itm = fDecorrMatrices.begin(); itm != fDecorrMatrices.end(); ++itm) {
303 TMatrixD* mat = (*itm);
315 gTools().WriteTMatrixDToXML(trf,
"Matrix",mat);
322 void TMVA::VariableDecorrTransform::ReadFromXML(
void* trfnode )
325 for( std::vector<TMatrixD*>::iterator it = fDecorrMatrices.begin(); it != fDecorrMatrices.end(); ++it )
326 if( (*it) != 0 )
delete (*it);
327 fDecorrMatrices.clear();
329 Bool_t newFormat = kFALSE;
331 void* inpnode = NULL;
333 inpnode = gTools().GetChild(trfnode,
"Selection");
341 VariableTransformBase::ReadFromXML( inpnode );
343 ch = gTools().GetNextChild(inpnode);
345 ch = gTools().GetChild(trfnode);
350 gTools().ReadAttr(ch,
"Rows", nrows);
351 gTools().ReadAttr(ch,
"Columns", ncols);
352 TMatrixD* mat =
new TMatrixD(nrows,ncols);
353 const char* content = gTools().GetContent(ch);
354 std::stringstream s(content);
355 for (Int_t row = 0; row<nrows; row++) {
356 for (Int_t col = 0; col<ncols; col++) {
357 s >> (*mat)[row][col];
360 fDecorrMatrices.push_back(mat);
361 ch = gTools().GetNextChild(ch);
369 void TMVA::VariableDecorrTransform::ReadTransformationFromStream( std::istream& istr,
const TString& classname )
372 istr.getline(buf,512);
373 TString strvar, dummy;
374 Int_t nrows(0), ncols(0);
376 while (!(buf[0]==
'#'&& buf[1]==
'#')) {
378 while (*p==
' ' || *p==
'\t') p++;
379 if (*p==
'#' || *p==
'\0') {
380 istr.getline(buf,512);
383 std::stringstream sstr(buf);
386 if (strvar==
"signal" || strvar==
"background") {
388 if(strvar==
"background") cls=1;
389 if(strvar==classname) classIdx = cls;
391 sstr >> nrows >> dummy >> ncols;
392 if (fDecorrMatrices.size() <= cls ) fDecorrMatrices.resize(cls+1);
393 if (fDecorrMatrices.at(cls) != 0)
delete fDecorrMatrices.at(cls);
394 TMatrixD* mat = fDecorrMatrices.at(cls) =
new TMatrixD(nrows,ncols);
396 for (Int_t row = 0; row<mat->GetNrows(); row++) {
397 for (Int_t col = 0; col<mat->GetNcols(); col++) {
398 istr >> (*mat)[row][col];
402 istr.getline(buf,512);
405 fDecorrMatrices.push_back(
new TMatrixD(*fDecorrMatrices[classIdx]) );
413 void TMVA::VariableDecorrTransform::PrintTransformation( std::ostream& )
416 for (std::vector<TMatrixD*>::iterator itm = fDecorrMatrices.begin(); itm != fDecorrMatrices.end(); ++itm) {
417 Log() << kINFO <<
"Transformation matrix "<< cls <<
":" << Endl;
425 void TMVA::VariableDecorrTransform::MakeFunction( std::ostream& fout,
const TString& fcncName, Int_t part, UInt_t trCounter, Int_t )
427 Int_t dp = fout.precision();
429 UInt_t numC = fDecorrMatrices.size();
432 TMatrixD* mat = fDecorrMatrices.at(0);
434 fout <<
" double fDecTF_"<<trCounter<<
"["<<numC<<
"]["<<mat->GetNrows()<<
"]["<<mat->GetNcols()<<
"];" << std::endl;
439 fout <<
"//_______________________________________________________________________" << std::endl;
440 fout <<
"inline void " << fcncName <<
"::InitTransform_"<<trCounter<<
"()" << std::endl;
441 fout <<
"{" << std::endl;
442 fout <<
" // Decorrelation transformation, initialisation" << std::endl;
443 for (UInt_t icls = 0; icls < numC; icls++){
444 TMatrixD* matx = fDecorrMatrices.at(icls);
445 for (
int i=0; i<matx->GetNrows(); i++) {
446 for (
int j=0; j<matx->GetNcols(); j++) {
447 fout <<
" fDecTF_"<<trCounter<<
"["<<icls<<
"]["<<i<<
"]["<<j<<
"] = " << std::setprecision(12) << (*matx)[i][j] <<
";" << std::endl;
451 fout <<
"}" << std::endl;
453 TMatrixD* matx = fDecorrMatrices.at(0);
454 fout <<
"//_______________________________________________________________________" << std::endl;
455 fout <<
"inline void " << fcncName <<
"::Transform_"<<trCounter<<
"( std::vector<double>& iv, int cls) const" << std::endl;
456 fout <<
"{" << std::endl;
457 fout <<
" // Decorrelation transformation" << std::endl;
458 fout <<
" if (cls < 0 || cls > "<<GetNClasses()<<
") {"<< std::endl;
459 fout <<
" if ("<<GetNClasses()<<
" > 1 ) cls = "<<GetNClasses()<<
";"<< std::endl;
460 fout <<
" else cls = "<<(fDecorrMatrices.size()==1?0:2)<<
";"<< std::endl;
461 fout <<
" }"<< std::endl;
463 VariableTransformBase::MakeFunction(fout, fcncName, 0, trCounter, 0 );
465 fout <<
" std::vector<double> tv;" << std::endl;
466 fout <<
" for (int i=0; i<"<<matx->GetNrows()<<
";i++) {" << std::endl;
467 fout <<
" double v = 0;" << std::endl;
468 fout <<
" for (int j=0; j<"<<matx->GetNcols()<<
"; j++)" << std::endl;
469 fout <<
" v += iv[indicesGet.at(j)] * fDecTF_"<<trCounter<<
"[cls][i][j];" << std::endl;
470 fout <<
" tv.push_back(v);" << std::endl;
471 fout <<
" }" << std::endl;
472 fout <<
" for (int i=0; i<"<<matx->GetNrows()<<
";i++) iv[indicesPut.at(i)] = tv[i];" << std::endl;
473 fout <<
"}" << std::endl;
476 fout << std::setprecision(dp);