Skip to content

Commit f4c74cc

Browse files
committed
Added parametric test
1 parent 7ca847b commit f4c74cc

File tree

1 file changed

+175
-7
lines changed

1 file changed

+175
-7
lines changed

tests/test_unit_cover.py

Lines changed: 175 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,140 @@
11
import numpy as np
2+
import pytest
23

34
from tdamapper.core import TrivialCover
45
from tdamapper.cover import BallCover, CubicalCover, KNNCover
56

67

7-
def dataset(dim=1, num=10000):
8+
def dataset_simple():
9+
"""
10+
Create a simple dataset of points in a 2D space.
11+
"""
12+
return [
13+
np.array([0.0, 1.0]),
14+
np.array([1.1, 0.0]),
15+
np.array([0.0, 0.0]),
16+
np.array([1.1, 1.0]),
17+
]
18+
19+
20+
def dataset_random(dim=1, num=1000):
21+
"""
22+
Create a random dataset of points in the unit square.
23+
"""
824
return [np.random.rand(dim) for _ in range(num)]
925

1026

11-
def test_trivial_cover():
12-
data = dataset()
13-
cover = TrivialCover()
27+
def dataset_two_lines(num=1000):
28+
"""
29+
Create a dataset consisting of two lines in the unit square.
30+
One line is horizontal at y=0, the other is vertical at x=1.
31+
"""
32+
t = np.linspace(0.0, 1.0, num)
33+
line1 = np.array([[x, 0.0] for x in t])
34+
line2 = np.array([[x, 1.0] for x in t])
35+
return np.concatenate((line1, line2), axis=0)
36+
37+
38+
def dataset_grid(num=1000):
39+
"""
40+
Create a grid dataset in the unit square.
41+
The grid consists of points evenly spaced in both dimensions.
42+
"""
43+
t = np.linspace(0.0, 1.0, num)
44+
s = np.linspace(0.0, 1.0, num)
45+
grid = np.array([[x, y] for x in t for y in s])
46+
return grid
47+
48+
49+
def assert_coverage(data, cover):
50+
"""
51+
Assert that the cover applies to the data and covers all points.
52+
"""
53+
covered = set()
1454
charts = list(cover.apply(data))
55+
for point_ids in charts:
56+
for point_id in point_ids:
57+
covered.add(point_id)
58+
assert len(covered) == len(data)
59+
return charts
60+
61+
62+
def count_components(charts):
63+
"""
64+
Count the number of unique connected components in the charts.
65+
Each chart is a list of point ids. When multiple charts share points,
66+
they are considered connected.
67+
"""
68+
# Create a mapping from point ids to chart ids
69+
point_charts = {}
70+
for chart_id, point_ids in enumerate(charts):
71+
for point_id in point_ids:
72+
if point_id not in point_charts:
73+
point_charts[point_id] = []
74+
point_charts[point_id].append(chart_id)
75+
76+
chart_components = {x: x for x in range(len(charts))}
77+
for point_id, chart_ids in point_charts.items():
78+
if len(chart_ids) > 1:
79+
# Union all chart ids for this point
80+
first_chart = chart_ids[0]
81+
for chart_id in chart_ids[1:]:
82+
chart_components[chart_id] = chart_components[first_chart]
83+
# Count unique components
84+
unique_components = set(chart_components.values())
85+
return len(unique_components)
86+
87+
88+
def test_trivial_cover_random():
89+
data = dataset_random()
90+
cover = TrivialCover()
91+
assert_coverage(data, cover)
92+
93+
94+
def test_trivial_cover_two_lines():
95+
data = dataset_two_lines()
96+
cover = TrivialCover()
97+
charts = assert_coverage(data, cover)
1598
assert 1 == len(charts)
99+
num_components = count_components(charts)
100+
assert 1 == num_components
101+
102+
103+
@pytest.mark.parametrize(
104+
"dataset, cover, num_charts, num_components",
105+
[
106+
# Simple dataset tests
107+
(dataset_simple(), TrivialCover(), 1, 1),
108+
(dataset_simple(), BallCover(radius=1.1, metric="euclidean"), 2, 2),
109+
(dataset_simple(), KNNCover(neighbors=2, metric="euclidean"), 2, 2),
110+
(dataset_simple(), CubicalCover(n_intervals=2, overlap_frac=0.5), 4, None),
111+
# Two lines dataset tests
112+
(dataset_two_lines(), TrivialCover(), 1, 1),
113+
(dataset_two_lines(), BallCover(radius=0.2, metric="euclidean"), None, 2),
114+
(dataset_two_lines(), KNNCover(neighbors=10, metric="euclidean"), None, 2),
115+
(dataset_two_lines(), CubicalCover(n_intervals=2, overlap_frac=0.5), 4, None),
116+
# Grid dataset tests
117+
(dataset_grid(), TrivialCover(), 1, 1),
118+
(dataset_grid(), BallCover(radius=0.05, metric="euclidean"), None, 1),
119+
(dataset_grid(), KNNCover(neighbors=10, metric="euclidean"), None, 1),
120+
(dataset_grid(), CubicalCover(n_intervals=2, overlap_frac=0.5), 4, None),
121+
],
122+
)
123+
def test_cover(dataset, cover, num_charts, num_components):
124+
charts = assert_coverage(dataset, cover)
125+
if num_charts is not None:
126+
assert len(charts) == num_charts
127+
if num_components is not None:
128+
assert count_components(charts) == num_components
129+
130+
131+
def test_trivial_cover_grid():
132+
data = dataset_two_lines()
133+
cover = TrivialCover()
134+
charts = assert_coverage(data, cover)
135+
assert 1 == len(charts)
136+
num_components = count_components(charts)
137+
assert 1 == num_components
16138

17139

18140
def test_ball_cover_simple():
@@ -23,8 +145,32 @@ def test_ball_cover_simple():
23145
np.array([1.0, 1.0]),
24146
]
25147
cover = BallCover(radius=1.1, metric="euclidean")
26-
charts = list(cover.apply(data))
148+
charts = assert_coverage(data, cover)
27149
assert 2 == len(charts)
150+
num_components = count_components(charts)
151+
assert 1 == num_components
152+
153+
154+
def test_ball_cover_random():
155+
data = dataset_random(dim=2, num=10)
156+
cover = BallCover(radius=0.2, metric="euclidean")
157+
assert_coverage(data, cover)
158+
159+
160+
def test_ball_cover_two_lines():
161+
data = dataset_two_lines()
162+
cover = BallCover(radius=0.2, metric="euclidean")
163+
charts = assert_coverage(data, cover)
164+
num_components = count_components(charts)
165+
assert 2 == num_components
166+
167+
168+
def test_ball_cover_grid():
169+
data = dataset_grid(num=100)
170+
cover = BallCover(radius=0.05, metric="euclidean")
171+
charts = assert_coverage(data, cover)
172+
num_components = count_components(charts)
173+
assert 1 == num_components
28174

29175

30176
def test_knn_cover_simple():
@@ -35,10 +181,26 @@ def test_knn_cover_simple():
35181
np.array([1.1, 1.0]),
36182
]
37183
cover = KNNCover(neighbors=2, metric="euclidean")
38-
charts = list(cover.apply(data))
184+
charts = assert_coverage(data, cover)
39185
assert 2 == len(charts)
40186

41187

188+
def test_knn_cover_two_lines():
189+
data = dataset_two_lines()
190+
cover = KNNCover(neighbors=10, metric="euclidean")
191+
charts = assert_coverage(data, cover)
192+
num_components = count_components(charts)
193+
assert 2 == num_components
194+
195+
196+
def test_knn_cover_grid():
197+
data = dataset_grid(num=100)
198+
cover = KNNCover(neighbors=10, metric="euclidean")
199+
charts = assert_coverage(data, cover)
200+
num_components = count_components(charts)
201+
assert 1 == num_components
202+
203+
42204
def test_cubical_cover_simple():
43205
data = [
44206
np.array([0.0, 1.0]),
@@ -51,7 +213,13 @@ def test_cubical_cover_simple():
51213
assert 4 == len(charts)
52214

53215

54-
def test_params():
216+
def test_cubical_cover_random():
217+
data = dataset_random(dim=2, num=100)
218+
cover = CubicalCover(n_intervals=5, overlap_frac=0.1)
219+
assert_coverage(data, cover)
220+
221+
222+
def test_cubical_cover_params():
55223
cover = CubicalCover(n_intervals=2, overlap_frac=0.5)
56224
params = cover.get_params(deep=True)
57225
assert 2 == params["n_intervals"]

0 commit comments

Comments
 (0)