Skip to content

Commit 689061d

Browse files
Adding XNearest for nearest neighbour interpolation
1 parent f14fab8 commit 689061d

2 files changed

Lines changed: 83 additions & 3 deletions

File tree

parcels/application_kernels/interpolation.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
"UXPiecewiseConstantFace",
1818
"UXPiecewiseLinearNode",
1919
"XLinear",
20+
"XNearest",
2021
"ZeroInterpolator",
2122
]
2223

@@ -111,6 +112,69 @@ def XLinear(
111112
return value.compute() if isinstance(value, dask.Array) else value
112113

113114

115+
def XNearest(
116+
field: Field,
117+
ti: int,
118+
position: dict[_XGRID_AXES, tuple[int, float | np.ndarray]],
119+
tau: np.float32 | np.float64,
120+
t: np.float32 | np.float64,
121+
z: np.float32 | np.float64,
122+
y: np.float32 | np.float64,
123+
x: np.float32 | np.float64,
124+
):
125+
"""
126+
Nearest-Neighbour spatial interpolation on a regular grid.
127+
Note that this still uses linear interpolation in time.
128+
"""
129+
xi, xsi = position["X"]
130+
yi, eta = position["Y"]
131+
zi, zeta = position["Z"]
132+
133+
axis_dim = field.grid.get_axis_dim_mapping(field.data.dims)
134+
data = field.data
135+
136+
lenT = 2 if np.any(tau > 0) else 1
137+
138+
# Spatial coordinates: left if barycentric < 0.5, otherwise right
139+
zi_1 = np.clip(zi + 1, 0, data.shape[1] - 1)
140+
zi_full = np.where(zeta < 0.5, zi, zi_1)
141+
142+
yi_1 = np.clip(yi + 1, 0, data.shape[2] - 1)
143+
yi_full = np.where(eta < 0.5, yi, yi_1)
144+
145+
xi_1 = np.clip(xi + 1, 0, data.shape[3] - 1)
146+
xi_full = np.where(xsi < 0.5, xi, xi_1)
147+
148+
# Time coordinates: 1 point at ti, then 1 point at ti+1
149+
if lenT == 1:
150+
ti_full = ti
151+
else:
152+
ti_1 = np.clip(ti + 1, 0, data.shape[0] - 1)
153+
ti_full = np.concatenate([ti, ti_1])
154+
xi_full = np.repeat(xi_full, 2)
155+
yi_full = np.repeat(yi_full, 2)
156+
zi_full = np.repeat(zi_full, 2)
157+
158+
# Create DataArrays for indexing
159+
selection_dict = {
160+
axis_dim["X"]: xr.DataArray(xi_full, dims=("points")),
161+
axis_dim["Y"]: xr.DataArray(yi_full, dims=("points")),
162+
}
163+
if "Z" in axis_dim:
164+
selection_dict[axis_dim["Z"]] = xr.DataArray(zi_full, dims=("points"))
165+
if "time" in data.dims:
166+
selection_dict["time"] = xr.DataArray(ti_full, dims=("points"))
167+
168+
corner_data = data.isel(selection_dict).data.reshape(lenT, len(xsi))
169+
170+
if lenT == 2:
171+
value = corner_data[0, :] * (1 - tau) + corner_data[1, :] * tau
172+
else:
173+
value = corner_data[0, :]
174+
175+
return value.compute() if isinstance(value, dask.Array) else value
176+
177+
114178
def UXPiecewiseConstantFace(
115179
field: Field,
116180
ti: int,

tests/v4/test_interpolation.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from parcels._datasets.unstructured.generic import datasets as datasets_unstructured
77
from parcels._index_search import _search_time_index
88
from parcels.application_kernels.advection import AdvectionRK4_3D
9-
from parcels.application_kernels.interpolation import UXPiecewiseLinearNode, XLinear, ZeroInterpolator
9+
from parcels.application_kernels.interpolation import UXPiecewiseLinearNode, XLinear, XNearest, ZeroInterpolator
1010
from parcels.field import Field, VectorField
1111
from parcels.fieldset import FieldSet
1212
from parcels.particle import Particle, Variable
@@ -49,9 +49,25 @@ def field():
4949
"func, t, z, y, x, expected",
5050
[
5151
pytest.param(ZeroInterpolator, np.timedelta64(1, "s"), 2.5, 0.49, 0.51, 0, id="Zero"),
52-
pytest.param(XLinear, np.timedelta64(0, "s"), 0, 0.49, 0.51, 1.49, id="Linear-1"),
53-
pytest.param(XLinear, np.timedelta64(1, "s"), 0, 0.49, 0.51, 6.49, id="Linear-2"),
52+
pytest.param(
53+
XLinear,
54+
[np.timedelta64(0, "s"), np.timedelta64(1, "s")],
55+
[0, 0],
56+
[0.49, 0.49],
57+
[0.51, 0.51],
58+
[1.49, 6.49],
59+
id="Linear",
60+
),
5461
pytest.param(XLinear, np.timedelta64(1, "s"), 2.5, 0.49, 0.51, 13.99, id="Linear-2"),
62+
pytest.param(
63+
XNearest,
64+
[np.timedelta64(0, "s"), np.timedelta64(3, "s")],
65+
[0.2, 0.2],
66+
[0.2, 0.2],
67+
[0.51, 0.51],
68+
[1.0, 16.0],
69+
id="Nearest",
70+
),
5571
],
5672
)
5773
def test_raw_2d_interpolation(field, func, t, z, y, x, expected):

0 commit comments

Comments
 (0)