diff --git a/app/streamlit_app.py b/app/streamlit_app.py index 1c1032ae..eb0c7731 100644 --- a/app/streamlit_app.py +++ b/app/streamlit_app.py @@ -211,7 +211,11 @@ def mode(arr): def quantile(q): - return lambda agg: np.nanquantile(agg, q=q) + + def _quantile_q(agg): + return np.nanquantile(agg, q=q) + + return _quantile_q @st.cache_data @@ -565,12 +569,12 @@ def plot_agg_input_section(): return agg, agg_name +def _hash_networkx_graph(graph): + return _encode_graph(_get_graph_no_attribs(graph)) + + @st.cache_data( - hash_funcs={ - "networkx.classes.graph.Graph": lambda g: _encode_graph( - _get_graph_no_attribs(g) - ) - }, + hash_funcs={"networkx.classes.graph.Graph": _hash_networkx_graph}, show_spinner="Generating Mapper Layout", ) def compute_mapper_plot(mapper_graph, dim, seed, iterations): @@ -610,8 +614,12 @@ def mapper_plot_section(mapper_graph): return mapper_plot +def _hash_mapper_plot(mapper_plot): + return mapper_plot.positions + + @st.cache_data( - hash_funcs={"tdamapper.plot.MapperPlot": lambda mp: mp.positions}, + hash_funcs={"tdamapper.plot.MapperPlot": _hash_mapper_plot}, show_spinner="Rendering Mapper", ) def compute_mapper_fig(mapper_plot, colors, node_size, cmap, _agg, agg_name): diff --git a/benchmarks/benchmark.py b/benchmarks/benchmark.py index 9a727130..1300eb24 100644 --- a/benchmarks/benchmark.py +++ b/benchmarks/benchmark.py @@ -12,6 +12,10 @@ from tdamapper.core import TrivialClustering +def _identity(x): + return x + + def _segment(cardinality, dimension, noise=0.1, start=None, end=None): if start is None: start = np.zeros(dimension) @@ -70,7 +74,7 @@ def fit(self, X, y=None): def run_gm(X, n, p): t0 = time.time() pipe = gm.make_mapper_pipeline( - filter_func=lambda x: x, + filter_func=_identity, cover=gm.CubicalCover(n_intervals=n, overlap_frac=p), clusterer=TrivialEstimator(), ) diff --git a/src/tdamapper/_common.py b/src/tdamapper/_common.py index 5daabca8..4a66980c 100644 --- a/src/tdamapper/_common.py +++ b/src/tdamapper/_common.py @@ 
-2,6 +2,9 @@ This module provides common functionalities for internal use. """ +import cProfile +import io +import pstats import warnings import numpy as np @@ -147,3 +150,22 @@ def clone(obj): obj_noargs = type(obj)() obj_noargs.set_params(**params) return obj_noargs + + +def profile(n_lines=10): + def decorator(func): + def wrapper(*args, **kwargs): + profiler = cProfile.Profile() + profiler.enable() + result = func(*args, **kwargs) + profiler.disable() + + s = io.StringIO() + ps = pstats.Stats(profiler, stream=s).sort_stats("cumulative") + ps.print_stats(n_lines) + print(s.getvalue()) + return result + + return wrapper + + return decorator diff --git a/src/tdamapper/utils/_metrics.py b/src/tdamapper/utils/_metrics.py index 5e805bd0..7bd09cdb 100644 --- a/src/tdamapper/utils/_metrics.py +++ b/src/tdamapper/utils/_metrics.py @@ -2,27 +2,27 @@ from numba import njit -@njit(fastmath=True) +@njit(fastmath=True) # pragma: no cover def euclidean(x, y): return np.linalg.norm(x - y) -@njit(fastmath=True) +@njit(fastmath=True) # pragma: no cover def manhattan(x, y): return np.linalg.norm(x - y, ord=1) -@njit(fastmath=True) +@njit(fastmath=True) # pragma: no cover def chebyshev(x, y): return np.linalg.norm(x - y, ord=np.inf) -@njit(fastmath=True) +@njit(fastmath=True) # pragma: no cover def minkowski(p, x, y): return np.linalg.norm(x - y, ord=p) -@njit(fastmath=True) +@njit(fastmath=True) # pragma: no cover def cosine(x, y): xy = np.dot(x, y) xx = np.linalg.norm(x) diff --git a/src/tdamapper/utils/metrics.py b/src/tdamapper/utils/metrics.py index 84909cbc..1c0fa74c 100644 --- a/src/tdamapper/utils/metrics.py +++ b/src/tdamapper/utils/metrics.py @@ -114,7 +114,11 @@ def minkowski(p): return euclidean() elif np.isinf(p): return chebyshev() - return lambda x, y: _metrics.minkowski(p, x, y) + + def dist(x, y): + return _metrics.minkowski(p, x, y) + + return dist def cosine(): diff --git a/src/tdamapper/utils/quickselect.py b/src/tdamapper/utils/quickselect.py index 
93bc528f..1fad5fcd 100755 --- a/src/tdamapper/utils/quickselect.py +++ b/src/tdamapper/utils/quickselect.py @@ -1,55 +1,103 @@ -def __swap(arr, i, j): +import numpy as np +from numba import njit + +_ARR = np.zeros(1) + + +@njit # pragma: no cover +def swap(arr, i, j): arr[i], arr[j] = arr[j], arr[i] -def partition(data, start, end, p_ord): +@njit # pragma: no cover +def _swap_all(arr, i, j, extra1, use_extra1, extra2, use_extra2): + swap(arr, i, j) + if use_extra1: + swap(extra1, i, j) + if use_extra2: + swap(extra2, i, j) + + +@njit # pragma: no cover +def _partition(data, start, end, p_ord, extra1, use_extra1, extra2, use_extra2): higher = start for j in range(start, end): - j_ord, _ = data[j] + j_ord = data[j] if j_ord < p_ord: - __swap(data, higher, j) + _swap_all(data, higher, j, extra1, use_extra1, extra2, use_extra2) higher += 1 return higher -def quickselect(data, start, end, k): +@njit # pragma: no cover +def _quickselect(data, start, end, k, extra1, use_extra1, extra2, use_extra2): if (k < start) or (k >= end): return - start_, end_, higher = start, end, None + start_, end_, higher = start, end, -1 while higher != k + 1: - p, _ = data[k] - __swap(data, start_, k) - higher = partition(data, start_ + 1, end_, p) - __swap(data, start_, higher - 1) + p = data[k] + _swap_all(data, start_, k, extra1, use_extra1, extra2, use_extra2) + higher = _partition( + data, start_ + 1, end_, p, extra1, use_extra1, extra2, use_extra2 + ) + _swap_all(data, start_, higher - 1, extra1, use_extra1, extra2, use_extra2) if k <= higher - 1: end_ = higher else: start_ = higher -def partition_tuple(data_ord, data_arr, start, end, p_ord): - higher = start - for j in range(start, end): - j_ord = data_ord[j] - if j_ord < p_ord: - __swap(data_arr, higher, j) - __swap(data_ord, higher, j) - higher += 1 - return higher +def _to_array(extra1=None, extra2=None): + extra1_arr = _ARR if extra1 is None else extra1 + extra2_arr = _ARR if extra2 is None else extra2 + return extra1_arr, 
extra2_arr -def quickselect_tuple(data_ord, data_arr, start, end, k): - if (k < start) or (k >= end): - return - start_, end_, higher = start, end, None - while higher != k + 1: - p_ord = data_ord[k] - __swap(data_arr, start_, k) - __swap(data_ord, start_, k) - higher = partition_tuple(data_ord, data_arr, start_ + 1, end_, p_ord) - __swap(data_arr, start_, higher - 1) - __swap(data_ord, start_, higher - 1) - if k <= higher - 1: - end_ = higher - else: - start_ = higher +def _use_array(extra1=None, extra2=None): + use_extra1 = extra1 is not None + use_extra2 = extra2 is not None + return use_extra1, use_extra2 + + +def swap_all(arr, i, j, extra1=None, extra2=None): + extra1_arr, extra2_arr = _to_array(extra1, extra2) + use_extra1, use_extra2 = _use_array(extra1, extra2) + _swap_all( + arr, + i, + j, + extra1=extra1_arr, + use_extra1=use_extra1, + extra2=extra2_arr, + use_extra2=use_extra2, + ) + + +def partition(data, start, end, p_ord, extra1=None, extra2=None): + extra1_arr, extra2_arr = _to_array(extra1, extra2) + use_extra1, use_extra2 = _use_array(extra1, extra2) + return _partition( + data, + start, + end, + p_ord, + extra1=extra1_arr, + use_extra1=use_extra1, + extra2=extra2_arr, + use_extra2=use_extra2, + ) + + +def quickselect(data, start, end, k, extra1=None, extra2=None): + extra1_arr, extra2_arr = _to_array(extra1, extra2) + use_extra1, use_extra2 = _use_array(extra1, extra2) + _quickselect( + data, + start, + end, + k, + extra1=extra1_arr, + use_extra1=use_extra1, + extra2=extra2_arr, + use_extra2=use_extra2, + ) diff --git a/src/tdamapper/utils/vptree.py b/src/tdamapper/utils/vptree.py index a85ec3b9..3b3a62f6 100644 --- a/src/tdamapper/utils/vptree.py +++ b/src/tdamapper/utils/vptree.py @@ -2,8 +2,8 @@ A module for fast knn and range searches, depending only on a given metric """ -from tdamapper.utils.vptree_flat import VPTree as FVPT -from tdamapper.utils.vptree_hier import VPTree as HVPT +from tdamapper.utils.vptree_flat.vptree import VPTree as FVPT 
+from tdamapper.utils.vptree_hier.vptree import VPTree as HVPT class VPTree: diff --git a/src/tdamapper/utils/vptree_flat.py b/src/tdamapper/utils/vptree_flat.py deleted file mode 100755 index 3af82584..00000000 --- a/src/tdamapper/utils/vptree_flat.py +++ /dev/null @@ -1,259 +0,0 @@ -from random import randrange - -from tdamapper.utils.heap import MaxHeap -from tdamapper.utils.metrics import get_metric - - -def _swap(arr, i, j): - arr[i], arr[j] = arr[j], arr[i] - - -def _mid(start, end): - return (start + end) // 2 - - -def _partition(data, start, end, p_ord): - higher = start - for j in range(start, end): - j_ord, _, _ = data[j] - if j_ord < p_ord: - _swap(data, higher, j) - higher += 1 - return higher - - -def _quickselect(data, start, end, k): - if (k < start) or (k >= end): - return - start_, end_, higher = start, end, None - while higher != k + 1: - p, _, _ = data[k] - _swap(data, start_, k) - higher = _partition(data, start_ + 1, end_, p) - _swap(data, start_, higher - 1) - if k <= higher - 1: - end_ = higher - else: - start_ = higher - - -class VPTree: - - def __init__( - self, - X, - metric="euclidean", - metric_params=None, - leaf_capacity=1, - leaf_radius=0.0, - pivoting=None, - ): - self.__metric = metric - self.__metric_params = metric_params - self.__leaf_capacity = leaf_capacity - self.__leaf_radius = leaf_radius - self.__pivoting = pivoting - self.__dataset = self._Build(self, X).build() - - def get_metric(self): - return self.__metric - - def get_metric_params(self): - return self.__metric_params - - def get_leaf_capacity(self): - return self.__leaf_capacity - - def get_leaf_radius(self): - return self.__leaf_radius - - def get_pivoting(self): - return self.__pivoting - - def _get_dataset(self): - return self.__dataset - - def _get_distance(self): - metric_params = self.__metric_params or {} - return get_metric(self.__metric, **metric_params) - - class _Build: - - def __init__(self, vpt, X): - self.__distance = vpt._get_distance() - self.__dataset 
= [(0.0, x, False) for x in X] - self.__leaf_capacity = vpt.get_leaf_capacity() - self.__leaf_radius = vpt.get_leaf_radius() - pivoting = vpt.get_pivoting() - self.__pivoting = self._pivoting_disabled - if pivoting == "random": - self.__pivoting = self._pivoting_random - elif pivoting == "furthest": - self.__pivoting = self._pivoting_furthest - - def _pivoting_disabled(self, start, end): - pass - - def _pivoting_random(self, start, end): - if end <= start: - return - pivot = randrange(start, end) - if pivot > start: - _swap(self.__dataset, start, pivot) - - def _furthest(self, start, end, i): - furthest_dist = 0.0 - furthest = start - _, i_point, _ = self.__dataset[i] - for j in range(start, end): - _, j_point, _ = self.__dataset[j] - j_dist = self.__distance(i_point, j_point) - if j_dist > furthest_dist: - furthest = j - furthest_dist = j_dist - return furthest - - def _pivoting_furthest(self, start, end): - if end <= start: - return - rnd = randrange(start, end) - furthest_rnd = self._furthest(start, end, rnd) - furthest = self._furthest(start, end, furthest_rnd) - if furthest > start: - _swap(self.__dataset, start, furthest) - - def _update(self, start, end): - self.__pivoting(start, end) - _, v_point, is_terminal = self.__dataset[start] - for i in range(start + 1, end): - _, point, _ = self.__dataset[i] - self.__dataset[i] = self.__distance(v_point, point), point, is_terminal - - def build(self): - self._build_iter() - return self.__dataset - - def _build_iter(self): - stack = [(0, len(self.__dataset))] - while stack: - start, end = stack.pop() - mid = _mid(start, end) - self._update(start, end) - _, v_point, _ = self.__dataset[start] - _quickselect(self.__dataset, start + 1, end, mid) - v_radius, _, _ = self.__dataset[mid] - if (end - start > 2 * self.__leaf_capacity) and ( - v_radius > self.__leaf_radius - ): - self.__dataset[start] = (v_radius, v_point, False) - stack.append((mid, end)) - stack.append((start + 1, mid)) - else: - self.__dataset[start] = 
(v_radius, v_point, True) - - def ball_search(self, point, eps, inclusive=True): - return self._BallSearch(self, point, eps, inclusive).search() - - class _BallSearch: - - def __init__(self, vpt, point, eps, inclusive=True): - self.__dataset = vpt._get_dataset() - self.__distance = vpt._get_distance() - self.__point = point - self.__eps = eps - self.__inclusive = inclusive - - def search(self): - return self._search_iter() - - def _inside(self, dist): - if self.__inclusive: - return dist <= self.__eps - return dist < self.__eps - - def _search_iter(self): - stack = [(0, len(self.__dataset))] - result = [] - while stack: - start, end = stack.pop() - v_radius, v_point, is_terminal = self.__dataset[start] - if is_terminal: - for _, x, _ in self.__dataset[start:end]: - dist = self.__distance(self.__point, x) - if self._inside(dist): - result.append(x) - else: - dist = self.__distance(self.__point, v_point) - mid = _mid(start, end) - if self._inside(dist): - result.append(v_point) - if dist <= v_radius: - fst = (start + 1, mid) - snd = (mid, end) - else: - fst = (mid, end) - snd = (start + 1, mid) - if abs(dist - v_radius) <= self.__eps: - stack.append(snd) - stack.append(fst) - return result - - def knn_search(self, point, k): - return self._KnnSearch(self, point, k).search() - - class _KnnSearch: - - def __init__(self, vpt, point, neighbors): - self.__dataset = vpt._get_dataset() - self.__distance = vpt._get_distance() - self.__point = point - self.__neighbors = neighbors - self.__radius = float("inf") - self.__result = MaxHeap() - - def _get_items(self): - while len(self.__result) > self.__neighbors: - self.__result.pop() - return [x for (_, x) in self.__result] - - def search(self): - self._search_iter() - return self._get_items() - - def _process(self, x): - dist = self.__distance(self.__point, x) - if dist >= self.__radius: - return dist - self.__result.add(dist, x) - while len(self.__result) > self.__neighbors: - self.__result.pop() - if len(self.__result) == 
self.__neighbors: - self.__radius, _ = self.__result.top() - return dist - - def _search_iter(self): - PRE, POST = 0, 1 - self.__result = MaxHeap() - stack = [(0, len(self.__dataset), 0.0, PRE)] - while stack: - start, end, thr, action = stack.pop() - v_radius, v_point, is_terminal = self.__dataset[start] - if is_terminal: - for _, x, _ in self.__dataset[start:end]: - self._process(x) - else: - if action == PRE: - mid = _mid(start, end) - dist = self._process(v_point) - if dist <= v_radius: - fst_start, fst_end = start + 1, mid - snd_start, snd_end = mid, end - else: - fst_start, fst_end = mid, end - snd_start, snd_end = start + 1, mid - stack.append((snd_start, snd_end, abs(v_radius - dist), POST)) - stack.append((fst_start, fst_end, 0.0, PRE)) - elif action == POST: - if self.__radius > thr: - stack.append((start, end, 0.0, PRE)) - return self._get_items() diff --git a/src/tdamapper/utils/vptree_flat/__init__.py b/src/tdamapper/utils/vptree_flat/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/tdamapper/utils/vptree_flat/ball_search.py b/src/tdamapper/utils/vptree_flat/ball_search.py new file mode 100644 index 00000000..418375f5 --- /dev/null +++ b/src/tdamapper/utils/vptree_flat/ball_search.py @@ -0,0 +1,48 @@ +from tdamapper.utils.vptree_flat.common import _mid + + +class BallSearch: + + def __init__(self, vpt, point, eps, inclusive=True): + self._arr = vpt._get_arr() + self.__distance = vpt._get_distance() + self.__point = point + self.__eps = eps + self.__inclusive = inclusive + + def search(self): + return self._search_iter() + + def _inside(self, dist): + if self.__inclusive: + return dist <= self.__eps + return dist < self.__eps + + def _search_iter(self): + stack = [(0, self._arr.size())] + result = [] + while stack: + start, end = stack.pop() + v_radius = self._arr.get_distance(start) + v_point = self._arr.get_point(start) + is_terminal = self._arr.is_terminal(start) + if is_terminal: + for x in self._arr.get_points(start, end): 
+ dist = self.__distance(self.__point, x) + if self._inside(dist): + result.append(x) + else: + dist = self.__distance(self.__point, v_point) + mid = _mid(start, end) + if self._inside(dist): + result.append(v_point) + if dist <= v_radius: + fst = (start + 1, mid) + snd = (mid, end) + else: + fst = (mid, end) + snd = (start + 1, mid) + if abs(dist - v_radius) <= self.__eps: + stack.append(snd) + stack.append(fst) + return result diff --git a/src/tdamapper/utils/vptree_flat/builder.py b/src/tdamapper/utils/vptree_flat/builder.py new file mode 100644 index 00000000..a19b1c12 --- /dev/null +++ b/src/tdamapper/utils/vptree_flat/builder.py @@ -0,0 +1,93 @@ +from random import randrange + +import numpy as np + +from tdamapper.utils.vptree_flat.common import VPArray + + +def _mid(start, end): + return (start + end) // 2 + + +class Builder: + + def __init__(self, vpt, X): + self.__distance = vpt._get_distance() + + dataset = [x for x in X] + indices = np.array([i for i in range(len(dataset))]) + distances = np.array([0.0 for _ in X]) + is_terminal = np.array([False for _ in X]) + self._arr = VPArray(dataset, distances, indices, is_terminal) + + self.__leaf_capacity = vpt.get_leaf_capacity() + self.__leaf_radius = vpt.get_leaf_radius() + pivoting = vpt.get_pivoting() + self.__pivoting = self._pivoting_disabled + if pivoting == "random": + self.__pivoting = self._pivoting_random + elif pivoting == "furthest": + self.__pivoting = self._pivoting_furthest + + def _pivoting_disabled(self, start, end): + pass + + def _pivoting_random(self, start, end): + if end <= start: + return + pivot = randrange(start, end) + if pivot > start: + self._arr.swap(start, pivot) + + def _furthest(self, start, end, i): + furthest_dist = 0.0 + furthest = start + i_point = self._arr.get_point(i) + for j in range(start, end): + j_point = self._arr.get_point(j) + j_dist = self.__distance(i_point, j_point) + if j_dist > furthest_dist: + furthest = j + furthest_dist = j_dist + return furthest + + def 
_pivoting_furthest(self, start, end): + if end <= start: + return + rnd = randrange(start, end) + furthest_rnd = self._furthest(start, end, rnd) + furthest = self._furthest(start, end, furthest_rnd) + if furthest > start: + self._arr.swap(start, furthest) + + def _update(self, start, end): + self.__pivoting(start, end) + v_point = self._arr.get_point(start) + is_terminal = self._arr.is_terminal(start) + for i in range(start + 1, end): + point = self._arr.get_point(i) + self._arr.set_distance(i, self.__distance(v_point, point)) + self._arr.set_terminal(i, is_terminal) + + def build(self): + self._build_iter() + return self._arr + + def _build_iter(self): + stack = [(0, self._arr.size())] + while stack: + start, end = stack.pop() + mid = _mid(start, end) + self._update(start, end) + self._arr.partition(start + 1, end, mid) + v_radius = self._arr.get_distance(mid) + if (end - start > 2 * self.__leaf_capacity) and ( + v_radius > self.__leaf_radius + ): + self._arr.set_distance(start, v_radius) + self._arr.set_terminal(start, False) + stack.append((mid, end)) + stack.append((start + 1, mid)) + else: + self._arr.set_distance(start, v_radius) + self._arr.set_terminal(start, True) diff --git a/src/tdamapper/utils/vptree_flat/common.py b/src/tdamapper/utils/vptree_flat/common.py new file mode 100644 index 00000000..f51e9c95 --- /dev/null +++ b/src/tdamapper/utils/vptree_flat/common.py @@ -0,0 +1,42 @@ +from tdamapper.utils.quickselect import quickselect, swap_all + + +def _mid(start, end): + return (start + end) // 2 + + +class VPArray: + + def __init__(self, dataset, distances, indices, is_terminal): + self._dataset = dataset + self._distances = distances + self._indices = indices + self._is_terminal = is_terminal + + def size(self): + return len(self._dataset) + + def get_point(self, i): + return self._dataset[self._indices[i]] + + def get_points(self, s, e): + for x_index in self._indices[s:e]: + yield self._dataset[x_index] + + def get_distance(self, i): + return 
self._distances[i] + + def set_distance(self, i, dist): + self._distances[i] = dist + + def set_terminal(self, i, terminal): + self._is_terminal[i] = terminal + + def is_terminal(self, i): + return self._is_terminal[i] + + def swap(self, i, j): + swap_all(self._distances, i, j, self._indices, self._is_terminal) + + def partition(self, s, e, k): + quickselect(self._distances, s, e, k, self._indices, self._is_terminal) diff --git a/src/tdamapper/utils/vptree_flat/knn_search.py b/src/tdamapper/utils/vptree_flat/knn_search.py new file mode 100644 index 00000000..57b1d6cb --- /dev/null +++ b/src/tdamapper/utils/vptree_flat/knn_search.py @@ -0,0 +1,64 @@ +from tdamapper.utils.heap import MaxHeap +from tdamapper.utils.vptree_flat.common import _mid + + +class KnnSearch: + + def __init__(self, vpt, point, neighbors): + self._arr = vpt._get_arr() + self.__distance = vpt._get_distance() + self.__point = point + self.__neighbors = neighbors + self.__radius = float("inf") + self.__result = MaxHeap() + + def _get_items(self): + while len(self.__result) > self.__neighbors: + self.__result.pop() + return [x for (_, x) in self.__result] + + def search(self): + self._search_iter() + return self._get_items() + + def _process(self, x): + dist = self.__distance(self.__point, x) + if dist >= self.__radius: + return dist + self.__result.add(dist, x) + while len(self.__result) > self.__neighbors: + self.__result.pop() + if len(self.__result) == self.__neighbors: + self.__radius, _ = self.__result.top() + return dist + + def _search_iter(self): + PRE, POST = 0, 1 + self.__result = MaxHeap() + stack = [(0, self._arr.size(), 0.0, PRE)] + while stack: + start, end, thr, action = stack.pop() + + v_radius = self._arr.get_distance(start) + v_point = self._arr.get_point(start) + is_terminal = self._arr.is_terminal(start) + + if is_terminal: + for x in self._arr.get_points(start, end): + self._process(x) + else: + if action == PRE: + mid = _mid(start, end) + dist = self._process(v_point) + if 
dist <= v_radius: + fst_start, fst_end = start + 1, mid + snd_start, snd_end = mid, end + else: + fst_start, fst_end = mid, end + snd_start, snd_end = start + 1, mid + stack.append((snd_start, snd_end, abs(v_radius - dist), POST)) + stack.append((fst_start, fst_end, 0.0, PRE)) + elif action == POST: + if self.__radius > thr: + stack.append((start, end, 0.0, PRE)) + return self._get_items() diff --git a/src/tdamapper/utils/vptree_flat/vptree.py b/src/tdamapper/utils/vptree_flat/vptree.py new file mode 100755 index 00000000..f157bbc6 --- /dev/null +++ b/src/tdamapper/utils/vptree_flat/vptree.py @@ -0,0 +1,51 @@ +from tdamapper.utils.metrics import get_metric +from tdamapper.utils.vptree_flat.ball_search import BallSearch +from tdamapper.utils.vptree_flat.builder import Builder +from tdamapper.utils.vptree_flat.knn_search import KnnSearch + + +class VPTree: + + def __init__( + self, + X, + metric="euclidean", + metric_params=None, + leaf_capacity=1, + leaf_radius=0.0, + pivoting=None, + ): + self.__metric = metric + self.__metric_params = metric_params + self.__leaf_capacity = leaf_capacity + self.__leaf_radius = leaf_radius + self.__pivoting = pivoting + self._arr = Builder(self, X).build() + + def get_metric(self): + return self.__metric + + def get_metric_params(self): + return self.__metric_params + + def get_leaf_capacity(self): + return self.__leaf_capacity + + def get_leaf_radius(self): + return self.__leaf_radius + + def get_pivoting(self): + return self.__pivoting + + def _get_arr(self): + return self._arr + + def _get_distance(self): + metric_params = self.__metric_params or {} + return get_metric(self.__metric, **metric_params) + + def ball_search(self, point, eps, inclusive=True): + return BallSearch(self, point, eps, inclusive).search() + + def knn_search(self, point, k): + return KnnSearch(self, point, k).search() diff --git a/src/tdamapper/utils/vptree_hier.py b/src/tdamapper/utils/vptree_hier.py deleted file mode 100644 index 474ed247..00000000 --- 
a/src/tdamapper/utils/vptree_hier.py +++ /dev/null @@ -1,262 +0,0 @@ -from random import randrange - -from tdamapper.utils.heap import MaxHeap -from tdamapper.utils.metrics import get_metric -from tdamapper.utils.quickselect import quickselect - - -def _swap(arr, i, j): - arr[i], arr[j] = arr[j], arr[i] - - -def _mid(start, end): - return (start + end) // 2 - - -class VPTree: - - def __init__( - self, - X, - metric="euclidean", - metric_params=None, - leaf_capacity=1, - leaf_radius=0.0, - pivoting=None, - ): - self.__metric = metric - self.__metric_params = metric_params - self.__leaf_capacity = leaf_capacity - self.__leaf_radius = leaf_radius - self.__pivoting = pivoting - tree, dataset = self._Build(self, X).build() - self.__tree, self.__dataset = tree, dataset - - def get_metric(self): - return self.__metric - - def get_metric_params(self): - return self.__metric_params - - def get_leaf_capacity(self): - return self.__leaf_capacity - - def get_leaf_radius(self): - return self.__leaf_radius - - def get_pivoting(self): - return self.__pivoting - - def _get_tree(self): - return self.__tree - - def _get_dataset(self): - return self.__dataset - - def _get_distance(self): - metric_params = self.__metric_params or {} - return get_metric(self.__metric, **metric_params) - - class _Build: - - def __init__(self, vpt, X): - self.__distance = vpt._get_distance() - self.__dataset = [(0.0, x) for x in X] - self.__leaf_capacity = vpt.get_leaf_capacity() - self.__leaf_radius = vpt.get_leaf_radius() - pivoting = vpt.get_pivoting() - self.__pivoting = self._pivoting_disabled - if pivoting == "random": - self.__pivoting = self._pivoting_random - elif pivoting == "furthest": - self.__pivoting = self._pivoting_furthest - - def _pivoting_disabled(self, start, end): - pass - - def _pivoting_random(self, start, end): - if end <= start: - return - pivot = randrange(start, end) - if pivot > start: - _swap(self.__dataset, start, pivot) - - def _furthest(self, start, end, i): - 
furthest_dist = 0.0 - furthest = start - _, i_point = self.__dataset[i] - for j in range(start, end): - _, j_point = self.__dataset[j] - j_dist = self.__distance(i_point, j_point) - if j_dist > furthest_dist: - furthest = j - furthest_dist = j_dist - return furthest - - def _pivoting_furthest(self, start, end): - if end <= start: - return - rnd = randrange(start, end) - furthest_rnd = self._furthest(start, end, rnd) - furthest = self._furthest(start, end, furthest_rnd) - if furthest > start: - _swap(self.__dataset, start, furthest) - - def _update(self, start, end): - self.__pivoting(start, end) - _, v_point = self.__dataset[start] - for i in range(start + 1, end): - _, point = self.__dataset[i] - self.__dataset[i] = self.__distance(v_point, point), point - - def build(self): - tree = self._build_rec(0, len(self.__dataset)) - return tree, self.__dataset - - def _build_rec(self, start, end): - mid = _mid(start, end) - self._update(start, end) - _, v_point = self.__dataset[start] - quickselect(self.__dataset, start + 1, end, mid) - v_radius, _ = self.__dataset[mid] - self.__dataset[start] = (v_radius, v_point) - if (end - start <= 2 * self.__leaf_capacity) or ( - v_radius <= self.__leaf_radius - ): - left = _Leaf(start + 1, mid) - right = _Leaf(mid, end) - else: - left = self._build_rec(start + 1, mid) - right = self._build_rec(mid, end) - return _Node(v_radius, v_point, left, right) - - def ball_search(self, point, eps, inclusive=True): - return self._BallSearch(self, point, eps, inclusive).search() - - class _BallSearch: - - def __init__(self, vpt, point, eps, inclusive=True): - self.__tree = vpt._get_tree() - self.__dataset = vpt._get_dataset() - self.__distance = vpt._get_distance() - self.__point = point - self.__eps = eps - self.__inclusive = inclusive - self.__result = [] - - def search(self): - self.__result.clear() - self._search_rec(self.__tree) - return self.__result - - def _inside(self, dist): - if self.__inclusive: - return dist <= self.__eps - return 
dist < self.__eps - - def _search_rec(self, tree): - if tree.is_terminal(): - start, end = tree.get_bounds() - for _, x in self.__dataset[start:end]: - dist = self.__distance(self.__point, x) - if self._inside(dist): - self.__result.append(x) - else: - v_radius, v_point = tree.get_ball() - dist = self.__distance(v_point, self.__point) - if self._inside(dist): - self.__result.append(v_point) - if dist <= v_radius: - fst, snd = tree.get_left(), tree.get_right() - else: - fst, snd = tree.get_right(), tree.get_left() - self._search_rec(fst) - if abs(dist - v_radius) <= self.__eps: - self._search_rec(snd) - - def knn_search(self, point, k): - return self._KnnSearch(self, point, k).search() - - class _KnnSearch: - - def __init__(self, vpt, point, neighbors): - self.__tree = vpt._get_tree() - self.__dataset = vpt._get_dataset() - self.__distance = vpt._get_distance() - self.__point = point - self.__neighbors = neighbors - self.__items = MaxHeap() - - def _add(self, dist, x): - self.__items.add(dist, x) - if len(self.__items) > self.__neighbors: - self.__items.pop() - - def _get_items(self): - while len(self.__items) > self.__neighbors: - self.__items.pop() - return [x for (_, x) in self.__items] - - def _get_radius(self): - if len(self.__items) < self.__neighbors: - return float("inf") - furthest_dist, _ = self.__items.top() - return furthest_dist - - def search(self): - self._search_rec(self.__tree) - return self._get_items() - - def _search_rec(self, tree): - if tree.is_terminal(): - start, end = tree.get_bounds() - for _, x in self.__dataset[start:end]: - dist = self.__distance(self.__point, x) - if dist < self._get_radius(): - self._add(dist, x) - else: - v_radius, v_point = tree.get_ball() - dist = self.__distance(v_point, self.__point) - if dist < self._get_radius(): - self._add(dist, v_point) - if dist <= v_radius: - fst, snd = tree.get_left(), tree.get_right() - else: - fst, snd = tree.get_right(), tree.get_left() - self._search_rec(fst) - if abs(dist - v_radius) 
<= self._get_radius(): - self._search_rec(snd) - - -class _Node: - - def __init__(self, radius, center, left, right): - self.__radius = radius - self.__center = center - self.__left = left - self.__right = right - - def get_ball(self): - return self.__radius, self.__center - - def is_terminal(self): - return False - - def get_left(self): - return self.__left - - def get_right(self): - return self.__right - - -class _Leaf: - - def __init__(self, start, end): - self.__start = start - self.__end = end - - def get_bounds(self): - return self.__start, self.__end - - def is_terminal(self): - return True diff --git a/src/tdamapper/utils/vptree_hier/__init__.py b/src/tdamapper/utils/vptree_hier/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/tdamapper/utils/vptree_hier/ball_search.py b/src/tdamapper/utils/vptree_hier/ball_search.py new file mode 100644 index 00000000..1f519a39 --- /dev/null +++ b/src/tdamapper/utils/vptree_hier/ball_search.py @@ -0,0 +1,40 @@ +class BallSearch: + + def __init__(self, vpt, point, eps, inclusive=True): + self.__tree = vpt._get_tree() + self._arr = vpt._get_arr() + self.__distance = vpt._get_distance() + self.__point = point + self.__eps = eps + self.__inclusive = inclusive + self.__result = [] + + def search(self): + self.__result.clear() + self._search_rec(self.__tree) + return self.__result + + def _inside(self, dist): + if self.__inclusive: + return dist <= self.__eps + return dist < self.__eps + + def _search_rec(self, tree): + if tree.is_terminal(): + start, end = tree.get_bounds() + for x in self._arr.get_points(start, end): + dist = self.__distance(self.__point, x) + if self._inside(dist): + self.__result.append(x) + else: + v_radius, v_point = tree.get_ball() + dist = self.__distance(v_point, self.__point) + if self._inside(dist): + self.__result.append(v_point) + if dist <= v_radius: + fst, snd = tree.get_left(), tree.get_right() + else: + fst, snd = tree.get_right(), tree.get_left() + 
self._search_rec(fst) + if abs(dist - v_radius) <= self.__eps: + self._search_rec(snd) diff --git a/src/tdamapper/utils/vptree_hier/builder.py b/src/tdamapper/utils/vptree_hier/builder.py new file mode 100644 index 00000000..944b93e7 --- /dev/null +++ b/src/tdamapper/utils/vptree_hier/builder.py @@ -0,0 +1,85 @@ +from random import randrange + +import numpy as np + +from tdamapper.utils.quickselect import quickselect, swap_all +from tdamapper.utils.vptree_hier.common import Leaf, Node, VPArray, _mid + + +class Builder: + + def __init__(self, vpt, X): + self.__distance = vpt._get_distance() + + dataset = [x for x in X] + indices = np.array([i for i in range(len(dataset))]) + distances = np.array([0.0 for _ in X]) + self._arr = VPArray(dataset, distances, indices) + + self.__leaf_capacity = vpt.get_leaf_capacity() + self.__leaf_radius = vpt.get_leaf_radius() + pivoting = vpt.get_pivoting() + self.__pivoting = self._pivoting_disabled + if pivoting == "random": + self.__pivoting = self._pivoting_random + elif pivoting == "furthest": + self.__pivoting = self._pivoting_furthest + + def _pivoting_disabled(self, start, end): + pass + + def _pivoting_random(self, start, end): + if end <= start: + return + pivot = randrange(start, end) + if pivot > start: + self._arr.swap(start, pivot) + + def _furthest(self, start, end, i): + furthest_dist = 0.0 + furthest = start + i_point = self._arr.get_point(i) + for j in range(start, end): + j_point = self._arr.get_point(j) + j_dist = self.__distance(i_point, j_point) + if j_dist > furthest_dist: + furthest = j + furthest_dist = j_dist + return furthest + + def _pivoting_furthest(self, start, end): + if end <= start: + return + rnd = randrange(start, end) + furthest_rnd = self._furthest(start, end, rnd) + furthest = self._furthest(start, end, furthest_rnd) + if furthest > start: + self._arr.swap(start, furthest) + + def _update(self, start, end): + self.__pivoting(start, end) + v_point = self._arr.get_point(start) + for i in 
range(start + 1, end): + point = self._arr.get_point(i) + self._arr.set_distance(i, self.__distance(v_point, point)) + + def build(self): + tree = self._build_rec(0, self._arr.size()) + return tree, self._arr + + def _build_rec(self, start, end): + mid = _mid(start, end) + self._update(start, end) + v_point = self._arr.get_point(start) + self._arr.partition(start + 1, end, mid) + v_radius = self._arr.get_distance(mid) + self._arr.set_distance(start, v_radius) + if (end - start <= 2 * self.__leaf_capacity) or ( + v_radius <= self.__leaf_radius + ): + left = Leaf(start + 1, mid) + right = Leaf(mid, end) + else: + left = self._build_rec(start + 1, mid) + right = self._build_rec(mid, end) + return Node(v_radius, v_point, left, right) diff --git a/src/tdamapper/utils/vptree_hier/common.py b/src/tdamapper/utils/vptree_hier/common.py new file mode 100644 index 00000000..749c8fab --- /dev/null +++ b/src/tdamapper/utils/vptree_hier/common.py @@ -0,0 +1,69 @@ +from tdamapper.utils.quickselect import quickselect, swap_all + + +def _mid(start, end): + return (start + end) // 2 + + +class VPArray: + + def __init__(self, dataset, distances, indices): + self._dataset = dataset + self._distances = distances + self._indices = indices + + def size(self): + return len(self._dataset) + + def get_point(self, i): + return self._dataset[self._indices[i]] + + def get_points(self, s, e): + for x_index in self._indices[s:e]: + yield self._dataset[x_index] + + def get_distance(self, i): + return self._distances[i] + + def set_distance(self, i, dist): + self._distances[i] = dist + + def swap(self, i, j): + swap_all(self._distances, i, j, self._indices) + + def partition(self, s, e, k): + quickselect(self._distances, s, e, k, self._indices) + + +class Node: + + def __init__(self, radius, center, left, right): + self.__radius = radius + self.__center = center + self.__left = left + self.__right = right + + def get_ball(self): + return self.__radius, self.__center + + def is_terminal(self): + 
return False + + def get_left(self): + return self.__left + + def get_right(self): + return self.__right + + +class Leaf: + + def __init__(self, start, end): + self.__start = start + self.__end = end + + def get_bounds(self): + return self.__start, self.__end + + def is_terminal(self): + return True diff --git a/src/tdamapper/utils/vptree_hier/knn_search.py b/src/tdamapper/utils/vptree_hier/knn_search.py new file mode 100644 index 00000000..a6531414 --- /dev/null +++ b/src/tdamapper/utils/vptree_hier/knn_search.py @@ -0,0 +1,52 @@ +from tdamapper.utils.heap import MaxHeap + + +class KnnSearch: + + def __init__(self, vpt, point, neighbors): + self.__tree = vpt._get_tree() + self._arr = vpt._get_arr() + self.__distance = vpt._get_distance() + self.__point = point + self.__neighbors = neighbors + self.__items = MaxHeap() + + def _add(self, dist, x): + self.__items.add(dist, x) + if len(self.__items) > self.__neighbors: + self.__items.pop() + + def _get_items(self): + while len(self.__items) > self.__neighbors: + self.__items.pop() + return [x for (_, x) in self.__items] + + def _get_radius(self): + if len(self.__items) < self.__neighbors: + return float("inf") + furthest_dist, _ = self.__items.top() + return furthest_dist + + def search(self): + self._search_rec(self.__tree) + return self._get_items() + + def _search_rec(self, tree): + if tree.is_terminal(): + start, end = tree.get_bounds() + for x in self._arr.get_points(start, end): + dist = self.__distance(self.__point, x) + if dist < self._get_radius(): + self._add(dist, x) + else: + v_radius, v_point = tree.get_ball() + dist = self.__distance(v_point, self.__point) + if dist < self._get_radius(): + self._add(dist, v_point) + if dist <= v_radius: + fst, snd = tree.get_left(), tree.get_right() + else: + fst, snd = tree.get_right(), tree.get_left() + self._search_rec(fst) + if abs(dist - v_radius) <= self._get_radius(): + self._search_rec(snd) diff --git a/src/tdamapper/utils/vptree_hier/vptree.py 
b/src/tdamapper/utils/vptree_hier/vptree.py new file mode 100644 index 00000000..cd0f0172 --- /dev/null +++ b/src/tdamapper/utils/vptree_hier/vptree.py @@ -0,0 +1,54 @@ +from tdamapper.utils.metrics import get_metric +from tdamapper.utils.vptree_hier.ball_search import BallSearch +from tdamapper.utils.vptree_hier.builder import Builder +from tdamapper.utils.vptree_hier.knn_search import KnnSearch + + +class VPTree: + + def __init__( + self, + X, + metric="euclidean", + metric_params=None, + leaf_capacity=1, + leaf_radius=0.0, + pivoting=None, + ): + self.__metric = metric + self.__metric_params = metric_params + self.__leaf_capacity = leaf_capacity + self.__leaf_radius = leaf_radius + self.__pivoting = pivoting + self.__tree, self._arr = Builder(self, X).build() + + def get_metric(self): + return self.__metric + + def get_metric_params(self): + return self.__metric_params + + def get_leaf_capacity(self): + return self.__leaf_capacity + + def get_leaf_radius(self): + return self.__leaf_radius + + def get_pivoting(self): + return self.__pivoting + + def _get_tree(self): + return self.__tree + + def _get_arr(self): + return self._arr + + def _get_distance(self): + metric_params = self.__metric_params or {} + return get_metric(self.__metric, **metric_params) + + def ball_search(self, point, eps, inclusive=True): + return BallSearch(self, point, eps, inclusive).search() + + def knn_search(self, point, k): + return KnnSearch(self, point, k).search() diff --git a/tests/test_bench_cover.py b/tests/test_bench_cover.py index 67bd8549..0f94ee30 100644 --- a/tests/test_bench_cover.py +++ b/tests/test_bench_cover.py @@ -5,9 +5,10 @@ import numpy as np from sklearn.datasets import load_digits +from tdamapper._common import profile from tdamapper.utils.metrics import euclidean -from tdamapper.utils.vptree_flat import VPTree as FVPT -from tdamapper.utils.vptree_hier import VPTree as HVPT +from tdamapper.utils.vptree_flat.vptree import VPTree as FVPT +from 
tdamapper.utils.vptree_hier.vptree import VPTree as HVPT from tests.ball_tree import SkBallTree from tests.setup_logging import setup_logging @@ -21,6 +22,10 @@ def dataset(dim=10, num=1000): return [np.random.rand(dim) for _ in range(num)] +def dist_proj(x, y): + return dist(x[1:], y[1:]) + + class TestVpSettings(unittest.TestCase): setup_logging() @@ -39,11 +44,12 @@ def cover(self, vpt, X, r): def run_bench(self, X, r, dist, vp, **kwargs): XX = np.array([[i] + [xi for xi in x] for i, x in enumerate(X)]) t0 = time.time() - vpt = vp(XX, metric=lambda x, y: dist(x[1:], y[1:]), **kwargs) + vpt = vp(XX, metric=dist_proj, **kwargs) list(self.cover(vpt, XX, r)) t1 = time.time() self.logger.info(f"time: {t1 - t0}") + @profile(n_lines=20) def test_cover_random(self): for r in [1.0, 10.0, 100.0]: for n in [100, 1000, 10000]: @@ -61,6 +67,7 @@ def test_cover_random(self): self.run_bench(X, r, dist, SkBallTree, leaf_radius=r) self.logger.info("") + @profile(n_lines=20) def test_cover_digits(self): X, _ = load_digits(return_X_y=True) # X = PCA(n_components=3).fit_transform(X) diff --git a/tests/test_bench_metrics.py b/tests/test_bench_metrics.py index 5c7e1751..f5644e9f 100644 --- a/tests/test_bench_metrics.py +++ b/tests/test_bench_metrics.py @@ -47,7 +47,11 @@ def eval_dist(X, d): def run_dist_bench(X, d): eval_dist(X, d) - return timeit.timeit(lambda: eval_dist(X, d), number=200) + + def _eval(): + eval_dist(X, d) + + return timeit.timeit(_eval, number=200) def run_euclidean_bench(X): diff --git a/tests/test_bench_vptree.py b/tests/test_bench_vptree.py index 9f7f9530..df8ff2e1 100644 --- a/tests/test_bench_vptree.py +++ b/tests/test_bench_vptree.py @@ -6,8 +6,8 @@ from sklearn.datasets import load_breast_cancer, load_digits, load_iris from tdamapper.utils.metrics import euclidean, get_metric -from tdamapper.utils.vptree_flat import VPTree as FVPT -from tdamapper.utils.vptree_hier import VPTree as HVPT +from tdamapper.utils.vptree_flat.vptree import VPTree as FVPT +from 
tdamapper.utils.vptree_hier.vptree import VPTree as HVPT from tests.ball_tree import SkBallTree from tests.setup_logging import setup_logging @@ -93,7 +93,11 @@ def _test_knn_search_naive(self, data, name): d(np.array([0.0]), np.array([0.0])) # jit-compile numba t0 = time() for val in data: - data.sort(key=lambda x: d(x, val)) + + def _dist_key(x): + return d(x, val) + + data.sort(key=_dist_key) [x for x in data[: self.k]] t1 = time() self.logger.info(f"{name}: {t1 - t0}") diff --git a/tests/test_unit_knn.py b/tests/test_unit_knn.py index 02d402f4..a38e1c9e 100644 --- a/tests/test_unit_knn.py +++ b/tests/test_unit_knn.py @@ -4,7 +4,7 @@ from tdamapper.cover import KNNCover from tdamapper.utils.metrics import euclidean -from tdamapper.utils.vptree_flat import VPTree +from tdamapper.utils.vptree_flat.vptree import VPTree X = np.array( [ @@ -126,23 +126,32 @@ def test_vptree_simple(self): self.assertTrue(0.0 in dists) def check_vptree(self, vpt): - data = vpt._get_dataset() + arr = vpt._get_arr() + data = arr._dataset + distances = arr._distances + indices = arr._indices + dist = vpt._get_distance() leaf_capacity = vpt.get_leaf_capacity() leaf_radius = vpt.get_leaf_radius() def check_sub(start, end): - v_radius, v_point, *_ = data[start] + v_radius = distances[start] + v_point_index = indices[start] + v_point = data[v_point_index] + mid = (start + end) // 2 for i in range(start + 1, mid): - _, y, *_ = data[i] + y_index = indices[i] + y = data[y_index] self.assertTrue(dist(v_point, y) <= v_radius) for i in range(mid, end): - _, y, *_ = data[i] + y_index = indices[i] + y = data[y_index] self.assertTrue(dist(v_point, y) >= v_radius) def check_rec(start, end): - v_radius, *_ = data[start] + v_radius = distances[start] if (end - start > leaf_capacity) and (v_radius > leaf_radius): check_sub(start, end) mid = (start + end) // 2 diff --git a/tests/test_unit_proximity.py b/tests/test_unit_proximity.py index 90854235..333da250 100644 --- a/tests/test_unit_proximity.py +++ 
b/tests/test_unit_proximity.py @@ -10,11 +10,15 @@ def dataset(dim=1, num=10000): return [np.random.rand(dim) for _ in range(num)] +def absdist(x, y): + return abs(x - y) + + class TestProximity(unittest.TestCase): def test_ball_proximity(self): data = list(range(100)) - cover = BallCover(radius=10, metric=lambda x, y: abs(x - y)) + cover = BallCover(radius=10, metric=absdist) cover.fit(data) for x in data: result = cover.search(x) @@ -23,7 +27,7 @@ def test_ball_proximity(self): def test_knn_proximity(self): data = list(range(100)) - cover = KNNCover(neighbors=11, metric=lambda x, y: abs(x - y)) + cover = KNNCover(neighbors=11, metric=absdist) cover.fit(data) for x in range(5, 94): result = cover.search(x) diff --git a/tests/test_unit_quickselect.py b/tests/test_unit_quickselect.py index 86f23247..ca091dc3 100755 --- a/tests/test_unit_quickselect.py +++ b/tests/test_unit_quickselect.py @@ -1,50 +1,53 @@ import random import unittest -from tdamapper.utils.quickselect import ( - partition, - partition_tuple, - quickselect, - quickselect_tuple, -) +import numpy as np + +from tdamapper.utils.quickselect import partition, quickselect class TestQuickSelect(unittest.TestCase): def test_partition(self): n = 1000 - arr = [(i, random.randint(0, n - 1)) for i in range(n)] + arr = np.array([i for i in range(n)]) + arr_extra = np.array([random.randint(0, n - 1) for i in range(n)]) for choice in range(n): - h = partition(arr, 0, n, choice) + h = partition(arr, 0, n, choice, arr_extra) for i in range(0, h): - self.assertTrue(arr[i][0] < choice) + self.assertTrue(arr[i] < choice) for i in range(h, n): - self.assertTrue(arr[i][0] >= choice) + self.assertTrue(arr[i] >= choice) def test_quickselect_bounds(self): - arr = [(0, 4), (1, 5), (-1, 6)] - quickselect(arr, 1, 2, 0) - self.assertEqual((0, 4), arr[0]) - self.assertEqual((1, 5), arr[1]) - self.assertEqual((-1, 6), arr[2]) + arr = np.array([0, 1, -1]) + arr_extra = np.array([4, 5, 6]) + quickselect(arr, 1, 2, 0, arr_extra) + 
self.assertEqual(0, arr[0]) + self.assertEqual(1, arr[1]) + self.assertEqual(-1, arr[2]) + self.assertEqual(4, arr_extra[0]) + self.assertEqual(5, arr_extra[1]) + self.assertEqual(6, arr_extra[2]) def test_quickselect(self): n = 1000 - arr = [(i, random.randint(0, n - 1)) for i in range(n)] + arr = np.array([i for i in range(n)]) + arr_extra = np.array([random.randint(0, n - 1) for i in range(n)]) for choice in range(n): - quickselect(arr, 0, n, choice) - val = arr[choice][0] + quickselect(arr, 0, n, choice, arr_extra) + val = arr[choice] for i in range(0, choice): - self.assertTrue(arr[i][0] <= val) + self.assertTrue(arr[i] <= val) for i in range(choice, n): - self.assertTrue(arr[i][0] >= val) + self.assertTrue(arr[i] >= val) def test_partition_tuple(self): n = 1000 - arr_data = [random.randint(0, n - 1) for i in range(n)] - arr_ord = list(range(n)) + arr_data = np.array([random.randint(0, n - 1) for i in range(n)]) + arr_ord = np.array(list(range(n))) for choice in range(n): - h = partition_tuple(arr_ord, arr_data, 0, n, choice) + h = partition(arr_ord, 0, n, choice, arr_data) for i in range(0, h): self.assertTrue(arr_ord[i] < choice) for i in range(h, n): @@ -52,10 +55,10 @@ def test_partition_tuple(self): def test_quickselect_tuple(self): n = 1000 - arr_data = [random.randint(0, n - 1) for i in range(n)] - arr_ord = list(range(n)) + arr_data = np.array([random.randint(0, n - 1) for i in range(n)]) + arr_ord = np.array(list(range(n))) for choice in range(n): - quickselect_tuple(arr_ord, arr_data, 0, n, choice) + quickselect(arr_ord, 0, n, choice, arr_data) val = arr_ord[choice] for i in range(0, choice): self.assertTrue(arr_ord[i] <= val) diff --git a/tests/test_unit_vptree.py b/tests/test_unit_vptree.py index 2b31c152..e50ec1e5 100644 --- a/tests/test_unit_vptree.py +++ b/tests/test_unit_vptree.py @@ -4,8 +4,8 @@ import numpy as np from tdamapper.utils.metrics import get_metric -from tdamapper.utils.vptree_flat import VPTree as FVPT -from 
tdamapper.utils.vptree_hier import VPTree as HVPT +from tdamapper.utils.vptree_flat.vptree import VPTree as FVPT +from tdamapper.utils.vptree_hier.vptree import VPTree as HVPT from tests.ball_tree import SkBallTree distance = "euclidean"