11from random import randrange
22
33from tdamapper .utils .metrics import get_metric
4- from tdamapper .utils .quickselect import quickselect
54from tdamapper .utils .heap import MaxHeap
65
76
@@ -13,6 +12,31 @@ def _mid(start, end):
1312 return (start + end ) // 2
1413
1514
15+ def _partition (data , start , end , p_ord ):
16+ higher = start
17+ for j in range (start , end ):
18+ j_ord , _ , _ = data [j ]
19+ if j_ord < p_ord :
20+ _swap (data , higher , j )
21+ higher += 1
22+ return higher
23+
24+
25+ def _quickselect (data , start , end , k ):
26+ if (k < start ) or (k >= end ):
27+ return
28+ start_ , end_ , higher = start , end , None
29+ while higher != k + 1 :
30+ p , _ , _ = data [k ]
31+ _swap (data , start_ , k )
32+ higher = _partition (data , start_ + 1 , end_ , p )
33+ _swap (data , start_ , higher - 1 )
34+ if k <= higher - 1 :
35+ end_ = higher
36+ else :
37+ start_ = higher
38+
39+
1640class VPTree :
1741
1842 def __init__ (
@@ -57,7 +81,7 @@ class _Build:
5781
5882 def __init__ (self , vpt , X ):
5983 self .__distance = vpt ._get_distance ()
60- self .__dataset = [(0.0 , x ) for x in X ]
84+ self .__dataset = [(0.0 , x , False ) for x in X ]
6185 self .__leaf_capacity = vpt .get_leaf_capacity ()
6286 self .__leaf_radius = vpt .get_leaf_radius ()
6387 pivoting = vpt .get_pivoting ()
@@ -71,7 +95,7 @@ def _pivoting_disabled(self, start, end):
7195 pass
7296
7397 def _pivoting_random (self , start , end ):
74- if end - start < 2 :
98+ if end <= start :
7599 return
76100 pivot = randrange (start , end )
77101 if pivot > start :
@@ -80,17 +104,17 @@ def _pivoting_random(self, start, end):
80104 def _furthest (self , start , end , i ):
81105 furthest_dist = 0.0
82106 furthest = start
83- _ , i_point = self .__dataset [i ]
107+ _ , i_point , _ = self .__dataset [i ]
84108 for j in range (start , end ):
85- _ , j_point = self .__dataset [j ]
109+ _ , j_point , _ = self .__dataset [j ]
86110 j_dist = self .__distance (i_point , j_point )
87111 if j_dist > furthest_dist :
88112 furthest = j
89113 furthest_dist = j_dist
90114 return furthest
91115
92116 def _pivoting_furthest (self , start , end ):
93- if end - start < 2 :
117+ if end <= start :
94118 return
95119 rnd = randrange (start , end )
96120 furthest_rnd = self ._furthest (start , end , rnd )
@@ -100,10 +124,10 @@ def _pivoting_furthest(self, start, end):
100124
101125 def _update (self , start , end ):
102126 self .__pivoting (start , end )
103- _ , v_point = self .__dataset [start ]
104- for i in range (start , end ):
105- _ , point = self .__dataset [i ]
106- self .__dataset [i ] = self .__distance (v_point , point ), point
127+ _ , v_point , is_terminal = self .__dataset [start ]
128+ for i in range (start + 1 , end ):
129+ _ , point , _ = self .__dataset [i ]
130+ self .__dataset [i ] = self .__distance (v_point , point ), point , is_terminal
107131
108132 def build (self ):
109133 self ._build_iter ()
@@ -115,159 +139,122 @@ def _build_iter(self):
115139 start , end = stack .pop ()
116140 mid = _mid (start , end )
117141 self ._update (start , end )
118- _ , v_point = self .__dataset [start ]
119- quickselect (self .__dataset , start + 1 , end , mid )
120- v_radius , _ = self .__dataset [mid ]
121- self .__dataset [start ] = (v_radius , v_point )
122- if (end - start > 2 * self .__leaf_capacity ) and (v_radius > self .__leaf_radius ):
142+ _ , v_point , _ = self .__dataset [start ]
143+ _quickselect (self .__dataset , start + 1 , end , mid )
144+ v_radius , _ , _ = self .__dataset [mid ]
145+ if (
146+ (end - start > 2 * self .__leaf_capacity ) and
147+ (v_radius > self .__leaf_radius )
148+ ):
149+ self .__dataset [start ] = (v_radius , v_point , False )
123150 stack .append ((mid , end ))
124151 stack .append ((start + 1 , mid ))
152+ else :
153+ self .__dataset [start ] = (v_radius , v_point , True )
125154
def ball_search(self, point, eps, inclusive=True):
    """Return the stored points lying within distance ``eps`` of ``point``.

    :param point: Query point.
    :param eps: Radius of the query ball.
    :param inclusive: When true, points at distance exactly ``eps`` are
        returned as well; when false the comparison is strict.
    :return: Whatever the ``_BallSearch`` helper's ``search()`` returns.
    """
    return self._BallSearch(self, point, eps, inclusive).search()
128157
129158 class _BallSearch :
130159
131- def __init__ (self , vpt , center , radius , inclusive ):
132- self .__distance = vpt ._get_distance ()
160+ def __init__ (self , vpt , point , eps , inclusive = True ):
133161 self .__dataset = vpt ._get_dataset ()
134- self .__leaf_capacity = vpt ._get_leaf_capacity ()
135- self .__leaf_radius = vpt ._get_leaf_radius ()
136- self .__center = center
137- self .__radius = radius
138- self .__items = []
162+ self .__distance = vpt ._get_distance ()
163+ self .__point = point
164+ self .__eps = eps
139165 self .__inclusive = inclusive
140166
141- class _BSVisit :
142-
143- def __init__ (self , start , end , m_radius ):
144- self .__start = start
145- self .__end = end
146- self .__m_radius = m_radius
167+ def search (self ):
168+ return self ._search_iter ()
147169
148- def bounds (self ):
149- return self .__start , self .__end , self .__m_radius
170+ def _inside (self , dist ):
171+ if self .__inclusive :
172+ return dist <= self .__eps
173+ return dist < self .__eps
150174
151- def search (self ):
152- stack = [self ._BSVisit (0 , len (self .__dataset ), float ('inf' ))]
175+ def _search_iter (self ):
176+ stack = [(0 , len (self .__dataset ))]
177+ result = []
153178 while stack :
154- visit = stack .pop ()
155- start , end , m_radius = visit .bounds ()
156- v_radius , v_point = self .__dataset [start ]
157- if (end - start <= 2 * self .__leaf_capacity ) or (m_radius <= self .__leaf_radius ) or (v_radius <= self .__leaf_radius ):
158- for _ , x in self .__dataset [start :end ]:
159- dist = self .__distance (self .__center , x )
179+ start , end = stack .pop ()
180+ v_radius , v_point , is_terminal = self .__dataset [start ]
181+ if is_terminal :
182+ for _ , x , _ in self .__dataset [start :end ]:
183+ dist = self .__distance (self .__point , x )
160184 if self ._inside (dist ):
161- self . __items .append (x )
185+ result .append (x )
162186 else :
163- dist = self .__distance (self .__center , v_point )
164- if self ._inside (dist ):
165- self .__items .append (v_point )
187+ dist = self .__distance (self .__point , v_point )
166188 mid = _mid (start , end )
189+ if self ._inside (dist ):
190+ result .append (v_point )
167191 if dist <= v_radius :
168- fst_start , fst_end , fst_radius = start + 1 , mid , v_radius
169- snd_start , snd_end , snd_radius = mid , end , float ( 'inf' )
192+ fst = ( start + 1 , mid )
193+ snd = ( mid , end )
170194 else :
171- fst_start , fst_end , fst_radius = mid , end , float ( 'inf' )
172- snd_start , snd_end , snd_radius = start + 1 , mid , v_radius
173- if abs (dist - v_radius ) <= self .__radius :
174- stack .append (self . _BSVisit ( snd_start , snd_end , snd_radius ) )
175- stack .append (self . _BSVisit ( fst_start , fst_end , fst_radius ) )
176- return self . __items
195+ fst = ( mid , end )
196+ snd = ( start + 1 , mid )
197+ if abs (dist - v_radius ) <= self .__eps :
198+ stack .append (snd )
199+ stack .append (fst )
200+ return result
177201
178- def _inside (self , dist ):
179- if self .__inclusive :
180- return dist <= self .__radius
181- return dist < self .__radius
def knn_search(self, point, k):
    """Return up to ``k`` stored points nearest to ``point``.

    :param point: Query point.
    :param k: Number of neighbours requested.
    :return: Whatever the ``_KnnSearch`` helper's ``search()`` returns.
    """
    return self._KnnSearch(self, point, k).search()
182204
183- def knn_search (self , point , neighbors ):
184- return self ._KNNSearch (self , point , neighbors ).search ()
class _KnnSearch:
    """Nearest-neighbour query over the flat array backing the tree.

    The dataset is a sequence of ``(value, point, is_terminal)`` triples.
    For the entry at the start of a range, ``value`` is the vantage radius
    and ``is_terminal`` tells whether the range is scanned linearly or
    split further.  Candidates are kept in a max-heap keyed by distance.
    """

    def __init__(self, vpt, point, neighbors):
        """Capture the tree's dataset and distance together with the query.

        :param vpt: Tree object exposing ``_get_dataset()`` and
            ``_get_distance()``.
        :param point: Query point.
        :param neighbors: Number of neighbours requested.
        """
        self.__dataset = vpt._get_dataset()
        self.__distance = vpt._get_distance()
        self.__point = point
        self.__neighbors = neighbors
        # Distance a candidate must beat to be accepted; stays infinite
        # until the heap holds ``neighbors`` items.
        self.__radius = float('inf')
        self.__result = MaxHeap()

    def _get_items(self):
        """Trim the heap to at most ``neighbors`` items and return the points."""
        while len(self.__result) > self.__neighbors:
            self.__result.pop()
        return [x for (_, x) in self.__result]

    def search(self):
        """Run the query and return the collected points."""
        self._search_iter()
        return self._get_items()

    def _process(self, x):
        """Offer ``x`` as a candidate and return its distance from the query.

        The candidate is kept only when it is strictly closer than the
        current acceptance radius.
        """
        dist = self.__distance(self.__point, x)
        if dist >= self.__radius:
            return dist
        self.__result.add(dist, x)
        while len(self.__result) > self.__neighbors:
            self.__result.pop()
        if len(self.__result) == self.__neighbors:
            # The heap is full: its top is the worst kept candidate.
            self.__radius, _ = self.__result.top()
        return dist

    def _search_iter(self):
        """Walk the tree with an explicit stack, pruning unreachable halves."""
        PRE, POST = 0, 1
        # Reset the heap and the radius derived from it together, so a
        # repeated search does not start from a stale radius.
        self.__result = MaxHeap()
        self.__radius = float('inf')
        if self.__neighbors <= 0:
            # Nothing can be kept; avoid reading the top of an empty heap.
            return []
        stack = [(0, len(self.__dataset), 0.0, PRE)]
        while stack:
            start, end, thr, action = stack.pop()
            if start >= end:
                # An empty range has no vantage point to read; this also
                # covers a query against an empty dataset.
                continue
            if action == POST and not (self.__radius > thr):
                # Every point in this range is at least ``thr`` away from
                # the query, so none could beat the current radius.  The
                # test comes before the leaf scan so that terminal ranges
                # are pruned as well.
                continue
            v_radius, v_point, is_terminal = self.__dataset[start]
            if is_terminal:
                for _, x, _ in self.__dataset[start:end]:
                    self._process(x)
                continue
            mid = _mid(start, end)
            dist = self._process(v_point)
            if dist <= v_radius:
                fst_start, fst_end = start + 1, mid
                snd_start, snd_end = mid, end
            else:
                fst_start, fst_end = mid, end
                snd_start, snd_end = start + 1, mid
            # The far half is re-examined only after the near half has had
            # the chance to shrink the radius; the near half is pushed last
            # so that it is popped first.
            stack.append((snd_start, snd_end, abs(v_radius - dist), POST))
            stack.append((fst_start, fst_end, 0.0, PRE))
        return self._get_items()
249-
250- class _KVPre :
251-
252- def __init__ (self , start , end , m_radius ):
253- self .__start = start
254- self .__end = end
255- self .__m_radius = m_radius
256-
257- def bounds (self ):
258- return self .__start , self .__end , self .__m_radius
259-
260- class _KVPost :
261-
262- def __init__ (self , start , end , m_radius , dist , v_radius ):
263- self .__start = start
264- self .__end = end
265- self .__m_radius = m_radius
266- self .__dist = dist
267- self .__v_radius = v_radius
268-
269- def bounds (self ):
270- return self .__start , self .__end , self .__m_radius
271-
272- def rad (self ):
273- return self .__m_radius , self .__dist , self .__v_radius
0 commit comments