11import numpy as np
2+ import pytest
23
34from tdamapper .core import TrivialCover
45from tdamapper .cover import BallCover , CubicalCover , KNNCover
56
67
7- def dataset (dim = 1 , num = 10000 ):
8+ def dataset_simple ():
9+ """
10+ Create a simple dataset of points in a 2D space.
11+ """
12+ return [
13+ np .array ([0.0 , 1.0 ]),
14+ np .array ([1.1 , 0.0 ]),
15+ np .array ([0.0 , 0.0 ]),
16+ np .array ([1.1 , 1.0 ]),
17+ ]
18+
19+
20+ def dataset_random (dim = 1 , num = 1000 ):
21+ """
22+ Create a random dataset of points in the unit square.
23+ """
824 return [np .random .rand (dim ) for _ in range (num )]
925
1026
11- def test_trivial_cover ():
12- data = dataset ()
13- cover = TrivialCover ()
27+ def dataset_two_lines (num = 1000 ):
28+ """
29+ Create a dataset consisting of two lines in the unit square.
30+ One line is horizontal at y=0, the other is vertical at x=1.
31+ """
32+ t = np .linspace (0.0 , 1.0 , num )
33+ line1 = np .array ([[x , 0.0 ] for x in t ])
34+ line2 = np .array ([[x , 1.0 ] for x in t ])
35+ return np .concatenate ((line1 , line2 ), axis = 0 )
36+
37+
38+ def dataset_grid (num = 1000 ):
39+ """
40+ Create a grid dataset in the unit square.
41+ The grid consists of points evenly spaced in both dimensions.
42+ """
43+ t = np .linspace (0.0 , 1.0 , num )
44+ s = np .linspace (0.0 , 1.0 , num )
45+ grid = np .array ([[x , y ] for x in t for y in s ])
46+ return grid
47+
48+
49+ def assert_coverage (data , cover ):
50+ """
51+ Assert that the cover applies to the data and covers all points.
52+ """
53+ covered = set ()
1454 charts = list (cover .apply (data ))
55+ for point_ids in charts :
56+ for point_id in point_ids :
57+ covered .add (point_id )
58+ assert len (covered ) == len (data )
59+ return charts
60+
61+
62+ def count_components (charts ):
63+ """
64+ Count the number of unique connected components in the charts.
65+ Each chart is a list of point ids. When multiple charts share points,
66+ they are considered connected.
67+ """
68+ # Create a mapping from point ids to chart ids
69+ point_charts = {}
70+ for chart_id , point_ids in enumerate (charts ):
71+ for point_id in point_ids :
72+ if point_id not in point_charts :
73+ point_charts [point_id ] = []
74+ point_charts [point_id ].append (chart_id )
75+
76+ chart_components = {x : x for x in range (len (charts ))}
77+ for point_id , chart_ids in point_charts .items ():
78+ if len (chart_ids ) > 1 :
79+ # Union all chart ids for this point
80+ first_chart = chart_ids [0 ]
81+ for chart_id in chart_ids [1 :]:
82+ chart_components [chart_id ] = chart_components [first_chart ]
83+ # Count unique components
84+ unique_components = set (chart_components .values ())
85+ return len (unique_components )
86+
87+
88+ def test_trivial_cover_random ():
89+ data = dataset_random ()
90+ cover = TrivialCover ()
91+ assert_coverage (data , cover )
92+
93+
94+ def test_trivial_cover_two_lines ():
95+ data = dataset_two_lines ()
96+ cover = TrivialCover ()
97+ charts = assert_coverage (data , cover )
1598 assert 1 == len (charts )
99+ num_components = count_components (charts )
100+ assert 1 == num_components
101+
102+
103+ @pytest .mark .parametrize (
104+ "dataset, cover, num_charts, num_components" ,
105+ [
106+ # Simple dataset tests
107+ (dataset_simple (), TrivialCover (), 1 , 1 ),
108+ (dataset_simple (), BallCover (radius = 1.1 , metric = "euclidean" ), 2 , 2 ),
109+ (dataset_simple (), KNNCover (neighbors = 2 , metric = "euclidean" ), 2 , 2 ),
110+ (dataset_simple (), CubicalCover (n_intervals = 2 , overlap_frac = 0.5 ), 4 , None ),
111+ # Two lines dataset tests
112+ (dataset_two_lines (), TrivialCover (), 1 , 1 ),
113+ (dataset_two_lines (), BallCover (radius = 0.2 , metric = "euclidean" ), None , 2 ),
114+ (dataset_two_lines (), KNNCover (neighbors = 10 , metric = "euclidean" ), None , 2 ),
115+ (dataset_two_lines (), CubicalCover (n_intervals = 2 , overlap_frac = 0.5 ), 4 , None ),
116+ # Grid dataset tests
117+ (dataset_grid (), TrivialCover (), 1 , 1 ),
118+ (dataset_grid (), BallCover (radius = 0.05 , metric = "euclidean" ), None , 1 ),
119+ (dataset_grid (), KNNCover (neighbors = 10 , metric = "euclidean" ), None , 1 ),
120+ (dataset_grid (), CubicalCover (n_intervals = 2 , overlap_frac = 0.5 ), 4 , None ),
121+ ],
122+ )
123+ def test_cover (dataset , cover , num_charts , num_components ):
124+ charts = assert_coverage (dataset , cover )
125+ if num_charts is not None :
126+ assert len (charts ) == num_charts
127+ if num_components is not None :
128+ assert count_components (charts ) == num_components
129+
130+
131+ def test_trivial_cover_grid ():
132+ data = dataset_two_lines ()
133+ cover = TrivialCover ()
134+ charts = assert_coverage (data , cover )
135+ assert 1 == len (charts )
136+ num_components = count_components (charts )
137+ assert 1 == num_components
16138
17139
18140def test_ball_cover_simple ():
@@ -23,8 +145,32 @@ def test_ball_cover_simple():
23145 np .array ([1.0 , 1.0 ]),
24146 ]
25147 cover = BallCover (radius = 1.1 , metric = "euclidean" )
26- charts = list ( cover . apply ( data ) )
148+ charts = assert_coverage ( data , cover )
27149 assert 2 == len (charts )
150+ num_components = count_components (charts )
151+ assert 1 == num_components
152+
153+
154+ def test_ball_cover_random ():
155+ data = dataset_random (dim = 2 , num = 10 )
156+ cover = BallCover (radius = 0.2 , metric = "euclidean" )
157+ assert_coverage (data , cover )
158+
159+
160+ def test_ball_cover_two_lines ():
161+ data = dataset_two_lines ()
162+ cover = BallCover (radius = 0.2 , metric = "euclidean" )
163+ charts = assert_coverage (data , cover )
164+ num_components = count_components (charts )
165+ assert 2 == num_components
166+
167+
168+ def test_ball_cover_grid ():
169+ data = dataset_grid (num = 100 )
170+ cover = BallCover (radius = 0.05 , metric = "euclidean" )
171+ charts = assert_coverage (data , cover )
172+ num_components = count_components (charts )
173+ assert 1 == num_components
28174
29175
30176def test_knn_cover_simple ():
@@ -35,10 +181,26 @@ def test_knn_cover_simple():
35181 np .array ([1.1 , 1.0 ]),
36182 ]
37183 cover = KNNCover (neighbors = 2 , metric = "euclidean" )
38- charts = list ( cover . apply ( data ) )
184+ charts = assert_coverage ( data , cover )
39185 assert 2 == len (charts )
40186
41187
188+ def test_knn_cover_two_lines ():
189+ data = dataset_two_lines ()
190+ cover = KNNCover (neighbors = 10 , metric = "euclidean" )
191+ charts = assert_coverage (data , cover )
192+ num_components = count_components (charts )
193+ assert 2 == num_components
194+
195+
196+ def test_knn_cover_grid ():
197+ data = dataset_grid (num = 100 )
198+ cover = KNNCover (neighbors = 10 , metric = "euclidean" )
199+ charts = assert_coverage (data , cover )
200+ num_components = count_components (charts )
201+ assert 1 == num_components
202+
203+
42204def test_cubical_cover_simple ():
43205 data = [
44206 np .array ([0.0 , 1.0 ]),
@@ -51,7 +213,13 @@ def test_cubical_cover_simple():
51213 assert 4 == len (charts )
52214
53215
54- def test_params ():
216+ def test_cubical_cover_random ():
217+ data = dataset_random (dim = 2 , num = 100 )
218+ cover = CubicalCover (n_intervals = 5 , overlap_frac = 0.1 )
219+ assert_coverage (data , cover )
220+
221+
222+ def test_cubical_cover_params ():
55223 cover = CubicalCover (n_intervals = 2 , overlap_frac = 0.5 )
56224 params = cover .get_params (deep = True )
57225 assert 2 == params ["n_intervals" ]
0 commit comments