11from random import randrange
22
3+ import numpy as np
4+ from numba import njit
5+
36from tdamapper .utils .heap import MaxHeap
47from tdamapper .utils .metrics import get_metric
58
69
@njit
def _swap(arr, i, j):
    """Exchange the elements at positions ``i`` and ``j`` of ``arr`` in place.

    Nothing is returned; ``arr`` itself is modified. When ``i == j`` the
    call leaves ``arr`` unchanged.
    """
    arr[i], arr[j] = arr[j], arr[i]
913
@@ -12,25 +16,39 @@ def _mid(start, end):
1216 return (start + end ) // 2
1317
1418
@njit
def _partition(distances, indices, is_terminal, start, end, p_ord):
    """Move entries of ``[start, end)`` whose distance is below ``p_ord`` to the front.

    ``distances``, ``indices`` and ``is_terminal`` are treated as parallel
    sequences: every swap applied to ``distances`` is applied to the other
    two as well, so a given position keeps describing the same entry in
    all three.

    Returns the position just past the last entry that compared strictly
    smaller than ``p_ord``; it equals ``start`` when no entry did.
    """
    higher = start
    for j in range(start, end):
        if distances[j] < p_ord:
            # Keep the three parallel sequences aligned.
            _swap(distances, higher, j)
            _swap(indices, higher, j)
            _swap(is_terminal, higher, j)
            higher += 1
    return higher
2330
2431
25- def _quickselect (data , start , end , k ):
32+ def _quickselect (distances , indices , is_terminal , start , end , k ):
2633 if (k < start ) or (k >= end ):
2734 return
2835 start_ , end_ , higher = start , end , None
2936 while higher != k + 1 :
30- p , _ , _ = data [k ]
31- _swap (data , start_ , k )
32- higher = _partition (data , start_ + 1 , end_ , p )
33- _swap (data , start_ , higher - 1 )
37+ # TODO: pivot_index = randrange(start_, end_)
38+ pivot_index = k
39+
40+ p = distances [pivot_index ]
41+
42+ _swap (distances , start_ , pivot_index )
43+ _swap (indices , start_ , pivot_index )
44+ _swap (is_terminal , start_ , pivot_index )
45+
46+ higher = _partition (distances , indices , is_terminal , start_ + 1 , end_ , p )
47+
48+ _swap (distances , start_ , higher - 1 )
49+ _swap (indices , start_ , higher - 1 )
50+ _swap (is_terminal , start_ , higher - 1 )
51+
3452 if k <= higher - 1 :
3553 end_ = higher
3654 else :
@@ -53,7 +71,9 @@ def __init__(
5371 self .__leaf_capacity = leaf_capacity
5472 self .__leaf_radius = leaf_radius
5573 self .__pivoting = pivoting
56- self .__dataset = self ._Build (self , X ).build ()
74+ self .__dataset , self .__distances , self .__indices , self .__is_terminal = (
75+ self ._Build (self , X ).build ()
76+ )
5777
def get_metric(self):
    """Return the value held in this instance's private ``__metric`` attribute."""
    return self.__metric
@@ -77,11 +97,25 @@ def _get_distance(self):
7797 metric_params = self .__metric_params or {}
7898 return get_metric (self .__metric , ** metric_params )
7999
100+ def _get_distances (self ):
101+ return self .__distances
102+
103+ def _get_indices (self ):
104+ return self .__indices
105+
106+ def _get_is_terminal (self ):
107+ return self .__is_terminal
108+
80109 class _Build :
81110
82111 def __init__ (self , vpt , X ):
83112 self .__distance = vpt ._get_distance ()
84- self .__dataset = [(0.0 , x , False ) for x in X ]
113+
114+ self .__dataset = [x for x in X ]
115+ self .__indices = np .array ([i for i in range (len (self .__dataset ))])
116+ self .__distances = np .array ([0.0 for _ in X ])
117+ self .__is_terminal = np .array ([False for _ in X ])
118+
85119 self .__leaf_capacity = vpt .get_leaf_capacity ()
86120 self .__leaf_radius = vpt .get_leaf_radius ()
87121 pivoting = vpt .get_pivoting ()
@@ -99,14 +133,21 @@ def _pivoting_random(self, start, end):
99133 return
100134 pivot = randrange (start , end )
101135 if pivot > start :
102- _swap (self .__dataset , start , pivot )
136+ _swap (self .__distances , start , pivot )
137+ _swap (self .__indices , start , pivot )
138+ _swap (self .__is_terminal , start , pivot )
103139
104140 def _furthest (self , start , end , i ):
105141 furthest_dist = 0.0
106142 furthest = start
107- _ , i_point , _ = self .__dataset [i ]
143+
144+ i_point_index = self .__indices [i ]
145+ i_point = self .__dataset [i_point_index ]
146+
108147 for j in range (start , end ):
109- _ , j_point , _ = self .__dataset [j ]
148+ j_point_index = self .__indices [j ]
149+ j_point = self .__dataset [j_point_index ]
150+
110151 j_dist = self .__distance (i_point , j_point )
111152 if j_dist > furthest_dist :
112153 furthest = j
@@ -120,36 +161,61 @@ def _pivoting_furthest(self, start, end):
120161 furthest_rnd = self ._furthest (start , end , rnd )
121162 furthest = self ._furthest (start , end , furthest_rnd )
122163 if furthest > start :
123- _swap (self .__dataset , start , furthest )
164+ _swap (self .__distances , start , furthest )
165+ _swap (self .__indices , start , furthest )
166+ _swap (self .__is_terminal , start , furthest )
124167
125168 def _update (self , start , end ):
126169 self .__pivoting (start , end )
127- _ , v_point , is_terminal = self .__dataset [start ]
170+
171+ v_point_index = self .__indices [start ]
172+ v_point = self .__dataset [v_point_index ]
173+ is_terminal = self .__is_terminal [start ]
174+
128175 for i in range (start + 1 , end ):
129- _ , point , _ = self .__dataset [i ]
130- self .__dataset [i ] = self .__distance (v_point , point ), point , is_terminal
176+ point_index = self .__indices [i ]
177+ point = self .__dataset [point_index ]
178+
179+ self .__distances [i ] = self .__distance (v_point , point )
180+ self .__is_terminal [i ] = is_terminal
131181
def build(self):
    """Run the construction and return its results.

    Returns a 4-tuple ``(dataset, distances, indices, is_terminal)``, in
    that order.
    """
    self._build_iter()
    return self.__dataset, self.__distances, self.__indices, self.__is_terminal
135185
def _build_iter(self):
    """Build the tree in place using an explicit stack of ``(start, end)`` ranges.

    For each range: ``_update`` selects the vantage point at ``start`` and
    fills in distances, ``_quickselect`` rearranges ``[start + 1, end)``
    around position ``mid``, and the distance found at ``mid`` is stored
    at ``start`` as the node radius. A range is split into
    ``[start + 1, mid)`` and ``[mid, end)`` only when it holds more than
    twice the leaf capacity and its radius exceeds the leaf radius;
    otherwise the node is flagged as terminal.
    """
    distances = self.__distances
    indices = self.__indices
    is_terminal = self.__is_terminal
    leaf_capacity = self.__leaf_capacity
    leaf_radius = self.__leaf_radius

    stack = [(0, len(self.__dataset))]
    while stack:
        start, end = stack.pop()
        mid = _mid(start, end)
        self._update(start, end)

        v_point_index = indices[start]

        _quickselect(distances, indices, is_terminal, start + 1, end, mid)

        v_radius = distances[mid]
        splits = (end - start > 2 * leaf_capacity) and (v_radius > leaf_radius)

        # Both outcomes record the same node data; only the flag differs.
        distances[start] = v_radius
        indices[start] = v_point_index
        is_terminal[start] = not splits

        if splits:
            # Pushed last so the left part [start + 1, mid) is handled first.
            stack.append((mid, end))
            stack.append((start + 1, mid))
153219
def ball_search(self, point, eps, inclusive=True):
    """Run a ball search around ``point`` and return its result.

    The work is delegated to ``self._BallSearch``: an instance is created
    with this tree, ``point``, ``eps`` and ``inclusive``, and the value of
    its ``search()`` call is returned unchanged.
    """
    searcher = self._BallSearch(self, point, eps, inclusive)
    return searcher.search()
@@ -158,6 +224,9 @@ class _BallSearch:
158224
159225 def __init__ (self , vpt , point , eps , inclusive = True ):
160226 self .__dataset = vpt ._get_dataset ()
227+ self .__distances = vpt ._get_distances ()
228+ self .__indices = vpt ._get_indices ()
229+ self .__is_terminal = vpt ._get_is_terminal ()
161230 self .__distance = vpt ._get_distance ()
162231 self .__point = point
163232 self .__eps = eps
@@ -176,9 +245,15 @@ def _search_iter(self):
176245 result = []
177246 while stack :
178247 start , end = stack .pop ()
179- v_radius , v_point , is_terminal = self .__dataset [start ]
248+
249+ v_radius = self .__distances [start ]
250+ v_point_index = self .__indices [start ]
251+ v_point = self .__dataset [v_point_index ]
252+ is_terminal = self .__is_terminal [start ]
253+
180254 if is_terminal :
181- for _ , x , _ in self .__dataset [start :end ]:
255+ for x_index in self .__indices [start :end ]:
256+ x = self .__dataset [x_index ]
182257 dist = self .__distance (self .__point , x )
183258 if self ._inside (dist ):
184259 result .append (x )
@@ -205,6 +280,9 @@ class _KnnSearch:
205280
206281 def __init__ (self , vpt , point , neighbors ):
207282 self .__dataset = vpt ._get_dataset ()
283+ self .__distances = vpt ._get_distances ()
284+ self .__indices = vpt ._get_indices ()
285+ self .__is_terminal = vpt ._get_is_terminal ()
208286 self .__distance = vpt ._get_distance ()
209287 self .__point = point
210288 self .__neighbors = neighbors
@@ -237,9 +315,15 @@ def _search_iter(self):
237315 stack = [(0 , len (self .__dataset ), 0.0 , PRE )]
238316 while stack :
239317 start , end , thr , action = stack .pop ()
240- v_radius , v_point , is_terminal = self .__dataset [start ]
318+
319+ v_radius = self .__distances [start ]
320+ v_point_index = self .__indices [start ]
321+ v_point = self .__dataset [v_point_index ]
322+ is_terminal = self .__is_terminal [start ]
323+
241324 if is_terminal :
242- for _ , x , _ in self .__dataset [start :end ]:
325+ for x_index in self .__indices [start :end ]:
326+ x = self .__dataset [x_index ]
243327 self ._process (x )
244328 else :
245329 if action == PRE :
0 commit comments