From 3e07f197c83ac432fd34a44c125a7788aa0f01b7 Mon Sep 17 00:00:00 2001 From: Luca Simi Date: Wed, 21 May 2025 07:38:30 +0200 Subject: [PATCH 01/10] Added profiling decorator. Improved tests --- benchmarks/benchmark.py | 6 +++++- src/tdamapper/_common.py | 22 ++++++++++++++++++++++ src/tdamapper/utils/metrics.py | 6 +++++- tests/test_bench_cover.py | 9 ++++++++- tests/test_bench_vptree.py | 6 +++++- tests/test_unit_proximity.py | 8 ++++++-- 6 files changed, 51 insertions(+), 6 deletions(-) diff --git a/benchmarks/benchmark.py b/benchmarks/benchmark.py index 9a727130..1300eb24 100644 --- a/benchmarks/benchmark.py +++ b/benchmarks/benchmark.py @@ -12,6 +12,10 @@ from tdamapper.core import TrivialClustering +def _identity(x): + return x + + def _segment(cardinality, dimension, noise=0.1, start=None, end=None): if start is None: start = np.zeros(dimension) @@ -70,7 +74,7 @@ def fit(self, X, y=None): def run_gm(X, n, p): t0 = time.time() pipe = gm.make_mapper_pipeline( - filter_func=lambda x: x, + filter_func=_identity, cover=gm.CubicalCover(n_intervals=n, overlap_frac=p), clusterer=TrivialEstimator(), ) diff --git a/src/tdamapper/_common.py b/src/tdamapper/_common.py index 5daabca8..4a66980c 100644 --- a/src/tdamapper/_common.py +++ b/src/tdamapper/_common.py @@ -2,6 +2,9 @@ This module provides common functionalities for internal use. """ +import cProfile +import io +import pstats import warnings import numpy as np @@ -147,3 +150,22 @@ def clone(obj): obj_noargs = type(obj)() obj_noargs.set_params(**params) return obj_noargs + + +def profile(n_lines=10): + def decorator(func): + def wrapper(*args, **kwargs): + profiler = cProfile.Profile() + profiler.enable() + result = func(*args, **kwargs) + profiler.disable() + + s = io.StringIO() + ps = pstats.Stats(profiler, stream=s).sort_stats("cumulative") + ps.print_stats(n_lines) + print(s.getvalue()) + return result + + return wrapper + + return decorator diff --git a/src/tdamapper/utils/metrics.py b/src/tdamapper/utils/metrics.py index 84909cbc..1c0fa74c 100644 --- a/src/tdamapper/utils/metrics.py +++ b/src/tdamapper/utils/metrics.py @@ -114,7 +114,11 @@ def minkowski(p): return euclidean() elif np.isinf(p): return chebyshev() - return lambda x, y: _metrics.minkowski(p, x, y) + + def dist(x, y): + return _metrics.minkowski(p, x, y) + + return dist def cosine(): diff --git a/tests/test_bench_cover.py b/tests/test_bench_cover.py index 67bd8549..bfcb19ab 100644 --- a/tests/test_bench_cover.py +++ b/tests/test_bench_cover.py @@ -5,6 +5,7 @@ import numpy as np from sklearn.datasets import load_digits +from tdamapper._common import profile from tdamapper.utils.metrics import euclidean from tdamapper.utils.vptree_flat import VPTree as FVPT from tdamapper.utils.vptree_hier import VPTree as HVPT @@ -21,6 +22,10 @@ def dataset(dim=10, num=1000): return [np.random.rand(dim) for _ in range(num)] +def dist_proj(x, y): + return dist(x[1:], x[1:]) + + class TestVpSettings(unittest.TestCase): setup_logging() @@ -39,11 +44,12 @@ def cover(self, vpt, X, r): def run_bench(self, X, r, dist, vp, **kwargs): XX = np.array([[i] + [xi for xi in x] for i, x in enumerate(X)]) t0 = time.time() - vpt = vp(XX, metric=lambda x, y: dist(x[1:], y[1:]), **kwargs) + vpt = vp(XX, metric=dist_proj, **kwargs) list(self.cover(vpt, XX, r)) t1 = time.time() self.logger.info(f"time: {t1 - t0}") + @profile(n_lines=20) def test_cover_random(self): for r in [1.0, 10.0, 100.0]: for n in [100, 1000, 10000]: @@ -61,6 +67,7 @@ def test_cover_random(self): self.run_bench(X, r, dist, SkBallTree, leaf_radius=r) self.logger.info("") + @profile(n_lines=20) def test_cover_digits(self): X, _ = load_digits(return_X_y=True) # X = PCA(n_components=3).fit_transform(X) diff --git a/tests/test_bench_vptree.py b/tests/test_bench_vptree.py index 9f7f9530..faef9cf5 100644 --- a/tests/test_bench_vptree.py +++ b/tests/test_bench_vptree.py @@ -93,7 +93,11 @@ def _test_knn_search_naive(self, data, name): d(np.array([0.0]), np.array([0.0])) # jit-compile numba t0 = time() for val in data: - data.sort(key=lambda x: d(x, val)) + + def _dist_key(x): + return d(x, val) + + data.sort(key=_dist_key) [x for x in data[: self.k]] t1 = time() self.logger.info(f"{name}: {t1 - t0}") diff --git a/tests/test_unit_proximity.py b/tests/test_unit_proximity.py index 90854235..333da250 100644 --- a/tests/test_unit_proximity.py +++ b/tests/test_unit_proximity.py @@ -10,11 +10,15 @@ def dataset(dim=1, num=10000): return [np.random.rand(dim) for _ in range(num)] +def absdist(x, y): + return abs(x - y) + + class TestProximity(unittest.TestCase): def test_ball_proximity(self): data = list(range(100)) - cover = BallCover(radius=10, metric=lambda x, y: abs(x - y)) + cover = BallCover(radius=10, metric=absdist) cover.fit(data) for x in data: result = cover.search(x) @@ -23,7 +27,7 @@ def test_ball_proximity(self): def test_knn_proximity(self): data = list(range(100)) - cover = KNNCover(neighbors=11, metric=lambda x, y: abs(x - y)) + cover = KNNCover(neighbors=11, metric=absdist) cover.fit(data) for x in range(5, 94): result = cover.search(x) From fb05faf0b1a536b917016455ddc8e2b0472f2ea4 Mon Sep 17 00:00:00 2001 From: Luca Simi Date: Wed, 21 May 2025 23:42:36 +0200 Subject: [PATCH 02/10] Refactored flat vptree using numba for array operations --- src/tdamapper/utils/vptree_flat.py | 138 +++++++++++++++++++++++------ tests/test_unit_knn.py | 16 +++- 2 files changed, 123 insertions(+), 31 deletions(-) diff --git a/src/tdamapper/utils/vptree_flat.py b/src/tdamapper/utils/vptree_flat.py index 3af82584..d9c75d07 100755 --- a/src/tdamapper/utils/vptree_flat.py +++ b/src/tdamapper/utils/vptree_flat.py @@ -1,9 +1,13 @@ from random import randrange +import numpy as np +from numba import njit + from tdamapper.utils.heap import MaxHeap from tdamapper.utils.metrics import get_metric +@njit def _swap(arr, i, j): arr[i], arr[j] = arr[j], arr[i] @@ -12,25 +16,39 @@ def _mid(start, end): return (start + end) // 2 -def _partition(data, start, end, p_ord): +@njit +def _partition(distances, indices, is_terminal, start, end, p_ord): higher = start for j in range(start, end): - j_ord, _, _ = data[j] + j_ord = distances[j] if j_ord < p_ord: - _swap(data, higher, j) + _swap(distances, higher, j) + _swap(indices, higher, j) + _swap(is_terminal, higher, j) higher += 1 return higher -def _quickselect(data, start, end, k): +def _quickselect(distances, indices, is_terminal, start, end, k): if (k < start) or (k >= end): return start_, end_, higher = start, end, None while higher != k + 1: - p, _, _ = data[k] - _swap(data, start_, k) - higher = _partition(data, start_ + 1, end_, p) - _swap(data, start_, higher - 1) + # TODO: pivot_index = randrange(start_, end_) + pivot_index = k + + p = distances[pivot_index] + + _swap(distances, start_, pivot_index) + _swap(indices, start_, pivot_index) + _swap(is_terminal, start_, pivot_index) + + higher = _partition(distances, indices, is_terminal, start_ + 1, end_, p) + + _swap(distances, start_, higher - 1) + _swap(indices, start_, higher - 1) + _swap(is_terminal, start_, higher - 1) + if k <= higher - 1: end_ = higher else: @@ -53,7 +71,9 @@ def __init__( self.__leaf_capacity = leaf_capacity self.__leaf_radius = leaf_radius self.__pivoting = pivoting - self.__dataset = self._Build(self, X).build() + self.__dataset, self.__distances, self.__indices, self.__is_terminal = ( + self._Build(self, X).build() + ) def get_metric(self): return self.__metric @@ -77,11 +97,25 @@ def _get_distance(self): metric_params = self.__metric_params or {} return get_metric(self.__metric, **metric_params) + def _get_distances(self): + return self.__distances + + def _get_indices(self): + return self.__indices + + def _get_is_terminal(self): + return self.__is_terminal + class _Build: def __init__(self, vpt, X): self.__distance = vpt._get_distance() - self.__dataset = [(0.0, x, False) for x in X] + + self.__dataset = [x for x in X] + self.__indices = np.array([i for i in range(len(self.__dataset))]) + self.__distances = np.array([0.0 for _ in X]) + self.__is_terminal = np.array([False for _ in X]) + self.__leaf_capacity = vpt.get_leaf_capacity() self.__leaf_radius = vpt.get_leaf_radius() pivoting = vpt.get_pivoting() @@ -99,14 +133,21 @@ def _pivoting_random(self, start, end): return pivot = randrange(start, end) if pivot > start: - _swap(self.__dataset, start, pivot) + _swap(self.__distances, start, pivot) + _swap(self.__indices, start, pivot) + _swap(self.__is_terminal, start, pivot) def _furthest(self, start, end, i): furthest_dist = 0.0 furthest = start - _, i_point, _ = self.__dataset[i] + + i_point_index = self.__indices[i] + i_point = self.__dataset[i_point_index] + for j in range(start, end): - _, j_point, _ = self.__dataset[j] + j_point_index = self.__indices[j] + j_point = self.__dataset[j_point_index] + j_dist = self.__distance(i_point, j_point) if j_dist > furthest_dist: furthest = j @@ -120,18 +161,27 @@ def _pivoting_furthest(self, start, end): furthest_rnd = self._furthest(start, end, rnd) furthest = self._furthest(start, end, furthest_rnd) if furthest > start: - _swap(self.__dataset, start, furthest) + _swap(self.__distances, start, furthest) + _swap(self.__indices, start, furthest) + _swap(self.__is_terminal, start, furthest) def _update(self, start, end): self.__pivoting(start, end) - _, v_point, is_terminal = self.__dataset[start] + + v_point_index = self.__indices[start] + v_point = self.__dataset[v_point_index] + is_terminal = self.__is_terminal[start] + for i in range(start + 1, end): - _, point, _ = self.__dataset[i] - self.__dataset[i] = self.__distance(v_point, point), point, is_terminal + point_index = self.__indices[i] + point = self.__dataset[point_index] + + self.__distances[i] = self.__distance(v_point, point) + self.__is_terminal[i] = is_terminal def build(self): self._build_iter() - return self.__dataset + return self.__dataset, self.__distances, self.__indices, self.__is_terminal def _build_iter(self): stack = [(0, len(self.__dataset))] @@ -139,17 +189,33 @@ def _build_iter(self): start, end = stack.pop() mid = _mid(start, end) self._update(start, end) - _, v_point, _ = self.__dataset[start] - _quickselect(self.__dataset, start + 1, end, mid) - v_radius, _, _ = self.__dataset[mid] + + v_point_index = self.__indices[start] + + _quickselect( + self.__distances, + self.__indices, + self.__is_terminal, + start + 1, + end, + mid, + ) + + v_radius = self.__distances[mid] + if (end - start > 2 * self.__leaf_capacity) and ( v_radius > self.__leaf_radius ): - self.__dataset[start] = (v_radius, v_point, False) + self.__distances[start] = v_radius + self.__indices[start] = v_point_index + self.__is_terminal[start] = False + stack.append((mid, end)) stack.append((start + 1, mid)) else: - self.__dataset[start] = (v_radius, v_point, True) + self.__distances[start] = v_radius + self.__indices[start] = v_point_index + self.__is_terminal[start] = True def ball_search(self, point, eps, inclusive=True): return self._BallSearch(self, point, eps, inclusive).search() @@ -158,6 +224,9 @@ class _BallSearch: def __init__(self, vpt, point, eps, inclusive=True): self.__dataset = vpt._get_dataset() + self.__distances = vpt._get_distances() + self.__indices = vpt._get_indices() + self.__is_terminal = vpt._get_is_terminal() self.__distance = vpt._get_distance() self.__point = point self.__eps = eps @@ -176,9 +245,15 @@ def _search_iter(self): result = [] while stack: start, end = stack.pop() - v_radius, v_point, is_terminal = self.__dataset[start] + + v_radius = self.__distances[start] + v_point_index = self.__indices[start] + v_point = self.__dataset[v_point_index] + is_terminal = self.__is_terminal[start] + if is_terminal: - for _, x, _ in self.__dataset[start:end]: + for x_index in self.__indices[start:end]: + x = self.__dataset[x_index] dist = self.__distance(self.__point, x) if self._inside(dist): result.append(x) @@ -205,6 +280,9 @@ class _KnnSearch: def __init__(self, vpt, point, neighbors): self.__dataset = vpt._get_dataset() + self.__distances = vpt._get_distances() + self.__indices = vpt._get_indices() + self.__is_terminal = vpt._get_is_terminal() self.__distance = vpt._get_distance() self.__point = point self.__neighbors = neighbors @@ -237,9 +315,15 @@ def _search_iter(self): stack = [(0, len(self.__dataset), 0.0, PRE)] while stack: start, end, thr, action = stack.pop() - v_radius, v_point, is_terminal = self.__dataset[start] + + v_radius = self.__distances[start] + v_point_index = self.__indices[start] + v_point = self.__dataset[v_point_index] + is_terminal = self.__is_terminal[start] + if is_terminal: - for _, x, _ in self.__dataset[start:end]: + for x_index in self.__indices[start:end]: + x = self.__dataset[x_index] self._process(x) else: if action == PRE: diff --git a/tests/test_unit_knn.py b/tests/test_unit_knn.py index 02d402f4..9dffe37c 100644 --- a/tests/test_unit_knn.py +++ b/tests/test_unit_knn.py @@ -127,22 +127,30 @@ def test_vptree_simple(self): def check_vptree(self, vpt): data = vpt._get_dataset() + distances = vpt._get_distances() + indices = vpt._get_indices() + dist = vpt._get_distance() leaf_capacity = vpt.get_leaf_capacity() leaf_radius = vpt.get_leaf_radius() def check_sub(start, end): - v_radius, v_point, *_ = data[start] + v_radius = distances[start] + v_point_index = indices[start] + v_point = data[v_point_index] + mid = (start + end) // 2 for i in range(start + 1, mid): - _, y, *_ = data[i] + y_index = indices[i] + y = data[y_index] self.assertTrue(dist(v_point, y) <= v_radius) for i in range(mid, end): - _, y, *_ = data[i] + y_index = indices[i] + y = data[y_index] self.assertTrue(dist(v_point, y) >= v_radius) def check_rec(start, end): - v_radius, *_ = data[start] + v_radius = distances[start] if (end - start > leaf_capacity) and (v_radius > leaf_radius): check_sub(start, end) mid = (start + end) // 2 From afe5eb8534c2bf85c59e1c170e2fd1e688abcc82 Mon Sep 17 00:00:00 2001 From: Luca Simi Date: Thu, 22 May 2025 13:23:25 +0200 Subject: [PATCH 03/10] Refactored quickselect using numba --- src/tdamapper/utils/quickselect.py | 55 +++++------- src/tdamapper/utils/vptree_flat.py | 137 ++++++++++++----------------- src/tdamapper/utils/vptree_hier.py | 72 ++++++++++----- tests/test_unit_quickselect.py | 55 ++++++------ 4 files changed, 155 insertions(+), 164 deletions(-) diff --git a/src/tdamapper/utils/quickselect.py b/src/tdamapper/utils/quickselect.py index 93bc528f..d04c4e3f 100755 --- a/src/tdamapper/utils/quickselect.py +++ b/src/tdamapper/utils/quickselect.py @@ -1,54 +1,41 @@ -def __swap(arr, i, j): - arr[i], arr[j] = arr[j], arr[i] +from numba import njit -def partition(data, start, end, p_ord): - higher = start - for j in range(start, end): - j_ord, _ = data[j] - if j_ord < p_ord: - __swap(data, higher, j) - higher += 1 - return higher +@njit +def swap(arr, i, j): + arr[i], arr[j] = arr[j], arr[i] -def quickselect(data, start, end, k): - if (k < start) or (k >= end): - return - start_, end_, higher = start, end, None - while higher != k + 1: - p, _ = data[k] - __swap(data, start_, k) - higher = partition(data, start_ + 1, end_, p) - __swap(data, start_, higher - 1) - if k <= higher - 1: - end_ = higher - else: - start_ = higher +@njit +def swap_all(arr, i, j, extra1=None, extra2=None): + swap(arr, i, j) + if extra1 is not None: + swap(extra1, i, j) + if extra2 is not None: + swap(extra2, i, j) -def partition_tuple(data_ord, data_arr, start, end, p_ord): +@njit +def partition(data, start, end, p_ord, *extra): higher = start for j in range(start, end): - j_ord = data_ord[j] + j_ord = data[j] if j_ord < p_ord: - __swap(data_arr, higher, j) - __swap(data_ord, higher, j) + swap_all(data, higher, j, *extra) higher += 1 return higher -def quickselect_tuple(data_ord, data_arr, start, end, k): +@njit +def quickselect(data, start, end, k, *extra): if (k < start) or (k >= end): return start_, end_, higher = start, end, None while higher != k + 1: - p_ord = data_ord[k] - __swap(data_arr, start_, k) - __swap(data_ord, start_, k) - higher = partition_tuple(data_ord, data_arr, start_ + 1, end_, p_ord) - __swap(data_arr, start_, higher - 1) - __swap(data_ord, start_, higher - 1) + p = data[k] + swap_all(data, start_, k, *extra) + higher = partition(data, start_ + 1, end_, p, *extra) + swap_all(data, start_, higher - 1, *extra) if k <= higher - 1: end_ = higher else: diff --git a/src/tdamapper/utils/vptree_flat.py b/src/tdamapper/utils/vptree_flat.py index d9c75d07..e807cae0 100755 --- a/src/tdamapper/utils/vptree_flat.py +++ b/src/tdamapper/utils/vptree_flat.py @@ -1,60 +1,16 @@ from random import randrange import numpy as np -from numba import njit from tdamapper.utils.heap import MaxHeap from tdamapper.utils.metrics import get_metric - - -@njit -def _swap(arr, i, j): - arr[i], arr[j] = arr[j], arr[i] +from tdamapper.utils.quickselect import quickselect, swap_all def _mid(start, end): return (start + end) // 2 -@njit -def _partition(distances, indices, is_terminal, start, end, p_ord): - higher = start - for j in range(start, end): - j_ord = distances[j] - if j_ord < p_ord: - _swap(distances, higher, j) - _swap(indices, higher, j) - _swap(is_terminal, higher, j) - higher += 1 - return higher - - -def _quickselect(distances, indices, is_terminal, start, end, k): - if (k < start) or (k >= end): - return - start_, end_, higher = start, end, None - while higher != k + 1: - # TODO: pivot_index = randrange(start_, end_) - pivot_index = k - - p = distances[pivot_index] - - _swap(distances, start_, pivot_index) - _swap(indices, start_, pivot_index) - _swap(is_terminal, start_, pivot_index) - - higher = _partition(distances, indices, is_terminal, start_ + 1, end_, p) - - _swap(distances, start_, higher - 1) - _swap(indices, start_, higher - 1) - _swap(is_terminal, start_, higher - 1) - - if k <= higher - 1: - end_ = higher - else: - start_ = higher - - class VPTree: def __init__( @@ -71,9 +27,12 @@ def __init__( self.__leaf_capacity = leaf_capacity self.__leaf_radius = leaf_radius self.__pivoting = pivoting - self.__dataset, self.__distances, self.__indices, self.__is_terminal = ( - self._Build(self, X).build() - ) + ( + self.__dataset, + self.__arr_distances, + self.__arr_indices, + self.__arr_is_terminal, + ) = self._Build(self, X).build() def get_metric(self): return self.__metric @@ -98,13 +57,13 @@ def _get_distance(self): return get_metric(self.__metric, **metric_params) def _get_distances(self): - return self.__distances + return self.__arr_distances def _get_indices(self): - return self.__indices + return self.__arr_indices def _get_is_terminal(self): - return self.__is_terminal + return self.__arr_is_terminal class _Build: @@ -112,9 +71,9 @@ def __init__(self, vpt, X): self.__distance = vpt._get_distance() self.__dataset = [x for x in X] - self.__indices = np.array([i for i in range(len(self.__dataset))]) - self.__distances = np.array([0.0 for _ in X]) - self.__is_terminal = np.array([False for _ in X]) + self.__arr_indices = np.array([i for i in range(len(self.__dataset))]) + self.__arr_distances = np.array([0.0 for _ in X]) + self.__arr_is_terminal = np.array([False for _ in X]) self.__leaf_capacity = vpt.get_leaf_capacity() self.__leaf_radius = vpt.get_leaf_radius() @@ -133,20 +92,25 @@ def _pivoting_random(self, start, end): return pivot = randrange(start, end) if pivot > start: - _swap(self.__distances, start, pivot) - _swap(self.__indices, start, pivot) - _swap(self.__is_terminal, start, pivot) + swap_all( + self.__arr_distances, + start, + pivot, + self.__arr_indices, + self.__arr_is_terminal, + ) + + def _get_point(self, i): + return self.__dataset[self.__arr_indices[i]] def _furthest(self, start, end, i): furthest_dist = 0.0 furthest = start - i_point_index = self.__indices[i] - i_point = self.__dataset[i_point_index] + i_point = self._get_point(i) for j in range(start, end): - j_point_index = self.__indices[j] - j_point = self.__dataset[j_point_index] + j_point = self._get_point(j) j_dist = self.__distance(i_point, j_point) if j_dist > furthest_dist: @@ -161,27 +125,36 @@ def _pivoting_furthest(self, start, end): furthest_rnd = self._furthest(start, end, rnd) furthest = self._furthest(start, end, furthest_rnd) if furthest > start: - _swap(self.__distances, start, furthest) - _swap(self.__indices, start, furthest) - _swap(self.__is_terminal, start, furthest) + swap_all( + self.__arr_distances, + start, + furthest, + self.__arr_indices, + self.__arr_is_terminal, + ) def _update(self, start, end): self.__pivoting(start, end) - v_point_index = self.__indices[start] + v_point_index = self.__arr_indices[start] v_point = self.__dataset[v_point_index] - is_terminal = self.__is_terminal[start] + is_terminal = self.__arr_is_terminal[start] for i in range(start + 1, end): - point_index = self.__indices[i] + point_index = self.__arr_indices[i] point = self.__dataset[point_index] - self.__distances[i] = self.__distance(v_point, point) - self.__is_terminal[i] = is_terminal + self.__arr_distances[i] = self.__distance(v_point, point) + self.__arr_is_terminal[i] = is_terminal def build(self): self._build_iter() - return self.__dataset, self.__distances, self.__indices, self.__is_terminal + return ( + self.__dataset, + self.__arr_distances, + self.__arr_indices, + self.__arr_is_terminal, + ) def _build_iter(self): stack = [(0, len(self.__dataset))] @@ -190,32 +163,32 @@ def _build_iter(self): mid = _mid(start, end) self._update(start, end) - v_point_index = self.__indices[start] + # v_point_index = self.__indices[start] - _quickselect( - self.__distances, - self.__indices, - self.__is_terminal, + quickselect( + self.__arr_distances, start + 1, end, mid, + self.__arr_indices, + self.__arr_is_terminal, ) - v_radius = self.__distances[mid] + v_radius = self.__arr_distances[mid] if (end - start > 2 * self.__leaf_capacity) and ( v_radius > self.__leaf_radius ): - self.__distances[start] = v_radius - self.__indices[start] = v_point_index - self.__is_terminal[start] = False + self.__arr_distances[start] = v_radius + # self.__indices[start] = v_point_index + self.__arr_is_terminal[start] = False stack.append((mid, end)) stack.append((start + 1, mid)) else: - self.__distances[start] = v_radius - self.__indices[start] = v_point_index - self.__is_terminal[start] = True + self.__arr_distances[start] = v_radius + # self.__indices[start] = v_point_index + self.__arr_is_terminal[start] = True def ball_search(self, point, eps, inclusive=True): return self._BallSearch(self, point, eps, inclusive).search() diff --git a/src/tdamapper/utils/vptree_hier.py b/src/tdamapper/utils/vptree_hier.py index 474ed247..2078e43e 100644 --- a/src/tdamapper/utils/vptree_hier.py +++ b/src/tdamapper/utils/vptree_hier.py @@ -1,12 +1,10 @@ from random import randrange +import numpy as np + from tdamapper.utils.heap import MaxHeap from tdamapper.utils.metrics import get_metric -from tdamapper.utils.quickselect import quickselect - - -def _swap(arr, i, j): - arr[i], arr[j] = arr[j], arr[i] +from tdamapper.utils.quickselect import quickselect, swap_all def _mid(start, end): @@ -29,8 +27,13 @@ def __init__( self.__leaf_capacity = leaf_capacity self.__leaf_radius = leaf_radius self.__pivoting = pivoting - tree, dataset = self._Build(self, X).build() - self.__tree, self.__dataset = tree, dataset + tree, dataset, distances, indices = self._Build(self, X).build() + self.__tree, self.__dataset, self.__distances, self.__indices = ( + tree, + dataset, + distances, + indices, + ) def get_metric(self): return self.__metric @@ -53,6 +56,12 @@ def _get_tree(self): def _get_dataset(self): return self.__dataset + def _get_distances(self): + return self.__distances + + def _get_indices(self): + return self.__indices + def _get_distance(self): metric_params = self.__metric_params or {} return get_metric(self.__metric, **metric_params) @@ -61,7 +70,10 @@ class _Build: def __init__(self, vpt, X): self.__distance = vpt._get_distance() - self.__dataset = [(0.0, x) for x in X] + self.__dataset = [x for x in X] + self.__indices = np.array([i for i in range(len(self.__dataset))]) + self.__distances = np.array([0.0 for _ in X]) + self.__leaf_capacity = vpt.get_leaf_capacity() self.__leaf_radius = vpt.get_leaf_radius() pivoting = vpt.get_pivoting() @@ -79,14 +91,20 @@ def _pivoting_random(self, start, end): return pivot = randrange(start, end) if pivot > start: - _swap(self.__dataset, start, pivot) + swap_all(self.__distances, start, pivot, self.__indices) + + def _get_point(self, i): + return self.__dataset[self.__indices[i]] + + def _get_dist(self, i): + return self.__distances[i] def _furthest(self, start, end, i): furthest_dist = 0.0 furthest = start - _, i_point = self.__dataset[i] + i_point = self._get_point(i) for j in range(start, end): - _, j_point = self.__dataset[j] + j_point = self._get_point(j) j_dist = self.__distance(i_point, j_point) if j_dist > furthest_dist: furthest = j @@ -100,26 +118,26 @@ def _pivoting_furthest(self, start, end): furthest_rnd = self._furthest(start, end, rnd) furthest = self._furthest(start, end, furthest_rnd) if furthest > start: - _swap(self.__dataset, start, furthest) + swap_all(self.__distances, start, furthest, self.__indices) def _update(self, start, end): self.__pivoting(start, end) - _, v_point = self.__dataset[start] + v_point = self._get_point(start) for i in range(start + 1, end): - _, point = self.__dataset[i] - self.__dataset[i] = self.__distance(v_point, point), point + point = self._get_point(i) + self.__distances[i] = self.__distance(v_point, point) def build(self): tree = self._build_rec(0, len(self.__dataset)) - return tree, self.__dataset + return tree, self.__dataset, self.__distances, self.__indices def _build_rec(self, start, end): mid = _mid(start, end) self._update(start, end) - _, v_point = self.__dataset[start] - quickselect(self.__dataset, start + 1, end, mid) - v_radius, _ = self.__dataset[mid] - self.__dataset[start] = (v_radius, v_point) + v_point = self._get_point(start) + quickselect(self.__distances, start + 1, end, mid, self.__indices) + v_radius = self._get_dist(mid) + self.__distances[start] = v_radius if (end - start <= 2 * self.__leaf_capacity) or ( v_radius <= self.__leaf_radius ): @@ -138,6 +156,7 @@ class _BallSearch: def __init__(self, vpt, point, eps, inclusive=True): self.__tree = vpt._get_tree() self.__dataset = vpt._get_dataset() + self.__indices = vpt._get_indices() self.__distance = vpt._get_distance() self.__point = point self.__eps = eps @@ -149,6 +168,10 @@ def search(self): self._search_rec(self.__tree) return self.__result + def _get_points(self, s, e): + for x_index in self.__indices[s:e]: + yield self.__dataset[x_index] + def _inside(self, dist): if self.__inclusive: return dist <= self.__eps @@ -157,7 +180,7 @@ def _inside(self, dist): def _search_rec(self, tree): if tree.is_terminal(): start, end = tree.get_bounds() - for _, x in self.__dataset[start:end]: + for x in self._get_points(start, end): dist = self.__distance(self.__point, x) if self._inside(dist): self.__result.append(x) @@ -182,6 +205,7 @@ class _KnnSearch: def __init__(self, vpt, point, neighbors): self.__tree = vpt._get_tree() self.__dataset = vpt._get_dataset() + self.__indices = vpt._get_indices() self.__distance = vpt._get_distance() self.__point = point self.__neighbors = neighbors @@ -192,6 +216,10 @@ def _add(self, dist, x): if len(self.__items) > self.__neighbors: self.__items.pop() + def _get_points(self, s, e): + for x_index in self.__indices[s:e]: + yield self.__dataset[x_index] + def _get_items(self): while len(self.__items) > self.__neighbors: self.__items.pop() @@ -210,7 +238,7 @@ def search(self): def _search_rec(self, tree): if tree.is_terminal(): start, end = tree.get_bounds() - for _, x in self.__dataset[start:end]: + for x in self._get_points(start, end): dist = self.__distance(self.__point, x) if dist < self._get_radius(): self._add(dist, x) diff --git a/tests/test_unit_quickselect.py b/tests/test_unit_quickselect.py index 86f23247..ca091dc3 100755 --- a/tests/test_unit_quickselect.py +++ b/tests/test_unit_quickselect.py @@ -1,50 +1,53 @@ import random import unittest -from tdamapper.utils.quickselect import ( - partition, - partition_tuple, - quickselect, - quickselect_tuple, -) +import numpy as np + +from tdamapper.utils.quickselect import partition, quickselect class TestQuickSelect(unittest.TestCase): def test_partition(self): n = 1000 - arr = [(i, random.randint(0, n - 1)) for i in range(n)] + arr = np.array([i for i in range(n)]) + arr_extra = np.array([random.randint(0, n - 1) for i in range(n)]) for choice in range(n): - h = partition(arr, 0, n, choice) + h = partition(arr, 0, n, choice, arr_extra) for i in range(0, h): - self.assertTrue(arr[i][0] < choice) + self.assertTrue(arr[i] < choice) for i in range(h, n): - self.assertTrue(arr[i][0] >= choice) + self.assertTrue(arr[i] >= choice) def test_quickselect_bounds(self): - arr = [(0, 4), (1, 5), (-1, 6)] - quickselect(arr, 1, 2, 0) - self.assertEqual((0, 4), arr[0]) - self.assertEqual((1, 5), arr[1]) - self.assertEqual((-1, 6), arr[2]) + arr = np.array([0, 1, -1]) + arr_extra = np.array([4, 5, 6]) + quickselect(arr, 1, 2, 0, arr_extra) + self.assertEqual(0, arr[0]) + self.assertEqual(1, arr[1]) + self.assertEqual(-1, arr[2]) + self.assertEqual(4, arr_extra[0]) + self.assertEqual(5, arr_extra[1]) + self.assertEqual(6, arr_extra[2]) def test_quickselect(self): n = 1000 - arr = [(i, random.randint(0, n - 1)) for i in range(n)] + arr = np.array([i for i in range(n)]) + arr_extra = np.array([random.randint(0, n - 1) for i in range(n)]) for choice in range(n): - quickselect(arr, 0, n, choice) - val = arr[choice][0] + quickselect(arr, 0, n, choice, arr_extra) + val = arr[choice] for i in range(0, choice): - self.assertTrue(arr[i][0] <= val) + self.assertTrue(arr[i] <= val) for i in range(choice, n): - self.assertTrue(arr[i][0] >= val) + self.assertTrue(arr[i] >= val) def test_partition_tuple(self): n = 1000 - arr_data = [random.randint(0, n - 1) for i in range(n)] - arr_ord = list(range(n)) + arr_data = np.array([random.randint(0, n - 1) for i in range(n)]) + arr_ord = np.array(list(range(n))) for choice in range(n): - h = partition_tuple(arr_ord, arr_data, 0, n, choice) + h = partition(arr_ord, 0, n, choice, arr_data) for i in range(0, h): self.assertTrue(arr_ord[i] < choice) for i in range(h, n): @@ -52,10 +55,10 @@ def test_partition_tuple(self): def test_quickselect_tuple(self): n = 1000 - arr_data = [random.randint(0, n - 1) for i in range(n)] - arr_ord = list(range(n)) + arr_data = np.array([random.randint(0, n - 1) for i in range(n)]) + arr_ord = np.array(list(range(n))) for choice in range(n): - quickselect_tuple(arr_ord, arr_data, 0, n, choice) + quickselect(arr_ord, 0, n, choice, arr_data) val = arr_ord[choice] for i in range(0, choice): self.assertTrue(arr_ord[i] <= val) From 4b6087bea2d75d721e3a172cb7f3857bb23c828f Mon Sep 17 00:00:00 2001 From: Luca Simi Date: Thu, 22 May 2025 21:49:52 +0200 Subject: [PATCH 04/10] Refactor of vptree code. Improved modularity and performance using numba --- src/tdamapper/utils/vptree.py | 4 +- src/tdamapper/utils/vptree_flat.py | 316 ------------------ src/tdamapper/utils/vptree_flat/__init__.py | 0 .../utils/vptree_flat/ball_search.py | 48 +++ src/tdamapper/utils/vptree_flat/builder.py | 93 ++++++ src/tdamapper/utils/vptree_flat/common.py | 42 +++ src/tdamapper/utils/vptree_flat/knn_search.py | 64 ++++ src/tdamapper/utils/vptree_flat/vptree.py | 51 +++ src/tdamapper/utils/vptree_hier.py | 290 ---------------- src/tdamapper/utils/vptree_hier/__init__.py | 0 .../utils/vptree_hier/ball_search.py | 40 +++ src/tdamapper/utils/vptree_hier/builder.py | 85 +++++ src/tdamapper/utils/vptree_hier/common.py | 69 ++++ src/tdamapper/utils/vptree_hier/knn_search.py | 52 +++ src/tdamapper/utils/vptree_hier/vptree.py | 54 +++ tests/test_bench_cover.py | 4 +- tests/test_bench_vptree.py | 4 +- tests/test_unit_knn.py | 9 +- tests/test_unit_vptree.py | 4 +- 19 files changed, 611 insertions(+), 618 deletions(-) delete mode 100755 src/tdamapper/utils/vptree_flat.py create mode 100644 src/tdamapper/utils/vptree_flat/__init__.py create mode 100644 src/tdamapper/utils/vptree_flat/ball_search.py create mode 100644 src/tdamapper/utils/vptree_flat/builder.py create mode 100644 src/tdamapper/utils/vptree_flat/common.py create mode 100644 src/tdamapper/utils/vptree_flat/knn_search.py create mode 100755 src/tdamapper/utils/vptree_flat/vptree.py delete mode 100644 src/tdamapper/utils/vptree_hier.py create mode 100644 src/tdamapper/utils/vptree_hier/__init__.py create mode 100644 src/tdamapper/utils/vptree_hier/ball_search.py create mode 100644 src/tdamapper/utils/vptree_hier/builder.py create mode 100644 src/tdamapper/utils/vptree_hier/common.py create mode 100644 src/tdamapper/utils/vptree_hier/knn_search.py create mode 100644 src/tdamapper/utils/vptree_hier/vptree.py diff --git a/src/tdamapper/utils/vptree.py b/src/tdamapper/utils/vptree.py index a85ec3b9..3b3a62f6 100644 --- a/src/tdamapper/utils/vptree.py +++ b/src/tdamapper/utils/vptree.py @@ -2,8 +2,8 @@ A module for fast knn and range searches, depending only on a given metric """ -from tdamapper.utils.vptree_flat import VPTree as FVPT -from tdamapper.utils.vptree_hier import VPTree as HVPT +from tdamapper.utils.vptree_flat.vptree import VPTree as FVPT +from tdamapper.utils.vptree_hier.vptree import VPTree as HVPT class VPTree: diff --git a/src/tdamapper/utils/vptree_flat.py b/src/tdamapper/utils/vptree_flat.py deleted file mode 100755 index e807cae0..00000000 --- a/src/tdamapper/utils/vptree_flat.py +++ /dev/null @@ -1,316 +0,0 @@ -from random import randrange - -import numpy as np - -from tdamapper.utils.heap import MaxHeap -from tdamapper.utils.metrics import get_metric -from tdamapper.utils.quickselect import quickselect, swap_all - - -def _mid(start, end): - return (start + end) // 2 - - -class VPTree: - - def __init__( - self, - X, - metric="euclidean", - metric_params=None, - leaf_capacity=1, - leaf_radius=0.0, - pivoting=None, - ): - self.__metric = metric - self.__metric_params = metric_params - self.__leaf_capacity = leaf_capacity - self.__leaf_radius = leaf_radius - self.__pivoting = pivoting - ( - self.__dataset, - self.__arr_distances, - self.__arr_indices, - self.__arr_is_terminal, - ) = self._Build(self, X).build() - - def get_metric(self): - return self.__metric - - def get_metric_params(self): - return self.__metric_params - - def get_leaf_capacity(self): - return self.__leaf_capacity - - def get_leaf_radius(self): - return self.__leaf_radius - - def get_pivoting(self): - return self.__pivoting - - def _get_dataset(self): - return self.__dataset - - def _get_distance(self): - metric_params = self.__metric_params or {} - return get_metric(self.__metric, **metric_params) - - def _get_distances(self): - return self.__arr_distances - - def _get_indices(self): - return self.__arr_indices - - def _get_is_terminal(self): - return self.__arr_is_terminal - - class _Build: - - def __init__(self, vpt, X): - self.__distance = vpt._get_distance() - - self.__dataset = [x for x in X] - self.__arr_indices = np.array([i for i in range(len(self.__dataset))]) - self.__arr_distances = np.array([0.0 for _ in X]) - self.__arr_is_terminal = np.array([False for _ in X]) - - self.__leaf_capacity = vpt.get_leaf_capacity() - self.__leaf_radius = vpt.get_leaf_radius() - pivoting = vpt.get_pivoting() - self.__pivoting = self._pivoting_disabled - if pivoting == "random": - self.__pivoting = self._pivoting_random - elif pivoting == "furthest": - self.__pivoting = self._pivoting_furthest - - def _pivoting_disabled(self, start, end): - pass - - def _pivoting_random(self, start, end): - if end <= start: - return - pivot = randrange(start, end) - if pivot > start: - swap_all( - self.__arr_distances, - start, - pivot, - self.__arr_indices, - self.__arr_is_terminal, - ) - - def _get_point(self, i): - return self.__dataset[self.__arr_indices[i]] - - def _furthest(self, start, end, i): - furthest_dist = 0.0 - furthest = start - - i_point = self._get_point(i) - - for j in range(start, end): - j_point = self._get_point(j) - - j_dist = self.__distance(i_point, j_point) - if j_dist > furthest_dist: - furthest = j - furthest_dist = j_dist - return furthest - - def _pivoting_furthest(self, start, end): - if end <= start: - return - rnd = randrange(start, end) - furthest_rnd = self._furthest(start, end, rnd) - furthest = self._furthest(start, end, furthest_rnd) - if furthest > start: - swap_all( - self.__arr_distances, - start, - furthest, - self.__arr_indices, - self.__arr_is_terminal, - ) - - def _update(self, start, end): - self.__pivoting(start, end) - - v_point_index = self.__arr_indices[start] - v_point = self.__dataset[v_point_index] - is_terminal = self.__arr_is_terminal[start] - - for i in range(start + 1, end): - point_index = self.__arr_indices[i] - point = self.__dataset[point_index] - - self.__arr_distances[i] = self.__distance(v_point, point) - self.__arr_is_terminal[i] = is_terminal - - def build(self): - self._build_iter() - return ( - self.__dataset, - self.__arr_distances, - self.__arr_indices, - self.__arr_is_terminal, - ) - - def _build_iter(self): - stack = [(0, len(self.__dataset))] - while stack: - start, end = stack.pop() - mid = _mid(start, end) - self._update(start, end) - - # v_point_index = self.__indices[start] - - quickselect( - self.__arr_distances, - start + 1, - end, - mid, - self.__arr_indices, - self.__arr_is_terminal, - ) - - v_radius = self.__arr_distances[mid] - - if (end - start > 2 * self.__leaf_capacity) and ( - v_radius > self.__leaf_radius - ): - self.__arr_distances[start] = v_radius - # self.__indices[start] = v_point_index - self.__arr_is_terminal[start] = False - - stack.append((mid, end)) - stack.append((start + 1, mid)) - else: - self.__arr_distances[start] = v_radius - # self.__indices[start] = v_point_index - self.__arr_is_terminal[start] = True - - def ball_search(self, point, eps, inclusive=True): - return self._BallSearch(self, point, eps, inclusive).search() - - class _BallSearch: - - def __init__(self, vpt, point, eps, inclusive=True): - self.__dataset = vpt._get_dataset() - self.__distances = vpt._get_distances() - self.__indices = vpt._get_indices() - self.__is_terminal = vpt._get_is_terminal() - self.__distance = vpt._get_distance() - self.__point = point - self.__eps = eps - self.__inclusive = inclusive - - def search(self): - return self._search_iter() - - def _inside(self, dist): - if self.__inclusive: - return dist <= self.__eps - return dist < self.__eps - - def _search_iter(self): - stack = [(0, len(self.__dataset))] - result = [] - while stack: - start, end = stack.pop() - - v_radius = self.__distances[start] - v_point_index = self.__indices[start] - v_point = self.__dataset[v_point_index] - is_terminal = self.__is_terminal[start] - - if is_terminal: - for x_index in self.__indices[start:end]: - x = self.__dataset[x_index] - dist = self.__distance(self.__point, x) - if self._inside(dist): - result.append(x) - else: - dist = self.__distance(self.__point, v_point) - mid = _mid(start, end) - if self._inside(dist): - result.append(v_point) - if dist <= v_radius: - fst = (start + 1, mid) - snd = (mid, end) - else: - fst = (mid, end) - snd = (start + 1, mid) - if abs(dist - v_radius) <= self.__eps: - stack.append(snd) - stack.append(fst) - return result - - def knn_search(self, point, k): - return self._KnnSearch(self, point, k).search() - - class _KnnSearch: - - def __init__(self, vpt, point, neighbors): - self.__dataset = vpt._get_dataset() - self.__distances = vpt._get_distances() - self.__indices = vpt._get_indices() - self.__is_terminal = vpt._get_is_terminal() - self.__distance = vpt._get_distance() - self.__point = point - self.__neighbors = neighbors - self.__radius = float("inf") - self.__result = MaxHeap() - - def _get_items(self): - while len(self.__result) > self.__neighbors: - self.__result.pop() - return [x for (_, x) in self.__result] - - def search(self): - self._search_iter() - return self._get_items() - - def _process(self, x): - dist = self.__distance(self.__point, x) - if dist >= self.__radius: - return dist - self.__result.add(dist, x) - while len(self.__result) > self.__neighbors: - self.__result.pop() - if len(self.__result) == self.__neighbors: - self.__radius, _ = self.__result.top() - return dist - - def _search_iter(self): - PRE, POST = 0, 1 - self.__result = MaxHeap() - stack = [(0, len(self.__dataset), 0.0, PRE)] - while stack: - start, end, thr, action = stack.pop() - - v_radius = self.__distances[start] - v_point_index = self.__indices[start] - v_point = self.__dataset[v_point_index] - is_terminal = self.__is_terminal[start] - - if is_terminal: - for x_index in self.__indices[start:end]: - x = self.__dataset[x_index] - self._process(x) - else: - if action == PRE: - mid = _mid(start, end) - dist = self._process(v_point) - if dist <= v_radius: - fst_start, fst_end = start + 1, mid - snd_start, snd_end = mid, end - else: - fst_start, fst_end = mid, end - snd_start, snd_end = start + 1, mid - stack.append((snd_start, snd_end, abs(v_radius - dist), POST)) - stack.append((fst_start, fst_end, 0.0, PRE)) - elif action == POST: - if self.__radius > thr: - stack.append((start, end, 0.0, PRE)) - return self._get_items() diff --git a/src/tdamapper/utils/vptree_flat/__init__.py b/src/tdamapper/utils/vptree_flat/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/tdamapper/utils/vptree_flat/ball_search.py b/src/tdamapper/utils/vptree_flat/ball_search.py new file mode 100644 index 00000000..418375f5 --- /dev/null +++ b/src/tdamapper/utils/vptree_flat/ball_search.py @@ -0,0 +1,48 @@ +from tdamapper.utils.vptree_flat.common import _mid + + +class BallSearch: + + def __init__(self, vpt, point, eps, inclusive=True): + self._arr = vpt._get_arr() + self.__distance = vpt._get_distance() + self.__point = point + self.__eps = eps + self.__inclusive = inclusive + + def search(self): + return self._search_iter() + + def _inside(self, dist): + if self.__inclusive: + return dist <= self.__eps + return dist < self.__eps + + def _search_iter(self): + stack = [(0, self._arr.size())] + result = [] + while stack: + start, end = stack.pop() + v_radius = self._arr.get_distance(start) + v_point = self._arr.get_point(start) + is_terminal = self._arr.is_terminal(start) + if is_terminal: + for x in self._arr.get_points(start, end): + dist = self.__distance(self.__point, x) + if self._inside(dist): + result.append(x) + else: + dist = self.__distance(self.__point, v_point) + mid = _mid(start, end) + if self._inside(dist): + result.append(v_point) + if dist <= v_radius: + fst = (start + 1, mid) + snd = (mid, end) + else: + fst = (mid, end) + snd = (start + 1, mid) + if abs(dist - v_radius) <= self.__eps: + stack.append(snd) + stack.append(fst) + return result diff --git a/src/tdamapper/utils/vptree_flat/builder.py b/src/tdamapper/utils/vptree_flat/builder.py new file mode 100644 index 00000000..a19b1c12 --- /dev/null +++ b/src/tdamapper/utils/vptree_flat/builder.py @@ -0,0 +1,93 @@ +from random import randrange + +import numpy as np + +from tdamapper.utils.vptree_flat.common import VPArray + + +def _mid(start, end): + return (start + end) // 2 + + +class Builder: + + def __init__(self, vpt, X): + self.__distance = vpt._get_distance() + + dataset = [x for x in X] + indices = np.array([i for i in range(len(dataset))]) + distances = np.array([0.0 for _ in X]) + is_terminal = np.array([False for _ in X]) + self._arr = VPArray(dataset, distances, indices, is_terminal) + + self.__leaf_capacity = vpt.get_leaf_capacity() + self.__leaf_radius = vpt.get_leaf_radius() + pivoting = vpt.get_pivoting() + self.__pivoting = self._pivoting_disabled + if pivoting == "random": + self.__pivoting = self._pivoting_random + elif pivoting == "furthest": + self.__pivoting = self._pivoting_furthest + + def _pivoting_disabled(self, start, end): + pass + + def _pivoting_random(self, start, end): + if end <= start: + return + pivot = randrange(start, end) + if pivot > start: + self._arr.swap(start, pivot) + + def _furthest(self, start, end, i): + furthest_dist = 0.0 + furthest = start + i_point = self._arr.get_point(i) + for j in range(start, end): + j_point = self._arr.get_point(j) + j_dist = self.__distance(i_point, j_point) + if j_dist > furthest_dist: + furthest = j + furthest_dist = j_dist + return furthest + + def _pivoting_furthest(self, start, end): + if end <= start: + return + rnd = randrange(start, end) + furthest_rnd = self._furthest(start, end, rnd) + furthest = self._furthest(start, end, furthest_rnd) + if furthest > start: + self._arr.swap(start, furthest) + + def _update(self, start, end): + self.__pivoting(start, end) + v_point = self._arr.get_point(start) + is_terminal = self._arr.is_terminal(start) + for i in range(start + 1, end): + point = self._arr.get_point(i) + self._arr.set_distance(i, self.__distance(v_point, point)) + self._arr.set_terminal(i, is_terminal) + + def build(self): + self._build_iter() + return self._arr + + def _build_iter(self): + stack = [(0, self._arr.size())] + while stack: + start, end = stack.pop() + mid = _mid(start, end) + self._update(start, end) + self._arr.partition(start + 1, end, mid) + v_radius = self._arr.get_distance(mid) + if (end - start > 2 * self.__leaf_capacity) and ( + v_radius > self.__leaf_radius + ): + self._arr.set_distance(start, v_radius) + self._arr.set_terminal(start, False) + stack.append((mid, end)) + stack.append((start + 1, mid)) + else: + self._arr.set_distance(start, v_radius) + self._arr.set_terminal(start, True) diff --git a/src/tdamapper/utils/vptree_flat/common.py b/src/tdamapper/utils/vptree_flat/common.py new file mode 100644 index 00000000..f51e9c95 --- /dev/null +++ b/src/tdamapper/utils/vptree_flat/common.py @@ -0,0 +1,42 @@ +from tdamapper.utils.quickselect import quickselect, swap_all + + +def _mid(start, end): + return (start + end) // 2 + + +class VPArray: + + def __init__(self, dataset, distances, indices, is_terminal): + self._dataset = dataset + self._distances = distances + self._indices = indices + self._is_terminal = is_terminal + + def size(self): + return len(self._dataset) + + def get_point(self, i): + return self._dataset[self._indices[i]] + + def get_points(self, s, e): + for x_index in self._indices[s:e]: + yield self._dataset[x_index] + + def get_distance(self, i): + return self._distances[i] + + def set_distance(self, i, dist): + self._distances[i] = dist + + def set_terminal(self, i, terminal): + self._is_terminal[i] = terminal + + def is_terminal(self, i): + return self._is_terminal[i] + + def swap(self, i, j): + swap_all(self._distances, i, j, self._indices, self._is_terminal) + + def partition(self, s, e, k): + quickselect(self._distances, s, e, k, self._indices, self._is_terminal) diff --git a/src/tdamapper/utils/vptree_flat/knn_search.py b/src/tdamapper/utils/vptree_flat/knn_search.py new file mode 100644 index 00000000..57b1d6cb --- /dev/null +++ b/src/tdamapper/utils/vptree_flat/knn_search.py @@ -0,0 +1,64 @@ +from tdamapper.utils.heap import MaxHeap +from tdamapper.utils.vptree_flat.common import _mid + + +class KnnSearch: + + def __init__(self, vpt, point, neighbors): + self._arr = vpt._get_arr() + self.__distance = vpt._get_distance() + self.__point = point + self.__neighbors = neighbors + self.__radius = float("inf") + self.__result = MaxHeap() + + def _get_items(self): + while len(self.__result) > self.__neighbors: + self.__result.pop() + return [x for (_, x) in self.__result] + + def search(self): + self._search_iter() + return self._get_items() + + def _process(self, x): + dist = self.__distance(self.__point, x) + if dist >= self.__radius: + return dist + self.__result.add(dist, x) + while len(self.__result) > self.__neighbors: + self.__result.pop() + if len(self.__result) == self.__neighbors: + self.__radius, _ = self.__result.top() + return dist + + def _search_iter(self): + PRE, POST = 0, 1 + self.__result = MaxHeap() + stack = [(0, self._arr.size(), 0.0, PRE)] + while stack: + start, end, thr, action = stack.pop() + + v_radius = self._arr.get_distance(start) + v_point = self._arr.get_point(start) + is_terminal = self._arr.is_terminal(start) + + if is_terminal: + for x in self._arr.get_points(start, end): + self._process(x) + else: + if action == PRE: + mid = _mid(start, end) + dist = self._process(v_point) + if dist <= v_radius: + fst_start, fst_end = start + 1, mid + snd_start, snd_end = mid, end + else: + fst_start, fst_end = mid, end + snd_start, snd_end = start + 1, mid + stack.append((snd_start, snd_end, abs(v_radius - dist), POST)) + stack.append((fst_start, fst_end, 0.0, PRE)) + elif action == POST: + if self.__radius > thr: + stack.append((start, end, 0.0, PRE)) + return self._get_items() diff --git a/src/tdamapper/utils/vptree_flat/vptree.py b/src/tdamapper/utils/vptree_flat/vptree.py new file mode 100755 index 00000000..f157bbc6 --- /dev/null +++ b/src/tdamapper/utils/vptree_flat/vptree.py @@ -0,0 +1,51 @@ +from tdamapper.utils.metrics import get_metric +from tdamapper.utils.vptree_flat.ball_search import BallSearch +from tdamapper.utils.vptree_flat.builder import Builder +from tdamapper.utils.vptree_flat.knn_search import KnnSearch + + +class VPTree: + + def __init__( + self, + X, + metric="euclidean", + metric_params=None, + leaf_capacity=1, + leaf_radius=0.0, + pivoting=None, + ): + self.__metric = metric + self.__metric_params = metric_params + self.__leaf_capacity = leaf_capacity + self.__leaf_radius = leaf_radius + self.__pivoting = pivoting + self._arr = Builder(self, X).build() + + def get_metric(self): + return self.__metric + + def get_metric_params(self): + return self.__metric_params + + def get_leaf_capacity(self): + return self.__leaf_capacity + + def get_leaf_radius(self): + return self.__leaf_radius + + def get_pivoting(self): + return self.__pivoting + + def _get_arr(self): + return self._arr + + def _get_distance(self): + metric_params = self.__metric_params or {} + return get_metric(self.__metric, **metric_params) + + def ball_search(self, point, eps, inclusive=True): + return BallSearch(self, point, eps, inclusive).search() + + def knn_search(self, point, k): + return KnnSearch(self, point, k).search() diff --git a/src/tdamapper/utils/vptree_hier.py b/src/tdamapper/utils/vptree_hier.py deleted file mode 100644 index 2078e43e..00000000 --- a/src/tdamapper/utils/vptree_hier.py +++ /dev/null @@ -1,290 +0,0 @@ -from random import randrange - -import numpy as np - -from tdamapper.utils.heap import MaxHeap -from tdamapper.utils.metrics import get_metric -from tdamapper.utils.quickselect import quickselect, swap_all - - -def _mid(start, end): - return (start + end) // 2 - - -class VPTree: - - def __init__( - self, - X, - metric="euclidean", - metric_params=None, - leaf_capacity=1, - leaf_radius=0.0, - pivoting=None, - ): - self.__metric = metric - self.__metric_params = metric_params - self.__leaf_capacity = leaf_capacity - self.__leaf_radius = leaf_radius - self.__pivoting = pivoting - tree, dataset, distances, indices = self._Build(self, X).build() - self.__tree, self.__dataset, self.__distances, self.__indices = ( - tree, - dataset, - distances, - indices, - ) - - def get_metric(self): - return self.__metric - - def get_metric_params(self): - return self.__metric_params - - def get_leaf_capacity(self): - return self.__leaf_capacity - - def get_leaf_radius(self): - return self.__leaf_radius - - def get_pivoting(self): - return self.__pivoting - - def _get_tree(self): - return self.__tree - - def _get_dataset(self): - return self.__dataset - - def _get_distances(self): - return self.__distances - - def _get_indices(self): - return self.__indices - - def _get_distance(self): - metric_params = self.__metric_params or {} - return get_metric(self.__metric, **metric_params) - - class _Build: - - def __init__(self, vpt, X): - self.__distance = vpt._get_distance() - self.__dataset = [x for x in X] - self.__indices = np.array([i for i in range(len(self.__dataset))]) - self.__distances = np.array([0.0 for _ in X]) - - self.__leaf_capacity = vpt.get_leaf_capacity() - self.__leaf_radius = vpt.get_leaf_radius() - pivoting = vpt.get_pivoting() - self.__pivoting = self._pivoting_disabled - if pivoting == "random": - self.__pivoting = self._pivoting_random - elif pivoting == "furthest": - self.__pivoting = self._pivoting_furthest - - def _pivoting_disabled(self, start, end): - pass - - def _pivoting_random(self, start, end): - if end <= start: - return - pivot = randrange(start, end) - if pivot > start: - swap_all(self.__distances, start, pivot, self.__indices) - - def _get_point(self, i): - return self.__dataset[self.__indices[i]] - - def _get_dist(self, i): - return self.__distances[i] - - def _furthest(self, start, end, i): - furthest_dist = 0.0 - furthest = start - i_point = self._get_point(i) - for j in range(start, end): - j_point = self._get_point(j) - j_dist = self.__distance(i_point, j_point) - if j_dist > furthest_dist: - furthest = j - furthest_dist = j_dist - return furthest - - def _pivoting_furthest(self, start, end): - if end <= start: - return - rnd = randrange(start, end) - furthest_rnd = self._furthest(start, end, rnd) - furthest = self._furthest(start, end, furthest_rnd) - if furthest > start: - swap_all(self.__distances, start, furthest, self.__indices) - - def _update(self, start, end): - self.__pivoting(start, end) - v_point = self._get_point(start) - for i in range(start + 1, end): - point = self._get_point(i) - self.__distances[i] = self.__distance(v_point, point) - - def build(self): - tree = self._build_rec(0, len(self.__dataset)) - return tree, self.__dataset, self.__distances, self.__indices - - def _build_rec(self, start, end): - mid = _mid(start, end) - self._update(start, end) - v_point = self._get_point(start) - quickselect(self.__distances, start + 1, end, mid, self.__indices) - v_radius = self._get_dist(mid) - self.__distances[start] = v_radius - if (end - start <= 2 * self.__leaf_capacity) or ( - v_radius <= self.__leaf_radius - ): - left = _Leaf(start + 1, mid) - right = _Leaf(mid, end) - else: - left = self._build_rec(start + 1, mid) - right = self._build_rec(mid, end) - return _Node(v_radius, v_point, left, right) - - def ball_search(self, point, eps, inclusive=True): - return self._BallSearch(self, point, eps, inclusive).search() - - class _BallSearch: - - def __init__(self, vpt, point, eps, inclusive=True): - self.__tree = vpt._get_tree() - self.__dataset = vpt._get_dataset() - self.__indices = vpt._get_indices() - self.__distance = vpt._get_distance() - self.__point = point - self.__eps = eps - self.__inclusive = inclusive - self.__result = [] - - def search(self): - self.__result.clear() - self._search_rec(self.__tree) - return self.__result - - def _get_points(self, s, e): - for x_index in self.__indices[s:e]: - yield self.__dataset[x_index] - - def _inside(self, dist): - if self.__inclusive: - return dist <= self.__eps - return dist < self.__eps - - def _search_rec(self, tree): - if tree.is_terminal(): - start, end = tree.get_bounds() - for x in self._get_points(start, end): - dist = self.__distance(self.__point, x) - if self._inside(dist): - self.__result.append(x) - else: - v_radius, v_point = tree.get_ball() - dist = self.__distance(v_point, self.__point) - if self._inside(dist): - self.__result.append(v_point) - if dist <= v_radius: - fst, snd = tree.get_left(), tree.get_right() - else: - fst, snd = tree.get_right(), tree.get_left() - self._search_rec(fst) - if abs(dist - v_radius) <= self.__eps: - self._search_rec(snd) - - def knn_search(self, point, k): - return self._KnnSearch(self, point, k).search() - - class _KnnSearch: - - def __init__(self, vpt, point, neighbors): - self.__tree = vpt._get_tree() - self.__dataset = vpt._get_dataset() - self.__indices = vpt._get_indices() - self.__distance = vpt._get_distance() - self.__point = point - self.__neighbors = neighbors - self.__items = MaxHeap() - - def _add(self, dist, x): - self.__items.add(dist, x) - if len(self.__items) > self.__neighbors: - self.__items.pop() - - def _get_points(self, s, e): - for x_index in self.__indices[s:e]: - yield self.__dataset[x_index] - - def _get_items(self): - while len(self.__items) > self.__neighbors: - self.__items.pop() - return [x for (_, x) in self.__items] - - def _get_radius(self): - if len(self.__items) < self.__neighbors: - return float("inf") - furthest_dist, _ = self.__items.top() - return furthest_dist - - def search(self): - self._search_rec(self.__tree) - return self._get_items() - - def _search_rec(self, tree): - if tree.is_terminal(): - start, end = tree.get_bounds() - for x in self._get_points(start, end): - dist = self.__distance(self.__point, x) - if dist < self._get_radius(): - self._add(dist, x) - else: - v_radius, v_point = tree.get_ball() - dist = self.__distance(v_point, self.__point) - if dist < self._get_radius(): - self._add(dist, v_point) - if dist <= v_radius: - fst, snd = tree.get_left(), tree.get_right() - else: - fst, snd = tree.get_right(), tree.get_left() - self._search_rec(fst) - if abs(dist - v_radius) <= self._get_radius(): - self._search_rec(snd) - - -class _Node: - - def __init__(self, radius, center, left, right): - self.__radius = radius - self.__center = center - self.__left = left - self.__right = right - - def get_ball(self): - return self.__radius, self.__center - - def is_terminal(self): - return False - - def get_left(self): - return self.__left - - def get_right(self): - return self.__right - - -class _Leaf: - - def __init__(self, start, end): - self.__start = start - self.__end = end - - def get_bounds(self): - return self.__start, self.__end - - def is_terminal(self): - return True diff --git a/src/tdamapper/utils/vptree_hier/__init__.py b/src/tdamapper/utils/vptree_hier/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/tdamapper/utils/vptree_hier/ball_search.py b/src/tdamapper/utils/vptree_hier/ball_search.py new file mode 100644 index 00000000..1f519a39 --- /dev/null +++ b/src/tdamapper/utils/vptree_hier/ball_search.py @@ -0,0 +1,40 @@ +class BallSearch: + + def __init__(self, vpt, point, eps, inclusive=True): + self.__tree = vpt._get_tree() + self._arr = vpt._get_arr() + self.__distance = vpt._get_distance() + self.__point = point + self.__eps = eps + self.__inclusive = inclusive + self.__result = [] + + def search(self): + self.__result.clear() + self._search_rec(self.__tree) + return self.__result + + def _inside(self, dist): + if self.__inclusive: + return dist <= self.__eps + return dist < self.__eps + + def _search_rec(self, tree): + if tree.is_terminal(): + start, end = tree.get_bounds() + for x in self._arr.get_points(start, end): + dist = self.__distance(self.__point, x) + if self._inside(dist): + self.__result.append(x) + else: + v_radius, v_point = tree.get_ball() + dist = self.__distance(v_point, self.__point) + if self._inside(dist): + self.__result.append(v_point) + if dist <= v_radius: + fst, snd = tree.get_left(), tree.get_right() + else: + fst, snd = tree.get_right(), tree.get_left() + self._search_rec(fst) + if abs(dist - v_radius) <= self.__eps: + self._search_rec(snd) diff --git a/src/tdamapper/utils/vptree_hier/builder.py b/src/tdamapper/utils/vptree_hier/builder.py new file mode 100644 index 00000000..944b93e7 --- /dev/null +++ b/src/tdamapper/utils/vptree_hier/builder.py @@ -0,0 +1,85 @@ +from random import randrange + +import numpy as np + +from tdamapper.utils.quickselect import quickselect, swap_all +from tdamapper.utils.vptree_hier.common import Leaf, Node, VPArray, _mid + + +class Builder: + + def __init__(self, vpt, X): + self.__distance = vpt._get_distance() + + dataset = [x for x in X] + indices = np.array([i for i in range(len(dataset))]) + distances = np.array([0.0 for _ in X]) + self._arr = VPArray(dataset, distances, indices) + + self.__leaf_capacity = vpt.get_leaf_capacity() + self.__leaf_radius = vpt.get_leaf_radius() + pivoting = vpt.get_pivoting() + self.__pivoting = self._pivoting_disabled + if pivoting == "random": + self.__pivoting = self._pivoting_random + elif pivoting == "furthest": + self.__pivoting = self._pivoting_furthest + + def _pivoting_disabled(self, start, end): + pass + + def _pivoting_random(self, start, end): + if end <= start: + return + pivot = randrange(start, end) + if pivot > start: + self._arr.swap(start, pivot) + + def _furthest(self, start, end, i): + furthest_dist = 0.0 + furthest = start + i_point = self._arr.get_point(i) + for j in range(start, end): + j_point = self._arr.get_point(j) + j_dist = self.__distance(i_point, j_point) + if j_dist > furthest_dist: + furthest = j + furthest_dist = j_dist + return furthest + + def _pivoting_furthest(self, start, end): + if end <= start: + return + rnd = randrange(start, end) + furthest_rnd = self._furthest(start, end, rnd) + furthest = self._furthest(start, end, furthest_rnd) + if furthest > start: + self._arr.swap(start, furthest) + + def _update(self, start, end): + self.__pivoting(start, end) + v_point = self._arr.get_point(start) + for i in range(start + 1, end): + point = self._arr.get_point(i) + self._arr.set_distance(i, self.__distance(v_point, point)) + + def build(self): + tree = self._build_rec(0, self._arr.size()) + return tree, self._arr + + def _build_rec(self, start, end): + mid = _mid(start, end) + self._update(start, end) + v_point = self._arr.get_point(start) + self._arr.partition(start + 1, end, mid) + v_radius = self._arr.get_distance(mid) + self._arr.set_distance(start, v_radius) + if (end - start <= 2 * self.__leaf_capacity) or ( + v_radius <= self.__leaf_radius + ): + left = Leaf(start + 1, mid) + right = Leaf(mid, end) + else: + left = self._build_rec(start + 1, mid) + right = self._build_rec(mid, end) + return Node(v_radius, v_point, left, right) diff --git a/src/tdamapper/utils/vptree_hier/common.py b/src/tdamapper/utils/vptree_hier/common.py new file mode 100644 index 00000000..749c8fab --- /dev/null +++ b/src/tdamapper/utils/vptree_hier/common.py @@ -0,0 +1,69 @@ +from tdamapper.utils.quickselect import quickselect, swap_all + + +def _mid(start, end): + return (start + end) // 2 + + +class VPArray: + + def __init__(self, dataset, distances, indices): + self._dataset = dataset + self._distances = distances + self._indices = indices + + def size(self): + return len(self._dataset) + + def get_point(self, i): + return self._dataset[self._indices[i]] + + def get_points(self, s, e): + for x_index in self._indices[s:e]: + yield self._dataset[x_index] + + def get_distance(self, i): + return self._distances[i] + + def set_distance(self, i, dist): + self._distances[i] = dist + + def swap(self, i, j): + swap_all(self._distances, i, j, self._indices) + + def partition(self, s, e, k): + quickselect(self._distances, s, e, k, self._indices) + + +class Node: + + def __init__(self, radius, center, left, right): + self.__radius = radius + self.__center = center + self.__left = left + self.__right = right + + def get_ball(self): + return self.__radius, self.__center + + def is_terminal(self): + return False + + def get_left(self): + return self.__left + + def get_right(self): + return self.__right + + +class Leaf: + + def __init__(self, start, end): + self.__start = start + self.__end = end + + def get_bounds(self): + return self.__start, self.__end + + def is_terminal(self): + return True diff --git a/src/tdamapper/utils/vptree_hier/knn_search.py b/src/tdamapper/utils/vptree_hier/knn_search.py new file mode 100644 index 00000000..a6531414 --- /dev/null +++ b/src/tdamapper/utils/vptree_hier/knn_search.py @@ -0,0 +1,52 @@ +from tdamapper.utils.heap import MaxHeap + + +class KnnSearch: + + def __init__(self, vpt, point, neighbors): + self.__tree = vpt._get_tree() + self._arr = vpt._get_arr() + self.__distance = vpt._get_distance() + self.__point = point + self.__neighbors = neighbors + self.__items = MaxHeap() + + def _add(self, dist, x): + self.__items.add(dist, x) + if len(self.__items) > self.__neighbors: + self.__items.pop() + + def _get_items(self): + while len(self.__items) > self.__neighbors: + self.__items.pop() + return [x for (_, x) in self.__items] + + def _get_radius(self): + if len(self.__items) < self.__neighbors: + return float("inf") + furthest_dist, _ = self.__items.top() + return furthest_dist + + def search(self): + self._search_rec(self.__tree) + return self._get_items() + + def _search_rec(self, tree): + if tree.is_terminal(): + start, end = tree.get_bounds() + for x in self._arr.get_points(start, end): + dist = self.__distance(self.__point, x) + if dist < self._get_radius(): + self._add(dist, x) + else: + v_radius, v_point = tree.get_ball() + dist = self.__distance(v_point, self.__point) + if dist < self._get_radius(): + self._add(dist, v_point) + if dist <= v_radius: + fst, snd = tree.get_left(), tree.get_right() + else: + fst, snd = tree.get_right(), tree.get_left() + self._search_rec(fst) + if abs(dist - v_radius) <= self._get_radius(): + self._search_rec(snd) diff --git a/src/tdamapper/utils/vptree_hier/vptree.py b/src/tdamapper/utils/vptree_hier/vptree.py new file mode 100644 index 00000000..cd0f0172 --- /dev/null +++ b/src/tdamapper/utils/vptree_hier/vptree.py @@ -0,0 +1,54 @@ +from tdamapper.utils.metrics import get_metric +from tdamapper.utils.vptree_hier.ball_search import BallSearch +from tdamapper.utils.vptree_hier.builder import Builder +from tdamapper.utils.vptree_hier.knn_search import KnnSearch + + +class VPTree: + + def __init__( + self, + X, + metric="euclidean", + metric_params=None, + leaf_capacity=1, + leaf_radius=0.0, + pivoting=None, + ): + self.__metric = metric + self.__metric_params = metric_params + self.__leaf_capacity = leaf_capacity + self.__leaf_radius = leaf_radius + self.__pivoting = pivoting + self.__tree, self._arr = Builder(self, X).build() + + def get_metric(self): + return self.__metric + + def get_metric_params(self): + return self.__metric_params + + def get_leaf_capacity(self): + return self.__leaf_capacity + + def get_leaf_radius(self): + return self.__leaf_radius + + def get_pivoting(self): + return self.__pivoting + + def _get_tree(self): + return self.__tree + + def _get_arr(self): + return self._arr + + def _get_distance(self): + metric_params = self.__metric_params or {} + return get_metric(self.__metric, **metric_params) + + def ball_search(self, point, eps, inclusive=True): + return BallSearch(self, point, eps, inclusive).search() + + def knn_search(self, point, k): + return KnnSearch(self, point, k).search() diff --git a/tests/test_bench_cover.py b/tests/test_bench_cover.py index bfcb19ab..0f94ee30 100644 --- a/tests/test_bench_cover.py +++ b/tests/test_bench_cover.py @@ -7,8 +7,8 @@ from tdamapper._common import profile from tdamapper.utils.metrics import euclidean -from tdamapper.utils.vptree_flat import VPTree as FVPT -from tdamapper.utils.vptree_hier import VPTree as HVPT +from tdamapper.utils.vptree_flat.vptree import VPTree as FVPT +from tdamapper.utils.vptree_hier.vptree import VPTree as HVPT from tests.ball_tree import SkBallTree from tests.setup_logging import setup_logging diff --git a/tests/test_bench_vptree.py b/tests/test_bench_vptree.py index faef9cf5..df8ff2e1 100644 --- a/tests/test_bench_vptree.py +++ b/tests/test_bench_vptree.py @@ -6,8 +6,8 @@ from sklearn.datasets import load_breast_cancer, load_digits, load_iris from tdamapper.utils.metrics import euclidean, get_metric -from tdamapper.utils.vptree_flat import VPTree as FVPT -from tdamapper.utils.vptree_hier import VPTree as HVPT +from tdamapper.utils.vptree_flat.vptree import VPTree as FVPT +from tdamapper.utils.vptree_hier.vptree import VPTree as HVPT from tests.ball_tree import SkBallTree from tests.setup_logging import setup_logging diff --git a/tests/test_unit_knn.py b/tests/test_unit_knn.py index 9dffe37c..a38e1c9e 100644 --- a/tests/test_unit_knn.py +++ b/tests/test_unit_knn.py @@ -4,7 +4,7 @@ from tdamapper.cover import KNNCover from tdamapper.utils.metrics import euclidean -from tdamapper.utils.vptree_flat import VPTree +from tdamapper.utils.vptree_flat.vptree import VPTree X = np.array( [ @@ -126,9 +126,10 @@ def test_vptree_simple(self): self.assertTrue(0.0 in dists) def check_vptree(self, vpt): - data = vpt._get_dataset() - distances = vpt._get_distances() - indices = vpt._get_indices() + arr = vpt._get_arr() + data = arr._dataset + distances = arr._distances + indices = arr._indices dist = vpt._get_distance() leaf_capacity = vpt.get_leaf_capacity() diff --git a/tests/test_unit_vptree.py b/tests/test_unit_vptree.py index 2b31c152..e50ec1e5 100644 --- a/tests/test_unit_vptree.py +++ b/tests/test_unit_vptree.py @@ -4,8 +4,8 @@ import numpy as np from tdamapper.utils.metrics import get_metric -from tdamapper.utils.vptree_flat import VPTree as FVPT -from tdamapper.utils.vptree_hier import VPTree as HVPT +from tdamapper.utils.vptree_flat.vptree import VPTree as FVPT +from tdamapper.utils.vptree_hier.vptree import VPTree as HVPT from tests.ball_tree import SkBallTree distance = "euclidean" From 262b75b0ae3fb18ec776b95edf9171de50c3c0a4 Mon Sep 17 00:00:00 2001 From: Luca Simi Date: Thu, 22 May 2025 22:06:13 +0200 Subject: [PATCH 05/10] Removed variadic args to support numba on python < 3.10 --- src/tdamapper/utils/quickselect.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/tdamapper/utils/quickselect.py b/src/tdamapper/utils/quickselect.py index d04c4e3f..18ff5e7e 100755 --- a/src/tdamapper/utils/quickselect.py +++ b/src/tdamapper/utils/quickselect.py @@ -16,26 +16,26 @@ def swap_all(arr, i, j, extra1=None, extra2=None): @njit -def partition(data, start, end, p_ord, *extra): +def partition(data, start, end, p_ord, extra1=None, extra2=None): higher = start for j in range(start, end): j_ord = data[j] if j_ord < p_ord: - swap_all(data, higher, j, *extra) + swap_all(data, higher, j, extra1, extra2) higher += 1 return higher @njit -def quickselect(data, start, end, k, *extra): +def quickselect(data, start, end, k, extra1=None, extra2=None): if (k < start) or (k >= end): return start_, end_, higher = start, end, None while higher != k + 1: p = data[k] - swap_all(data, start_, k, *extra) - higher = partition(data, start_ + 1, end_, p, *extra) - swap_all(data, start_, higher - 1, *extra) + swap_all(data, start_, k, extra1, extra2) + higher = partition(data, start_ + 1, end_, p, extra1, extra2) + swap_all(data, start_, higher - 1, extra1, extra2) if k <= higher - 1: end_ = higher else: From 8c8313f35f80e5999b5135921d002e93ec96909c Mon Sep 17 00:00:00 2001 From: Luca Simi Date: Thu, 22 May 2025 22:42:50 +0200 Subject: [PATCH 06/10] Added flags for compatibility of numba with python 3.8 --- src/tdamapper/utils/quickselect.py | 79 ++++++++++++++++++++++++++---- 1 file changed, 70 insertions(+), 9 deletions(-) diff --git a/src/tdamapper/utils/quickselect.py b/src/tdamapper/utils/quickselect.py index 18ff5e7e..8cfe72fd 100755 --- a/src/tdamapper/utils/quickselect.py +++ b/src/tdamapper/utils/quickselect.py @@ -1,5 +1,8 @@ +import numpy as np from numba import njit +_ARR = np.zeros(1) + @njit def swap(arr, i, j): @@ -7,36 +10,94 @@ def swap(arr, i, j): @njit -def swap_all(arr, i, j, extra1=None, extra2=None): +def _swap_all(arr, i, j, extra1, use_extra1, extra2, use_extra2): swap(arr, i, j) - if extra1 is not None: + if use_extra1: swap(extra1, i, j) - if extra2 is not None: + if use_extra2: swap(extra2, i, j) @njit -def partition(data, start, end, p_ord, extra1=None, extra2=None): +def _partition(data, start, end, p_ord, extra1, use_extra1, extra2, use_extra2): higher = start for j in range(start, end): j_ord = data[j] if j_ord < p_ord: - swap_all(data, higher, j, extra1, extra2) + _swap_all(data, higher, j, extra1, use_extra1, extra2, use_extra2) higher += 1 return higher @njit -def quickselect(data, start, end, k, extra1=None, extra2=None): +def _quickselect(data, start, end, k, extra1, use_extra1, extra2, use_extra2): if (k < start) or (k >= end): return start_, end_, higher = start, end, None while higher != k + 1: p = data[k] - swap_all(data, start_, k, extra1, extra2) - higher = partition(data, start_ + 1, end_, p, extra1, extra2) - swap_all(data, start_, higher - 1, extra1, extra2) + _swap_all(data, start_, k, extra1, use_extra1, extra2, use_extra2) + higher = _partition( + data, start_ + 1, end_, p, extra1, use_extra1, extra2, use_extra2 + ) + _swap_all(data, start_, higher - 1, extra1, use_extra1, extra2, use_extra2) if k <= higher - 1: end_ = higher else: start_ = higher + + +def _to_array(extra1=None, extra2=None): + extra1_arr = _ARR if extra1 is None else extra1 + extra2_arr = _ARR if extra2 is None else extra2 + return extra1_arr, extra2_arr + + +def _use_array(extra1=None, extra2=None): + use_extra1 = extra1 is not None + use_extra2 = extra2 is not None + return use_extra1, use_extra2 + + +def swap_all(arr, i, j, extra1=None, extra2=None): + extra1_arr, extra2_arr = _to_array(extra1, extra2) + use_extra1, use_extra2 = _use_array(extra1, extra2) + _swap_all( + arr, + i, + j, + extra1=extra1_arr, + use_extra1=use_extra1, + extra2=extra2_arr, + use_extra2=use_extra2, + ) + + +def partition(data, start, end, p_ord, extra1=None, extra2=None): + extra1_arr, extra2_arr = _to_array(extra1, extra2) + use_extra1, use_extra2 = _use_array(extra1, extra2) + return _partition( + data, + start, + end, + p_ord, + extra1=extra1_arr, + use_extra1=use_extra1, + extra2=extra2_arr, + use_extra2=use_extra2, + ) + + +def quickselect(data, start, end, k, extra1=None, extra2=None): + extra1_arr, extra2_arr = _to_array(extra1, extra2) + use_extra1, use_extra2 = _use_array(extra1, extra2) + _quickselect( + data, + start, + end, + k, + extra1=extra1_arr, + use_extra1=use_extra1, + extra2=extra2_arr, + use_extra2=use_extra2, + ) From fa915da80a8bf7ef61826c2495ae93976b3be039 Mon Sep 17 00:00:00 2001 From: Luca Simi Date: Thu, 22 May 2025 22:48:30 +0200 Subject: [PATCH 07/10] Removed None for compatibility with numba --- src/tdamapper/utils/quickselect.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tdamapper/utils/quickselect.py b/src/tdamapper/utils/quickselect.py index 8cfe72fd..a61730f1 100755 --- a/src/tdamapper/utils/quickselect.py +++ b/src/tdamapper/utils/quickselect.py @@ -33,7 +33,7 @@ def _partition(data, start, end, p_ord, extra1, use_extra1, extra2, use_extra2): def _quickselect(data, start, end, k, extra1, use_extra1, extra2, use_extra2): if (k < start) or (k >= end): return - start_, end_, higher = start, end, None + start_, end_, higher = start, end, -1 while higher != k + 1: p = data[k] _swap_all(data, start_, k, extra1, use_extra1, extra2, use_extra2) From e8603cd15dc7584b1b5d9ba7745364ed8388957e Mon Sep 17 00:00:00 2001 From: Luca Simi Date: Thu, 22 May 2025 23:08:02 +0200 Subject: [PATCH 08/10] Skipped njitted functions from coverage count --- src/tdamapper/utils/_metrics.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/tdamapper/utils/_metrics.py b/src/tdamapper/utils/_metrics.py index 5e805bd0..7bd09cdb 100644 --- a/src/tdamapper/utils/_metrics.py +++ b/src/tdamapper/utils/_metrics.py @@ -2,27 +2,27 @@ from numba import njit -@njit(fastmath=True) +@njit(fastmath=True) # pragma: no cover def euclidean(x, y): return np.linalg.norm(x - y) -@njit(fastmath=True) +@njit(fastmath=True) # pragma: no cover def manhattan(x, y): return np.linalg.norm(x - y, ord=1) -@njit(fastmath=True) +@njit(fastmath=True) # pragma: no cover def chebyshev(x, y): return np.linalg.norm(x - y, ord=np.inf) -@njit(fastmath=True) +@njit(fastmath=True) # pragma: no cover def minkowski(p, x, y): return np.linalg.norm(x - y, ord=p) -@njit(fastmath=True) +@njit(fastmath=True) # pragma: no cover def cosine(x, y): xy = np.dot(x, y) xx = np.linalg.norm(x) From 1bf67c9a3f8e34e17ffd976eb30e7ba8cc3d9237 Mon Sep 17 00:00:00 2001 From: Luca Simi Date: Thu, 22 May 2025 23:08:21 +0200 Subject: [PATCH 09/10] Skipped njitted functions from coverage --- src/tdamapper/utils/quickselect.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/tdamapper/utils/quickselect.py b/src/tdamapper/utils/quickselect.py index a61730f1..1fad5fcd 100755 --- a/src/tdamapper/utils/quickselect.py +++ b/src/tdamapper/utils/quickselect.py @@ -4,12 +4,12 @@ _ARR = np.zeros(1) -@njit +@njit # pragma: no cover def swap(arr, i, j): arr[i], arr[j] = arr[j], arr[i] -@njit +@njit # pragma: no cover def _swap_all(arr, i, j, extra1, use_extra1, extra2, use_extra2): swap(arr, i, j) if use_extra1: @@ -18,7 +18,7 @@ def _swap_all(arr, i, j, extra1, use_extra1, extra2, use_extra2): swap(extra2, i, j) -@njit +@njit # pragma: no cover def _partition(data, start, end, p_ord, extra1, use_extra1, extra2, use_extra2): higher = start for j in range(start, end): @@ -29,7 +29,7 @@ def _partition(data, start, end, p_ord, extra1, use_extra1, extra2, use_extra2): return higher -@njit +@njit # pragma: no cover def _quickselect(data, start, end, k, extra1, use_extra1, extra2, use_extra2): if (k < start) or (k >= end): return From 219a07041152aa6dc639434ef6f60ff5c0bedd6f Mon Sep 17 00:00:00 2001 From: Luca Simi Date: Thu, 22 May 2025 23:17:20 +0200 Subject: [PATCH 10/10] Minor improvements --- app/streamlit_app.py | 22 +++++++++++++++------- tests/test_bench_metrics.py | 6 +++++- 2 files changed, 20 insertions(+), 8 deletions(-) diff --git a/app/streamlit_app.py b/app/streamlit_app.py index 1c1032ae..eb0c7731 100644 --- a/app/streamlit_app.py +++ b/app/streamlit_app.py @@ -211,7 +211,11 @@ def mode(arr): def quantile(q): - return lambda agg: np.nanquantile(agg, q=q) + + def _quantile_q(agg): + return np.nanquantile(agg, q=q) + + return _quantile_q @st.cache_data @@ -565,12 +569,12 @@ def plot_agg_input_section(): return agg, agg_name +def _hash_networkx_graph(graph): + return _encode_graph(_get_graph_no_attribs(graph)) + + @st.cache_data( - hash_funcs={ - "networkx.classes.graph.Graph": lambda g: _encode_graph( - _get_graph_no_attribs(g) - ) - }, + hash_funcs={"networkx.classes.graph.Graph": _hash_networkx_graph}, show_spinner="Generating Mapper Layout", ) def compute_mapper_plot(mapper_graph, dim, seed, iterations): @@ -610,8 +614,12 @@ def mapper_plot_section(mapper_graph): return mapper_plot +def _hash_mapper_plot(mapper_plot): + return mapper_plot.positions + + @st.cache_data( - hash_funcs={"tdamapper.plot.MapperPlot": lambda mp: mp.positions}, + hash_funcs={"tdamapper.plot.MapperPlot": _hash_mapper_plot}, show_spinner="Rendering Mapper", ) def compute_mapper_fig(mapper_plot, colors, node_size, cmap, _agg, agg_name): diff --git a/tests/test_bench_metrics.py b/tests/test_bench_metrics.py index 5c7e1751..f5644e9f 100644 --- a/tests/test_bench_metrics.py +++ b/tests/test_bench_metrics.py @@ -47,7 +47,11 @@ def eval_dist(X, d): def run_dist_bench(X, d): eval_dist(X, d) - return timeit.timeit(lambda: eval_dist(X, d), number=200) + + def _eval(): + eval_dist(X, d) + + return timeit.timeit(_eval, number=200) def run_euclidean_bench(X):