28 #ifndef ROOT_TMVA_DataLoader
29 #define ROOT_TMVA_DataLoader
47 class DataInputHandler;
50 class VariableTransformBase;
52 class DataLoader :
public Configurable {
55 DataLoader(TString thedlName=
"default");
58 virtual ~DataLoader();
62 void AddSignalTrainingEvent (
const std::vector<Double_t>& event, Double_t weight = 1.0 );
63 void AddBackgroundTrainingEvent(
const std::vector<Double_t>& event, Double_t weight = 1.0 );
64 void AddSignalTestEvent (
const std::vector<Double_t>& event, Double_t weight = 1.0 );
65 void AddBackgroundTestEvent (
const std::vector<Double_t>& event, Double_t weight = 1.0 );
66 void AddTrainingEvent(
const TString& className,
const std::vector<Double_t>& event, Double_t weight );
67 void AddTestEvent (
const TString& className,
const std::vector<Double_t>& event, Double_t weight );
68 void AddEvent (
const TString& className, Types::ETreeType tt,
const std::vector<Double_t>& event, Double_t weight );
69 Bool_t UserAssignEvents(UInt_t clIndex);
70 TTree* CreateEventAssignTrees(
const TString& name );
72 DataSetInfo& AddDataSet( DataSetInfo& );
73 DataSetInfo& AddDataSet(
const TString& );
74 DataSetInfo& GetDataSetInfo();
75 DataLoader* VarTransform(TString trafoDefinition);
80 void SetInputTrees(
const TString& signalFileName,
const TString& backgroundFileName,
81 Double_t signalWeight=1.0, Double_t backgroundWeight=1.0 );
82 void SetInputTrees( TTree* inputTree,
const TCut& SigCut,
const TCut& BgCut );
84 void SetInputTrees( TTree* signal, TTree* background,
85 Double_t signalWeight=1.0, Double_t backgroundWeight=1.0) ;
87 void AddSignalTree( TTree* signal, Double_t weight=1.0, Types::ETreeType treetype = Types::kMaxTreeType );
88 void AddSignalTree( TString datFileS, Double_t weight=1.0, Types::ETreeType treetype = Types::kMaxTreeType );
89 void AddSignalTree( TTree* signal, Double_t weight,
const TString& treetype );
92 void SetSignalTree( TTree* signal, Double_t weight=1.0);
94 void AddBackgroundTree( TTree* background, Double_t weight=1.0, Types::ETreeType treetype = Types::kMaxTreeType );
95 void AddBackgroundTree( TString datFileB, Double_t weight=1.0, Types::ETreeType treetype = Types::kMaxTreeType );
96 void AddBackgroundTree( TTree* background, Double_t weight,
const TString & treetype );
99 void SetBackgroundTree( TTree* background, Double_t weight=1.0 );
101 void SetSignalWeightExpression(
const TString& variable );
102 void SetBackgroundWeightExpression(
const TString& variable );
105 void AddRegressionTree( TTree* tree, Double_t weight = 1.0,
106 Types::ETreeType treetype = Types::kMaxTreeType ) {
107 AddTree( tree,
"Regression", weight,
"", treetype );
113 void SetTree( TTree* tree,
const TString& className, Double_t weight );
114 void AddTree( TTree* tree,
const TString& className, Double_t weight=1.0,
115 const TCut& cut =
"",
116 Types::ETreeType tt = Types::kMaxTreeType );
117 void AddTree( TTree* tree,
const TString& className, Double_t weight,
const TCut& cut,
const TString& treeType );
120 void SetInputVariables ( std::vector<TString>* theVariables );
122 void AddVariable (
const TString& expression,
const TString& title,
const TString& unit,
123 char type=
'F', Double_t min = 0, Double_t max = 0 );
124 void AddVariable (
const TString& expression,
char type=
'F',
125 Double_t min = 0, Double_t max = 0 );
128 void AddVariablesArray(
const TString &expression,
int size,
char type =
'F',
129 Double_t min = 0, Double_t max = 0);
132 void AddTarget (
const TString& expression,
const TString& title =
"",
const TString& unit =
"",
133 Double_t min = 0, Double_t max = 0 );
134 void AddRegressionTarget(
const TString& expression,
const TString& title =
"",
const TString& unit =
"",
135 Double_t min = 0, Double_t max = 0 )
137 AddTarget( expression, title, unit, min, max );
139 void AddSpectator (
const TString& expression,
const TString& title =
"",
const TString& unit =
"",
140 Double_t min = 0, Double_t max = 0 );
143 void SetWeightExpression(
const TString& variable,
const TString& className =
"" );
146 void SetCut(
const TString& cut,
const TString& className =
"" );
147 void SetCut(
const TCut& cut,
const TString& className =
"" );
148 void AddCut(
const TString& cut,
const TString& className =
"" );
149 void AddCut(
const TCut& cut,
const TString& className =
"" );
153 void PrepareTrainingAndTestTree(
const TCut& cut,
const TString& splitOpt );
154 void PrepareTrainingAndTestTree( TCut sigcut, TCut bkgcut,
const TString& splitOpt );
157 void PrepareTrainingAndTestTree(
const TCut& cut, Int_t Ntrain, Int_t Ntest = -1 );
159 void PrepareTrainingAndTestTree(
const TCut& cut, Int_t NsigTrain, Int_t NbkgTrain, Int_t NsigTest, Int_t NbkgTest,
160 const TString& otherOpt=
"SplitMode=Random:!V" );
163 void MakeKFoldDataSet(CvSplit & s);
164 void PrepareFoldDataSet(CvSplit & s, UInt_t foldNumber, Types::ETreeType tt = Types::kTraining);
165 void RecombineKFoldDataSet(CvSplit & s, Types::ETreeType tt = Types::kTraining);
167 const DataSetInfo& GetDefaultDataSetInfo(){
return DefaultDataSetInfo(); }
169 TH2* GetCorrelationMatrix(
const TString& className);
172 DataLoader* MakeCopy(TString name);
173 friend void DataLoaderCopy(TMVA::DataLoader* des, TMVA::DataLoader* src);
174 DataInputHandler& DataInput() {
return *fDataInputHandler; }
179 DataSetInfo& DefaultDataSetInfo();
180 void SetInputTreesFromEventAssignTrees();
// ---- Data members ---------------------------------------------------------
// Data-set manager, held as a raw pointer (ownership/lifetime not visible here).
188 DataSetManager* fDataSetManager;
// Input handler. DataInput() dereferences it without a null check, so it
// presumably must be non-null after construction -- TODO confirm in the .cxx.
191 DataInputHandler* fDataInputHandler;
// Raw VariableTransformBase pointers; whether DataLoader deletes them is not
// visible from this header.
193 std::vector<TMVA::VariableTransformBase*> fDefaultTrfs;
// Transformation specification string (presumably built up by
// VarTransform(TString) -- verify).
197 TString fTransformations;
// Presumably records how data were supplied (as trees or as individual events).
// NOTE(review): only the first enumerator is visible; the remaining enumerators
// and the closing "};" appear to be missing from this extract -- confirm
// against the full header before relying on this declaration.
201 enum DataAssignType { kUndefined = 0,
204 DataAssignType fDataAssignType;
// Per-class temporary trees for events assigned one by one (presumably filled
// via the Add*Event methods / CreateEventAssignTrees() -- verify).
205 std::vector<TTree*> fTrainAssignTree;
206 std::vector<TTree*> fTestAssignTree;
// Buffers for one assigned event (presumably type, weight and values), with
// in-class initializers. If they buffer the Add*Event input, values are
// narrowed: the API takes Double_t but these store Float_t.
208 Int_t fATreeType = 0;
209 Float_t fATreeWeight = 0.0;
210 std::vector<Float_t> fATreeEvent;
// Analysis type; no in-class initializer is visible here.
212 Types::EAnalysisType fAnalysisType;
// ROOT ClassDef macro; the second argument (4) is the class version number.
216 ClassDef(DataLoader,4);
218 void DataLoaderCopy(TMVA::DataLoader* des, TMVA::DataLoader* src);