Skip to content

Commit c677708

Browse files
Starting move of interpolation_regression_test to v4
But can't finsih until we have ParticleFile implemented
1 parent 4969bb3 commit c677708

3 files changed

Lines changed: 90 additions & 83 deletions

File tree

parcels/application_kernels/interpolation.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77
import numpy as np
88

99
from parcels.field import Field
10+
from parcels.tools.statuscodes import (
11+
FieldOutOfBoundError,
12+
)
1013

1114
if TYPE_CHECKING:
1215
from parcels.uxgrid import _UXGRID_AXES
@@ -105,6 +108,9 @@ def XTriLinear(
105108
yi, eta = position["Y"]
106109
zi, zeta = position["Z"]
107110

111+
if zi < 0 or xi < 0 or yi < 0:
112+
raise FieldOutOfBoundError
113+
108114
data = field.data.data[:, zi : zi + 2, yi : yi + 2, xi : xi + 2]
109115
data = (1 - tau) * data[ti, :, :, :] + tau * data[ti + 1, :, :, :]
110116
if zeta > 0:

tests/test_interpolation.py

Lines changed: 1 addition & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,7 @@
33
import xarray as xr
44

55
import parcels._interpolation as interpolation
6-
from parcels import AdvectionRK4_3D, FieldSet, Particle, ParticleSet
7-
from tests.utils import TEST_DATA, create_fieldset_zeros_3d
6+
from tests.utils import create_fieldset_zeros_3d
87

98

109
@pytest.fixture
@@ -109,50 +108,3 @@ def test_interpolator2(ctx: interpolation.InterpolationContext3D):
109108
return 0
110109

111110
fieldset.U[0.5, 0.5, 0.5, 0.5]
112-
113-
114-
@pytest.mark.v4remove
115-
@pytest.mark.xfail(reason="GH1946")
116-
@pytest.mark.parametrize(
117-
"interp_method",
118-
[
119-
"linear",
120-
"freeslip",
121-
"nearest",
122-
"cgrid_velocity",
123-
],
124-
)
125-
def test_interp_regression_v3(interp_method):
126-
"""Test that the v4 versions of the interpolation are the same as the v3 versions."""
127-
variables = {"U": "U", "V": "V", "W": "W"}
128-
dimensions = {"time": "time", "lon": "lon", "lat": "lat", "depth": "depth"}
129-
ds = xr.open_dataset(str(TEST_DATA / f"test_interpolation_data_random_{interp_method}.nc"))
130-
fieldset = FieldSet.from_xarray_dataset(
131-
ds,
132-
variables,
133-
dimensions,
134-
mesh="flat",
135-
)
136-
137-
for field in [fieldset.U, fieldset.V, fieldset.W]: # Set a land point (for testing freeslip)
138-
field.interp_method = interp_method
139-
140-
x, y, z = np.meshgrid(np.linspace(0, 1, 7), np.linspace(0, 1, 13), np.linspace(0, 1, 5))
141-
142-
TestP = Particle.add_variable("pid", dtype=np.int32, initial=0)
143-
pset = ParticleSet(fieldset, pclass=TestP, lon=x, lat=y, depth=z, pid=np.arange(x.size))
144-
145-
def DeleteParticle(particle, fieldset, time):
146-
if particle.state >= 50:
147-
particle.delete()
148-
149-
outfile = pset.ParticleFile(f"test_interpolation_v4_{interp_method}", outputdt=1)
150-
pset.execute([AdvectionRK4_3D, DeleteParticle], runtime=4, dt=1, output_file=outfile)
151-
152-
ds_v3 = xr.open_zarr(str(TEST_DATA / f"test_interpolation_jit_{interp_method}.zarr"))
153-
ds_v4 = xr.open_zarr(f"test_interpolation_v4_{interp_method}.zarr")
154-
155-
tol = 1e-6
156-
np.testing.assert_allclose(ds_v3.lon, ds_v4.lon, atol=tol)
157-
np.testing.assert_allclose(ds_v3.lat, ds_v4.lat, atol=tol)
158-
np.testing.assert_allclose(ds_v3.z, ds_v4.z, atol=tol)

tests/v4/test_interpolation.py

Lines changed: 83 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,49 +1,25 @@
11
import numpy as np
22
import pytest
3+
import xarray as xr
34

45
from parcels._datasets.structured.generic import simple_UV_dataset
6+
from parcels.application_kernels.advection import AdvectionRK4_3D
7+
from parcels.application_kernels.interpolation import XBiLinear, XTriLinear
58
from parcels.field import Field, VectorField
6-
from parcels.xgrid import _XGRID_AXES, XGrid
7-
8-
9-
def BiRectiLinear( # TODO move to interpolation file
10-
field: Field,
11-
ti: int,
12-
position: dict[_XGRID_AXES, tuple[int, float | np.ndarray]],
13-
tau: np.float32 | np.float64,
14-
t: np.float32 | np.float64,
15-
z: np.float32 | np.float64,
16-
y: np.float32 | np.float64,
17-
x: np.float32 | np.float64,
18-
):
19-
"""Bilinear interpolation on a rectilinear grid."""
20-
xi, xsi = position["X"]
21-
yi, eta = position["Y"]
22-
23-
data = field.data.data[:, :, yi : yi + 2, xi : xi + 2]
24-
val_t0 = (
25-
(1 - xsi) * (1 - eta) * data[0, 0, 0, 0]
26-
+ xsi * (1 - eta) * data[0, 0, 0, 1]
27-
+ xsi * eta * data[0, 0, 1, 1]
28-
+ (1 - xsi) * eta * data[0, 0, 1, 0]
29-
)
30-
31-
val_t1 = (
32-
(1 - xsi) * (1 - eta) * data[1, 0, 0, 0]
33-
+ xsi * (1 - eta) * data[1, 0, 0, 1]
34-
+ xsi * eta * data[1, 0, 1, 1]
35-
+ (1 - xsi) * eta * data[1, 0, 1, 0]
36-
)
37-
return val_t0 * (1 - tau) + val_t1 * tau
9+
from parcels.fieldset import FieldSet
10+
from parcels.particle import Particle, Variable
11+
from parcels.particleset import ParticleSet
12+
from parcels.xgrid import XGrid
13+
from tests.utils import TEST_DATA
3814

3915

4016
@pytest.mark.parametrize("mesh_type", ["spherical", "flat"])
4117
def test_interpolation_mesh_type(mesh_type, npart=10):
4218
ds = simple_UV_dataset(mesh_type=mesh_type)
4319
ds["U"].data[:] = 1.0
4420
grid = XGrid.from_dataset(ds)
45-
U = Field("U", ds["U"], grid, mesh_type=mesh_type, interp_method=BiRectiLinear)
46-
V = Field("V", ds["V"], grid, mesh_type=mesh_type, interp_method=BiRectiLinear)
21+
U = Field("U", ds["U"], grid, mesh_type=mesh_type, interp_method=XBiLinear)
22+
V = Field("V", ds["V"], grid, mesh_type=mesh_type, interp_method=XBiLinear)
4723
UV = VectorField("UV", U, V)
4824

4925
lat = 30.0
@@ -58,3 +34,76 @@ def test_interpolation_mesh_type(mesh_type, npart=10):
5834
assert v == 0.0
5935

6036
assert U.eval(time, 0, lat, 0, applyConversion=False) == 1
37+
38+
39+
interp_methods = {
40+
"linear": XTriLinear,
41+
}
42+
43+
44+
@pytest.mark.xfail(reason="ParticleFile not implemented yet")
45+
@pytest.mark.parametrize(
46+
"interp_name",
47+
[
48+
"linear",
49+
# "freeslip",
50+
# "nearest",
51+
# "cgrid_velocity",
52+
],
53+
)
54+
def test_interp_regression_v3(interp_name):
55+
"""Test that the v4 versions of the interpolation are the same as the v3 versions."""
56+
ds_input = xr.open_dataset(str(TEST_DATA / f"test_interpolation_data_random_{interp_name}.nc"))
57+
ydim = ds_input["U"].shape[2]
58+
xdim = ds_input["U"].shape[3]
59+
time = [np.timedelta64(int(t), "s") for t in ds_input["time"].values]
60+
61+
ds = xr.Dataset(
62+
{
63+
"U": (["time", "depth", "YG", "XG"], ds_input["U"].values),
64+
"V": (["time", "depth", "YG", "XG"], ds_input["V"].values),
65+
"W": (["time", "depth", "YG", "XG"], ds_input["W"].values),
66+
},
67+
coords={
68+
"time": (["time"], time, {"axis": "T"}),
69+
"depth": (["depth"], ds_input["depth"].values, {"axis": "Z"}),
70+
"YC": (["YC"], np.arange(ydim) + 0.5, {"axis": "Y"}),
71+
"YG": (["YG"], np.arange(ydim), {"axis": "Y", "c_grid_axis_shift": -0.5}),
72+
"XC": (["XC"], np.arange(xdim) + 0.5, {"axis": "X"}),
73+
"XG": (["XG"], np.arange(xdim), {"axis": "X", "c_grid_axis_shift": -0.5}),
74+
"lat": (["YG"], ds_input["lat"].values, {"axis": "Y", "c_grid_axis_shift": 0.5}),
75+
"lon": (["XG"], ds_input["lon"].values, {"axis": "X", "c_grid_axis_shift": -0.5}),
76+
},
77+
)
78+
79+
grid = XGrid.from_dataset(ds)
80+
U = Field("U", ds["U"], grid, mesh_type="flat", interp_method=interp_methods[interp_name])
81+
V = Field("V", ds["V"], grid, mesh_type="flat", interp_method=interp_methods[interp_name])
82+
W = Field("W", ds["W"], grid, mesh_type="flat", interp_method=interp_methods[interp_name])
83+
fieldset = FieldSet([U, V, W, VectorField("UVW", U, V, W)])
84+
85+
x, y, z = np.meshgrid(np.linspace(0, 1, 7), np.linspace(0, 1, 13), np.linspace(0, 1, 5))
86+
87+
TestP = Particle.add_variable(Variable("pid", dtype=np.int32, initial=0))
88+
pset = ParticleSet(fieldset, pclass=TestP, lon=x, lat=y, depth=z, pid=np.arange(x.size))
89+
90+
def DeleteParticle(particle, fieldset, time):
91+
if particle.state >= 50:
92+
particle.delete()
93+
94+
outfile = pset.ParticleFile(f"test_interpolation_v4_{interp_name}", outputdt=np.timedelta64(1, "s"))
95+
pset.execute(
96+
[AdvectionRK4_3D, DeleteParticle],
97+
runtime=np.timedelta64(4, "s"),
98+
dt=np.timedelta64(1, "s"),
99+
output_file=outfile,
100+
)
101+
102+
print(str(TEST_DATA / f"test_interpolation_jit_{interp_name}.zarr"))
103+
ds_v3 = xr.open_zarr(str(TEST_DATA / f"test_interpolation_jit_{interp_name}.zarr"))
104+
ds_v4 = xr.open_zarr(f"test_interpolation_v4_{interp_name}.zarr")
105+
106+
tol = 1e-6
107+
np.testing.assert_allclose(ds_v3.lon, ds_v4.lon, atol=tol)
108+
np.testing.assert_allclose(ds_v3.lat, ds_v4.lat, atol=tol)
109+
np.testing.assert_allclose(ds_v3.z, ds_v4.z, atol=tol)

0 commit comments

Comments
 (0)