Skip to content

Commit 85a9744

Browse files
committed
refactor: consolidate 1D cleanup tests
1 parent 7d2f750 commit 85a9744

5 files changed

Lines changed: 180 additions & 189 deletions

File tree

adaptive/tests/test_learnernd.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import numpy as np
2+
import pytest
23
import scipy.spatial
34

45
from adaptive.learner import LearnerND
@@ -50,16 +51,21 @@ def test_vector_return_with_a_flat_layer():
5051
simple(learner, loss_goal=0.1)
5152

5253

53-
def test_learnerND_1d_basic():
54+
@pytest.mark.parametrize(
55+
("run_kwargs", "expected_npoints"),
56+
[
57+
({"npoints_goal": 10}, 10),
58+
({"loss_goal": 0.1}, None),
59+
],
60+
ids=["npoints-goal", "loss-goal"],
61+
)
62+
def test_learnerND_1d(run_kwargs, expected_npoints):
5463
"""Test LearnerND works with 1D bounds."""
5564
learner = LearnerND(lambda x: x[0] ** 2, bounds=[(-1, 1)])
56-
simple(learner, npoints_goal=10)
57-
assert learner.npoints == 10
58-
assert learner.loss() < float("inf")
59-
65+
simple(learner, **run_kwargs)
6066

61-
def test_learnerND_1d_with_loss_goal():
62-
"""Test LearnerND 1D converges with a loss goal."""
63-
learner = LearnerND(lambda x: x[0] ** 2, bounds=[(-1, 1)])
64-
simple(learner, loss_goal=0.1)
65-
assert learner.loss() <= 0.1
67+
if expected_npoints is not None:
68+
assert learner.npoints == expected_npoints
69+
assert learner.loss() < float("inf")
70+
if "loss_goal" in run_kwargs:
71+
assert learner.loss() <= run_kwargs["loss_goal"]

adaptive/tests/test_triangulation.py

Lines changed: 47 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -76,14 +76,6 @@ def test_triangulation_raises_exception_for_1d_list():
7676
with pytest.raises(TypeError):
7777
Triangulation(pts)
7878

79-
80-
def test_triangulation_supports_1d_points():
81-
pts = [(0,), (1,)]
82-
t = Triangulation(pts)
83-
assert t.simplices == {(0, 1)}
84-
assert t.hull == {0, 1}
85-
86-
8779
@with_dimension_incl_1d
8880
def test_triangulation_of_standard_simplex(dim):
8981
t = Triangulation(_make_standard_simplex(dim))
@@ -336,61 +328,44 @@ def test_initialisation_accepts_more_than_one_simplex(dim):
336328
# ---- 1D-specific triangulation tests ----
337329

338330

339-
def test_1d_triangulation_basic():
340-
"""Test basic 1D triangulation with two points."""
341-
t = Triangulation([(0.0,), (1.0,)])
342-
assert t.simplices == {(0, 1)}
343-
assert t.hull == {0, 1}
344-
assert t.volume((0, 1)) == 1.0
345-
_check_triangulation_is_valid(t)
346-
347-
348-
def test_1d_triangulation_multiple_points():
349-
"""Test 1D triangulation with multiple initial points."""
350-
pts = [(0.0,), (0.5,), (1.0,)]
351-
t = Triangulation(pts)
352-
# Points 0=(0.0,), 1=(0.5,), 2=(1.0,) → sorted: 0, 1, 2
353-
assert len(t.simplices) == 2
354-
assert t.hull == {0, 2} # endpoints
355-
_check_triangulation_is_valid(t)
356-
assert np.isclose(sum(t.volumes()), 1.0)
357-
358-
359-
def test_1d_triangulation_unsorted_points():
360-
"""Test that 1D triangulation handles unsorted initial points."""
361-
pts = [(1.0,), (0.0,), (0.5,)]
362-
t = Triangulation(pts)
363-
assert len(t.simplices) == 2
331+
@pytest.mark.parametrize(
332+
("points", "expected_simplices", "expected_hull", "expected_total_volume"),
333+
[
334+
([(0.0,), (1.0,)], {(0, 1)}, {0, 1}, 1.0),
335+
([(0.0,), (0.5,), (1.0,)], {(0, 1), (1, 2)}, {0, 2}, 1.0),
336+
([(1.0,), (0.0,), (0.5,)], {(0, 2), (1, 2)}, {0, 1}, 1.0),
337+
],
338+
ids=["two-points", "sorted", "unsorted"],
339+
)
340+
def test_1d_triangulation_initialisation(
341+
points, expected_simplices, expected_hull, expected_total_volume
342+
):
343+
t = Triangulation(points)
344+
345+
assert t.simplices == expected_simplices
346+
assert t.hull == expected_hull
364347
_check_triangulation_is_valid(t)
365-
assert np.isclose(sum(t.volumes()), 1.0)
366-
367-
368-
def test_1d_add_point_inside():
369-
"""Test adding a point inside a 1D interval."""
370-
t = Triangulation([(0.0,), (1.0,)])
371-
_add_point_with_check(t, (0.5,))
372-
assert len(t.simplices) == 2
373-
assert t.hull == {0, 1} # original endpoints are still hull
374-
_check_triangulation_is_valid(t)
375-
assert np.isclose(sum(t.volumes()), 1.0)
376-
377-
378-
def test_1d_add_point_outside_right():
379-
"""Test adding a point to the right of a 1D triangulation."""
348+
assert np.isclose(sum(t.volumes()), expected_total_volume)
349+
350+
351+
@pytest.mark.parametrize(
352+
("point", "expected_simplices", "expected_hull", "expected_total_volume"),
353+
[
354+
((0.5,), {(0, 2), (1, 2)}, {0, 1}, 1.0),
355+
((2.0,), {(0, 1), (1, 2)}, {0, 2}, 2.0),
356+
((-1.0,), {(0, 1), (0, 2)}, {1, 2}, 2.0),
357+
],
358+
ids=["inside", "outside-right", "outside-left"],
359+
)
360+
def test_1d_add_point(point, expected_simplices, expected_hull, expected_total_volume):
380361
t = Triangulation([(0.0,), (1.0,)])
381-
_add_point_with_check(t, (2.0,))
382-
assert t.simplices == {(0, 1), (1, 2)}
383-
assert t.hull == {0, 2}
384-
_check_triangulation_is_valid(t)
385362

363+
_add_point_with_check(t, point)
386364

387-
def test_1d_add_point_outside_left():
388-
"""Test adding a point to the left of a 1D triangulation."""
389-
t = Triangulation([(0.0,), (1.0,)])
390-
_add_point_with_check(t, (-1.0,))
391-
assert t.simplices == {(0, 1), (0, 2)}
392-
assert t.hull == {1, 2}
365+
assert t.simplices == expected_simplices
366+
assert t.hull == expected_hull
393367
_check_triangulation_is_valid(t)
368+
assert np.isclose(sum(t.volumes()), expected_total_volume)
394369

395370

396371
def test_1d_locate_point():
@@ -407,18 +382,23 @@ def test_1d_locate_point():
407382
assert simplex == ()
408383

409384

410-
def test_1d_duplicate_coordinates_skipped():
385+
@pytest.mark.parametrize(
386+
("points", "expected_simplices"),
387+
[
388+
([(0.0,), (1.0,), (1.0,)], {(0, 1)}),
389+
([(0.0,), (0.0,), (0.5,), (1.0,), (1.0,)], None),
390+
],
391+
ids=["single-duplicate", "multiple-duplicates"],
392+
)
393+
def test_1d_duplicate_coordinates_skipped(points, expected_simplices):
411394
"""Test that duplicate 1D coordinates don't create degenerate simplices."""
412-
t = Triangulation([(0.0,), (1.0,), (1.0,)])
413-
# The duplicate (1.0,) should be skipped, leaving only one simplex
414-
assert t.simplices == {(0, 1)}
395+
t = Triangulation(points)
396+
397+
if expected_simplices is not None:
398+
assert t.simplices == expected_simplices
415399
assert all(v > 0 for v in t.volumes())
416400
_check_triangulation_is_valid(t)
417-
418-
# Multiple duplicates
419-
t2 = Triangulation([(0.0,), (0.0,), (0.5,), (1.0,), (1.0,)])
420-
assert all(v > 0 for v in t2.volumes())
421-
_check_triangulation_is_valid(t2)
401+
assert np.isclose(sum(t.volumes()), 1.0)
422402

423403

424404
def test_1d_opposing_vertices():

adaptive/tests/unit/test_learnernd.py

Lines changed: 65 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,13 @@
55
from scipy.spatial import ConvexHull
66

77
from adaptive.learner.base_learner import uses_nth_neighbors
8-
from adaptive.learner.learnerND import LearnerND, curvature_loss_function
8+
from adaptive.learner.learnerND import (
9+
LearnerND,
10+
curvature_loss_function,
11+
default_loss,
12+
std_loss,
13+
uniform_loss,
14+
)
915

1016

1117
def ring_of_fire(xy):
@@ -47,76 +53,87 @@ def loss(*args):
4753
# ---- 1D-specific LearnerND tests ----
4854

4955

56+
ONE_D_BOUNDS = [(-1, 1)]
57+
ONE_D_POINTS = (-1.0, -0.5, 0.0, 0.5, 1.0)
58+
59+
5060
def f_1d(x):
5161
"""Simple 1D test function."""
5262
return x[0] ** 2
5363

5464

65+
def make_1d_learner(function=f_1d, **kwargs):
66+
return LearnerND(function, bounds=ONE_D_BOUNDS, **kwargs)
67+
68+
69+
def tell_1d_points(learner, function=None, points=ONE_D_POINTS):
70+
function = learner.function if function is None else function
71+
for x in points:
72+
learner.tell((x,), function((x,)))
73+
74+
75+
def initialize_1d_learner(**kwargs):
76+
learner = make_1d_learner(**kwargs)
77+
points, _ = learner.ask(2)
78+
for point in points:
79+
learner.tell(point, learner.function(point))
80+
return learner
81+
82+
5583
def test_learnerND_1d_construction():
5684
"""Test that LearnerND can be constructed with 1D bounds."""
57-
learner = LearnerND(f_1d, bounds=[(-1, 1)])
85+
learner = make_1d_learner()
5886
assert learner.ndim == 1
5987
assert learner._bounds_points == [(-1,), (1,)]
6088
assert learner._bbox == ((-1.0, 1.0),)
6189

6290

63-
def test_learnerND_1d_tell_ask():
91+
@pytest.mark.parametrize(
92+
("loss_fn", "expected_nth_neighbors"),
93+
[
94+
(None, 0),
95+
(curvature_loss_function(), 1),
96+
],
97+
ids=["default", "curvature"],
98+
)
99+
def test_learnerND_1d_tell_ask(loss_fn, expected_nth_neighbors):
64100
"""Test basic tell/ask cycle for 1D LearnerND."""
65-
learner = LearnerND(f_1d, bounds=[(-1, 1)])
66-
# Ask for bound points first
67-
points, losses = learner.ask(2)
68-
assert len(points) == 2
69-
# Tell the boundary values
70-
for p in points:
71-
learner.tell(p, f_1d(p))
72-
# Now we should have a triangulation
101+
kwargs = {} if loss_fn is None else {"loss_per_simplex": loss_fn}
102+
learner = initialize_1d_learner(**kwargs)
103+
73104
assert learner.tri is not None
74-
# Ask for more points
105+
assert learner.nth_neighbors == expected_nth_neighbors
106+
75107
points2, losses2 = learner.ask(3)
108+
76109
assert len(points2) == 3
110+
assert all(loss > 0 for loss in losses2)
77111

78112

79-
def test_learnerND_1d_loss_functions():
113+
@pytest.mark.parametrize(
114+
"loss_fn",
115+
[
116+
pytest.param(loss_fn, id=loss_fn.__name__)
117+
for loss_fn in (default_loss, uniform_loss, std_loss)
118+
],
119+
)
120+
def test_learnerND_1d_loss_functions(loss_fn):
80121
"""Test that all standard loss functions work for 1D."""
81-
from adaptive.learner.learnerND import (
82-
default_loss,
83-
std_loss,
84-
uniform_loss,
85-
)
86-
87-
for loss_fn in [default_loss, uniform_loss, std_loss]:
88-
learner = LearnerND(f_1d, bounds=[(-1, 1)], loss_per_simplex=loss_fn)
89-
points, _ = learner.ask(2)
90-
for p in points:
91-
learner.tell(p, f_1d(p))
92-
points2, losses2 = learner.ask(3)
93-
assert len(points2) == 3
94-
assert all(l > 0 for l in losses2)
95-
96-
97-
def test_learnerND_1d_curvature_loss():
98-
"""Test that curvature loss function works for 1D."""
99-
loss = curvature_loss_function()
100-
learner = LearnerND(f_1d, bounds=[(-1, 1)], loss_per_simplex=loss)
101-
assert learner.nth_neighbors == 1
102-
points, _ = learner.ask(2)
103-
for p in points:
104-
learner.tell(p, f_1d(p))
105-
points2, _ = learner.ask(3)
122+
learner = initialize_1d_learner(loss_per_simplex=loss_fn)
123+
points2, losses2 = learner.ask(3)
124+
106125
assert len(points2) == 3
126+
assert all(loss > 0 for loss in losses2)
107127

108128

109129
def test_learnerND_1d_interpolation():
110130
"""Test that 1D interpolation works correctly."""
111-
learner = LearnerND(f_1d, bounds=[(-1, 1)])
112-
# Tell some points
113-
for x in [-1.0, -0.5, 0.0, 0.5, 1.0]:
114-
learner.tell((x,), x**2)
131+
learner = make_1d_learner()
132+
tell_1d_points(learner)
115133
ip = learner._ip()
116-
# Check interpolation at known points
134+
117135
assert np.isclose(ip(0.0), 0.0)
118136
assert np.isclose(ip(1.0), 1.0)
119-
# Check interpolation at midpoint (linear interpolation)
120137
assert np.isclose(ip(0.25), 0.125) # linear between 0 and 0.5
121138

122139

@@ -126,9 +143,8 @@ def test_learnerND_1d_vector_output_interpolation():
126143
def f_vec(x):
127144
return np.array([x[0] ** 2, np.sin(x[0])])
128145

129-
learner = LearnerND(f_vec, bounds=[(-1, 1)])
130-
for x in [-1.0, -0.5, 0.0, 0.5, 1.0]:
131-
learner.tell((x,), f_vec((x,)))
146+
learner = make_1d_learner(function=f_vec)
147+
tell_1d_points(learner, function=f_vec)
132148
ip = learner._ip()
133149
result = ip(0.0)
134150
assert result.shape == (2,)
@@ -141,8 +157,7 @@ def test_learnerND_1d_plot():
141157
import holoviews as hv
142158

143159
hv.extension("bokeh")
144-
learner = LearnerND(f_1d, bounds=[(-1, 1)])
145-
for x in [-1.0, -0.5, 0.0, 0.5, 1.0]:
146-
learner.tell((x,), x**2)
160+
learner = make_1d_learner()
161+
tell_1d_points(learner)
147162
plot = learner.plot()
148163
assert plot is not None

0 commit comments

Comments
 (0)