Logo ROOT   6.30.04
Reference Guide
 All Namespaces Files Pages
TActivationTanh.cxx
Go to the documentation of this file.
1 // @(#)root/tmva $Id$
2 // Author: Matt Jachowski
3 
4 /**********************************************************************************
5  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6  * Package: TMVA *
7  * Class : TActivationTanh *
8  * Web : http://tmva.sourceforge.net *
9  * *
10  * Description: *
11  * Tanh activation function (sigmoid normalized in [-1,1]) for an ANN. *
12  * *
13  * Authors (alphabetical): *
14  * Matt Jachowski <jachowski@stanford.edu> - Stanford University, USA *
15  * *
16  * Copyright (c) 2005: *
17  * CERN, Switzerland *
18  * *
19  * Redistribution and use in source and binary forms, with or without *
20  * modification, are permitted according to the terms listed in LICENSE *
21  * (http://tmva.sourceforge.net/LICENSE) *
22  **********************************************************************************/
23 
24 /*! \class TMVA::TActivationTanh
25 \ingroup TMVA
26 Tanh activation function for ANN.
27 */
28 
29 #include "TMVA/TActivationTanh.h"
30 
31 #include "TMVA/TActivation.h"
32 
33 #include "TMath.h"
34 #include "TString.h"
35 
36 #include <iostream>
37 
38 ClassImp(TMVA::TActivationTanh);
39 
40 ////////////////////////////////////////////////////////////////////////////////
41 /// a fast tanh approximation
42 
43 Double_t TMVA::TActivationTanh::fast_tanh(Double_t arg){
44  if (arg > 4.97) return 1;
45  if (arg < -4.97) return -1;
46  float arg2 = arg * arg;
47  float a = arg * (135135.0f + arg2 * (17325.0f + arg2 * (378.0f + arg2)));
48  float b = 135135.0f + arg2 * (62370.0f + arg2 * (3150.0f + arg2 * 28.0f));
49  return a/b;
50 }
51 
52 ////////////////////////////////////////////////////////////////////////////////
53 /// evaluate the tanh
54 
55 Double_t TMVA::TActivationTanh::Eval(Double_t arg)
56 {
57  return fFAST ? fast_tanh(arg) : TMath::TanH(arg);
58 }
59 
60 ////////////////////////////////////////////////////////////////////////////////
61 /// evaluate the derivative
62 
63 Double_t TMVA::TActivationTanh::EvalDerivative(Double_t arg)
64 {
65  Double_t tmp=Eval(arg);
66  return ( 1-tmp*tmp);
67 }
68 
69 ////////////////////////////////////////////////////////////////////////////////
70 /// get expressions for the tanh and its derivative
71 /// whatever that may be good for ...
72 
73 TString TMVA::TActivationTanh::GetExpression()
74 {
75  TString expr = "tanh(x)\t\t (1-tanh()^2)";
76  return expr;
77 }
78 
79 ////////////////////////////////////////////////////////////////////////////////
80 /// writes the Tanh sigmoid activation function source code
81 
82 void TMVA::TActivationTanh::MakeFunction( std::ostream& fout, const TString& fncName )
83 {
84  if (fFAST) {
85  fout << "double " << fncName << "(double x) const {" << std::endl;
86  fout << " // fast hyperbolic tan approximation" << std::endl;
87  fout << " if (x > 4.97) return 1;" << std::endl;
88  fout << " if (x < -4.97) return -1;" << std::endl;
89  fout << " float x2 = x * x;" << std::endl;
90  fout << " float a = x * (135135.0f + x2 * (17325.0f + x2 * (378.0f + x2)));" << std::endl;
91  fout << " float b = 135135.0f + x2 * (62370.0f + x2 * (3150.0f + x2 * 28.0f));" << std::endl;
92  fout << " return a / b;" << std::endl;
93  fout << "}" << std::endl;
94  } else {
95  fout << "double " << fncName << "(double x) const {" << std::endl;
96  fout << " // hyperbolic tan" << std::endl;
97  fout << " return tanh(x);" << std::endl;
98  fout << "}" << std::endl;
99  }
100 }