Skip to content

Commit fa42c42

Browse files
committed
Improved testing
1 parent 2e33ace commit fa42c42

2 files changed

Lines changed: 77 additions & 10 deletions

File tree

src/tdamapper/core.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -377,7 +377,8 @@ def transform(self, x_arr: ArrayLike) -> Generator[List[int], None, None]:
377377
:param x_arr: A dataset of n points.
378378
:yield: A generator yielding a single list of indices.
379379
"""
380-
yield list(range(len(x_arr)))
380+
if len(x_arr) > 0:
381+
yield list(range(len(x_arr)))
381382

382383
def fit_transform(self, x_arr: ArrayLike) -> Generator[List[int], None, None]:
383384
"""

tests/test_unit_cover.py

Lines changed: 75 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,31 @@
44
from tdamapper.cover import BallCover, CubicalCover, KNNCover
55

66

7-
def dataset(dim=1, num=10000):
8-
return [np.random.rand(dim) for _ in range(num)]
7+
def test_trivial_cover_empty():
8+
data = []
9+
cover = TrivialCover()
10+
cover.fit(data)
11+
charts = list(cover.transform(data))
12+
assert 0 == len(charts)
913

1014

11-
def test_trivial_cover():
12-
data = dataset()
15+
def test_trivial_cover_ok():
16+
data = np.array([[0.0, 1.0], [1.0, 0.0], [0.0, 0.0], [1.0, 1.0]])
1317
cover = TrivialCover()
18+
cover.fit(data)
1419
charts = list(cover.transform(data))
1520
assert 1 == len(charts)
1621

1722

18-
def test_ball_cover_simple():
23+
def test_ball_cover_empty():
24+
data = []
25+
cover = BallCover(radius=1.0, metric="euclidean")
26+
cover.fit(data)
27+
charts = list(cover.transform(data))
28+
assert 0 == len(charts)
29+
30+
31+
def test_ball_cover_ok():
1932
data = [
2033
np.array([0.0, 1.0]),
2134
np.array([1.0, 0.0]),
@@ -28,7 +41,22 @@ def test_ball_cover_simple():
2841
assert 2 == len(charts)
2942

3043

31-
def test_knn_cover_simple():
44+
def test_ball_cover_params():
45+
cover = BallCover(radius=1.0, metric="euclidean")
46+
params = cover.get_params(deep=True)
47+
assert 1.0 == params["radius"]
48+
assert "euclidean" == params["metric"]
49+
50+
51+
def test_knn_cover_empty():
52+
data = []
53+
cover = KNNCover(neighbors=2, metric="euclidean")
54+
cover.fit(data)
55+
charts = list(cover.transform(data))
56+
assert 0 == len(charts)
57+
58+
59+
def test_knn_cover_ok():
3260
data = [
3361
np.array([0.0, 1.0]),
3462
np.array([1.1, 0.0]),
@@ -41,7 +69,22 @@ def test_knn_cover_simple():
4169
assert 2 == len(charts)
4270

4371

44-
def test_cubical_cover_simple():
72+
def test_knn_cover_params():
73+
cover = KNNCover(neighbors=2, metric="euclidean")
74+
params = cover.get_params(deep=True)
75+
assert 2 == params["neighbors"]
76+
assert "euclidean" == params["metric"]
77+
78+
79+
def test_cubical_cover_empty():
80+
data = []
81+
cover = CubicalCover(n_intervals=2, overlap_frac=0.5)
82+
cover.fit(data)
83+
charts = list(cover.transform(data))
84+
assert 0 == len(charts)
85+
86+
87+
def test_cubical_cover_ok():
4588
data = [
4689
np.array([0.0, 1.0]),
4790
np.array([1.1, 0.0]),
@@ -54,14 +97,25 @@ def test_cubical_cover_simple():
5497
assert 4 == len(charts)
5598

5699

57-
def test_params():
100+
def test_cubical_cover_params():
58101
cover = CubicalCover(n_intervals=2, overlap_frac=0.5)
59102
params = cover.get_params(deep=True)
60103
assert 2 == params["n_intervals"]
61104
assert 0.5 == params["overlap_frac"]
62105

63106

64-
def test_standard_cover_simple():
107+
def test_standard_cover_empty():
108+
data = []
109+
cover = CubicalCover(
110+
n_intervals=2,
111+
overlap_frac=0.5,
112+
)
113+
cover.fit(data)
114+
charts = list(cover.transform(data))
115+
assert 0 == len(charts)
116+
117+
118+
def test_standard_cover_ok():
65119
data = [
66120
np.array([0.0, 1.0]),
67121
np.array([1.1, 0.0]),
@@ -76,3 +130,15 @@ def test_standard_cover_simple():
76130
cover.fit(data)
77131
charts = list(cover.transform(data))
78132
assert 4 == len(charts)
133+
134+
135+
def test_standard_cover_params():
136+
cover = CubicalCover(
137+
n_intervals=2,
138+
overlap_frac=0.5,
139+
algorithm="standard",
140+
)
141+
params = cover.get_params(deep=True)
142+
assert 2 == params["n_intervals"]
143+
assert 0.5 == params["overlap_frac"]
144+
assert "standard" == params["algorithm"]

0 commit comments

Comments
 (0)