Skip to content

Commit 74fe825

Browse files
CopilotaZira371
authored andcommitted
TST: Add tests for shepard_fallback in test_function_grid.py (#879)
* Add tests for shepard_fallback in test_function_grid.py Co-authored-by: aZira371 <99824864+aZira371@users.noreply.github.com> --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: aZira371 <99824864+aZira371@users.noreply.github.com>
1 parent 3f76344 commit 74fe825

1 file changed

Lines changed: 140 additions & 0 deletions

File tree

tests/unit/mathutils/test_function_grid.py

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
"""Unit tests for Function.from_grid() method and grid interpolation."""
22

3+
import warnings
4+
35
import numpy as np
46
import pytest
57

@@ -137,3 +139,141 @@ def test_from_grid_backward_compatibility():
137139
# Test callable function
138140
func3 = Function(lambda x: x**2)
139141
assert func3(2) == 4
142+
143+
144+
def test_shepard_fallback_warning():
145+
"""Test that shepard_fallback is triggered and emits a warning.
146+
147+
When linear_grid interpolation is set but no grid interpolator is available,
148+
the Function class should fall back to shepard interpolation and emit a warning.
149+
"""
150+
# Create a 2D function with scattered points (not structured grid)
151+
source = [(0, 0, 0), (1, 0, 1), (0, 1, 2), (1, 1, 3)]
152+
func = Function(
153+
source=source, inputs=["x", "y"], outputs="z", interpolation="shepard"
154+
)
155+
156+
# Now manually change interpolation to linear_grid without setting up the grid
157+
# This simulates the fallback scenario
158+
with warnings.catch_warnings(record=True) as w:
159+
warnings.simplefilter("always")
160+
func.set_interpolation("linear_grid")
161+
162+
# Check that a warning was issued
163+
assert len(w) == 1
164+
assert "falling back to shepard interpolation" in str(w[0].message)
165+
166+
167+
def test_shepard_fallback_2d_interpolation():
168+
"""Test that shepard_fallback produces correct interpolation for 2D data.
169+
170+
This test verifies the fallback interpolation works correctly when
171+
linear_grid is set without a grid interpolator.
172+
"""
173+
# Create a 2D function: z = x + y
174+
source = [
175+
(0, 0, 0), # f(0, 0) = 0
176+
(1, 0, 1), # f(1, 0) = 1
177+
(0, 1, 1), # f(0, 1) = 1
178+
(1, 1, 2), # f(1, 1) = 2
179+
]
180+
181+
# First, create with shepard to get baseline results
182+
func_shepard = Function(
183+
source=source, inputs=["x", "y"], outputs="z", interpolation="shepard"
184+
)
185+
186+
# Create another function and trigger the fallback
187+
func_fallback = Function(
188+
source=source, inputs=["x", "y"], outputs="z", interpolation="shepard"
189+
)
190+
191+
# Trigger fallback
192+
with warnings.catch_warnings():
193+
warnings.simplefilter("ignore") # Suppress warnings for this test
194+
func_fallback.set_interpolation("linear_grid")
195+
196+
# Test that both produce the same results at exact points
197+
assert func_fallback(0, 0) == func_shepard(0, 0)
198+
assert func_fallback(1, 1) == func_shepard(1, 1)
199+
200+
# Test interpolation at an intermediate point
201+
result_fallback = func_fallback(0.5, 0.5)
202+
result_shepard = func_shepard(0.5, 0.5)
203+
assert np.isclose(result_fallback, result_shepard, atol=1e-6)
204+
205+
206+
def test_shepard_fallback_3d_interpolation():
207+
"""Test that shepard_fallback produces correct interpolation for 3D data.
208+
209+
This test verifies the fallback interpolation works correctly for
210+
3-dimensional input data.
211+
"""
212+
# Create a 3D function: w = x + y + z
213+
source = [
214+
(0, 0, 0, 0), # f(0, 0, 0) = 0
215+
(1, 0, 0, 1), # f(1, 0, 0) = 1
216+
(0, 1, 0, 1), # f(0, 1, 0) = 1
217+
(0, 0, 1, 1), # f(0, 0, 1) = 1
218+
(1, 1, 1, 3), # f(1, 1, 1) = 3
219+
]
220+
221+
# Create with shepard to get baseline results
222+
func_shepard = Function(
223+
source=source,
224+
inputs=["x", "y", "z"],
225+
outputs="w",
226+
interpolation="shepard",
227+
)
228+
229+
# Create another function and trigger the fallback
230+
func_fallback = Function(
231+
source=source,
232+
inputs=["x", "y", "z"],
233+
outputs="w",
234+
interpolation="shepard",
235+
)
236+
237+
# Trigger fallback
238+
with warnings.catch_warnings():
239+
warnings.simplefilter("ignore")
240+
func_fallback.set_interpolation("linear_grid")
241+
242+
# Test that both produce the same results at exact points
243+
assert func_fallback(0, 0, 0) == func_shepard(0, 0, 0)
244+
assert func_fallback(1, 1, 1) == func_shepard(1, 1, 1)
245+
246+
# Test interpolation at an intermediate point
247+
result_fallback = func_fallback(0.5, 0.5, 0.5)
248+
result_shepard = func_shepard(0.5, 0.5, 0.5)
249+
assert np.isclose(result_fallback, result_shepard, atol=1e-6)
250+
251+
252+
def test_shepard_fallback_at_exact_data_points():
253+
"""Test that shepard_fallback returns exact values at data points.
254+
255+
When querying at exact data points, the fallback should return the
256+
exact value stored at that point.
257+
"""
258+
# Create a 2D function
259+
source = [
260+
(0, 0, 10),
261+
(1, 0, 20),
262+
(0, 1, 30),
263+
(1, 1, 40),
264+
]
265+
266+
func = Function(
267+
source=source, inputs=["x", "y"], outputs="z", interpolation="shepard"
268+
)
269+
270+
# Trigger fallback
271+
with warnings.catch_warnings():
272+
warnings.simplefilter("ignore")
273+
func.set_interpolation("linear_grid")
274+
275+
# Test exact data points - should return exact values
276+
assert func(0, 0) == 10
277+
assert func(1, 0) == 20
278+
assert func(0, 1) == 30
279+
assert func(1, 1) == 40

0 commit comments

Comments
 (0)