Skip to content

Commit 3588b0c

Browse files
authored
Merge pull request #144 from lucasimi/feature/improve-vptree
Feature/improve vptree
2 parents 064f1df + cf8ade2 commit 3588b0c

File tree

5 files changed

+138
-146
lines changed

5 files changed

+138
-146
lines changed

LICENSE

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@
186186
same "printed page" as the copyright notice for easier
187187
identification within third-party archives.
188188

189-
Copyright [yyyy] [name of copyright owner]
189+
Copyright 2020 Luca Simi
190190

191191
Licensed under the Apache License, Version 2.0 (the "License");
192192
you may not use this file except in compliance with the License.

src/tdamapper/utils/heap.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def __len__(self):
5151

5252
def top(self):
5353
if not self.__heap:
54-
return None
54+
return (None, None)
5555
return self.__heap[0].get()
5656

5757
def pop(self):

src/tdamapper/utils/vptree_flat.py

Lines changed: 118 additions & 131 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from random import randrange
22

33
from tdamapper.utils.metrics import get_metric
4-
from tdamapper.utils.quickselect import quickselect
54
from tdamapper.utils.heap import MaxHeap
65

76

@@ -13,6 +12,31 @@ def _mid(start, end):
1312
return (start + end) // 2
1413

1514

def _partition(data, start, end, p_ord):
    """Move entries ordered strictly below ``p_ord`` to the front of a slice.

    Walks ``data[start:end]`` once. Every entry is unpacked as a 3-tuple
    and only its first field is compared against ``p_ord``.

    :param data: Mutable sequence of 3-tuples; rearranged in place
        through ``_swap``.
    :param start: First index of the slice (inclusive).
    :param end: Last index of the slice (exclusive).
    :param p_ord: Pivot value the first tuple field is compared against.
    :return: ``start`` plus the number of entries whose first field is
        strictly less than ``p_ord``.
    """
    # NOTE(review): `_swap` is defined outside this view; presumably it
    # exchanges the two given positions of `data` -- confirm.
    boundary = start
    for idx in range(start, end):
        order, _, _ = data[idx]
        if not (order < p_ord):
            # Entries that are not strictly smaller stay where they are.
            continue
        _swap(data, boundary, idx)
        boundary += 1
    return boundary
23+
24+
def _quickselect(data, start, end, k):
    """Iteratively partition ``data[start:end]`` around the entry at ``k``.

    Each round takes the first field of ``data[k]`` as the pivot value,
    parks that entry at the front of the current window, partitions the
    rest of the window with ``_partition`` and then swaps the parked
    entry to the slot just before the returned boundary. The loop ends
    once that slot is index ``k``.

    :param data: Mutable sequence of 3-tuples; rearranged in place.
    :param start: First index of the range (inclusive).
    :param end: Last index of the range (exclusive).
    :param k: Target index. When it lies outside ``[start, end)`` the
        function returns immediately without touching ``data``.
    :return: ``None``.
    """
    if not (start <= k < end):
        return
    lo, hi = start, end
    boundary = None
    # The pivot ends up at `boundary - 1`, so we are done once that is `k`.
    while boundary != k + 1:
        pivot_ord, _, _ = data[k]
        # NOTE(review): `_swap` is defined outside this view; presumably it
        # exchanges the two given positions of `data` -- confirm.
        _swap(data, lo, k)
        boundary = _partition(data, lo + 1, hi, pivot_ord)
        _swap(data, lo, boundary - 1)
        if k < boundary:
            # Target index is at or before the pivot slot: keep the left part.
            hi = boundary
        else:
            # Target index is past the pivot slot: keep the right part.
            lo = boundary
38+
39+
1640
class VPTree:
1741

1842
def __init__(
@@ -57,7 +81,7 @@ class _Build:
5781

5882
def __init__(self, vpt, X):
5983
self.__distance = vpt._get_distance()
60-
self.__dataset = [(0.0, x) for x in X]
84+
self.__dataset = [(0.0, x, False) for x in X]
6185
self.__leaf_capacity = vpt.get_leaf_capacity()
6286
self.__leaf_radius = vpt.get_leaf_radius()
6387
pivoting = vpt.get_pivoting()
@@ -71,7 +95,7 @@ def _pivoting_disabled(self, start, end):
7195
pass
7296

7397
def _pivoting_random(self, start, end):
74-
if end - start < 2:
98+
if end <= start:
7599
return
76100
pivot = randrange(start, end)
77101
if pivot > start:
@@ -80,17 +104,17 @@ def _pivoting_random(self, start, end):
80104
def _furthest(self, start, end, i):
81105
furthest_dist = 0.0
82106
furthest = start
83-
_, i_point = self.__dataset[i]
107+
_, i_point, _ = self.__dataset[i]
84108
for j in range(start, end):
85-
_, j_point = self.__dataset[j]
109+
_, j_point, _ = self.__dataset[j]
86110
j_dist = self.__distance(i_point, j_point)
87111
if j_dist > furthest_dist:
88112
furthest = j
89113
furthest_dist = j_dist
90114
return furthest
91115

92116
def _pivoting_furthest(self, start, end):
93-
if end - start < 2:
117+
if end <= start:
94118
return
95119
rnd = randrange(start, end)
96120
furthest_rnd = self._furthest(start, end, rnd)
@@ -100,10 +124,10 @@ def _pivoting_furthest(self, start, end):
100124

101125
def _update(self, start, end):
102126
self.__pivoting(start, end)
103-
_, v_point = self.__dataset[start]
104-
for i in range(start, end):
105-
_, point = self.__dataset[i]
106-
self.__dataset[i] = self.__distance(v_point, point), point
127+
_, v_point, is_terminal = self.__dataset[start]
128+
for i in range(start + 1, end):
129+
_, point, _ = self.__dataset[i]
130+
self.__dataset[i] = self.__distance(v_point, point), point, is_terminal
107131

108132
def build(self):
109133
self._build_iter()
@@ -115,159 +139,122 @@ def _build_iter(self):
115139
start, end = stack.pop()
116140
mid = _mid(start, end)
117141
self._update(start, end)
118-
_, v_point = self.__dataset[start]
119-
quickselect(self.__dataset, start + 1, end, mid)
120-
v_radius, _ = self.__dataset[mid]
121-
self.__dataset[start] = (v_radius, v_point)
122-
if (end - start > 2 * self.__leaf_capacity) and (v_radius > self.__leaf_radius):
142+
_, v_point, _ = self.__dataset[start]
143+
_quickselect(self.__dataset, start + 1, end, mid)
144+
v_radius, _, _ = self.__dataset[mid]
145+
if (
146+
(end - start > 2 * self.__leaf_capacity) and
147+
(v_radius > self.__leaf_radius)
148+
):
149+
self.__dataset[start] = (v_radius, v_point, False)
123150
stack.append((mid, end))
124151
stack.append((start + 1, mid))
152+
else:
153+
self.__dataset[start] = (v_radius, v_point, True)
125154

126155
def ball_search(self, point, eps, inclusive=True):
127156
return self._BallSearch(self, point, eps, inclusive).search()
128157

129158
class _BallSearch:
130159

131-
def __init__(self, vpt, center, radius, inclusive):
132-
self.__distance = vpt._get_distance()
160+
def __init__(self, vpt, point, eps, inclusive=True):
133161
self.__dataset = vpt._get_dataset()
134-
self.__leaf_capacity = vpt._get_leaf_capacity()
135-
self.__leaf_radius = vpt._get_leaf_radius()
136-
self.__center = center
137-
self.__radius = radius
138-
self.__items = []
162+
self.__distance = vpt._get_distance()
163+
self.__point = point
164+
self.__eps = eps
139165
self.__inclusive = inclusive
140166

141-
class _BSVisit:
142-
143-
def __init__(self, start, end, m_radius):
144-
self.__start = start
145-
self.__end = end
146-
self.__m_radius = m_radius
167+
def search(self):
168+
return self._search_iter()
147169

148-
def bounds(self):
149-
return self.__start, self.__end, self.__m_radius
170+
def _inside(self, dist):
171+
if self.__inclusive:
172+
return dist <= self.__eps
173+
return dist < self.__eps
150174

151-
def search(self):
152-
stack = [self._BSVisit(0, len(self.__dataset), float('inf'))]
175+
def _search_iter(self):
176+
stack = [(0, len(self.__dataset))]
177+
result = []
153178
while stack:
154-
visit = stack.pop()
155-
start, end, m_radius = visit.bounds()
156-
v_radius, v_point = self.__dataset[start]
157-
if (end - start <= 2 * self.__leaf_capacity) or (m_radius <= self.__leaf_radius) or (v_radius <= self.__leaf_radius):
158-
for _, x in self.__dataset[start:end]:
159-
dist = self.__distance(self.__center, x)
179+
start, end = stack.pop()
180+
v_radius, v_point, is_terminal = self.__dataset[start]
181+
if is_terminal:
182+
for _, x, _ in self.__dataset[start:end]:
183+
dist = self.__distance(self.__point, x)
160184
if self._inside(dist):
161-
self.__items.append(x)
185+
result.append(x)
162186
else:
163-
dist = self.__distance(self.__center, v_point)
164-
if self._inside(dist):
165-
self.__items.append(v_point)
187+
dist = self.__distance(self.__point, v_point)
166188
mid = _mid(start, end)
189+
if self._inside(dist):
190+
result.append(v_point)
167191
if dist <= v_radius:
168-
fst_start, fst_end, fst_radius = start + 1, mid, v_radius
169-
snd_start, snd_end, snd_radius = mid, end, float('inf')
192+
fst = (start + 1, mid)
193+
snd = (mid, end)
170194
else:
171-
fst_start, fst_end, fst_radius = mid, end, float('inf')
172-
snd_start, snd_end, snd_radius = start + 1, mid, v_radius
173-
if abs(dist - v_radius) <= self.__radius:
174-
stack.append(self._BSVisit(snd_start, snd_end, snd_radius))
175-
stack.append(self._BSVisit(fst_start, fst_end, fst_radius))
176-
return self.__items
195+
fst = (mid, end)
196+
snd = (start + 1, mid)
197+
if abs(dist - v_radius) <= self.__eps:
198+
stack.append(snd)
199+
stack.append(fst)
200+
return result
177201

178-
def _inside(self, dist):
179-
if self.__inclusive:
180-
return dist <= self.__radius
181-
return dist < self.__radius
202+
def knn_search(self, point, k):
203+
return self._KnnSearch(self, point, k).search()
182204

183-
def knn_search(self, point, neighbors):
184-
return self._KNNSearch(self, point, neighbors).search()
205+
class _KnnSearch:
185206

186-
class _KNNSearch:
187-
188-
def __init__(self, vpt, center, neighbors):
189-
self.__distance = vpt._VPTree__distance
190-
self.__dataset = vpt._VPTree__dataset
191-
self.__leaf_capacity = vpt._VPTree__leaf_capacity
192-
self.__leaf_radius = vpt._VPTree__leaf_radius
193-
self.__center = center
207+
def __init__(self, vpt, point, neighbors):
208+
self.__dataset = vpt._get_dataset()
209+
self.__distance = vpt._get_distance()
210+
self.__point = point
194211
self.__neighbors = neighbors
195-
self.__items = MaxHeap()
212+
self.__radius = float('inf')
213+
self.__result = MaxHeap()
196214

197215
def _get_items(self):
198-
while len(self.__items) > self.__neighbors:
199-
self.__items.pop()
200-
return [x for (_, x) in self.__items]
216+
while len(self.__result) > self.__neighbors:
217+
self.__result.pop()
218+
return [x for (_, x) in self.__result]
201219

202-
def _get_radius(self):
203-
if len(self.__items) < self.__neighbors:
204-
return float('inf')
205-
furthest_dist, _ = self.__items.top()
206-
return furthest_dist
220+
def search(self):
221+
self._search_iter()
222+
return self._get_items()
207223

208224
def _process(self, x):
209-
dist = self.__distance(self.__center, x)
210-
if dist >= self._get_radius():
225+
dist = self.__distance(self.__point, x)
226+
if dist >= self.__radius:
211227
return dist
212-
self.__items.add(dist, x)
213-
while len(self.__items) > self.__neighbors:
214-
self.__items.pop()
228+
self.__result.add(dist, x)
229+
while len(self.__result) > self.__neighbors:
230+
self.__result.pop()
231+
if len(self.__result) == self.__neighbors:
232+
self.__radius, _ = self.__result.top()
215233
return dist
216234

217-
def pre(self, pre, stack):
218-
start, end, _ = pre.bounds()
219-
v_radius, v_point = self.__dataset[start]
220-
dist = self._process(v_point)
221-
mid = _mid(start, end)
222-
if dist <= v_radius:
223-
fst_start, fst_end, fst_radius = start + 1, mid, v_radius
224-
snd_start, snd_end, snd_radius = mid, end, float('inf')
225-
else:
226-
fst_start, fst_end, fst_radius = mid, end, float('inf')
227-
snd_start, snd_end, snd_radius = start + 1, mid, v_radius
228-
stack.append((self._KVPost(snd_start, snd_end, snd_radius, dist, v_radius), self.post))
229-
stack.append((self._KVPre(fst_start, fst_end, fst_radius), self.pre))
230-
231-
def post(self, post, stack):
232-
start, end, _ = post.bounds()
233-
m_radius, dist, v_radius = post.rad()
234-
if abs(dist - v_radius) <= self._get_radius():
235-
stack.append((self._KVPre(start, end, m_radius), self.pre))
236-
237-
def search(self):
238-
stack = [(self._KVPre(0, len(self.__dataset), float('inf')), self.pre)]
235+
def _search_iter(self):
236+
PRE, POST = 0, 1
237+
self.__result = MaxHeap()
238+
stack = [(0, len(self.__dataset), 0.0, PRE)]
239239
while stack:
240-
visit, after = stack.pop()
241-
start, end, m_radius = visit.bounds()
242-
v_radius, _ = self.__dataset[start]
243-
if (end - start <= 2 * self.__leaf_capacity) or (m_radius <= self.__leaf_radius) or (v_radius <= self.__leaf_radius):
244-
for _, x in self.__dataset[start:end]:
240+
start, end, thr, action = stack.pop()
241+
v_radius, v_point, is_terminal = self.__dataset[start]
242+
if is_terminal:
243+
for _, x, _ in self.__dataset[start:end]:
245244
self._process(x)
246245
else:
247-
after(visit, stack)
246+
if action == PRE:
247+
mid = _mid(start, end)
248+
dist = self._process(v_point)
249+
if dist <= v_radius:
250+
fst_start, fst_end = start + 1, mid
251+
snd_start, snd_end = mid, end
252+
else:
253+
fst_start, fst_end = mid, end
254+
snd_start, snd_end = start + 1, mid
255+
stack.append((snd_start, snd_end, abs(v_radius - dist), POST))
256+
stack.append((fst_start, fst_end, 0.0, PRE))
257+
elif action == POST:
258+
if self.__radius > thr:
259+
stack.append((start, end, 0.0, PRE))
248260
return self._get_items()
249-
250-
class _KVPre:
251-
252-
def __init__(self, start, end, m_radius):
253-
self.__start = start
254-
self.__end = end
255-
self.__m_radius = m_radius
256-
257-
def bounds(self):
258-
return self.__start, self.__end, self.__m_radius
259-
260-
class _KVPost:
261-
262-
def __init__(self, start, end, m_radius, dist, v_radius):
263-
self.__start = start
264-
self.__end = end
265-
self.__m_radius = m_radius
266-
self.__dist = dist
267-
self.__v_radius = v_radius
268-
269-
def bounds(self):
270-
return self.__start, self.__end, self.__m_radius
271-
272-
def rad(self):
273-
return self.__m_radius, self.__dist, self.__v_radius

src/tdamapper/utils/vptree_hier.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,10 @@ def _build_rec(self, start, end):
120120
quickselect(self.__dataset, start + 1, end, mid)
121121
v_radius, _ = self.__dataset[mid]
122122
self.__dataset[start] = (v_radius, v_point)
123-
if (end - start <= 2 * self.__leaf_capacity) or (v_radius <= self.__leaf_radius):
123+
if (
124+
(end - start <= 2 * self.__leaf_capacity) or
125+
(v_radius <= self.__leaf_radius)
126+
):
124127
left = _Leaf(start + 1, mid)
125128
right = _Leaf(mid, end)
126129
else:
@@ -201,6 +204,10 @@ def _get_radius(self):
201204
furthest_dist, _ = self.__items.top()
202205
return furthest_dist
203206

207+
def search(self):
208+
self._search_rec(self.__tree)
209+
return self._get_items()
210+
204211
def _search_rec(self, tree):
205212
if tree.is_terminal():
206213
start, end = tree.get_bounds()
@@ -221,10 +228,6 @@ def _search_rec(self, tree):
221228
if abs(dist - v_radius) <= self._get_radius():
222229
self._search_rec(snd)
223230

224-
def search(self):
225-
self._search_rec(self.__tree)
226-
return self._get_items()
227-
228231

229232
class _Node:
230233

0 commit comments

Comments
 (0)