/// Construct a handler for the data held by the given DataLoader.
///
/// Allocates a MsgLogger named "VarTransformHandler" at kINFO level, takes the
/// DataSetInfo from `dl`, initialises fEvents from that data set's event
/// collection, and logs the number of events.
///
/// @param dl  loader to read from; it is dereferenced without a null check.
//
// NOTE(review): fDataLoader is read by VarianceThreshold(), but its
// initialisation is not visible in this listing (embedded line 59 is absent).
// Presumably it is set from `dl` in this initializer list - confirm against the
// full file.
56 TMVA::VarTransformHandler::VarTransformHandler( DataLoader* dl )
57 : fLogger ( new MsgLogger(TString(
"VarTransformHandler").Data(), kINFO) ),
58 fDataSetInfo(dl->GetDataSetInfo()),
60 fEvents (fDataSetInfo.GetDataSet()->GetEventCollection())
// Report how many events were picked up from the data set.
62 Log() << kINFO <<
"Number of events - " << fEvents.size() << Endl;
/// Destructor.
//
// NOTE(review): the body is not visible in this listing. The constructor
// allocates fLogger with `new`; confirm that it is released here.
68 TMVA::VarTransformHandler::~VarTransformHandler()
/// Build a new DataLoader that keeps only the variables whose variance is
/// above `threshold`.
///
/// A DataLoader named "vt_transformed_dataset" is allocated with `new`; each
/// variable whose VariableInfo::GetVariance() exceeds `threshold` (strict `>`)
/// is added to it by expression and type. The input trees of fDataLoader are
/// then copied over and PrepareTrainingAndTestTree() is called with the
/// "Signal"/"Background" cuts and split options of fDataLoader.
///
/// @param threshold  variance value a variable must exceed to be kept.
/// @return pointer to the newly allocated loader (presumably owned by the
///         caller - nothing in this function deletes it).
//
// NOTE(review): the variances read below are whatever is currently stored in
// the VariableInfo objects (CalcNorm() stores them via SetVariance). A call
// that computes them is not visible in this listing (embedded lines 92-93 are
// absent) - confirm against the full file.
91 TMVA::DataLoader* TMVA::VarTransformHandler::VarianceThreshold(Double_t threshold)
94 const UInt_t nvars = fDataSetInfo.GetNVariables();
95 Log() << kINFO <<
"Number of variables before transformation: " << nvars << Endl;
96 std::vector<VariableInfo>& vars = fDataSetInfo.GetVariableInfos();
// Destination loader that will receive the selected variables.
102 TMVA::DataLoader *transformedLoader =
new TMVA::DataLoader(
"vt_transformed_dataset");
103 Log() << kINFO <<
"Selecting variables whose variance is above threshold value = " << threshold << Endl;
104 Int_t maxL = fDataSetInfo.GetVariableNameMaxLength();
// Table header: variable name column padded to the longest name, then "Variance".
106 Log() << kINFO <<
"----------------------------------------------------------------" << Endl;
107 Log() << kINFO << std::setiosflags(std::ios::left) << std::setw(maxL) <<
"Selected Variables";
108 Log() << kINFO << std::setiosflags(std::ios::left) << std::setw(10) <<
"Variance" << Endl;
109 Log() << kINFO <<
"----------------------------------------------------------------" << Endl;
// Walk all variables and keep those above the threshold.
// NOTE(review): the braces of the `if` are not visible in this listing; the two
// log lines and the AddVariable call are presumably all guarded by it.
// NOTE(review): the variance value is printed with width maxL, whereas the
// header above uses width 10 for the same column.
110 for (UInt_t ivar=0; ivar<nvars; ivar++) {
111 Double_t variance = vars[ivar].GetVariance();
112 if (variance > threshold)
114 Log() << kINFO << std::setiosflags(std::ios::left) << std::setw(maxL) << vars[ivar].GetExpression();
115 Log() << kINFO << std::setiosflags(std::ios::left) << std::setw(maxL) << variance << Endl;
116 transformedLoader->AddVariable(vars[ivar].GetExpression(), vars[ivar].GetVarType());
// Give the new loader the same signal/background input trees as the source loader.
119 CopyDataLoader(transformedLoader,fDataLoader);
120 Log() << kINFO <<
"----------------------------------------------------------------" << Endl;
// Reuse the source loader's "Signal"/"Background" cuts and split options.
123 transformedLoader->PrepareTrainingAndTestTree(fDataLoader->GetDataSetInfo().GetCut(
"Signal"), fDataLoader->GetDataSetInfo().GetCut(
"Background"), fDataLoader->GetDataSetInfo().GetSplitOptions());
124 Log() << kINFO <<
"Number of variables after transformation: " << transformedLoader->GetDataSetInfo().GetNVariables() << Endl;
126 return transformedLoader;
/// Widen the stored [min, max] range of one variable or target so that it
/// includes the value `x`.
///
/// @param ivar  combined index: the code addresses variables as vars[ivar] and
///              targets as tars[ivar - nvars], where nvars is the number of
///              variables. CalcNorm() passes nvars + itgt for targets.
/// @param x     value to fold into the range.
//
// NOTE(review): the condition that selects between the variable branch and the
// target branch (presumably `ivar < nvars`) is not visible in this listing
// (embedded lines 141 and 144-145 are absent) - confirm against the full file.
136 void TMVA::VarTransformHandler::UpdateNorm (Int_t ivar, Double_t x)
138 Int_t nvars = fDataSetInfo.GetNVariables();
139 std::vector<VariableInfo>& vars = fDataSetInfo.GetVariableInfos();
140 std::vector<VariableInfo>& tars = fDataSetInfo.GetTargetInfos();
// Variable case: lower the minimum / raise the maximum when x falls outside.
142 if (x < vars[ivar].GetMin()) vars[ivar].SetMin(x);
143 if (x > vars[ivar].GetMax()) vars[ivar].SetMax(x);
// Target case: same update, with the index shifted down by nvars.
146 if (x < tars[ivar-nvars].GetMin()) tars[ivar-nvars].SetMin(x);
147 if (x > tars[ivar-nvars].GetMax()) tars[ivar-nvars].SetMax(x);
/// Compute weighted statistics for every input variable and target and store
/// them in the corresponding VariableInfo objects.
///
/// Over the event collection of the data set this function:
///  - tracks the min/max of each variable and target,
///  - accumulates weighted sums of x and x*x to obtain the mean and
///    RMS = sqrt(sum(w*x*x)/sum(w) - mean*mean),
///  - in a second pass accumulates sum(w*(x-mean)^2) and stores
///    variance = that sum / sum(w),
///  - logs the variances and the min/max ranges.
///
/// A kFATAL message is logged when the sum of weights is <= 0 or when the RMS
/// argument is negative.
155 void TMVA::VarTransformHandler::CalcNorm()
157 const std::vector<TMVA::Event*>& events = fDataSetInfo.GetDataSet()->GetEventCollection();
159 const UInt_t nvars = fDataSetInfo.GetNVariables();
160 const UInt_t ntgts = fDataSetInfo.GetNTargets();
161 std::vector<VariableInfo>& vars = fDataSetInfo.GetVariableInfos();
162 std::vector<VariableInfo>& tars = fDataSetInfo.GetTargetInfos();
164 UInt_t nevts = events.size();
// Accumulators indexed [0, nvars) for variables and [nvars, nvars+ntgts) for
// targets: x2 = sum(w*x*x), x0 = sum(w*x), v0 = sum(w*(x-mean)^2).
// Each is explicitly zeroed by multiplying by 0.
166 TVectorD x2( nvars+ntgts ); x2 *= 0;
167 TVectorD x0( nvars+ntgts ); x0 *= 0;
168 TVectorD v0( nvars+ntgts ); v0 *= 0;
// First pass: sum of weights, min/max and the first/second weighted moments.
170 Double_t sumOfWeights = 0;
171 for (UInt_t ievt=0; ievt<nevts; ievt++) {
172 const Event* ev = events[ievt];
174 Double_t weight = ev->GetWeight();
175 sumOfWeights += weight;
176 for (UInt_t ivar=0; ivar<nvars; ivar++) {
177 Double_t x = ev->GetValue(ivar);
// NOTE(review): the unconditional SetMin/SetMax below are presumably the
// first-event initialisation, with UpdateNorm() used for later events; the
// guarding condition is not visible in this listing (embedded lines 178 and
// 181-182 are absent) - confirm against the full file.
179 vars[ivar].SetMin(x);
180 vars[ivar].SetMax(x);
183 UpdateNorm(ivar, x );
185 x0(ivar) += x*weight;
186 x2(ivar) += x*x*weight;
// Same accumulation for the targets, stored after the variables.
188 for (UInt_t itgt=0; itgt<ntgts; itgt++) {
189 Double_t x = ev->GetTarget(itgt);
// NOTE(review): same remark as for the variables - the guard around this
// initialisation is not visible here.
191 tars[itgt].SetMin(x);
192 tars[itgt].SetMax(x);
195 UpdateNorm( nvars+itgt, x );
197 x0(nvars+itgt) += x*weight;
198 x2(nvars+itgt) += x*x*weight;
// The sum of weights is used as a divisor below, so a non-positive value is
// reported as fatal. Note that the test is `<= 0` while the message text says
// "== 0".
202 if (sumOfWeights <= 0) {
203 Log() << kFATAL <<
" the sum of event weights calculated for your input is == 0"
204 <<
" or exactly: " << sumOfWeights <<
" there is obviously some problem..."<< Endl;
// Mean and RMS of each variable from the accumulated moments.
208 for (UInt_t ivar=0; ivar<nvars; ivar++) {
209 Double_t mean = x0(ivar)/sumOfWeights;
211 vars[ivar].SetMean( mean );
// A negative value here would make the square root below undefined.
212 if (x2(ivar)/sumOfWeights - mean*mean < 0) {
213 Log() << kFATAL <<
" the RMS of your input variable " << ivar
214 <<
" evaluates to an imaginary number: sqrt("<< x2(ivar)/sumOfWeights - mean*mean
215 <<
") .. sometimes related to a problem with outliers and negative event weights"
218 vars[ivar].SetRMS( TMath::Sqrt( x2(ivar)/sumOfWeights - mean*mean) );
// Mean and RMS of each target, same formula.
220 for (UInt_t itgt=0; itgt<ntgts; itgt++) {
221 Double_t mean = x0(nvars+itgt)/sumOfWeights;
222 tars[itgt].SetMean( mean );
223 if (x2(nvars+itgt)/sumOfWeights - mean*mean < 0) {
224 Log() << kFATAL <<
" the RMS of your target variable " << itgt
225 <<
" evaluates to an imaginary number: sqrt(" << x2(nvars+itgt)/sumOfWeights - mean*mean
226 <<
") .. sometimes related to a problem with outliers and negative event weights"
229 tars[itgt].SetRMS( TMath::Sqrt( x2(nvars+itgt)/sumOfWeights - mean*mean) );
// Second pass: weighted sum of squared deviations from the stored means.
233 for (UInt_t ievt=0; ievt<nevts; ievt++) {
234 const Event* ev = events[ievt];
235 Double_t weight = ev->GetWeight();
237 for (UInt_t ivar=0; ivar<nvars; ivar++) {
238 Double_t x = ev->GetValue(ivar);
239 Double_t mean = vars[ivar].GetMean();
240 v0(ivar) += weight*(x-mean)*(x-mean);
243 for (UInt_t itgt=0; itgt<ntgts; itgt++) {
244 Double_t x = ev->GetTarget(itgt);
245 Double_t mean = tars[itgt].GetMean();
246 v0(nvars+itgt) += weight*(x-mean)*(x-mean);
// Variance table for the variables; the name column is padded to the longest
// variable name.
250 Int_t maxL = fDataSetInfo.GetVariableNameMaxLength();
252 Log() << kINFO <<
"----------------------------------------------------------------" << Endl;
253 Log() << kINFO << std::setiosflags(std::ios::left) << std::setw(maxL) <<
"Variables";
254 Log() << kINFO << std::setiosflags(std::ios::left) << std::setw(10) <<
"Variance" << Endl;
255 Log() << kINFO <<
"----------------------------------------------------------------" << Endl;
258 Log() << std::setprecision(5);
// Store and print variance = sum(w*(x-mean)^2) / sum(w) for each variable.
259 for (UInt_t ivar=0; ivar<nvars; ivar++) {
260 Double_t variance = v0(ivar)/sumOfWeights;
261 vars[ivar].SetVariance( variance );
262 Log() << kINFO << std::setiosflags(std::ios::left) << std::setw(maxL) << vars[ivar].GetExpression();
263 Log() << kINFO << std::setiosflags(std::ios::left) << std::setw(maxL) << variance << Endl;
// Variance table for the targets; the column width switches to the longest
// target name.
266 maxL = fDataSetInfo.GetTargetNameMaxLength();
268 Log() << kINFO <<
"----------------------------------------------------------------" << Endl;
269 Log() << kINFO << std::setiosflags(std::ios::left) << std::setw(maxL) <<
"Targets";
270 Log() << kINFO << std::setiosflags(std::ios::left) << std::setw(10) <<
"Variance" << Endl;
271 Log() << kINFO <<
"----------------------------------------------------------------" << Endl;
273 for (UInt_t itgt=0; itgt<ntgts; itgt++) {
274 Double_t variance = v0(nvars+itgt)/sumOfWeights;
275 tars[itgt].SetVariance( variance );
276 Log() << kINFO << std::setiosflags(std::ios::left) << std::setw(maxL) << tars[itgt].GetExpression();
277 Log() << kINFO << std::setiosflags(std::ios::left) << std::setw(maxL) << variance << Endl;
// Print the min/max range of every variable with 3 digits of precision.
280 Log() << kINFO <<
"Set minNorm/maxNorm for variables to: " << Endl;
281 Log() << std::setprecision(3);
282 for (UInt_t ivar=0; ivar<nvars; ivar++)
283 Log() <<
" " << vars[ivar].GetExpression()
284 <<
"\t: [" << vars[ivar].GetMin() <<
"\t, " << vars[ivar].GetMax() <<
"\t] " << Endl;
// Same listing for the targets.
285 Log() << kINFO <<
"Set minNorm/maxNorm for targets to: " << Endl;
286 Log() << std::setprecision(3);
287 for (UInt_t itgt=0; itgt<ntgts; itgt++)
288 Log() <<
" " << tars[itgt].GetExpression()
289 <<
"\t: [" << tars[itgt].GetMin() <<
"\t, " << tars[itgt].GetMax() <<
"\t] " << Endl;
// Put the logger precision back to 5, the value used for the variance tables.
290 Log() << std::setprecision(5);
/// Register every input tree of `src` with `des`.
///
/// Iterates the signal trees (Sbegin()..Send()) and then the background trees
/// (Bbegin()..Bend()) of src->DataInput(), and adds each one to `des` with the
/// same tree pointer, weight and tree type.
///
/// @param des  loader that receives the trees.
/// @param src  loader whose DataInput() is read.
294 void TMVA::VarTransformHandler::CopyDataLoader(TMVA::DataLoader* des, TMVA::DataLoader* src)
// Signal trees.
296 for( std::vector<TreeInfo>::const_iterator treeinfo=src->DataInput().Sbegin();treeinfo!=src->DataInput().Send();++treeinfo)
298 des->AddSignalTree( (*treeinfo).GetTree(), (*treeinfo).GetWeight(),(*treeinfo).GetTreeType());
// Background trees.
301 for( std::vector<TreeInfo>::const_iterator treeinfo=src->DataInput().Bbegin();treeinfo!=src->DataInput().Bend();++treeinfo)
303 des->AddBackgroundTree( (*treeinfo).GetTree(), (*treeinfo).GetWeight(),(*treeinfo).GetTreeType());