Skip to content

Commit 25fb9d0

Browse files
Merge pull request #2157 from OceanParcels/nearest_interpolation
`XNearest` interpolation
2 parents 8ee9b08 + 3e481b0 commit 25fb9d0

3 files changed

Lines changed: 129 additions & 48 deletions

File tree

parcels/application_kernels/interpolation.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
"UXPiecewiseConstantFace",
1818
"UXPiecewiseLinearNode",
1919
"XLinear",
20+
"XNearest",
2021
"ZeroInterpolator",
2122
]
2223

@@ -111,6 +112,69 @@ def XLinear(
111112
return value.compute() if isinstance(value, dask.Array) else value
112113

113114

115+
def XNearest(
116+
field: Field,
117+
ti: int,
118+
position: dict[_XGRID_AXES, tuple[int, float | np.ndarray]],
119+
tau: np.float32 | np.float64,
120+
t: np.float32 | np.float64,
121+
z: np.float32 | np.float64,
122+
y: np.float32 | np.float64,
123+
x: np.float32 | np.float64,
124+
):
125+
"""
126+
Nearest-Neighbour spatial interpolation on a regular grid.
127+
Note that this still uses linear interpolation in time.
128+
"""
129+
xi, xsi = position["X"]
130+
yi, eta = position["Y"]
131+
zi, zeta = position["Z"]
132+
133+
axis_dim = field.grid.get_axis_dim_mapping(field.data.dims)
134+
data = field.data
135+
136+
lenT = 2 if np.any(tau > 0) else 1
137+
138+
# Spatial coordinates: left if barycentric < 0.5, otherwise right
139+
zi_1 = np.clip(zi + 1, 0, data.shape[1] - 1)
140+
zi_full = np.where(zeta < 0.5, zi, zi_1)
141+
142+
yi_1 = np.clip(yi + 1, 0, data.shape[2] - 1)
143+
yi_full = np.where(eta < 0.5, yi, yi_1)
144+
145+
xi_1 = np.clip(xi + 1, 0, data.shape[3] - 1)
146+
xi_full = np.where(xsi < 0.5, xi, xi_1)
147+
148+
# Time coordinates: 1 point at ti, then 1 point at ti+1
149+
if lenT == 1:
150+
ti_full = ti
151+
else:
152+
ti_1 = np.clip(ti + 1, 0, data.shape[0] - 1)
153+
ti_full = np.concatenate([ti, ti_1])
154+
xi_full = np.repeat(xi_full, 2)
155+
yi_full = np.repeat(yi_full, 2)
156+
zi_full = np.repeat(zi_full, 2)
157+
158+
# Create DataArrays for indexing
159+
selection_dict = {
160+
axis_dim["X"]: xr.DataArray(xi_full, dims=("points")),
161+
axis_dim["Y"]: xr.DataArray(yi_full, dims=("points")),
162+
}
163+
if "Z" in axis_dim:
164+
selection_dict[axis_dim["Z"]] = xr.DataArray(zi_full, dims=("points"))
165+
if "time" in data.dims:
166+
selection_dict["time"] = xr.DataArray(ti_full, dims=("points"))
167+
168+
corner_data = data.isel(selection_dict).data.reshape(lenT, len(xsi))
169+
170+
if lenT == 2:
171+
value = corner_data[0, :] * (1 - tau) + corner_data[1, :] * tau
172+
else:
173+
value = corner_data[0, :]
174+
175+
return value.compute() if isinstance(value, dask.Array) else value
176+
177+
114178
def UXPiecewiseConstantFace(
115179
field: Field,
116180
ti: int,

tests/test_interpolation.py

Lines changed: 0 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
1-
import numpy as np
21
import pytest
3-
import xarray as xr
42

53
import parcels._interpolation as interpolation
64
from tests.utils import create_fieldset_zeros_3d
@@ -31,50 +29,6 @@ def some_function():
3129
assert f() == g() == "test"
3230

3331

34-
def create_interpolation_data():
35-
"""Reference data used for testing interpolation.
36-
37-
Most interpolation will be focussed around index
38-
(depth, lat, lon) = (zi, yi, xi) = (1, 1, 1) with ti=0.
39-
"""
40-
z0 = np.array( # each x is +1 from the previous, each y is +2 from the previous
41-
[
42-
[0.0, 1.0, 2.0, 3.0],
43-
[2.0, 3.0, 4.0, 5.0],
44-
[4.0, 5.0, 6.0, 7.0],
45-
[6.0, 7.0, 8.0, 9.0],
46-
]
47-
)
48-
spatial_data = [z0, z0 + 3, z0 + 6, z0 + 9] # each z is +3 from the previous
49-
return xr.DataArray([spatial_data, spatial_data, spatial_data], dims=("time", "depth", "lat", "lon"))
50-
51-
52-
@pytest.fixture
53-
def data_2d():
54-
"""2D slice of the reference data at depth=0."""
55-
return create_interpolation_data().isel(depth=0).values
56-
57-
58-
@pytest.mark.v4remove
59-
@pytest.mark.xfail(reason="GH1946")
60-
@pytest.mark.parametrize(
61-
"func, eta, xsi, expected",
62-
[
63-
pytest.param(interpolation._nearest_2d, 0.49, 0.49, 3.0, id="nearest_2d-1"),
64-
pytest.param(interpolation._nearest_2d, 0.49, 0.51, 4.0, id="nearest_2d-2"),
65-
pytest.param(interpolation._nearest_2d, 0.51, 0.49, 5.0, id="nearest_2d-3"),
66-
pytest.param(interpolation._nearest_2d, 0.51, 0.51, 6.0, id="nearest_2d-4"),
67-
pytest.param(interpolation._tracer_2d, None, None, 6.0, id="tracer_2d"),
68-
],
69-
)
70-
def test_raw_2d_interpolation(data_2d, func, eta, xsi, expected):
71-
"""Test the 2D interpolation functions on the raw arrays."""
72-
tau, ti = 0, 0
73-
yi, xi = 1, 1
74-
ctx = interpolation.InterpolationContext2D(data_2d, tau, eta, xsi, ti, yi, xi)
75-
assert func(ctx) == expected
76-
77-
7832
@pytest.mark.v4remove
7933
@pytest.mark.xfail(reason="GH1946")
8034
@pytest.mark.usefixtures("tmp_interpolator_registry")

tests/v4/test_interpolation.py

Lines changed: 65 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44

55
from parcels._datasets.structured.generated import simple_UV_dataset
66
from parcels._datasets.unstructured.generic import datasets as datasets_unstructured
7+
from parcels._index_search import _search_time_index
78
from parcels.application_kernels.advection import AdvectionRK4_3D
8-
from parcels.application_kernels.interpolation import UXPiecewiseLinearNode, XLinear
9+
from parcels.application_kernels.interpolation import UXPiecewiseLinearNode, XLinear, XNearest, ZeroInterpolator
910
from parcels.field import Field, VectorField
1011
from parcels.fieldset import FieldSet
1112
from parcels.particle import Particle, Variable
@@ -16,8 +17,70 @@
1617
from tests.utils import TEST_DATA
1718

1819

20+
@pytest.fixture
21+
def field():
22+
"""Reference data used for testing interpolation."""
23+
z0 = np.array( # each x is +1 from the previous, each y is +2 from the previous
24+
[
25+
[0.0, 1.0, 2.0, 3.0],
26+
[2.0, 3.0, 4.0, 5.0],
27+
[4.0, 5.0, 6.0, 7.0],
28+
[6.0, 7.0, 8.0, 9.0],
29+
]
30+
)
31+
spatial_data = np.array([z0, z0 + 3, z0 + 6, z0 + 9]) # each z is +3 from the previous
32+
temporal_data = np.array([spatial_data, spatial_data + 10, spatial_data + 20]) # each t is +10 from the previous
33+
34+
ds = xr.Dataset(
35+
{"U": (["time", "depth", "lat", "lon"], temporal_data)},
36+
coords={
37+
"time": (["time"], [np.timedelta64(t, "s") for t in [0, 2, 4]], {"axis": "T"}),
38+
"depth": (["depth"], [0, 1, 2, 3], {"axis": "Z"}),
39+
"lat": (["lat"], [0, 1, 2, 3], {"axis": "Y", "c_grid_axis_shift": -0.5}),
40+
"lon": (["lon"], [0, 1, 2, 3], {"axis": "X", "c_grid_axis_shift": -0.5}),
41+
"x": (["x"], [0.5, 1.5, 2.5, 3.5], {"axis": "X"}),
42+
"y": (["y"], [0.5, 1.5, 2.5, 3.5], {"axis": "Y"}),
43+
},
44+
)
45+
return Field("U", ds["U"], XGrid.from_dataset(ds))
46+
47+
48+
@pytest.mark.parametrize(
49+
"func, t, z, y, x, expected",
50+
[
51+
pytest.param(ZeroInterpolator, np.timedelta64(1, "s"), 2.5, 0.49, 0.51, 0, id="Zero"),
52+
pytest.param(
53+
XLinear,
54+
[np.timedelta64(0, "s"), np.timedelta64(1, "s")],
55+
[0, 0],
56+
[0.49, 0.49],
57+
[0.51, 0.51],
58+
[1.49, 6.49],
59+
id="Linear",
60+
),
61+
pytest.param(XLinear, np.timedelta64(1, "s"), 2.5, 0.49, 0.51, 13.99, id="Linear-2"),
62+
pytest.param(
63+
XNearest,
64+
[np.timedelta64(0, "s"), np.timedelta64(3, "s")],
65+
[0.2, 0.2],
66+
[0.2, 0.2],
67+
[0.51, 0.51],
68+
[1.0, 16.0],
69+
id="Nearest",
70+
),
71+
],
72+
)
73+
def test_raw_2d_interpolation(field, func, t, z, y, x, expected):
74+
"""Test the interpolation functions on the Field."""
75+
tau, ti = _search_time_index(field, t)
76+
position = field.grid.search(z, y, x)
77+
78+
value = func(field, ti, position, tau, 0, 0, y, x)
79+
np.testing.assert_equal(value, expected)
80+
81+
1982
@pytest.mark.parametrize("mesh", ["spherical", "flat"])
20-
def test_interpolation_mesh(mesh, npart=10):
83+
def test_interpolation_mesh_type(mesh, npart=10):
2184
ds = simple_UV_dataset(mesh=mesh)
2285
ds["U"].data[:] = 1.0
2386
grid = XGrid.from_dataset(ds, mesh=mesh)

0 commit comments

Comments
 (0)