Skip to content

Commit 4b6087b

Browse files
committed
Refactor of vptree code. Improved modularity and performance using numba
1 parent afe5eb8 commit 4b6087b

19 files changed

Lines changed: 611 additions & 618 deletions

src/tdamapper/utils/vptree.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
A module for fast knn and range searches, depending only on a given metric
33
"""
44

5-
from tdamapper.utils.vptree_flat import VPTree as FVPT
6-
from tdamapper.utils.vptree_hier import VPTree as HVPT
5+
from tdamapper.utils.vptree_flat.vptree import VPTree as FVPT
6+
from tdamapper.utils.vptree_hier.vptree import VPTree as HVPT
77

88

99
class VPTree:

src/tdamapper/utils/vptree_flat.py

Lines changed: 0 additions & 316 deletions
This file was deleted.

src/tdamapper/utils/vptree_flat/__init__.py

Whitespace-only changes.
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
from tdamapper.utils.vptree_flat.common import _mid
2+
3+
4+
class BallSearch:
5+
6+
def __init__(self, vpt, point, eps, inclusive=True):
7+
self._arr = vpt._get_arr()
8+
self.__distance = vpt._get_distance()
9+
self.__point = point
10+
self.__eps = eps
11+
self.__inclusive = inclusive
12+
13+
def search(self):
14+
return self._search_iter()
15+
16+
def _inside(self, dist):
17+
if self.__inclusive:
18+
return dist <= self.__eps
19+
return dist < self.__eps
20+
21+
def _search_iter(self):
22+
stack = [(0, self._arr.size())]
23+
result = []
24+
while stack:
25+
start, end = stack.pop()
26+
v_radius = self._arr.get_distance(start)
27+
v_point = self._arr.get_point(start)
28+
is_terminal = self._arr.is_terminal(start)
29+
if is_terminal:
30+
for x in self._arr.get_points(start, end):
31+
dist = self.__distance(self.__point, x)
32+
if self._inside(dist):
33+
result.append(x)
34+
else:
35+
dist = self.__distance(self.__point, v_point)
36+
mid = _mid(start, end)
37+
if self._inside(dist):
38+
result.append(v_point)
39+
if dist <= v_radius:
40+
fst = (start + 1, mid)
41+
snd = (mid, end)
42+
else:
43+
fst = (mid, end)
44+
snd = (start + 1, mid)
45+
if abs(dist - v_radius) <= self.__eps:
46+
stack.append(snd)
47+
stack.append(fst)
48+
return result

0 commit comments

Comments
 (0)