// Registers this method under the name "CrossValidation" through the REGISTER_METHOD macro.
// NOTE(review): presumably this is what lets ClassifierFactory create the method by name -- confirm.
26 REGISTER_METHOD(CrossValidation)
// ROOT ClassImp macro invocation for TMVA::MethodCrossValidation.
28 ClassImp(TMVA::MethodCrossValidation);
/// Constructor taking a job name, a method title, the dataset description and an option string.
///
/// All four arguments are forwarded unchanged to the MethodBase constructor, together with the
/// method type Types::kCrossValidation. The split expression starts out unset
/// (fSplitExpr is initialised to nullptr).
///
/// \param jobName      forwarded to MethodBase
/// \param methodTitle  forwarded to MethodBase
/// \param theData      forwarded to MethodBase
/// \param theOption    forwarded to MethodBase
///
/// NOTE(review): the constructor body is not visible in this excerpt.
33 TMVA::MethodCrossValidation::MethodCrossValidation(const TString &jobName, const TString &methodTitle,
34 DataSetInfo &theData, const TString &theOption)
35 : TMVA::MethodBase(jobName, Types::kCrossValidation, methodTitle, theData, theOption), fSplitExpr(
nullptr)
/// Constructor taking the dataset description and a weight file.
///
/// Both arguments are forwarded unchanged to the MethodBase constructor, together with the
/// method type Types::kCrossValidation. The split expression starts out unset
/// (fSplitExpr is initialised to nullptr).
///
/// \param theData        forwarded to MethodBase
/// \param theWeightFile  forwarded to MethodBase
///
/// NOTE(review): the constructor body is not visible in this excerpt.
41 TMVA::MethodCrossValidation::MethodCrossValidation(DataSetInfo &theData,
const TString &theWeightFile)
42 : TMVA::MethodBase(Types::kCrossValidation, theData, theWeightFile), fSplitExpr(nullptr)
50 TMVA::MethodCrossValidation::~MethodCrossValidation(
void) {}
54 void TMVA::MethodCrossValidation::DeclareOptions()
56 DeclareOptionRef(fEncapsulatedMethodName,
"EncapsulatedMethodName",
"");
57 DeclareOptionRef(fEncapsulatedMethodTypeName,
"EncapsulatedMethodTypeName",
"");
58 DeclareOptionRef(fNumFolds,
"NumFolds",
"Number of folds to generate");
59 DeclareOptionRef(fOutputEnsembling = TString(
"None"),
"OutputEnsembling",
60 "Combines output from contained methods. If None, no combination is performed. (default None)");
61 AddPreDefVal(TString(
"None"));
62 AddPreDefVal(TString(
"Avg"));
63 DeclareOptionRef(fSplitExprString,
"SplitExpr",
"The expression used to assign events to folds");
69 void TMVA::MethodCrossValidation::DeclareCompatibilityOptions()
71 MethodBase::DeclareCompatibilityOptions();
77 void TMVA::MethodCrossValidation::ProcessOptions()
79 Log() << kDEBUG <<
"ProcessOptions -- fNumFolds: " << fNumFolds << Endl;
80 Log() << kDEBUG <<
"ProcessOptions -- fEncapsulatedMethodName: " << fEncapsulatedMethodName << Endl;
81 Log() << kDEBUG <<
"ProcessOptions -- fEncapsulatedMethodTypeName: " << fEncapsulatedMethodTypeName << Endl;
83 if (fSplitExprString != TString(
"")) {
84 fSplitExpr = std::unique_ptr<CvSplitKFoldsExpr>(
new CvSplitKFoldsExpr(DataInfo(), fSplitExprString));
87 for (UInt_t iFold = 0; iFold < fNumFolds; ++iFold) {
88 TString weightfile = GetWeightFileNameForFold(iFold);
90 Log() << kINFO <<
"Reading weightfile: " << weightfile << Endl;
91 MethodBase *fold_method = InstantiateMethodFromXML(fEncapsulatedMethodTypeName, weightfile);
92 fEncapsulatedMethods.push_back(fold_method);
99 void TMVA::MethodCrossValidation::Init(
void)
101 fMulticlassValues = std::vector<Float_t>(DataInfo().GetNClasses());
102 fRegressionValues = std::vector<Float_t>(DataInfo().GetNTargets());
108 void TMVA::MethodCrossValidation::Reset(
void) {}
114 TString TMVA::MethodCrossValidation::GetWeightFileNameForFold(UInt_t iFold)
const
116 if (iFold >= fNumFolds) {
117 Log() << kFATAL << iFold <<
" out of range. "
118 <<
"Should be < " << fNumFolds <<
"." << Endl;
121 TString foldStr = Form(
"fold%i", iFold + 1);
122 TString fileDir = gSystem->DirName(GetWeightFileName());
123 TString weightfile = fileDir +
"/" + fJobName +
"_" + fEncapsulatedMethodName +
"_" + foldStr +
".weights.xml";
146 void TMVA::MethodCrossValidation::Train() {}
154 TMVA::MethodCrossValidation::InstantiateMethodFromXML(TString methodTypeName, TString weightfile)
const
156 TMVA::MethodBase *m =
dynamic_cast<MethodBase *
>(
157 ClassifierFactory::Instance().Create(std::string(methodTypeName.Data()), DataInfo(), weightfile));
159 if (m->GetMethodType() == Types::kCategory) {
160 Log() << kFATAL <<
"MethodCategory not supported for the moment." << Endl;
163 TString fileDir = TString(DataInfo().GetName()) +
"/" + gConfig().GetIONames().fWeightFileDir;
164 m->SetWeightFileDir(fileDir);
167 m->SetAnalysisType(fAnalysisType);
169 m->ReadStateFromFile();
/// Writes this method's configuration to XML.
///
/// A child node named "Weights" is added to `parent`, and the job name, split expression,
/// number of folds, encapsulated method name/type and output-ensembling mode are stored on
/// it as attributes.
///
/// \param parent  XML node that receives the new "Weights" child
///
/// NOTE(review): only the first statement of the per-fold loop is visible in this excerpt;
/// what is done with each fold's weight-file name cannot be told from here.
178 void TMVA::MethodCrossValidation::AddWeightsXMLTo(
void *parent)
const
180 void *wght = gTools().AddChild(parent,
"Weights");
// Attribute names written here are the same ones ReadWeightsFromXML reads back.
182 gTools().AddAttr(wght,
"JobName", fJobName);
183 gTools().AddAttr(wght,
"SplitExpr", fSplitExprString);
184 gTools().AddAttr(wght,
"NumFolds", fNumFolds);
185 gTools().AddAttr(wght,
"EncapsulatedMethodName", fEncapsulatedMethodName);
186 gTools().AddAttr(wght,
"EncapsulatedMethodTypeName", fEncapsulatedMethodTypeName);
187 gTools().AddAttr(wght,
"OutputEnsembling", fOutputEnsembling);
// Visits every fold and computes its weight-file name.
189 for (UInt_t iFold = 0; iFold < fNumFolds; ++iFold) {
190 TString weightfile = GetWeightFileNameForFold(iFold);
208 void TMVA::MethodCrossValidation::ReadWeightsFromXML(
void *parent)
210 gTools().ReadAttr(parent,
"JobName", fJobName);
211 gTools().ReadAttr(parent,
"SplitExpr", fSplitExprString);
212 gTools().ReadAttr(parent,
"NumFolds", fNumFolds);
213 gTools().ReadAttr(parent,
"EncapsulatedMethodName", fEncapsulatedMethodName);
214 gTools().ReadAttr(parent,
"EncapsulatedMethodTypeName", fEncapsulatedMethodTypeName);
215 gTools().ReadAttr(parent,
"OutputEnsembling", fOutputEnsembling);
218 for (UInt_t iFold = 0; iFold < fNumFolds; ++iFold) {
219 TString weightfile = GetWeightFileNameForFold(iFold);
221 Log() << kINFO <<
"Reading weightfile: " << weightfile << Endl;
222 MethodBase *fold_method = InstantiateMethodFromXML(fEncapsulatedMethodTypeName, weightfile);
223 fEncapsulatedMethods.push_back(fold_method);
227 if (fSplitExprString != TString(
"")) {
228 fSplitExpr = std::unique_ptr<CvSplitKFoldsExpr>(
new CvSplitKFoldsExpr(DataInfo(), fSplitExprString));
236 void TMVA::MethodCrossValidation::ReadWeightsFromStream(std::istream & )
238 Log() << kFATAL <<
"CrossValidation currently supports only reading from XML." << Endl;
244 Double_t TMVA::MethodCrossValidation::GetMvaValue(Double_t *err, Double_t *errUpper)
246 const Event *ev = GetEvent();
248 if (fOutputEnsembling ==
"None") {
249 if (fSplitExpr !=
nullptr) {
251 UInt_t iFold = fSplitExpr->Eval(fNumFolds, ev);
252 return fEncapsulatedMethods.at(iFold)->GetMvaValue(err, errUpper);
255 UInt_t iFold = fEventToFoldMapping.at(Data()->GetEvent());
256 return fEncapsulatedMethods.at(iFold)->GetMvaValue(err, errUpper);
258 }
else if (fOutputEnsembling ==
"Avg") {
260 for (
auto &m : fEncapsulatedMethods) {
261 val += m->GetMvaValue(err, errUpper);
263 return val / fEncapsulatedMethods.size();
265 Log() << kFATAL <<
"Ensembling type " << fOutputEnsembling <<
" unknown" << Endl;
273 const std::vector<Float_t> &TMVA::MethodCrossValidation::GetMulticlassValues()
275 const Event *ev = GetEvent();
277 if (fOutputEnsembling ==
"None") {
278 if (fSplitExpr !=
nullptr) {
280 UInt_t iFold = fSplitExpr->Eval(fNumFolds, ev);
281 return fEncapsulatedMethods.at(iFold)->GetMulticlassValues();
284 UInt_t iFold = fEventToFoldMapping.at(Data()->GetEvent());
285 return fEncapsulatedMethods.at(iFold)->GetMulticlassValues();
287 }
else if (fOutputEnsembling ==
"Avg") {
289 for (
auto &e : fMulticlassValues) {
293 for (
auto &m : fEncapsulatedMethods) {
294 auto methodValues = m->GetMulticlassValues();
295 for (
size_t i = 0; i < methodValues.size(); ++i) {
296 fMulticlassValues[i] += methodValues[i];
300 for (
auto &e : fMulticlassValues) {
301 e /= fEncapsulatedMethods.size();
304 return fMulticlassValues;
307 Log() << kFATAL <<
"Ensembling type " << fOutputEnsembling <<
" unknown" << Endl;
308 return fMulticlassValues;
315 const std::vector<Float_t> &TMVA::MethodCrossValidation::GetRegressionValues()
317 const Event *ev = GetEvent();
319 if (fOutputEnsembling ==
"None") {
320 if (fSplitExpr !=
nullptr) {
322 UInt_t iFold = fSplitExpr->Eval(fNumFolds, ev);
323 return fEncapsulatedMethods.at(iFold)->GetRegressionValues();
326 UInt_t iFold = fEventToFoldMapping.at(Data()->GetEvent());
327 return fEncapsulatedMethods.at(iFold)->GetRegressionValues();
329 }
else if (fOutputEnsembling ==
"Avg") {
331 for (
auto &e : fRegressionValues) {
335 for (
auto &m : fEncapsulatedMethods) {
336 auto methodValues = m->GetRegressionValues();
337 for (
size_t i = 0; i < methodValues.size(); ++i) {
338 fRegressionValues[i] += methodValues[i];
342 for (
auto &e : fRegressionValues) {
343 e /= fEncapsulatedMethods.size();
346 return fRegressionValues;
349 Log() << kFATAL <<
"Ensembling type " << fOutputEnsembling <<
" unknown" << Endl;
350 return fRegressionValues;
/// Const member taking no arguments and returning nothing.
/// NOTE(review): the body is not visible in this excerpt, so what (if anything) is written
/// cannot be told from here.
357 void TMVA::MethodCrossValidation::WriteMonitoringHistosToFile(
void)
const
/// Emits the help text for this method.
/// The visible part of the text says that the method should not be created manually,
/// only as part of using TMVA::Reader.
/// NOTE(review): the start of the output statement (the stream being written to) and the
/// end of the function are not visible in this excerpt.
368 void TMVA::MethodCrossValidation::GetHelpMessage()
const
371 <<
"Method CrossValidation should not be created manually,"
372 " only as part of using TMVA::Reader."
/// Returns a pointer to a const TMVA::Ranking.
/// NOTE(review): the body is not visible in this excerpt, so the returned value is unknown here.
379 const TMVA::Ranking *TMVA::MethodCrossValidation::CreateRanking()
/// Returns a Bool_t; the first two parameters (a Types::EAnalysisType and a UInt_t) are unnamed.
/// NOTE(review): the rest of the parameter list and the body are not visible in this excerpt.
386 Bool_t TMVA::MethodCrossValidation::HasAnalysisType(Types::EAnalysisType , UInt_t ,
398 void TMVA::MethodCrossValidation::MakeClassSpecific(std::ostream & ,
const TString & )
const
400 Log() << kWARNING <<
"MakeClassSpecific not implemented for CrossValidation" << Endl;
/// Not implemented for CrossValidation: both parameters are unnamed and a kWARNING message
/// saying so is logged.
/// NOTE(review): the end of the function is not visible in this excerpt.
406 void TMVA::MethodCrossValidation::MakeClassSpecificHeader(std::ostream & ,
const TString & )
const
408 Log() << kWARNING <<
"MakeClassSpecificHeader not implemented for CrossValidation" << Endl;