Logo ROOT   6.30.04
Reference Guide
 All Namespaces Files Pages
Objectives.hxx
Go to the documentation of this file.
1 /**********************************************************************************
2  * Project: ROOT - a Root-integrated toolkit for multivariate data analysis *
3  * Package: TMVA *
4  * Web : http://tmva.sourceforge.net *
5  * *
6  * Description: *
7  * *
8  * Authors: *
9  * Stefan Wunsch (stefan.wunsch@cern.ch) *
10  * Luca Zampieri (luca.zampieri@alumni.epfl.ch) *
11  * *
12  * Copyright (c) 2019: *
13  * CERN, Switzerland *
14  * *
15  * Redistribution and use in source and binary forms, with or without *
16  * modification, are permitted according to the terms listed in LICENSE *
17  * (http://tmva.sourceforge.net/LICENSE) *
18  **********************************************************************************/
19 
20 #ifndef TMVA_TREEINFERENCE_OBJECTIVES
21 #define TMVA_TREEINFERENCE_OBJECTIVES
22 
23 #include <string>
24 #include <stdexcept>
25 #include <cmath> // std::exp
26 #include <functional> // std::function
27 
28 namespace TMVA {
29 namespace Experimental {
30 namespace Objectives {
31 
/// Logistic (sigmoid) function f(x) = 1 / (1 + exp(-x))
///
/// \tparam T Type of the argument and of the returned value
/// \param value Argument x of the function
/// \return 1 / (1 + exp(-x)), converted to T
template <typename T>
inline T Logistic(T value)
{
   // The literals are doubles, so for T = float the arithmetic is carried
   // out in double and only narrowed to T when the result is returned.
   const auto negated = -1.0 * value;
   const auto denominator = 1.0 + std::exp(negated);
   return 1.0 / denominator;
}
38 
/// Identity function f(x) = x
///
/// \tparam T Type of the argument and of the returned value
/// \param value Argument x of the function
/// \return The argument, unmodified
template <typename T>
inline T Identity(T value)
{
   return value;
}
45 
/// Natural exponential function f(x) = exp(x)
///
/// Building block of the softmax objective used in the multiclass case,
/// exp(x)/sum(exp(x)) for the vector x. Only the exponential of a single
/// element is computed here; the normalization by the sum is not part of
/// this function.
///
/// \tparam T Type of the argument and of the returned value
/// \param value Argument x of the function
/// \return exp(x), converted to T
template <typename T>
inline T Exponential(T value)
{
   const auto exponentiated = std::exp(value);
   return exponentiated;
}
55 
56 /// Get function pointer to implementation from name given as string
57 template <typename T>
58 std::function<T(T)> GetFunction(const std::string &name)
59 {
60  if (name.compare("identity") == 0)
61  return std::function<T(T)>(Identity<T>);
62  else if (name.compare("logistic") == 0)
63  return std::function<T(T)>(Logistic<T>);
64  else if (name.compare("softmax") == 0)
65  return std::function<T(T)>(Exponential<T>);
66  else
67  throw std::runtime_error("Objective function with name \"" + name + "\" is not implemented.");
68 }
69 
70 } // namespace Objectives
71 } // namespace Experimental
72 } // namespace TMVA
73 
74 #endif // TMVA_TREEINFERENCE_OBJECTIVES