|
1 | 1 | """Unit tests for Function.from_grid() method and grid interpolation.""" |
2 | 2 |
|
| 3 | +import warnings |
| 4 | + |
3 | 5 | import numpy as np |
4 | 6 | import pytest |
5 | 7 |
|
@@ -137,3 +139,141 @@ def test_from_grid_backward_compatibility(): |
137 | 139 | # Test callable function |
138 | 140 | func3 = Function(lambda x: x**2) |
139 | 141 | assert func3(2) == 4 |
| 142 | + |
| 143 | + |
| 144 | +def test_shepard_fallback_warning(): |
| 145 | + """Test that shepard_fallback is triggered and emits a warning. |
| 146 | +
|
| 147 | + When linear_grid interpolation is set but no grid interpolator is available, |
| 148 | + the Function class should fall back to shepard interpolation and emit a warning. |
| 149 | + """ |
| 150 | + # Create a 2D function with scattered points (not structured grid) |
| 151 | + source = [(0, 0, 0), (1, 0, 1), (0, 1, 2), (1, 1, 3)] |
| 152 | + func = Function( |
| 153 | + source=source, inputs=["x", "y"], outputs="z", interpolation="shepard" |
| 154 | + ) |
| 155 | + |
| 156 | + # Now manually change interpolation to linear_grid without setting up the grid |
| 157 | + # This simulates the fallback scenario |
| 158 | + with warnings.catch_warnings(record=True) as w: |
| 159 | + warnings.simplefilter("always") |
| 160 | + func.set_interpolation("linear_grid") |
| 161 | + |
| 162 | + # Check that a warning was issued |
| 163 | + assert len(w) == 1 |
| 164 | + assert "falling back to shepard interpolation" in str(w[0].message) |
| 165 | + |
| 166 | + |
| 167 | +def test_shepard_fallback_2d_interpolation(): |
| 168 | + """Test that shepard_fallback produces correct interpolation for 2D data. |
| 169 | +
|
| 170 | + This test verifies the fallback interpolation works correctly when |
| 171 | + linear_grid is set without a grid interpolator. |
| 172 | + """ |
| 173 | + # Create a 2D function: z = x + y |
| 174 | + source = [ |
| 175 | + (0, 0, 0), # f(0, 0) = 0 |
| 176 | + (1, 0, 1), # f(1, 0) = 1 |
| 177 | + (0, 1, 1), # f(0, 1) = 1 |
| 178 | + (1, 1, 2), # f(1, 1) = 2 |
| 179 | + ] |
| 180 | + |
| 181 | + # First, create with shepard to get baseline results |
| 182 | + func_shepard = Function( |
| 183 | + source=source, inputs=["x", "y"], outputs="z", interpolation="shepard" |
| 184 | + ) |
| 185 | + |
| 186 | + # Create another function and trigger the fallback |
| 187 | + func_fallback = Function( |
| 188 | + source=source, inputs=["x", "y"], outputs="z", interpolation="shepard" |
| 189 | + ) |
| 190 | + |
| 191 | + # Trigger fallback |
| 192 | + with warnings.catch_warnings(): |
| 193 | + warnings.simplefilter("ignore") # Suppress warnings for this test |
| 194 | + func_fallback.set_interpolation("linear_grid") |
| 195 | + |
| 196 | + # Test that both produce the same results at exact points |
| 197 | + assert func_fallback(0, 0) == func_shepard(0, 0) |
| 198 | + assert func_fallback(1, 1) == func_shepard(1, 1) |
| 199 | + |
| 200 | + # Test interpolation at an intermediate point |
| 201 | + result_fallback = func_fallback(0.5, 0.5) |
| 202 | + result_shepard = func_shepard(0.5, 0.5) |
| 203 | + assert np.isclose(result_fallback, result_shepard, atol=1e-6) |
| 204 | + |
| 205 | + |
| 206 | +def test_shepard_fallback_3d_interpolation(): |
| 207 | + """Test that shepard_fallback produces correct interpolation for 3D data. |
| 208 | +
|
| 209 | + This test verifies the fallback interpolation works correctly for |
| 210 | + 3-dimensional input data. |
| 211 | + """ |
| 212 | + # Create a 3D function: w = x + y + z |
| 213 | + source = [ |
| 214 | + (0, 0, 0, 0), # f(0, 0, 0) = 0 |
| 215 | + (1, 0, 0, 1), # f(1, 0, 0) = 1 |
| 216 | + (0, 1, 0, 1), # f(0, 1, 0) = 1 |
| 217 | + (0, 0, 1, 1), # f(0, 0, 1) = 1 |
| 218 | + (1, 1, 1, 3), # f(1, 1, 1) = 3 |
| 219 | + ] |
| 220 | + |
| 221 | + # Create with shepard to get baseline results |
| 222 | + func_shepard = Function( |
| 223 | + source=source, |
| 224 | + inputs=["x", "y", "z"], |
| 225 | + outputs="w", |
| 226 | + interpolation="shepard", |
| 227 | + ) |
| 228 | + |
| 229 | + # Create another function and trigger the fallback |
| 230 | + func_fallback = Function( |
| 231 | + source=source, |
| 232 | + inputs=["x", "y", "z"], |
| 233 | + outputs="w", |
| 234 | + interpolation="shepard", |
| 235 | + ) |
| 236 | + |
| 237 | + # Trigger fallback |
| 238 | + with warnings.catch_warnings(): |
| 239 | + warnings.simplefilter("ignore") |
| 240 | + func_fallback.set_interpolation("linear_grid") |
| 241 | + |
| 242 | + # Test that both produce the same results at exact points |
| 243 | + assert func_fallback(0, 0, 0) == func_shepard(0, 0, 0) |
| 244 | + assert func_fallback(1, 1, 1) == func_shepard(1, 1, 1) |
| 245 | + |
| 246 | + # Test interpolation at an intermediate point |
| 247 | + result_fallback = func_fallback(0.5, 0.5, 0.5) |
| 248 | + result_shepard = func_shepard(0.5, 0.5, 0.5) |
| 249 | + assert np.isclose(result_fallback, result_shepard, atol=1e-6) |
| 250 | + |
| 251 | + |
| 252 | +def test_shepard_fallback_at_exact_data_points(): |
| 253 | + """Test that shepard_fallback returns exact values at data points. |
| 254 | +
|
| 255 | + When querying at exact data points, the fallback should return the |
| 256 | + exact value stored at that point. |
| 257 | + """ |
| 258 | + # Create a 2D function |
| 259 | + source = [ |
| 260 | + (0, 0, 10), |
| 261 | + (1, 0, 20), |
| 262 | + (0, 1, 30), |
| 263 | + (1, 1, 40), |
| 264 | + ] |
| 265 | + |
| 266 | + func = Function( |
| 267 | + source=source, inputs=["x", "y"], outputs="z", interpolation="shepard" |
| 268 | + ) |
| 269 | + |
| 270 | + # Trigger fallback |
| 271 | + with warnings.catch_warnings(): |
| 272 | + warnings.simplefilter("ignore") |
| 273 | + func.set_interpolation("linear_grid") |
| 274 | + |
| 275 | + # Test exact data points - should return exact values |
| 276 | + assert func(0, 0) == 10 |
| 277 | + assert func(1, 0) == 20 |
| 278 | + assert func(0, 1) == 30 |
| 279 | + assert func(1, 1) == 40 |
0 commit comments