11from random import randrange
22
33import numpy as np
4- from numba import njit
54
65from tdamapper .utils .heap import MaxHeap
76from tdamapper .utils .metrics import get_metric
8-
9-
10- @njit
11- def _swap (arr , i , j ):
12- arr [i ], arr [j ] = arr [j ], arr [i ]
7+ from tdamapper .utils .quickselect import quickselect , swap_all
138
149
1510def _mid (start , end ):
1611 return (start + end ) // 2
1712
1813
19- @njit
20- def _partition (distances , indices , is_terminal , start , end , p_ord ):
21- higher = start
22- for j in range (start , end ):
23- j_ord = distances [j ]
24- if j_ord < p_ord :
25- _swap (distances , higher , j )
26- _swap (indices , higher , j )
27- _swap (is_terminal , higher , j )
28- higher += 1
29- return higher
30-
31-
32- def _quickselect (distances , indices , is_terminal , start , end , k ):
33- if (k < start ) or (k >= end ):
34- return
35- start_ , end_ , higher = start , end , None
36- while higher != k + 1 :
37- # TODO: pivot_index = randrange(start_, end_)
38- pivot_index = k
39-
40- p = distances [pivot_index ]
41-
42- _swap (distances , start_ , pivot_index )
43- _swap (indices , start_ , pivot_index )
44- _swap (is_terminal , start_ , pivot_index )
45-
46- higher = _partition (distances , indices , is_terminal , start_ + 1 , end_ , p )
47-
48- _swap (distances , start_ , higher - 1 )
49- _swap (indices , start_ , higher - 1 )
50- _swap (is_terminal , start_ , higher - 1 )
51-
52- if k <= higher - 1 :
53- end_ = higher
54- else :
55- start_ = higher
56-
57-
5814class VPTree :
5915
6016 def __init__ (
@@ -71,9 +27,12 @@ def __init__(
7127 self .__leaf_capacity = leaf_capacity
7228 self .__leaf_radius = leaf_radius
7329 self .__pivoting = pivoting
74- self .__dataset , self .__distances , self .__indices , self .__is_terminal = (
75- self ._Build (self , X ).build ()
76- )
30+ (
31+ self .__dataset ,
32+ self .__arr_distances ,
33+ self .__arr_indices ,
34+ self .__arr_is_terminal ,
35+ ) = self ._Build (self , X ).build ()
7736
7837 def get_metric (self ):
7938 return self .__metric
@@ -98,23 +57,23 @@ def _get_distance(self):
9857 return get_metric (self .__metric , ** metric_params )
9958
10059 def _get_distances (self ):
101- return self .__distances
60+ return self .__arr_distances
10261
10362 def _get_indices (self ):
104- return self .__indices
63+ return self .__arr_indices
10564
10665 def _get_is_terminal (self ):
107- return self .__is_terminal
66+ return self .__arr_is_terminal
10867
10968 class _Build :
11069
11170 def __init__ (self , vpt , X ):
11271 self .__distance = vpt ._get_distance ()
11372
11473 self .__dataset = [x for x in X ]
115- self .__indices = np .array ([i for i in range (len (self .__dataset ))])
116- self .__distances = np .array ([0.0 for _ in X ])
117- self .__is_terminal = np .array ([False for _ in X ])
74+ self .__arr_indices = np .array ([i for i in range (len (self .__dataset ))])
75+ self .__arr_distances = np .array ([0.0 for _ in X ])
76+ self .__arr_is_terminal = np .array ([False for _ in X ])
11877
11978 self .__leaf_capacity = vpt .get_leaf_capacity ()
12079 self .__leaf_radius = vpt .get_leaf_radius ()
@@ -133,20 +92,25 @@ def _pivoting_random(self, start, end):
13392 return
13493 pivot = randrange (start , end )
13594 if pivot > start :
136- _swap (self .__distances , start , pivot )
137- _swap (self .__indices , start , pivot )
138- _swap (self .__is_terminal , start , pivot )
95+ swap_all (
96+ self .__arr_distances ,
97+ start ,
98+ pivot ,
99+ self .__arr_indices ,
100+ self .__arr_is_terminal ,
101+ )
102+
103+ def _get_point (self , i ):
104+ return self .__dataset [self .__arr_indices [i ]]
139105
140106 def _furthest (self , start , end , i ):
141107 furthest_dist = 0.0
142108 furthest = start
143109
144- i_point_index = self .__indices [i ]
145- i_point = self .__dataset [i_point_index ]
110+ i_point = self ._get_point (i )
146111
147112 for j in range (start , end ):
148- j_point_index = self .__indices [j ]
149- j_point = self .__dataset [j_point_index ]
113+ j_point = self ._get_point (j )
150114
151115 j_dist = self .__distance (i_point , j_point )
152116 if j_dist > furthest_dist :
@@ -161,27 +125,36 @@ def _pivoting_furthest(self, start, end):
161125 furthest_rnd = self ._furthest (start , end , rnd )
162126 furthest = self ._furthest (start , end , furthest_rnd )
163127 if furthest > start :
164- _swap (self .__distances , start , furthest )
165- _swap (self .__indices , start , furthest )
166- _swap (self .__is_terminal , start , furthest )
128+ swap_all (
129+ self .__arr_distances ,
130+ start ,
131+ furthest ,
132+ self .__arr_indices ,
133+ self .__arr_is_terminal ,
134+ )
167135
168136 def _update (self , start , end ):
169137 self .__pivoting (start , end )
170138
171- v_point_index = self .__indices [start ]
139+ v_point_index = self .__arr_indices [start ]
172140 v_point = self .__dataset [v_point_index ]
173- is_terminal = self .__is_terminal [start ]
141+ is_terminal = self .__arr_is_terminal [start ]
174142
175143 for i in range (start + 1 , end ):
176- point_index = self .__indices [i ]
144+ point_index = self .__arr_indices [i ]
177145 point = self .__dataset [point_index ]
178146
179- self .__distances [i ] = self .__distance (v_point , point )
180- self .__is_terminal [i ] = is_terminal
147+ self .__arr_distances [i ] = self .__distance (v_point , point )
148+ self .__arr_is_terminal [i ] = is_terminal
181149
182150 def build (self ):
183151 self ._build_iter ()
184- return self .__dataset , self .__distances , self .__indices , self .__is_terminal
152+ return (
153+ self .__dataset ,
154+ self .__arr_distances ,
155+ self .__arr_indices ,
156+ self .__arr_is_terminal ,
157+ )
185158
186159 def _build_iter (self ):
187160 stack = [(0 , len (self .__dataset ))]
@@ -190,32 +163,32 @@ def _build_iter(self):
190163 mid = _mid (start , end )
191164 self ._update (start , end )
192165
193- v_point_index = self .__indices [start ]
166+ # v_point_index = self.__indices[start]
194167
195- _quickselect (
196- self .__distances ,
197- self .__indices ,
198- self .__is_terminal ,
168+ quickselect (
169+ self .__arr_distances ,
199170 start + 1 ,
200171 end ,
201172 mid ,
173+ self .__arr_indices ,
174+ self .__arr_is_terminal ,
202175 )
203176
204- v_radius = self .__distances [mid ]
177+ v_radius = self .__arr_distances [mid ]
205178
206179 if (end - start > 2 * self .__leaf_capacity ) and (
207180 v_radius > self .__leaf_radius
208181 ):
209- self .__distances [start ] = v_radius
210- self .__indices [start ] = v_point_index
211- self .__is_terminal [start ] = False
182+ self .__arr_distances [start ] = v_radius
183+ # self.__indices[start] = v_point_index
184+ self .__arr_is_terminal [start ] = False
212185
213186 stack .append ((mid , end ))
214187 stack .append ((start + 1 , mid ))
215188 else :
216- self .__distances [start ] = v_radius
217- self .__indices [start ] = v_point_index
218- self .__is_terminal [start ] = True
189+ self .__arr_distances [start ] = v_radius
190+ # self.__indices[start] = v_point_index
191+ self .__arr_is_terminal [start ] = True
219192
220193 def ball_search (self , point , eps , inclusive = True ):
221194 return self ._BallSearch (self , point , eps , inclusive ).search ()
0 commit comments