/// Sentinel value that InitNeuron() writes into the neuron's numeric state
/// (fValue, fActivationValue, fDelta, fDEDw, fError) to mark it as not yet computed.
41 static const Int_t UNINITIALIZED = -1;
// ROOT class-implementation macro for TMVA::TNeuron.
45 ClassImp(TMVA::TNeuron);
/// Default constructor.
/// NOTE(review): the constructor body is not visible here; presumably it calls
/// InitNeuron() to allocate the link arrays and reset the state — confirm.
50 TMVA::TNeuron::TNeuron()
/// Destructor: deletes the incoming (fLinksIn) and outgoing (fLinksOut) link
/// containers when they are non-NULL.
/// NOTE(review): only the TObjArray containers are deleted here. The TSynapse
/// objects they hold are not deleted here; DeleteLinksArray() is what deletes
/// synapses. fActivation / fInputCalculator are not deleted in the lines shown.
/// Confirm who owns each of them. (Deleting a NULL pointer is a no-op in C++,
/// so the NULL guards are redundant but harmless.)
58 TMVA::TNeuron::~TNeuron()
60 if (fLinksIn != NULL)
delete fLinksIn;
61 if (fLinksOut != NULL)
delete fLinksOut;
/// Put the neuron into its initial state:
/// - allocate fresh, empty incoming/outgoing link arrays;
/// - mark all computed quantities as UNINITIALIZED;
/// - clear the forced-value flag;
/// - reset fInputCalculator to NULL without deleting any previous calculator.
/// NOTE(review): fActivation is not assigned in the lines shown, yet
/// SetActivationEqn() deletes it when non-NULL. Make sure it is set to NULL
/// before first use.
/// NOTE(review): the link-array pointers are overwritten without deleting any
/// previous arrays, so calling this on an already-initialized neuron would leak them.
67 void TMVA::TNeuron::InitNeuron()
69 fLinksIn =
new TObjArray();
70 fLinksOut =
new TObjArray();
// Every computed quantity starts out as "not computed yet".
71 fValue = UNINITIALIZED;
72 fActivationValue = UNINITIALIZED;
73 fDelta = UNINITIALIZED;
74 fDEDw = UNINITIALIZED;
75 fError = UNINITIALIZED;
77 fForcedValue = kFALSE;
78 fInputCalculator = NULL;
/// Force the neuron's value to `value`.
/// NOTE(review): the body is not visible here. CalculateValue() skips
/// recomputation while fForcedValue is set, so this presumably stores `value`
/// in fValue and sets fForcedValue — confirm.
84 void TMVA::TNeuron::ForceValue(Double_t value)
/// Compute the neuron's value fValue from the attached input calculator,
/// unless a value has been forced via ForceValue().
93 void TMVA::TNeuron::CalculateValue()
// A forced value takes precedence; leave fValue untouched.
95 if (fForcedValue)
return;
// NOTE(review): fInputCalculator is dereferenced without a NULL check, while
// InitNeuron() sets it to NULL. SetInputCalculator() must have been called first.
96 fValue = fInputCalculator->GetInput(
this);
/// Compute fActivationValue by evaluating the activation equation at fValue.
/// If no activation equation is set, a warning is logged and fActivationValue
/// is reset to UNINITIALIZED.
102 void TMVA::TNeuron::CalculateActivationValue()
104 if (fActivation == NULL) {
105 PrintMessage( kWARNING ,
"No activation equation specified." );
106 fActivationValue = UNINITIALIZED;
// NOTE(review): the lines closing the branch above (presumably an early return)
// are not visible here. Without one, the Eval call below would dereference a
// NULL fActivation — confirm.
109 fActivationValue = fActivation->Eval(fValue);
/// Compute this neuron's delta: fDelta = (error term) * derivative of the
/// activation equation evaluated at GetValue().
/// The error term depends on the neuron type:
/// - output neuron: the error term is fError;
/// - otherwise: the sum of GetWeightedDelta() over all outgoing synapses.
/// Input neurons take a separate branch whose body is not visible here.
115 void TMVA::TNeuron::CalculateDelta()
118 if (IsInputNeuron()) {
126 if (IsOutputNeuron()) error = fError;
// Not an output neuron: accumulate the weighted deltas of every outgoing link.
131 TSynapse* synapse = NULL;
// Stack-allocated iterator over fLinksOut; the loop stops when Next() returns NULL.
136 TObjArrayIter iter(fLinksOut);
138 synapse = (TSynapse*) iter.Next();
139 if (synapse == NULL)
break;
140 error += synapse->GetWeightedDelta();
// NOTE(review): fActivation is dereferenced here without the NULL check that
// CalculateActivationValue() performs.
145 fDelta = error * fActivation->EvalDerivative(GetValue());
/// Install a new input calculator, deleting the previously installed one.
/// The neuron therefore treats the calculator as owned.
/// NOTE(review): passing the currently installed pointer again deletes it and
/// then stores the dangling pointer.
151 void TMVA::TNeuron::SetInputCalculator(TNeuronInput* calculator)
153 if (fInputCalculator != NULL)
delete fInputCalculator;
154 fInputCalculator = calculator;
/// Install a new activation equation, deleting the previously installed one.
/// The neuron therefore treats the activation as owned.
/// NOTE(review): passing the currently installed pointer again deletes it and
/// then stores the dangling pointer. The same applies if one TActivation
/// instance is shared between several neurons.
160 void TMVA::TNeuron::SetActivationEqn(TActivation* activation)
162 if (fActivation != NULL)
delete fActivation;
163 fActivation = activation;
/// Register an incoming synapse `pre`. The call is a no-op for input neurons.
/// NOTE(review): the line storing `pre` is not visible here; presumably it is
/// added to fLinksIn, mirroring AddPostLink() — confirm.
169 void TMVA::TNeuron::AddPreLink(TSynapse* pre)
171 if (IsInputNeuron())
return;
/// Register an outgoing synapse `post` by appending it to fLinksOut.
/// The call is a no-op for output neurons.
178 void TMVA::TNeuron::AddPostLink(TSynapse* post)
180 if (IsOutputNeuron())
return;
181 fLinksOut->Add(post);
/// Delete this neuron's incoming synapses by handing fLinksIn to DeleteLinksArray().
187 void TMVA::TNeuron::DeletePreLinks()
189 DeleteLinksArray(fLinksIn);
/// Delete every non-NULL TSynapse stored in `links`. A NULL array is ignored.
/// `links` is taken by reference to pointer. The remainder of the function (not
/// visible here) can therefore presumably delete the array itself and reset the
/// caller's pointer — confirm.
195 void TMVA::TNeuron::DeleteLinksArray(TObjArray*& links)
197 if (links == NULL)
return;
199 TSynapse* synapse = NULL;
200 Int_t numLinks = links->GetEntriesFast();
201 for (Int_t i=0; i<numLinks; i++) {
202 synapse = (TSynapse*)links->At(i);
// TObjArray slots may be empty (At() returns NULL), hence the check.
203 if (synapse != NULL)
delete synapse;
/// Set the neuron's error term. A warning is logged when this is called on a
/// non-output neuron.
/// NOTE(review): the assignment itself is not visible here. Presumably it stores
/// `error` in fError, which CalculateDelta() uses for output neurons — confirm.
212 void TMVA::TNeuron::SetError(Double_t error)
214 if (!IsOutputNeuron())
215 PrintMessage( kWARNING,
"Warning! Setting an error on a non-output neuron is probably not what you want to do." );
/// Batch-style update (per the name): for a non-input neuron, call
/// CalculateDelta() on every incoming synapse.
/// No AdjustWeight() call appears in the lines shown; compare AdjustSynapseWeights().
224 void TMVA::TNeuron::UpdateSynapsesBatch()
226 if (IsInputNeuron())
return;
228 TSynapse* synapse = NULL;
// Walk fLinksIn until the iterator returns NULL.
229 TObjArrayIter iter(fLinksIn);
231 synapse = (TSynapse*) iter.Next();
232 if (synapse == NULL)
break;
233 synapse->CalculateDelta();
/// Sequential-style update (per the name): for a non-input neuron, process each
/// incoming synapse in turn. For each one: reset its delta (InitDelta), compute
/// it (CalculateDelta), then apply it to the weight immediately (AdjustWeight).
242 void TMVA::TNeuron::UpdateSynapsesSequential()
244 if (IsInputNeuron())
return;
246 TSynapse* synapse = NULL;
// Walk fLinksIn until the iterator returns NULL.
247 TObjArrayIter iter(fLinksIn);
250 synapse = (TSynapse*) iter.Next();
251 if (synapse == NULL)
break;
252 synapse->InitDelta();
253 synapse->CalculateDelta();
254 synapse->AdjustWeight();
/// For a non-input neuron, call AdjustWeight() on every incoming synapse.
263 void TMVA::TNeuron::AdjustSynapseWeights()
265 if (IsInputNeuron())
return;
267 TSynapse* synapse = NULL;
// Walk fLinksIn until the iterator returns NULL.
268 TObjArrayIter iter(fLinksIn);
272 synapse = (TSynapse*) iter.Next();
273 if (synapse == NULL)
break;
274 synapse->AdjustWeight();
/// For a non-input neuron, call InitDelta() on every incoming synapse.
283 void TMVA::TNeuron::InitSynapseDeltas()
286 if (IsInputNeuron())
return;
288 TSynapse* synapse = NULL;
// Walk fLinksIn until the iterator returns NULL.
289 TObjArrayIter iter(fLinksIn);
292 synapse = (TSynapse*) iter.Next();
294 if (synapse == NULL)
break;
295 synapse->InitDelta();
/// Log the synapses in `links`. For each one the output line has weight,
/// weighted value, weighted delta and learning rate. "<none>" is logged at
/// kDEBUG in the case handled at the top; its condition is not visible here.
/// NOTE(review): no NULL check on `synapse` is visible before it is dereferenced
/// (DeleteLinksArray() does check At(i) for NULL).
/// NOTE(review): the label "weighta:" looks like a typo for "weight:". It is a
/// runtime string, so it is left unchanged here.
303 void TMVA::TNeuron::PrintLinks(TObjArray* links)
const
306 Log() << kDEBUG <<
"\t\t\t<none>" << Endl;
// One output line per entry of the array.
312 Int_t numLinks = links->GetEntriesFast();
313 for (Int_t i = 0; i < numLinks; i++) {
314 synapse = (TSynapse*)links->At(i);
316 "\t\t\tweighta: " << synapse->GetWeight()
317 <<
"\t\tw-value: " << synapse->GetWeightedValue()
318 <<
"\t\tw-delta: " << synapse->GetWeightedDelta()
319 <<
"\t\tl-rate: " << synapse->GetLearningRate()
/// Log (at kDEBUG) the expression of the activation equation, or "<none>" if
/// no activation equation is set.
327 void TMVA::TNeuron::PrintActivationEqn()
329 if (fActivation != NULL) Log() << kDEBUG << fActivation->GetExpression() << Endl;
330 else Log() << kDEBUG <<
"<none>" << Endl;
/// Write `message` to this class's logger with severity `type`.
336 void TMVA::TNeuron::PrintMessage( EMsgType type, TString message)
338 Log() << type << message << Endl;
/// Access the MsgLogger used by TNeuron.
/// The logger is declared with source name "TNeuron" and default level kDEBUG.
/// Judging by the macro name (TTHREAD_TLS_DECL_ARG2), it is a thread-local
/// instance — confirm. The return statement is not visible here.
343 TMVA::MsgLogger& TMVA::TNeuron::Log()
const
345 TTHREAD_TLS_DECL_ARG2(MsgLogger,logger,
"TNeuron",kDEBUG);