Logo ROOT   6.30.04
Reference Guide
 All Namespaces Files Pages
ModulekNN.h
Go to the documentation of this file.
1 // @(#)root/tmva $Id$
2 // Author: Rustem Ospanov
3 
4 /**********************************************************************************
5  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6  * Package: TMVA *
7  * Class : ModulekNN *
8  * Web : http://tmva.sourceforge.net *
9  * *
10  * Description: *
11  * Module for k-nearest neighbor algorithm *
12  * *
13  * Author: *
14  * Rustem Ospanov <rustem@fnal.gov> - U. of Texas at Austin, USA *
15  * *
16  * Copyright (c) 2007: *
17  * CERN, Switzerland *
18  * MPI-K Heidelberg, Germany *
19  * U. of Texas at Austin, USA *
20  * *
21  * Redistribution and use in source and binary forms, with or without *
22  * modification, are permitted according to the terms listed in LICENSE *
23  * (http://tmva.sourceforge.net/LICENSE) *
24  **********************************************************************************/
25 
26 #ifndef ROOT_TMVA_ModulekNN
27 #define ROOT_TMVA_ModulekNN
28 
29 //______________________________________________________________________
30 /*
 kNN::Event describes a point in the input-variable vector space, with
 additional functionality such as computing the distance between points.
33 */
34 //______________________________________________________________________
35 
36 
// C++
#include <cassert>
#include <iosfwd>
#include <list>
#include <map>
#include <string>
#include <utility>
#include <vector>

// ROOT
#include "Rtypes.h"
#include "TRandom3.h"
#include "ThreadLocalStorage.h"
#include "TMVA/NodekNN.h"
49 
50 namespace TMVA {
51 
52  class MsgLogger;
53 
54  namespace kNN {
55 
56  typedef Float_t VarType;
57  typedef std::vector<VarType> VarVec;
58 
59  class Event {
60  public:
61 
62  Event();
63  Event(const VarVec &vec, Double_t weight, Short_t type);
64  Event(const VarVec &vec, Double_t weight, Short_t type, const VarVec &tvec);
65  ~Event();
66 
67  Double_t GetWeight() const;
68 
69  VarType GetVar(UInt_t i) const;
70  VarType GetTgt(UInt_t i) const;
71 
72  UInt_t GetNVar() const;
73  UInt_t GetNTgt() const;
74 
75  Short_t GetType() const;
76 
77  // keep these two function separate
78  VarType GetDist(VarType var, UInt_t ivar) const;
79  VarType GetDist(const Event &other) const;
80 
81  void SetTargets(const VarVec &tvec);
82  const VarVec& GetTargets() const;
83  const VarVec& GetVars() const;
84 
85  void Print() const;
86  void Print(std::ostream& os) const;
87 
88  private:
89 
90  VarVec fVar; // coordinates (variables) for knn search
91  VarVec fTgt; // targets for regression analysis
92 
93  Double_t fWeight; // event weight
94  Short_t fType; // event type ==0 or == 1, expand it to arbitrary class types?
95  };
96 
97  typedef std::vector<TMVA::kNN::Event> EventVec;
98  typedef std::pair<const Node<Event> *, VarType> Elem;
99  typedef std::list<Elem> List;
100 
101  std::ostream& operator<<(std::ostream& os, const Event& event);
102 
103  class ModulekNN
104  {
105  public:
106 
107  typedef std::map<int, std::vector<Double_t> > VarMap;
108 
109  public:
110 
111  ModulekNN();
112  ~ModulekNN();
113 
114  void Clear();
115 
116  void Add(const Event &event);
117 
118  Bool_t Fill(const UShort_t odepth, UInt_t ifrac, const std::string &option = "");
119 
120  Bool_t Find(Event event, UInt_t nfind = 100, const std::string &option = "count") const;
121  Bool_t Find(UInt_t nfind, const std::string &option) const;
122 
123  const EventVec& GetEventVec() const;
124 
125  const List& GetkNNList() const;
126  const Event& GetkNNEvent() const;
127 
128  const VarMap& GetVarMap() const;
129 
130  const std::map<Int_t, Double_t>& GetMetric() const;
131 
132  void Print() const;
133  void Print(std::ostream &os) const;
134 
135  private:
136 
137  Node<Event>* Optimize(UInt_t optimize_depth);
138 
139  void ComputeMetric(UInt_t ifrac);
140 
141  const Event Scale(const Event &event) const;
142 
143  private:
144 
145  // This is a workaround for OSx where static thread_local data members are
146  // not supported. The C++ solution would indeed be the following:
147  static TRandom3& GetRndmThreadLocal() {TTHREAD_TLS_DECL_ARG(TRandom3,fgRndm,1); return fgRndm;};
148 
149  UInt_t fDimn;
150 
151  Node<Event> *fTree;
152 
153  std::map<Int_t, Double_t> fVarScale;
154 
155  mutable List fkNNList; // latest result from kNN search
156  mutable Event fkNNEvent; // latest event used for kNN search
157 
158  std::map<Short_t, UInt_t> fCount; // count number of events of each type
159 
160  EventVec fEvent; // vector of all events used to build tree and analysis
161  VarMap fVar; // sorted map of variables in each dimension for all event types
162 
163  mutable MsgLogger* fLogger; // message logger
164  MsgLogger& Log() const { return *fLogger; }
165  };
166 
167  //
168  // inlined functions for Event class
169  //
170  inline VarType Event::GetDist(const VarType var1, const UInt_t ivar) const
171  {
172  const VarType var2 = GetVar(ivar);
173  return (var1 - var2) * (var1 - var2);
174  }
175  inline Double_t Event::GetWeight() const
176  {
177  return fWeight;
178  }
179  inline VarType Event::GetVar(const UInt_t i) const
180  {
181  return fVar[i];
182  }
183  inline VarType Event::GetTgt(const UInt_t i) const
184  {
185  return fTgt[i];
186  }
187 
188  inline UInt_t Event::GetNVar() const
189  {
190  return fVar.size();
191  }
192  inline UInt_t Event::GetNTgt() const
193  {
194  return fTgt.size();
195  }
196  inline Short_t Event::GetType() const
197  {
198  return fType;
199  }
200 
201  //
202  // inline functions for ModulekNN class
203  //
204  inline const List& ModulekNN::GetkNNList() const
205  {
206  return fkNNList;
207  }
208  inline const Event& ModulekNN::GetkNNEvent() const
209  {
210  return fkNNEvent;
211  }
212  inline const EventVec& ModulekNN::GetEventVec() const
213  {
214  return fEvent;
215  }
216  inline const ModulekNN::VarMap& ModulekNN::GetVarMap() const
217  {
218  return fVar;
219  }
220  inline const std::map<Int_t, Double_t>& ModulekNN::GetMetric() const
221  {
222  return fVarScale;
223  }
224 
225  } // end of kNN namespace
226 } // end of TMVA namespace
227 
228 #endif
229