Skip to content

Commit 2e33ace

Browse files
committed
Improve search algorithm testing
1 parent ac981fc commit 2e33ace

6 files changed

Lines changed: 45 additions & 23 deletions

File tree

src/tdamapper/search.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -120,8 +120,6 @@ def search(self, x: PointLike) -> List[int]:
120120
:param x: A query point for which we want to find neighbors.
121121
:return: The indices of the neighbors contained in the dataset.
122122
"""
123-
if self._vptree is None:
124-
return []
125123
neighs = self._vptree.ball_search(
126124
(-1, x),
127125
self._radius,
@@ -219,8 +217,6 @@ def search(self, x: PointLike) -> List[int]:
219217
:param x: A query point for which we want to find neighbors.
220218
:return: The indices of the neighbors contained in the dataset.
221219
"""
222-
if self._vptree is None:
223-
return []
224220
neighs = self._vptree.knn_search((-1, x), self._neighbors)
225221
return [x for (x, _) in neighs]
226222

@@ -306,7 +302,7 @@ def _gamma_n_inv(self, x):
306302
def _get_bounds(
307303
self, arr: NDArray[np.float64]
308304
) -> Optional[Tuple[NDArray[np.float64], NDArray[np.float64], NDArray[np.float64]]]:
309-
if (arr is None) or len(arr) == 0:
305+
if len(arr) == 0:
310306
return None
311307
_min, _max = arr[0], arr[0]
312308
eps = np.finfo(np.float64).eps
@@ -330,7 +326,9 @@ def fit(self, arr: ArrayLike) -> CubicalSearch:
330326
raise ValueError("The parameter overlap_frac is expected to be > 0.0")
331327
if self.overlap_frac is not None and self.overlap_frac > 0.5:
332328
warn_user("The parameter overlap_frac is expected to be <= 0.5")
333-
arr_ = np.asarray(arr).reshape(len(arr), -1).astype(float)
329+
arr_ = np.asarray(arr)
330+
if len(arr) > 0:
331+
arr_ = np.asarray(arr).reshape(len(arr), -1).astype(float)
334332
dim = 1 if arr_.ndim == 1 else arr_.shape[1]
335333
self._n_intervals = self.n_intervals
336334
self._overlap_frac = (
@@ -339,9 +337,12 @@ def fit(self, arr: ArrayLike) -> CubicalSearch:
339337
else self._get_overlap_frac(dim, 0.5)
340338
)
341339
bounds = self._get_bounds(arr_)
342-
if bounds is None:
343-
raise ValueError("The dataset is empty or not properly defined.")
344-
self._min, self._max, self._delta = bounds
340+
if bounds is not None:
341+
self._min, self._max, self._delta = bounds
342+
else:
343+
self._min = np.zeros(dim, dtype=np.float64)
344+
self._max = np.ones(dim, dtype=np.float64)
345+
self._delta = np.ones(dim, dtype=np.float64)
345346
radius = 1.0 / (2.0 - 2.0 * self._overlap_frac)
346347
self._ball_search = BallSearch(
347348
radius,

src/tdamapper/vptree_flat/ball_search.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@ def search(self) -> List[T]:
5555
:return: A list of points that are within the specified distance from
5656
the given point.
5757
"""
58+
if self._arr.size() == 0:
59+
return []
5860
return self._search_iter()
5961

6062
def _inside(self, dist: float) -> bool:

src/tdamapper/vptree_flat/builder.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,8 @@ def build(self) -> VPArray[T]:
102102
103103
:return: The VPArray instance containing the constructed VP-tree.
104104
"""
105+
if self._arr.size() == 0:
106+
return self._arr
105107
self._build_iter()
106108
return self._arr
107109

src/tdamapper/vptree_flat/common.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,7 @@
99

1010
from __future__ import annotations
1111

12-
from typing import (
13-
Callable,
14-
Generic,
15-
Iterable,
16-
List,
17-
Optional,
18-
Protocol,
19-
TypeVar,
20-
)
12+
from typing import Callable, Generic, Iterable, List, Optional, Protocol, TypeVar
2113

2214
import numpy as np
2315
from numpy.typing import NDArray

src/tdamapper/vptree_flat/knn_search.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,9 @@ def search(self) -> List[T]:
5959
:return: A list of points that are the k-nearest neighbors of the
6060
given point.
6161
"""
62-
self._search_iter()
63-
return self._get_items()
62+
if self._arr.size() == 0:
63+
return []
64+
return self._search_iter()
6465

6566
def _process(self, x: T) -> float:
6667
dist = self._distance(self._point, x)

tests/test_unit_search.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,15 @@
33
from tdamapper.search import BallSearch, CubicalSearch, KNNSearch
44

55

6-
def test_ball_search():
6+
def test_ball_search_empty():
7+
search = BallSearch(radius=1.0)
8+
data = np.array([])
9+
search.fit(data)
10+
11+
assert 0 == len(search.search(np.array([0.0])))
12+
13+
14+
def test_ball_search_ok():
715
search = BallSearch(radius=1.0)
816
data = np.array([[0.0], [1.0], [2.0]])
917
search.fit(data)
@@ -17,7 +25,15 @@ def test_ball_search():
1725
assert 1 in search.search(np.array([0.5]))
1826

1927

20-
def test_knn_search():
28+
def test_knn_search_empty():
29+
search = KNNSearch(neighbors=2)
30+
data = np.array([])
31+
search.fit(data)
32+
33+
assert 0 == len(search.search(np.array([0.0])))
34+
35+
36+
def test_knn_search_ok():
2137
search = KNNSearch(neighbors=2)
2238
data = np.array([[0.0], [1.0], [2.0], [3.0]])
2339
search.fit(data)
@@ -31,7 +47,15 @@ def test_knn_search():
3147
assert 1 in search.search(np.array([0.5]))
3248

3349

34-
def test_cubical_search():
50+
def test_cubical_search_empty():
51+
search = CubicalSearch(n_intervals=4, overlap_frac=0.25)
52+
data = np.array([])
53+
search.fit(data)
54+
55+
assert 0 == len(search.search(np.array([1.0])))
56+
57+
58+
def test_cubical_search_ok():
3559
search = CubicalSearch(n_intervals=4, overlap_frac=0.25)
3660
data = np.array([[0.0], [1.0], [2.0], [3.0], [4.0]])
3761
search.fit(data)

0 commit comments

Comments
 (0)