Skip to content

Commit fb05faf

Browse files
committed
Refactored flat vptree using numba for array operations
1 parent 3e07f19 commit fb05faf

2 files changed

Lines changed: 123 additions & 31 deletions

File tree

src/tdamapper/utils/vptree_flat.py

Lines changed: 111 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -1,9 +1,13 @@
11
from random import randrange
22

3+
import numpy as np
4+
from numba import njit
5+
36
from tdamapper.utils.heap import MaxHeap
47
from tdamapper.utils.metrics import get_metric
58

69

10+
@njit
def _swap(arr, i, j):
    """Exchange the elements of ``arr`` at positions ``i`` and ``j`` in place."""
    # Hold the first element aside so it survives being overwritten.
    held = arr[i]
    arr[i] = arr[j]
    arr[j] = held
913

@@ -12,25 +16,39 @@ def _mid(start, end):
1216
return (start + end) // 2
1317

1418

15-
@njit
def _partition(distances, indices, is_terminal, start, end, p_ord):
    """Move entries of ``distances[start:end]`` below ``p_ord`` to the front.

    Every swap applied to ``distances`` is applied to ``indices`` and
    ``is_terminal`` as well, so position ``k`` of the three arrays keeps
    describing the same entry.

    Returns the position just past the last entry that was moved forward
    (``start`` when no entry is strictly less than ``p_ord``).
    """
    boundary = start
    for pos in range(start, end):
        # Strict comparison: entries equal to p_ord stay on the upper side.
        if not (distances[pos] < p_ord):
            continue
        _swap(distances, boundary, pos)
        _swap(indices, boundary, pos)
        _swap(is_terminal, boundary, pos)
        boundary += 1
    return boundary
2330

2431

25-
def _quickselect(data, start, end, k):
32+
def _quickselect(distances, indices, is_terminal, start, end, k):
2633
if (k < start) or (k >= end):
2734
return
2835
start_, end_, higher = start, end, None
2936
while higher != k + 1:
30-
p, _, _ = data[k]
31-
_swap(data, start_, k)
32-
higher = _partition(data, start_ + 1, end_, p)
33-
_swap(data, start_, higher - 1)
37+
# TODO: pivot_index = randrange(start_, end_)
38+
pivot_index = k
39+
40+
p = distances[pivot_index]
41+
42+
_swap(distances, start_, pivot_index)
43+
_swap(indices, start_, pivot_index)
44+
_swap(is_terminal, start_, pivot_index)
45+
46+
higher = _partition(distances, indices, is_terminal, start_ + 1, end_, p)
47+
48+
_swap(distances, start_, higher - 1)
49+
_swap(indices, start_, higher - 1)
50+
_swap(is_terminal, start_, higher - 1)
51+
3452
if k <= higher - 1:
3553
end_ = higher
3654
else:
@@ -53,7 +71,9 @@ def __init__(
5371
self.__leaf_capacity = leaf_capacity
5472
self.__leaf_radius = leaf_radius
5573
self.__pivoting = pivoting
56-
self.__dataset = self._Build(self, X).build()
74+
self.__dataset, self.__distances, self.__indices, self.__is_terminal = (
75+
self._Build(self, X).build()
76+
)
5777

5878
def get_metric(self):
    """Return the value held in this object's private ``__metric`` attribute."""
    metric = self.__metric
    return metric
@@ -77,11 +97,25 @@ def _get_distance(self):
7797
metric_params = self.__metric_params or {}
7898
return get_metric(self.__metric, **metric_params)
7999

100+
def _get_distances(self):
101+
return self.__distances
102+
103+
def _get_indices(self):
104+
return self.__indices
105+
106+
def _get_is_terminal(self):
107+
return self.__is_terminal
108+
80109
class _Build:
81110

82111
def __init__(self, vpt, X):
83112
self.__distance = vpt._get_distance()
84-
self.__dataset = [(0.0, x, False) for x in X]
113+
114+
self.__dataset = [x for x in X]
115+
self.__indices = np.array([i for i in range(len(self.__dataset))])
116+
self.__distances = np.array([0.0 for _ in X])
117+
self.__is_terminal = np.array([False for _ in X])
118+
85119
self.__leaf_capacity = vpt.get_leaf_capacity()
86120
self.__leaf_radius = vpt.get_leaf_radius()
87121
pivoting = vpt.get_pivoting()
@@ -99,14 +133,21 @@ def _pivoting_random(self, start, end):
99133
return
100134
pivot = randrange(start, end)
101135
if pivot > start:
102-
_swap(self.__dataset, start, pivot)
136+
_swap(self.__distances, start, pivot)
137+
_swap(self.__indices, start, pivot)
138+
_swap(self.__is_terminal, start, pivot)
103139

104140
def _furthest(self, start, end, i):
105141
furthest_dist = 0.0
106142
furthest = start
107-
_, i_point, _ = self.__dataset[i]
143+
144+
i_point_index = self.__indices[i]
145+
i_point = self.__dataset[i_point_index]
146+
108147
for j in range(start, end):
109-
_, j_point, _ = self.__dataset[j]
148+
j_point_index = self.__indices[j]
149+
j_point = self.__dataset[j_point_index]
150+
110151
j_dist = self.__distance(i_point, j_point)
111152
if j_dist > furthest_dist:
112153
furthest = j
@@ -120,36 +161,61 @@ def _pivoting_furthest(self, start, end):
120161
furthest_rnd = self._furthest(start, end, rnd)
121162
furthest = self._furthest(start, end, furthest_rnd)
122163
if furthest > start:
123-
_swap(self.__dataset, start, furthest)
164+
_swap(self.__distances, start, furthest)
165+
_swap(self.__indices, start, furthest)
166+
_swap(self.__is_terminal, start, furthest)
124167

125168
def _update(self, start, end):
126169
self.__pivoting(start, end)
127-
_, v_point, is_terminal = self.__dataset[start]
170+
171+
v_point_index = self.__indices[start]
172+
v_point = self.__dataset[v_point_index]
173+
is_terminal = self.__is_terminal[start]
174+
128175
for i in range(start + 1, end):
129-
_, point, _ = self.__dataset[i]
130-
self.__dataset[i] = self.__distance(v_point, point), point, is_terminal
176+
point_index = self.__indices[i]
177+
point = self.__dataset[point_index]
178+
179+
self.__distances[i] = self.__distance(v_point, point)
180+
self.__is_terminal[i] = is_terminal
131181

132182
def build(self):
    """Run the iterative build, then hand back the four backing containers.

    Returns a 4-tuple ``(dataset, distances, indices, is_terminal)`` made of
    the objects held in the corresponding private attributes.
    """
    self._build_iter()
    result = (
        self.__dataset,
        self.__distances,
        self.__indices,
        self.__is_terminal,
    )
    return result
135185

136186
def _build_iter(self):
    """Lay out the flat arrays range by range using an explicit stack.

    Starting from the whole range ``[0, len(dataset))``, each popped range is
    updated via ``_update``, its tail ``[start + 1, end)`` is passed to
    ``_quickselect`` with ``k = mid``, and the distance found at ``mid`` is
    recorded at ``start`` as the radius. A range is split further only when it
    holds more than ``2 * leaf_capacity`` positions and the radius exceeds
    ``leaf_radius``; otherwise its first position is flagged as terminal.
    """
    pending = [(0, len(self.__dataset))]
    while pending:
        start, end = pending.pop()
        mid = _mid(start, end)
        self._update(start, end)

        # Remember which point sits at `start` before the tail is reordered.
        vantage_index = self.__indices[start]

        _quickselect(
            self.__distances,
            self.__indices,
            self.__is_terminal,
            start + 1,
            end,
            mid,
        )

        radius = self.__distances[mid]
        splits = (end - start > 2 * self.__leaf_capacity) and (
            radius > self.__leaf_radius
        )

        # The header at `start` is written the same way in both cases; only
        # the terminal flag differs.
        self.__distances[start] = radius
        self.__indices[start] = vantage_index
        self.__is_terminal[start] = not splits

        if splits:
            pending.append((mid, end))
            pending.append((start + 1, mid))
153219

154220
def ball_search(self, point, eps, inclusive=True):
    """Build a ``_BallSearch`` for the given arguments and return its ``search()`` result."""
    searcher = self._BallSearch(self, point, eps, inclusive)
    return searcher.search()
@@ -158,6 +224,9 @@ class _BallSearch:
158224

159225
def __init__(self, vpt, point, eps, inclusive=True):
160226
self.__dataset = vpt._get_dataset()
227+
self.__distances = vpt._get_distances()
228+
self.__indices = vpt._get_indices()
229+
self.__is_terminal = vpt._get_is_terminal()
161230
self.__distance = vpt._get_distance()
162231
self.__point = point
163232
self.__eps = eps
@@ -176,9 +245,15 @@ def _search_iter(self):
176245
result = []
177246
while stack:
178247
start, end = stack.pop()
179-
v_radius, v_point, is_terminal = self.__dataset[start]
248+
249+
v_radius = self.__distances[start]
250+
v_point_index = self.__indices[start]
251+
v_point = self.__dataset[v_point_index]
252+
is_terminal = self.__is_terminal[start]
253+
180254
if is_terminal:
181-
for _, x, _ in self.__dataset[start:end]:
255+
for x_index in self.__indices[start:end]:
256+
x = self.__dataset[x_index]
182257
dist = self.__distance(self.__point, x)
183258
if self._inside(dist):
184259
result.append(x)
@@ -205,6 +280,9 @@ class _KnnSearch:
205280

206281
def __init__(self, vpt, point, neighbors):
207282
self.__dataset = vpt._get_dataset()
283+
self.__distances = vpt._get_distances()
284+
self.__indices = vpt._get_indices()
285+
self.__is_terminal = vpt._get_is_terminal()
208286
self.__distance = vpt._get_distance()
209287
self.__point = point
210288
self.__neighbors = neighbors
@@ -237,9 +315,15 @@ def _search_iter(self):
237315
stack = [(0, len(self.__dataset), 0.0, PRE)]
238316
while stack:
239317
start, end, thr, action = stack.pop()
240-
v_radius, v_point, is_terminal = self.__dataset[start]
318+
319+
v_radius = self.__distances[start]
320+
v_point_index = self.__indices[start]
321+
v_point = self.__dataset[v_point_index]
322+
is_terminal = self.__is_terminal[start]
323+
241324
if is_terminal:
242-
for _, x, _ in self.__dataset[start:end]:
325+
for x_index in self.__indices[start:end]:
326+
x = self.__dataset[x_index]
243327
self._process(x)
244328
else:
245329
if action == PRE:

tests/test_unit_knn.py

Lines changed: 12 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -127,22 +127,30 @@ def test_vptree_simple(self):
127127

128128
def check_vptree(self, vpt):
129129
data = vpt._get_dataset()
130+
distances = vpt._get_distances()
131+
indices = vpt._get_indices()
132+
130133
dist = vpt._get_distance()
131134
leaf_capacity = vpt.get_leaf_capacity()
132135
leaf_radius = vpt.get_leaf_radius()
133136

134137
def check_sub(start, end):
135-
v_radius, v_point, *_ = data[start]
138+
v_radius = distances[start]
139+
v_point_index = indices[start]
140+
v_point = data[v_point_index]
141+
136142
mid = (start + end) // 2
137143
for i in range(start + 1, mid):
138-
_, y, *_ = data[i]
144+
y_index = indices[i]
145+
y = data[y_index]
139146
self.assertTrue(dist(v_point, y) <= v_radius)
140147
for i in range(mid, end):
141-
_, y, *_ = data[i]
148+
y_index = indices[i]
149+
y = data[y_index]
142150
self.assertTrue(dist(v_point, y) >= v_radius)
143151

144152
def check_rec(start, end):
145-
v_radius, *_ = data[start]
153+
v_radius = distances[start]
146154
if (end - start > leaf_capacity) and (v_radius > leaf_radius):
147155
check_sub(start, end)
148156
mid = (start + end) // 2

0 commit comments

Comments
 (0)