Skip to content

Commit f14fab8

Browse files
Porting interpolation unit test from v3 to v4
1 parent 62a596c commit f14fab8

2 files changed

Lines changed: 48 additions & 47 deletions

File tree

tests/test_interpolation.py

Lines changed: 0 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
1-
import numpy as np
21
import pytest
3-
import xarray as xr
42

53
import parcels._interpolation as interpolation
64
from tests.utils import create_fieldset_zeros_3d
@@ -31,50 +29,6 @@ def some_function():
3129
assert f() == g() == "test"
3230

3331

34-
def create_interpolation_data():
35-
"""Reference data used for testing interpolation.
36-
37-
Most interpolation will be focussed around index
38-
(depth, lat, lon) = (zi, yi, xi) = (1, 1, 1) with ti=0.
39-
"""
40-
z0 = np.array( # each x is +1 from the previous, each y is +2 from the previous
41-
[
42-
[0.0, 1.0, 2.0, 3.0],
43-
[2.0, 3.0, 4.0, 5.0],
44-
[4.0, 5.0, 6.0, 7.0],
45-
[6.0, 7.0, 8.0, 9.0],
46-
]
47-
)
48-
spatial_data = [z0, z0 + 3, z0 + 6, z0 + 9] # each z is +3 from the previous
49-
return xr.DataArray([spatial_data, spatial_data, spatial_data], dims=("time", "depth", "lat", "lon"))
50-
51-
52-
@pytest.fixture
53-
def data_2d():
54-
"""2D slice of the reference data at depth=0."""
55-
return create_interpolation_data().isel(depth=0).values
56-
57-
58-
@pytest.mark.v4remove
59-
@pytest.mark.xfail(reason="GH1946")
60-
@pytest.mark.parametrize(
61-
"func, eta, xsi, expected",
62-
[
63-
pytest.param(interpolation._nearest_2d, 0.49, 0.49, 3.0, id="nearest_2d-1"),
64-
pytest.param(interpolation._nearest_2d, 0.49, 0.51, 4.0, id="nearest_2d-2"),
65-
pytest.param(interpolation._nearest_2d, 0.51, 0.49, 5.0, id="nearest_2d-3"),
66-
pytest.param(interpolation._nearest_2d, 0.51, 0.51, 6.0, id="nearest_2d-4"),
67-
pytest.param(interpolation._tracer_2d, None, None, 6.0, id="tracer_2d"),
68-
],
69-
)
70-
def test_raw_2d_interpolation(data_2d, func, eta, xsi, expected):
71-
"""Test the 2D interpolation functions on the raw arrays."""
72-
tau, ti = 0, 0
73-
yi, xi = 1, 1
74-
ctx = interpolation.InterpolationContext2D(data_2d, tau, eta, xsi, ti, yi, xi)
75-
assert func(ctx) == expected
76-
77-
7832
@pytest.mark.v4remove
7933
@pytest.mark.xfail(reason="GH1946")
8034
@pytest.mark.usefixtures("tmp_interpolator_registry")

tests/v4/test_interpolation.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44

55
from parcels._datasets.structured.generated import simple_UV_dataset
66
from parcels._datasets.unstructured.generic import datasets as datasets_unstructured
7+
from parcels._index_search import _search_time_index
78
from parcels.application_kernels.advection import AdvectionRK4_3D
8-
from parcels.application_kernels.interpolation import UXPiecewiseLinearNode, XLinear
9+
from parcels.application_kernels.interpolation import UXPiecewiseLinearNode, XLinear, ZeroInterpolator
910
from parcels.field import Field, VectorField
1011
from parcels.fieldset import FieldSet
1112
from parcels.particle import Particle, Variable
@@ -16,6 +17,52 @@
1617
from tests.utils import TEST_DATA
1718

1819

20+
@pytest.fixture
21+
def field():
22+
"""Reference data used for testing interpolation."""
23+
z0 = np.array( # each x is +1 from the previous, each y is +2 from the previous
24+
[
25+
[0.0, 1.0, 2.0, 3.0],
26+
[2.0, 3.0, 4.0, 5.0],
27+
[4.0, 5.0, 6.0, 7.0],
28+
[6.0, 7.0, 8.0, 9.0],
29+
]
30+
)
31+
spatial_data = np.array([z0, z0 + 3, z0 + 6, z0 + 9]) # each z is +3 from the previous
32+
temporal_data = np.array([spatial_data, spatial_data + 10, spatial_data + 20]) # each t is +10 from the previous
33+
34+
ds = xr.Dataset(
35+
{"U": (["time", "depth", "lat", "lon"], temporal_data)},
36+
coords={
37+
"time": (["time"], [np.timedelta64(t, "s") for t in [0, 2, 4]], {"axis": "T"}),
38+
"depth": (["depth"], [0, 1, 2, 3], {"axis": "Z"}),
39+
"lat": (["lat"], [0, 1, 2, 3], {"axis": "Y", "c_grid_axis_shift": -0.5}),
40+
"lon": (["lon"], [0, 1, 2, 3], {"axis": "X", "c_grid_axis_shift": -0.5}),
41+
"x": (["x"], [0.5, 1.5, 2.5, 3.5], {"axis": "X"}),
42+
"y": (["y"], [0.5, 1.5, 2.5, 3.5], {"axis": "Y"}),
43+
},
44+
)
45+
return Field("U", ds["U"], XGrid.from_dataset(ds))
46+
47+
48+
@pytest.mark.parametrize(
49+
"func, t, z, y, x, expected",
50+
[
51+
pytest.param(ZeroInterpolator, np.timedelta64(1, "s"), 2.5, 0.49, 0.51, 0, id="Zero"),
52+
pytest.param(XLinear, np.timedelta64(0, "s"), 0, 0.49, 0.51, 1.49, id="Linear-1"),
53+
pytest.param(XLinear, np.timedelta64(1, "s"), 0, 0.49, 0.51, 6.49, id="Linear-2"),
54+
pytest.param(XLinear, np.timedelta64(1, "s"), 2.5, 0.49, 0.51, 13.99, id="Linear-2"),
55+
],
56+
)
57+
def test_raw_2d_interpolation(field, func, t, z, y, x, expected):
58+
"""Test the interpolation functions on the Field."""
59+
tau, ti = _search_time_index(field, t)
60+
position = field.grid.search(z, y, x)
61+
62+
value = func(field, ti, position, tau, 0, 0, y, x)
63+
np.testing.assert_equal(value, expected)
64+
65+
1966
@pytest.mark.parametrize("mesh_type", ["spherical", "flat"])
2067
def test_interpolation_mesh_type(mesh_type, npart=10):
2168
ds = simple_UV_dataset(mesh_type=mesh_type)

0 commit comments

Comments
 (0)