Skip to content

Commit afe5eb8

Browse files
committed
Refactored quickselect using numba
1 parent fb05faf commit afe5eb8

4 files changed

Lines changed: 155 additions & 164 deletions

File tree

src/tdamapper/utils/quickselect.py

Lines changed: 21 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,54 +1,41 @@
1-
def __swap(arr, i, j):
2-
arr[i], arr[j] = arr[j], arr[i]
1+
from numba import njit
32

43

5-
def partition(data, start, end, p_ord):
6-
higher = start
7-
for j in range(start, end):
8-
j_ord, _ = data[j]
9-
if j_ord < p_ord:
10-
__swap(data, higher, j)
11-
higher += 1
12-
return higher
4+
@njit
5+
def swap(arr, i, j):
6+
arr[i], arr[j] = arr[j], arr[i]
137

148

15-
def quickselect(data, start, end, k):
16-
if (k < start) or (k >= end):
17-
return
18-
start_, end_, higher = start, end, None
19-
while higher != k + 1:
20-
p, _ = data[k]
21-
__swap(data, start_, k)
22-
higher = partition(data, start_ + 1, end_, p)
23-
__swap(data, start_, higher - 1)
24-
if k <= higher - 1:
25-
end_ = higher
26-
else:
27-
start_ = higher
9+
@njit
10+
def swap_all(arr, i, j, extra1=None, extra2=None):
11+
swap(arr, i, j)
12+
if extra1 is not None:
13+
swap(extra1, i, j)
14+
if extra2 is not None:
15+
swap(extra2, i, j)
2816

2917

30-
def partition_tuple(data_ord, data_arr, start, end, p_ord):
18+
@njit
19+
def partition(data, start, end, p_ord, *extra):
3120
higher = start
3221
for j in range(start, end):
33-
j_ord = data_ord[j]
22+
j_ord = data[j]
3423
if j_ord < p_ord:
35-
__swap(data_arr, higher, j)
36-
__swap(data_ord, higher, j)
24+
swap_all(data, higher, j, *extra)
3725
higher += 1
3826
return higher
3927

4028

41-
def quickselect_tuple(data_ord, data_arr, start, end, k):
29+
@njit
30+
def quickselect(data, start, end, k, *extra):
4231
if (k < start) or (k >= end):
4332
return
4433
start_, end_, higher = start, end, None
4534
while higher != k + 1:
46-
p_ord = data_ord[k]
47-
__swap(data_arr, start_, k)
48-
__swap(data_ord, start_, k)
49-
higher = partition_tuple(data_ord, data_arr, start_ + 1, end_, p_ord)
50-
__swap(data_arr, start_, higher - 1)
51-
__swap(data_ord, start_, higher - 1)
35+
p = data[k]
36+
swap_all(data, start_, k, *extra)
37+
higher = partition(data, start_ + 1, end_, p, *extra)
38+
swap_all(data, start_, higher - 1, *extra)
5239
if k <= higher - 1:
5340
end_ = higher
5441
else:

src/tdamapper/utils/vptree_flat.py

Lines changed: 55 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -1,60 +1,16 @@
11
from random import randrange
22

33
import numpy as np
4-
from numba import njit
54

65
from tdamapper.utils.heap import MaxHeap
76
from tdamapper.utils.metrics import get_metric
8-
9-
10-
@njit
11-
def _swap(arr, i, j):
12-
arr[i], arr[j] = arr[j], arr[i]
7+
from tdamapper.utils.quickselect import quickselect, swap_all
138

149

1510
def _mid(start, end):
1611
return (start + end) // 2
1712

1813

19-
@njit
20-
def _partition(distances, indices, is_terminal, start, end, p_ord):
21-
higher = start
22-
for j in range(start, end):
23-
j_ord = distances[j]
24-
if j_ord < p_ord:
25-
_swap(distances, higher, j)
26-
_swap(indices, higher, j)
27-
_swap(is_terminal, higher, j)
28-
higher += 1
29-
return higher
30-
31-
32-
def _quickselect(distances, indices, is_terminal, start, end, k):
33-
if (k < start) or (k >= end):
34-
return
35-
start_, end_, higher = start, end, None
36-
while higher != k + 1:
37-
# TODO: pivot_index = randrange(start_, end_)
38-
pivot_index = k
39-
40-
p = distances[pivot_index]
41-
42-
_swap(distances, start_, pivot_index)
43-
_swap(indices, start_, pivot_index)
44-
_swap(is_terminal, start_, pivot_index)
45-
46-
higher = _partition(distances, indices, is_terminal, start_ + 1, end_, p)
47-
48-
_swap(distances, start_, higher - 1)
49-
_swap(indices, start_, higher - 1)
50-
_swap(is_terminal, start_, higher - 1)
51-
52-
if k <= higher - 1:
53-
end_ = higher
54-
else:
55-
start_ = higher
56-
57-
5814
class VPTree:
5915

6016
def __init__(
@@ -71,9 +27,12 @@ def __init__(
7127
self.__leaf_capacity = leaf_capacity
7228
self.__leaf_radius = leaf_radius
7329
self.__pivoting = pivoting
74-
self.__dataset, self.__distances, self.__indices, self.__is_terminal = (
75-
self._Build(self, X).build()
76-
)
30+
(
31+
self.__dataset,
32+
self.__arr_distances,
33+
self.__arr_indices,
34+
self.__arr_is_terminal,
35+
) = self._Build(self, X).build()
7736

7837
def get_metric(self):
7938
return self.__metric
@@ -98,23 +57,23 @@ def _get_distance(self):
9857
return get_metric(self.__metric, **metric_params)
9958

10059
def _get_distances(self):
101-
return self.__distances
60+
return self.__arr_distances
10261

10362
def _get_indices(self):
104-
return self.__indices
63+
return self.__arr_indices
10564

10665
def _get_is_terminal(self):
107-
return self.__is_terminal
66+
return self.__arr_is_terminal
10867

10968
class _Build:
11069

11170
def __init__(self, vpt, X):
11271
self.__distance = vpt._get_distance()
11372

11473
self.__dataset = [x for x in X]
115-
self.__indices = np.array([i for i in range(len(self.__dataset))])
116-
self.__distances = np.array([0.0 for _ in X])
117-
self.__is_terminal = np.array([False for _ in X])
74+
self.__arr_indices = np.array([i for i in range(len(self.__dataset))])
75+
self.__arr_distances = np.array([0.0 for _ in X])
76+
self.__arr_is_terminal = np.array([False for _ in X])
11877

11978
self.__leaf_capacity = vpt.get_leaf_capacity()
12079
self.__leaf_radius = vpt.get_leaf_radius()
@@ -133,20 +92,25 @@ def _pivoting_random(self, start, end):
13392
return
13493
pivot = randrange(start, end)
13594
if pivot > start:
136-
_swap(self.__distances, start, pivot)
137-
_swap(self.__indices, start, pivot)
138-
_swap(self.__is_terminal, start, pivot)
95+
swap_all(
96+
self.__arr_distances,
97+
start,
98+
pivot,
99+
self.__arr_indices,
100+
self.__arr_is_terminal,
101+
)
102+
103+
def _get_point(self, i):
104+
return self.__dataset[self.__arr_indices[i]]
139105

140106
def _furthest(self, start, end, i):
141107
furthest_dist = 0.0
142108
furthest = start
143109

144-
i_point_index = self.__indices[i]
145-
i_point = self.__dataset[i_point_index]
110+
i_point = self._get_point(i)
146111

147112
for j in range(start, end):
148-
j_point_index = self.__indices[j]
149-
j_point = self.__dataset[j_point_index]
113+
j_point = self._get_point(j)
150114

151115
j_dist = self.__distance(i_point, j_point)
152116
if j_dist > furthest_dist:
@@ -161,27 +125,36 @@ def _pivoting_furthest(self, start, end):
161125
furthest_rnd = self._furthest(start, end, rnd)
162126
furthest = self._furthest(start, end, furthest_rnd)
163127
if furthest > start:
164-
_swap(self.__distances, start, furthest)
165-
_swap(self.__indices, start, furthest)
166-
_swap(self.__is_terminal, start, furthest)
128+
swap_all(
129+
self.__arr_distances,
130+
start,
131+
furthest,
132+
self.__arr_indices,
133+
self.__arr_is_terminal,
134+
)
167135

168136
def _update(self, start, end):
169137
self.__pivoting(start, end)
170138

171-
v_point_index = self.__indices[start]
139+
v_point_index = self.__arr_indices[start]
172140
v_point = self.__dataset[v_point_index]
173-
is_terminal = self.__is_terminal[start]
141+
is_terminal = self.__arr_is_terminal[start]
174142

175143
for i in range(start + 1, end):
176-
point_index = self.__indices[i]
144+
point_index = self.__arr_indices[i]
177145
point = self.__dataset[point_index]
178146

179-
self.__distances[i] = self.__distance(v_point, point)
180-
self.__is_terminal[i] = is_terminal
147+
self.__arr_distances[i] = self.__distance(v_point, point)
148+
self.__arr_is_terminal[i] = is_terminal
181149

182150
def build(self):
183151
self._build_iter()
184-
return self.__dataset, self.__distances, self.__indices, self.__is_terminal
152+
return (
153+
self.__dataset,
154+
self.__arr_distances,
155+
self.__arr_indices,
156+
self.__arr_is_terminal,
157+
)
185158

186159
def _build_iter(self):
187160
stack = [(0, len(self.__dataset))]
@@ -190,32 +163,32 @@ def _build_iter(self):
190163
mid = _mid(start, end)
191164
self._update(start, end)
192165

193-
v_point_index = self.__indices[start]
166+
# v_point_index = self.__indices[start]
194167

195-
_quickselect(
196-
self.__distances,
197-
self.__indices,
198-
self.__is_terminal,
168+
quickselect(
169+
self.__arr_distances,
199170
start + 1,
200171
end,
201172
mid,
173+
self.__arr_indices,
174+
self.__arr_is_terminal,
202175
)
203176

204-
v_radius = self.__distances[mid]
177+
v_radius = self.__arr_distances[mid]
205178

206179
if (end - start > 2 * self.__leaf_capacity) and (
207180
v_radius > self.__leaf_radius
208181
):
209-
self.__distances[start] = v_radius
210-
self.__indices[start] = v_point_index
211-
self.__is_terminal[start] = False
182+
self.__arr_distances[start] = v_radius
183+
# self.__indices[start] = v_point_index
184+
self.__arr_is_terminal[start] = False
212185

213186
stack.append((mid, end))
214187
stack.append((start + 1, mid))
215188
else:
216-
self.__distances[start] = v_radius
217-
self.__indices[start] = v_point_index
218-
self.__is_terminal[start] = True
189+
self.__arr_distances[start] = v_radius
190+
# self.__indices[start] = v_point_index
191+
self.__arr_is_terminal[start] = True
219192

220193
def ball_search(self, point, eps, inclusive=True):
221194
return self._BallSearch(self, point, eps, inclusive).search()

0 commit comments

Comments
 (0)