// ClassImp macro invocation for TMVA::VariableNormalizeTransform; the macro
// itself is defined outside this excerpt.
54 ClassImp(TMVA::VariableNormalizeTransform);
// Constructor: forwards the DataSetInfo reference, the Types::kNormalized tag
// and the name "Norm" to the VariableTransformBase constructor.
// NOTE(review): the constructor body is not visible in this excerpt.
59 TMVA::VariableNormalizeTransform::VariableNormalizeTransform( DataSetInfo& dsi )
60 : VariableTransformBase( dsi, Types::kNormalized,
"Norm" )
// Destructor. NOTE(review): its body is not visible in this excerpt.
66 TMVA::VariableNormalizeTransform::~VariableNormalizeTransform() {
// Sizes and zero-fills the per-class range tables fMin / fMax.
// The number of rows is GetNClasses()+1 when there is more than one class and
// 1 otherwise; every row holds one entry per selected input (fGet.size()).
// NOTE(review): fMin.at(i) / fMax.at(i) throw unless the outer vectors already
// hold numC rows; the statement that sizes the outer vectors is not visible in
// this excerpt - confirm it exists before the loop.
72 void TMVA::VariableNormalizeTransform::Initialize()
74 UInt_t inputSize = fGet.size();
75 Int_t numC = GetNClasses()+1;
76 if (GetNClasses() <= 1 ) numC = 1;
80 for (Int_t i=0; i<numC; i++) {
81 fMin.at(i).resize(inputSize);
82 fMax.at(i).resize(inputSize);
// assign() sets both the length and every value, so all entries end up as 0.
83 fMin.at(i).assign(inputSize, 0);
84 fMax.at(i).assign(inputSize, 0);
// Prepares the normalisation from a set of events.
// Returns kTRUE immediately when the transformation is disabled or has already
// been created; otherwise it logs a debug message and delegates the range
// computation to CalcNormalizationParams().
// NOTE(review): the statements following CalcNormalizationParams() (including
// the final return) are not visible in this excerpt.
91 Bool_t TMVA::VariableNormalizeTransform::PrepareTransformation (
const std::vector<Event*>& events)
93 if (!IsEnabled() || IsCreated())
return kTRUE;
95 Log() << kDEBUG <<
"\tPreparing the transformation." << Endl;
99 CalcNormalizationParams( events );
// Applies the normalisation to one event and returns the transformed event.
// Each selected input value is mapped linearly with the [min, max] range stored
// for class `cls`:  valnorm = (val - min) / (max - min) * 2 - 1,
// i.e. min maps to -1 and max maps to +1.
//   ev  : event to transform.
//   cls : row of fMin/fMax to use; out-of-range values fall back to the last row.
// Returns fTransformedEvent, a member that is allocated on first use and reused
// on later calls.
// NOTE(review): the declarations of `input`, `output`, `min`, `max` and `iidx`
// (and any use of `itMask`) are not visible in this excerpt.
109 const TMVA::Event* TMVA::VariableNormalizeTransform::Transform(
const TMVA::Event*
const ev, Int_t cls )
const
// Reports a kFATAL log message if the transformation has not been created yet.
111 if (!IsCreated()) Log() << kFATAL <<
"Transformation not yet created" << Endl;
// Any class index outside [0, fMin.size()) selects the last row of the tables.
120 if (cls < 0 || cls >= (
int) fMin.size()) cls = fMin.size()-1;
125 std::vector<Char_t> mask;
126 GetInput( ev, input, mask );
// Lazily allocate the output event the first time it is needed.
128 if (fTransformedEvent==0) fTransformedEvent =
new Event();
131 const FloatVector& minVector = fMin.at(cls);
132 const FloatVector& maxVector = fMax.at(cls);
135 std::vector<Char_t>::iterator itMask = mask.begin();
136 for ( std::vector<Float_t>::iterator itInp = input.begin(), itInpEnd = input.end(); itInp != itInpEnd; ++itInp) {
144 Float_t val = (*itInp);
146 min = minVector.at(iidx);
147 max = maxVector.at(iidx);
148 Float_t offset = min;
// NOTE(review): when max == min this divides by zero, so scale becomes
// infinite and valnorm is inf/NaN - confirm whether constant inputs can occur.
149 Float_t scale = 1.0/(max-min);
151 Float_t valnorm = (val-offset)*scale * 2 - 1;
152 output.push_back( valnorm );
158 SetOutput( fTransformedEvent, output, mask, ev );
159 return fTransformedEvent;
// Undoes the normalisation for one event and returns the back-transformed event.
// Each selected value is mapped with the range stored for class `cls`:
//   result = min + (val + 1) / (2 * scale),  scale = 1 / (max - min),
// which is the algebraic inverse of  (val - min) * scale * 2 - 1.
//   ev  : event holding normalised values.
//   cls : row of fMin/fMax to use.
// Returns fBackTransformedEvent, a member created on first use as a copy of *ev
// and reused afterwards.
// NOTE(review): the declarations of `input`, `output`, `min`, `max` and `iidx`
// are not visible in this excerpt.
165 const TMVA::Event* TMVA::VariableNormalizeTransform::InverseTransform(
const TMVA::Event*
const ev, Int_t cls )
const
// Reports a kFATAL log message if the transformation has not been created yet.
167 if (!IsCreated()) Log() << kFATAL <<
"Transformation not yet created" << Endl;
// Out-of-range class index: with more than one class, use row GetNClasses().
// NOTE(review): the branch taken when GetNClasses() <= 1 is not visible here.
171 if (cls < 0 || cls > GetNClasses()) {
172 if (GetNClasses() > 1 ) cls = GetNClasses();
178 std::vector<Char_t> mask;
// The extra kTRUE argument is interpreted by the base class; presumably it
// selects the back-transformation direction - confirm in VariableTransformBase.
179 GetInput( ev, input, mask, kTRUE );
181 if (fBackTransformedEvent==0) fBackTransformedEvent =
new Event( *ev );
184 const FloatVector& minVector = fMin.at(cls);
185 const FloatVector& maxVector = fMax.at(cls);
188 for ( std::vector<Float_t>::iterator itInp = input.begin(), itInpEnd = input.end(); itInp != itInpEnd; ++itInp) {
189 Float_t val = (*itInp);
191 min = minVector.at(iidx);
192 max = maxVector.at(iidx);
193 Float_t offset = min;
// NOTE(review): divides by zero when max == min (scale becomes infinite).
194 Float_t scale = 1.0/(max-min);
196 Float_t valnorm = offset+((val+1)/(scale * 2));
197 output.push_back( valnorm );
202 SetOutput( fBackTransformedEvent, output, mask, ev, kTRUE );
204 return fBackTransformedEvent;
// Scans the events and records, for every selected input, the minimum and
// maximum value seen - once in the row of the event's own class and once in a
// combined row indexed by `all`.
//   events : events to scan; a kFATAL message is logged when there are fewer
//            than two.
// NOTE(review): the declarations of `numC`, `all`, `input` and `iidx` are not
// visible in this excerpt, so the exact row layout cannot be confirmed here.
210 void TMVA::VariableNormalizeTransform::CalcNormalizationParams(
const std::vector< Event*>& events )
212 if (events.size() <= 1)
213 Log() << kFATAL <<
"Not enough events (found " << events.size() <<
") to calculate the normalization" << Endl;
216 std::vector<Char_t> mask;
218 UInt_t inputSize = fGet.size();
220 const UInt_t nCls = GetNClasses();
// Seed every range with the widest possible inverted interval so that the
// first value compared against it replaces both bounds.
228 for (UInt_t iinp=0; iinp<inputSize; ++iinp) {
229 for (Int_t ic = 0; ic < numC; ic++) {
230 fMin.at(ic).at(iinp) = FLT_MAX;
231 fMax.at(ic).at(iinp) = -FLT_MAX;
235 std::vector<Event*>::const_iterator evIt = events.begin();
236 for (;evIt!=events.end();++evIt) {
237 const TMVA::Event*
event = (*evIt);
239 UInt_t cls = (*evIt)->GetClass();
// Per-class row for this event ...
241 FloatVector& minVector = fMin.at(cls);
242 FloatVector& maxVector = fMax.at(cls);
// ... and the combined row updated by every event.
244 FloatVector& minVectorAll = fMin.at(all);
245 FloatVector& maxVectorAll = fMax.at(all);
247 GetInput(event,input,mask);
249 for ( std::vector<Float_t>::iterator itInp = input.begin(), itInpEnd = input.end(); itInp != itInpEnd; ++itInp) {
250 Float_t val = (*itInp);
252 if( minVector.at(iidx) > val ) minVector.at(iidx) = val;
253 if( maxVector.at(iidx) < val ) maxVector.at(iidx) = val;
256 if (minVectorAll.at(iidx) > val) minVectorAll.at(iidx) = val;
257 if (maxVectorAll.at(iidx) < val) maxVectorAll.at(iidx) = val;
// Builds one formula string per selected input describing the normalisation
// for class `cls`, in the form  "2*<scale>*([<label>] - <offset>) - 1"
// with scale = 1/(max-min) and offset = min.
//   cls : row of fMin/fMax to use; out-of-range values fall back to GetNClasses().
// Returns a vector allocated with `new`, holding fGet.size() strings; ownership
// presumably passes to the caller - confirm against the call sites.
// NOTE(review): the declarations of `min`, `max`, `iinp` and `str` are not
// visible in this excerpt.
270 std::vector<TString>* TMVA::VariableNormalizeTransform::GetTransformationStrings( Int_t cls )
const
// NOTE(review): when fMin holds a single row (Initialize() uses one row for
// GetNClasses() <= 1) this fallback can select an index that fMin.at() below
// rejects - confirm.
274 if (cls < 0 || cls > GetNClasses()) cls = GetNClasses();
277 const UInt_t size = fGet.size();
278 std::vector<TString>* strVec =
new std::vector<TString>(size);
281 for( ItVarTypeIdxConst itGet = fGet.begin(), itGetEnd = fGet.end(); itGet != itGetEnd; ++itGet ) {
282 min = fMin.at(cls).at(iinp);
283 max = fMax.at(cls).at(iinp);
285 Char_t type = (*itGet).first;
286 UInt_t idx = (*itGet).second;
287 Float_t offset = min;
// NOTE(review): divides by zero when max == min.
288 Float_t scale = 1.0/(max-min);
// 'v' selects a variable, 't' a target, anything else a spectator.
290 VariableInfo& varInfo = (type==
'v'?fDsi.GetVariableInfo(idx):(type==
't'?fDsi.GetTargetInfo(idx):fDsi.GetSpectatorInfo(idx)));
// A negative offset is printed as "+ |offset|" instead of "- -offset".
292 if (offset < 0) str = Form(
"2*%g*([%s] + %g) - 1", scale, varInfo.GetLabel().Data(), -offset );
293 else str = Form(
"2*%g*([%s] - %g) - 1", scale, varInfo.GetLabel().Data(), offset );
294 (*strVec)[iinp] = str;
// Writes the stored ranges to a text stream.
// Output: a '#' header line, then for each class row a line with the row index
// followed by one "min max" line per variable and one per target (targets are
// stored after the variables, at index nvars+itgt), and a "##" marker line.
// Values are written with precision 12 and field width 20.
//   o : destination stream.
// NOTE(review): brace lines are missing from this excerpt, so whether "##" is
// written once or per class cannot be confirmed here.
305 void TMVA::VariableNormalizeTransform::WriteTransformationToStream( std::ostream& o )
const
307 o <<
"# min max for all variables for all classes one after the other and as a last entry for all classes together" << std::endl;
// One row per class plus one extra row when there is more than one class.
309 Int_t numC = GetNClasses()+1;
310 if (GetNClasses() <= 1 ) numC = 1;
312 UInt_t nvars = GetNVariables();
313 UInt_t ntgts = GetNTargets();
315 for (Int_t icls = 0; icls < numC; icls++ ) {
316 o << icls << std::endl;
317 for (UInt_t ivar=0; ivar<nvars; ivar++)
318 o << std::setprecision(12) << std::setw(20) << fMin.at(icls).at(ivar) <<
" "
319 << std::setprecision(12) << std::setw(20) << fMax.at(icls).at(ivar) << std::endl;
320 for (UInt_t itgt=0; itgt<ntgts; itgt++)
321 o << std::setprecision(12) << std::setw(20) << fMin.at(icls).at(nvars+itgt) <<
" "
322 << std::setprecision(12) << std::setw(20) << fMax.at(icls).at(nvars+itgt) << std::endl;
324 o <<
"##" << std::endl;
// Serialises the transformation as XML below `parent`.
// Layout produced by the visible calls:
//   Transform (Name="Normalize")
//     <content attached by VariableTransformBase::AttachXMLTo>
//     Class (ClassIndex=icls)            - one per class row
//       Ranges
//         Range (Index, Min, Max)        - one per entry of fGet
//   parent : XML node that receives the new "Transform" child.
// NOTE(review): the declaration and increment of `iinp` are not visible in
// this excerpt.
330 void TMVA::VariableNormalizeTransform::AttachXMLTo(
void* parent)
332 void* trfxml = gTools().AddChild(parent,
"Transform");
333 gTools().AddAttr(trfxml,
"Name",
"Normalize");
334 VariableTransformBase::AttachXMLTo( trfxml );
// One row per class plus one extra row when there is more than one class.
336 Int_t numC = (GetNClasses()<= 1)?1:GetNClasses()+1;
338 for( Int_t icls=0; icls<numC; icls++ ) {
339 void* clsxml = gTools().AddChild(trfxml,
"Class");
340 gTools().AddAttr(clsxml,
"ClassIndex", icls);
341 void* inpxml = gTools().AddChild(clsxml,
"Ranges");
343 for( ItVarTypeIdx itGet = fGet.begin(), itGetEnd = fGet.end(); itGet != itGetEnd; ++itGet ) {
344 void* mmxml = gTools().AddChild(inpxml,
"Range");
345 gTools().AddAttr(mmxml,
"Index", iinp);
346 gTools().AddAttr(mmxml,
"Min", fMin.at(icls).at(iinp) );
347 gTools().AddAttr(mmxml,
"Max", fMax.at(icls).at(iinp) );
// Restores the ranges from an XML node. Two layouts are handled:
//  * a layout with a "Selection" child: the selection is read by
//    VariableTransformBase::ReadFromXML, then every "Class" node supplies a
//    "Ranges" list of Range(Index, Min, Max) entries;
//  * a layout without it: "NVariables"/"NTargets" attributes give the counts,
//    fGet is rebuilt as all variables followed by all targets, and every class
//    node supplies "Variables" (VarIndex) and "Targets" (TargetIndex) lists,
//    with targets stored at index nvars+tgtindex.
//   trfnode : XML node of this transformation.
// NOTE(review): the loop headers, the branch on `newFormat`, the assignment to
// `newFormat` and the declaration of `ci` are not visible in this excerpt.
// NOTE(review): indices read from the file are used with unchecked operator[],
// and resize(classindex+1) would drop rows if class nodes were not in ascending
// ClassIndex order - confirm the writer guarantees both.
356 void TMVA::VariableNormalizeTransform::ReadFromXML(
void* trfnode )
358 Bool_t newFormat = kFALSE;
360 void* inpnode = NULL;
// Presence of a "Selection" child is what is tested to recognise the layout.
362 inpnode = gTools().GetChild(trfnode,
"Selection");
363 if( inpnode != NULL )
369 VariableTransformBase::ReadFromXML( inpnode );
373 UInt_t size = fGet.size();
374 UInt_t classindex, idx;
376 void* ch = gTools().GetChild( trfnode,
"Class" );
379 gTools().ReadAttr(ch,
"ClassIndex", ci);
380 classindex = UInt_t(ci);
// Grow the tables to include this class row, zero-filling new entries.
382 fMin.resize(classindex+1);
383 fMax.resize(classindex+1);
385 fMin[classindex].resize(size,Float_t(0));
386 fMax[classindex].resize(size,Float_t(0));
388 void* clch = gTools().GetChild( ch );
390 TString nodeName(gTools().GetName(clch));
391 if(nodeName==
"Ranges") {
392 void* varch = gTools().GetChild( clch );
394 gTools().ReadAttr(varch,
"Index", idx);
395 gTools().ReadAttr(varch,
"Min", fMin[classindex][idx]);
396 gTools().ReadAttr(varch,
"Max", fMax[classindex][idx]);
397 varch = gTools().GetNextChild( varch );
400 clch = gTools().GetNextChild( clch );
402 ch = gTools().GetNextChild( ch );
// Layout driven by the NVariables / NTargets attributes.
409 UInt_t classindex, varindex, tgtindex, nvars, ntgts;
411 gTools().ReadAttr(trfnode,
"NVariables", nvars);
413 gTools().ReadAttr(trfnode,
"NTargets", ntgts);
// Rebuild the selection: every variable ('v'), then every target ('t').
416 for( UInt_t ivar = 0; ivar < nvars; ++ivar ){
417 fGet.push_back(std::pair<Char_t,UInt_t>(
'v',ivar));
419 for( UInt_t itgt = 0; itgt < ntgts; ++itgt ){
420 fGet.push_back(std::pair<Char_t,UInt_t>(
't',itgt));
422 void* ch = gTools().GetChild( trfnode );
424 gTools().ReadAttr(ch,
"ClassIndex", classindex);
426 fMin.resize(classindex+1);
427 fMax.resize(classindex+1);
428 fMin[classindex].resize(nvars+ntgts,Float_t(0));
429 fMax[classindex].resize(nvars+ntgts,Float_t(0));
431 void* clch = gTools().GetChild( ch );
433 TString nodeName(gTools().GetName(clch));
434 if(nodeName==
"Variables") {
435 void* varch = gTools().GetChild( clch );
437 gTools().ReadAttr(varch,
"VarIndex", varindex);
438 gTools().ReadAttr(varch,
"Min", fMin[classindex][varindex]);
439 gTools().ReadAttr(varch,
"Max", fMax[classindex][varindex]);
440 varch = gTools().GetNextChild( varch );
442 }
else if (nodeName==
"Targets") {
443 void* tgtch = gTools().GetChild( clch );
445 gTools().ReadAttr(tgtch,
"TargetIndex", tgtindex);
// Target ranges live after the variable ranges in each row.
446 gTools().ReadAttr(tgtch,
"Min", fMin[classindex][nvars+tgtindex]);
447 gTools().ReadAttr(tgtch,
"Max", fMax[classindex][nvars+tgtindex]);
448 tgtch = gTools().GetNextChild( tgtch );
451 clch = gTools().GetNextChild( clch );
453 ch = gTools().GetNextChild( ch );
// Builds the range tables from the GetMin()/GetMax() values of the supplied
// variable descriptions instead of from events.
//   var : one VariableInfo per variable; a kFATAL message is logged when its
//         size differs from GetNVariables().
// Every class row is sized to nvars+GetNTargets() zero entries and the first
// var.size() entries are overwritten with the supplied ranges; target entries
// stay 0 in the visible code.
// NOTE(review): the declaration of `vidx` is not visible in this excerpt.
463 void TMVA::VariableNormalizeTransform::BuildTransformationFromVarInfo(
const std::vector<TMVA::VariableInfo>& var )
465 UInt_t nvars = GetNVariables();
467 if(var.size() != nvars)
468 Log() << kFATAL <<
"<BuildTransformationFromVarInfo> can't build transformation,"
469 <<
" since the number of variables disagree" << Endl;
// One row per class plus one extra row when there is more than one class.
471 UInt_t numC = (GetNClasses()<=1)?1:GetNClasses()+1;
472 fMin.clear();fMin.resize( numC );
473 fMax.clear();fMax.resize( numC );
476 for(UInt_t cls=0; cls<numC; ++cls) {
477 fMin[cls].resize(nvars+GetNTargets(),0);
478 fMax[cls].resize(nvars+GetNTargets(),0);
480 for(std::vector<TMVA::VariableInfo>::const_iterator v = var.begin(); v!=var.end(); ++v, ++vidx) {
481 fMin[cls][vidx] = v->GetMin();
482 fMax[cls][vidx] = v->GetMax();
// NOTE(review): this push_back appears to sit inside the per-class loop, which
// would append the same ('v', vidx) entries to fGet once per class row - confirm
// that duplicates in fGet are intended when numC > 1.
483 fGet.push_back(std::pair<Char_t,UInt_t>(
'v',vidx));
// Reads the ranges from a text stream.
// fGet is first rebuilt as all variables ('v') followed by all targets ('t').
// Lines are then read until one starts with "##"; lines that are empty or start
// with '#' after leading blanks/tabs are skipped. Each remaining line opens a
// class section and is followed by one "min max" line per variable and one per
// target (targets stored at index nvars+itgt).
//   istr : source stream; the second (unnamed) TString parameter is unused in
//          the visible code.
// NOTE(review): the declarations of `buf`, `buf2`, `p` and `icls`, and the
// statement that extracts `icls` from `sstr`, are not visible in this excerpt.
// NOTE(review): fMin[icls][...] uses unchecked operator[], so the tables must
// already have the right size for the indices found in the file - confirm.
492 void TMVA::VariableNormalizeTransform::ReadTransformationFromStream( std::istream& istr,
const TString& )
494 UInt_t nvars = GetNVariables();
495 UInt_t ntgts = GetNTargets();
496 for( UInt_t ivar = 0; ivar < nvars; ++ivar ){
497 fGet.push_back(std::pair<Char_t,UInt_t>(
'v',ivar));
499 for( UInt_t itgt = 0; itgt < ntgts; ++itgt ){
500 fGet.push_back(std::pair<Char_t,UInt_t>(
't',itgt));
// getline is limited to 512 characters per line, including the terminator.
504 istr.getline(buf,512);
505 TString strvar, dummy;
// "##" at the start of a line ends the section.
508 while (!(buf[0]==
'#'&& buf[1]==
'#')) {
// Skip leading blanks and tabs, then ignore comment and empty lines.
510 while (*p==
' ' || *p==
'\t') p++;
511 if (*p==
'#' || *p==
'\0') {
512 istr.getline(buf,512);
515 std::stringstream sstr(buf);
517 for (UInt_t ivar=0;ivar<nvars;ivar++) {
518 istr.getline(buf2,512);
519 std::stringstream sstr2(buf2);
520 sstr2 >> fMin[icls][ivar] >> fMax[icls][ivar];
522 for (UInt_t itgt=0;itgt<ntgts;itgt++) {
523 istr.getline(buf2,512);
524 std::stringstream sstr2(buf2);
525 sstr2 >> fMin[icls][nvars+itgt] >> fMax[icls][nvars+itgt];
527 istr.getline(buf,512);
// Logs the stored ranges at kINFO level: a heading per class row followed by
// one line per selected input with its kind ("Variable", "Target" or
// "Spectator") and its min and max, each in a field of width 20.
// The std::ostream parameter is unnamed and not used in the visible code; all
// output goes through Log().
// NOTE(review): the declaration of `numC` and the condition that chooses
// between the two headings are not visible in this excerpt.
535 void TMVA::VariableNormalizeTransform::PrintTransformation( std::ostream& )
537 Int_t nCls = GetNClasses();
539 if (nCls <= 1 ) numC = 1;
540 for (Int_t icls = 0; icls < numC; icls++ ) {
542 Log() << kINFO <<
"Transformation for all classes based on these ranges:" << Endl;
544 Log() << kINFO <<
"Transformation for class " << icls <<
" based on these ranges:" << Endl;
546 for( ItVarTypeIdxConst itGet = fGet.begin(), itGetEnd = fGet.end(); itGet != itGetEnd; ++itGet ){
547 Char_t type = (*itGet).first;
548 UInt_t idx = (*itGet).second;
550 TString typeString = (type==
'v'?
"Variable: ": (type==
't'?
"Target : ":
"Spectator : ") );
// NOTE(review): the tables are indexed here with `idx` (the index stored in the
// fGet entry), whereas AttachXMLTo and GetTransformationStrings index them by
// position within fGet; for targets/spectators the two may differ - confirm.
551 Log() << typeString.Data() << std::setw(20) << fMin[icls][idx] << std::setw(20) << fMax[icls][idx] << Endl;
562 void TMVA::VariableNormalizeTransform::MakeFunction( std::ostream& fout,
const TString& fcncName,
563 Int_t part, UInt_t trCounter, Int_t )
565 UInt_t nVar = fGet.size();
566 UInt_t numC = fMin.size();
569 fout <<
" double fOff_" << trCounter <<
"[" << numC <<
"][" << nVar <<
"];" << std::endl;
570 fout <<
" double fScal_" << trCounter <<
"[" << numC <<
"][" << nVar <<
"];" << std::endl;
575 fout <<
"//_______________________________________________________________________" << std::endl;
576 fout <<
"inline void " << fcncName <<
"::InitTransform_" << trCounter <<
"()" << std::endl;
577 fout <<
"{" << std::endl;
578 fout <<
" double fMin_" << trCounter <<
"[" << numC <<
"][" << nVar <<
"];" << std::endl;
579 fout <<
" double fMax_" << trCounter <<
"[" << numC <<
"][" << nVar <<
"];" << std::endl;
580 fout <<
" // Normalization transformation, initialisation" << std::endl;
581 for (UInt_t ivar = 0; ivar < nVar; ivar++) {
582 for (UInt_t icls = 0; icls < numC; icls++) {
583 Double_t min = TMath::Min(FLT_MAX, fMin.at(icls).at(ivar));
584 Double_t max = TMath::Max(-FLT_MAX, fMax.at(icls).at(ivar));
585 fout <<
" fMin_" << trCounter <<
"[" << icls <<
"][" << ivar <<
"] = " << std::setprecision(12) << min
587 fout <<
" fMax_" << trCounter <<
"[" << icls <<
"][" << ivar <<
"] = " << std::setprecision(12) << max
589 fout <<
" fScal_" << trCounter <<
"[" << icls <<
"][" << ivar <<
"] = 2.0/(fMax_" << trCounter <<
"["
590 << icls <<
"][" << ivar <<
"]-fMin_" << trCounter <<
"[" << icls <<
"][" << ivar <<
"]);" << std::endl;
591 fout <<
" fOff_" << trCounter <<
"[" << icls <<
"][" << ivar <<
"] = fMin_" << trCounter <<
"[" << icls
592 <<
"][" << ivar <<
"]*fScal_" << trCounter <<
"[" << icls <<
"][" << ivar <<
"]+1.;" << std::endl;
595 fout <<
"}" << std::endl;
597 fout <<
"//_______________________________________________________________________" << std::endl;
598 fout <<
"inline void " << fcncName <<
"::Transform_" << trCounter <<
"( std::vector<double>& iv, int cls) const"
600 fout <<
"{" << std::endl;
601 fout <<
" // Normalization transformation" << std::endl;
602 fout <<
" if (cls < 0 || cls > " << GetNClasses() <<
") {" << std::endl;
603 fout <<
" if (" << GetNClasses() <<
" > 1 ) cls = " << GetNClasses() <<
";" << std::endl;
604 fout <<
" else cls = " << (fMin.size() == 1 ? 0 : 2) <<
";" << std::endl;
605 fout <<
" }" << std::endl;
606 fout <<
" const int nVar = " << nVar <<
";" << std::endl << std::endl;
607 fout <<
" // get indices of used variables" << std::endl;
608 VariableTransformBase::MakeFunction(fout, fcncName, 0, trCounter, 0);
609 fout <<
" static std::vector<double> dv;"
611 fout <<
" dv.resize(nVar);" << std::endl;
612 fout <<
" for (int ivar=0; ivar<nVar; ivar++) dv[ivar] = iv[indicesGet.at(ivar)];" << std::endl;
614 fout <<
" for (int ivar=0;ivar<" << nVar <<
";ivar++) {" << std::endl;
615 fout <<
" double offset = fOff_" << trCounter <<
"[cls][ivar];" << std::endl;
616 fout <<
" double scale = fScal_" << trCounter <<
"[cls][ivar];" << std::endl;
617 fout <<
" iv[indicesPut.at(ivar)] = scale*dv[ivar]-offset;" << std::endl;
618 fout <<
" }" << std::endl;
619 fout <<
"}" << std::endl;