|
4 | 4 |
|
5 | 5 | from parcels._datasets.structured.generated import simple_UV_dataset |
6 | 6 | from parcels._datasets.unstructured.generic import datasets as datasets_unstructured |
| 7 | +from parcels._index_search import _search_time_index |
7 | 8 | from parcels.application_kernels.advection import AdvectionRK4_3D |
8 | | -from parcels.application_kernels.interpolation import UXPiecewiseLinearNode, XLinear |
| 9 | +from parcels.application_kernels.interpolation import UXPiecewiseLinearNode, XLinear, XNearest, ZeroInterpolator |
9 | 10 | from parcels.field import Field, VectorField |
10 | 11 | from parcels.fieldset import FieldSet |
11 | 12 | from parcels.particle import Particle, Variable |
|
16 | 17 | from tests.utils import TEST_DATA |
17 | 18 |
|
18 | 19 |
|
| 20 | +@pytest.fixture |
| 21 | +def field(): |
| 22 | + """Reference data used for testing interpolation.""" |
| 23 | + z0 = np.array( # each x is +1 from the previous, each y is +2 from the previous |
| 24 | + [ |
| 25 | + [0.0, 1.0, 2.0, 3.0], |
| 26 | + [2.0, 3.0, 4.0, 5.0], |
| 27 | + [4.0, 5.0, 6.0, 7.0], |
| 28 | + [6.0, 7.0, 8.0, 9.0], |
| 29 | + ] |
| 30 | + ) |
| 31 | + spatial_data = np.array([z0, z0 + 3, z0 + 6, z0 + 9]) # each z is +3 from the previous |
| 32 | + temporal_data = np.array([spatial_data, spatial_data + 10, spatial_data + 20]) # each t is +10 from the previous |
| 33 | + |
| 34 | + ds = xr.Dataset( |
| 35 | + {"U": (["time", "depth", "lat", "lon"], temporal_data)}, |
| 36 | + coords={ |
| 37 | + "time": (["time"], [np.timedelta64(t, "s") for t in [0, 2, 4]], {"axis": "T"}), |
| 38 | + "depth": (["depth"], [0, 1, 2, 3], {"axis": "Z"}), |
| 39 | + "lat": (["lat"], [0, 1, 2, 3], {"axis": "Y", "c_grid_axis_shift": -0.5}), |
| 40 | + "lon": (["lon"], [0, 1, 2, 3], {"axis": "X", "c_grid_axis_shift": -0.5}), |
| 41 | + "x": (["x"], [0.5, 1.5, 2.5, 3.5], {"axis": "X"}), |
| 42 | + "y": (["y"], [0.5, 1.5, 2.5, 3.5], {"axis": "Y"}), |
| 43 | + }, |
| 44 | + ) |
| 45 | + return Field("U", ds["U"], XGrid.from_dataset(ds)) |
| 46 | + |
| 47 | + |
| 48 | +@pytest.mark.parametrize( |
| 49 | + "func, t, z, y, x, expected", |
| 50 | + [ |
| 51 | + pytest.param(ZeroInterpolator, np.timedelta64(1, "s"), 2.5, 0.49, 0.51, 0, id="Zero"), |
| 52 | + pytest.param( |
| 53 | + XLinear, |
| 54 | + [np.timedelta64(0, "s"), np.timedelta64(1, "s")], |
| 55 | + [0, 0], |
| 56 | + [0.49, 0.49], |
| 57 | + [0.51, 0.51], |
| 58 | + [1.49, 6.49], |
| 59 | + id="Linear", |
| 60 | + ), |
| 61 | + pytest.param(XLinear, np.timedelta64(1, "s"), 2.5, 0.49, 0.51, 13.99, id="Linear-2"), |
| 62 | + pytest.param( |
| 63 | + XNearest, |
| 64 | + [np.timedelta64(0, "s"), np.timedelta64(3, "s")], |
| 65 | + [0.2, 0.2], |
| 66 | + [0.2, 0.2], |
| 67 | + [0.51, 0.51], |
| 68 | + [1.0, 16.0], |
| 69 | + id="Nearest", |
| 70 | + ), |
| 71 | + ], |
| 72 | +) |
| 73 | +def test_raw_2d_interpolation(field, func, t, z, y, x, expected): |
| 74 | + """Test the interpolation functions on the Field.""" |
| 75 | + tau, ti = _search_time_index(field, t) |
| 76 | + position = field.grid.search(z, y, x) |
| 77 | + |
| 78 | + value = func(field, ti, position, tau, 0, 0, y, x) |
| 79 | + np.testing.assert_equal(value, expected) |
| 80 | + |
| 81 | + |
19 | 82 | @pytest.mark.parametrize("mesh", ["spherical", "flat"]) |
20 | | -def test_interpolation_mesh(mesh, npart=10): |
| 83 | +def test_interpolation_mesh_type(mesh, npart=10): |
21 | 84 | ds = simple_UV_dataset(mesh=mesh) |
22 | 85 | ds["U"].data[:] = 1.0 |
23 | 86 | grid = XGrid.from_dataset(ds, mesh=mesh) |
|
0 commit comments