11import numpy as np
22import pytest
3+ import xarray as xr
34
45from parcels ._datasets .structured .generic import simple_UV_dataset
6+ from parcels .application_kernels .advection import AdvectionRK4_3D
7+ from parcels .application_kernels .interpolation import XBiLinear , XTriLinear
58from parcels .field import Field , VectorField
6- from parcels .xgrid import _XGRID_AXES , XGrid
7-
8-
9- def BiRectiLinear ( # TODO move to interpolation file
10- field : Field ,
11- ti : int ,
12- position : dict [_XGRID_AXES , tuple [int , float | np .ndarray ]],
13- tau : np .float32 | np .float64 ,
14- t : np .float32 | np .float64 ,
15- z : np .float32 | np .float64 ,
16- y : np .float32 | np .float64 ,
17- x : np .float32 | np .float64 ,
18- ):
19- """Bilinear interpolation on a rectilinear grid."""
20- xi , xsi = position ["X" ]
21- yi , eta = position ["Y" ]
22-
23- data = field .data .data [:, :, yi : yi + 2 , xi : xi + 2 ]
24- val_t0 = (
25- (1 - xsi ) * (1 - eta ) * data [0 , 0 , 0 , 0 ]
26- + xsi * (1 - eta ) * data [0 , 0 , 0 , 1 ]
27- + xsi * eta * data [0 , 0 , 1 , 1 ]
28- + (1 - xsi ) * eta * data [0 , 0 , 1 , 0 ]
29- )
30-
31- val_t1 = (
32- (1 - xsi ) * (1 - eta ) * data [1 , 0 , 0 , 0 ]
33- + xsi * (1 - eta ) * data [1 , 0 , 0 , 1 ]
34- + xsi * eta * data [1 , 0 , 1 , 1 ]
35- + (1 - xsi ) * eta * data [1 , 0 , 1 , 0 ]
36- )
37- return val_t0 * (1 - tau ) + val_t1 * tau
9+ from parcels .fieldset import FieldSet
10+ from parcels .particle import Particle , Variable
11+ from parcels .particleset import ParticleSet
12+ from parcels .xgrid import XGrid
13+ from tests .utils import TEST_DATA
3814
3915
4016@pytest .mark .parametrize ("mesh_type" , ["spherical" , "flat" ])
4117def test_interpolation_mesh_type (mesh_type , npart = 10 ):
4218 ds = simple_UV_dataset (mesh_type = mesh_type )
4319 ds ["U" ].data [:] = 1.0
4420 grid = XGrid .from_dataset (ds )
45- U = Field ("U" , ds ["U" ], grid , mesh_type = mesh_type , interp_method = BiRectiLinear )
46- V = Field ("V" , ds ["V" ], grid , mesh_type = mesh_type , interp_method = BiRectiLinear )
21+ U = Field ("U" , ds ["U" ], grid , mesh_type = mesh_type , interp_method = XBiLinear )
22+ V = Field ("V" , ds ["V" ], grid , mesh_type = mesh_type , interp_method = XBiLinear )
4723 UV = VectorField ("UV" , U , V )
4824
4925 lat = 30.0
@@ -58,3 +34,76 @@ def test_interpolation_mesh_type(mesh_type, npart=10):
5834 assert v == 0.0
5935
6036 assert U .eval (time , 0 , lat , 0 , applyConversion = False ) == 1
37+
38+
39+ interp_methods = {
40+ "linear" : XTriLinear ,
41+ }
42+
43+
44+ @pytest .mark .xfail (reason = "ParticleFile not implemented yet" )
45+ @pytest .mark .parametrize (
46+ "interp_name" ,
47+ [
48+ "linear" ,
49+ # "freeslip",
50+ # "nearest",
51+ # "cgrid_velocity",
52+ ],
53+ )
54+ def test_interp_regression_v3 (interp_name ):
55+ """Test that the v4 versions of the interpolation are the same as the v3 versions."""
56+ ds_input = xr .open_dataset (str (TEST_DATA / f"test_interpolation_data_random_{ interp_name } .nc" ))
57+ ydim = ds_input ["U" ].shape [2 ]
58+ xdim = ds_input ["U" ].shape [3 ]
59+ time = [np .timedelta64 (int (t ), "s" ) for t in ds_input ["time" ].values ]
60+
61+ ds = xr .Dataset (
62+ {
63+ "U" : (["time" , "depth" , "YG" , "XG" ], ds_input ["U" ].values ),
64+ "V" : (["time" , "depth" , "YG" , "XG" ], ds_input ["V" ].values ),
65+ "W" : (["time" , "depth" , "YG" , "XG" ], ds_input ["W" ].values ),
66+ },
67+ coords = {
68+ "time" : (["time" ], time , {"axis" : "T" }),
69+ "depth" : (["depth" ], ds_input ["depth" ].values , {"axis" : "Z" }),
70+ "YC" : (["YC" ], np .arange (ydim ) + 0.5 , {"axis" : "Y" }),
71+ "YG" : (["YG" ], np .arange (ydim ), {"axis" : "Y" , "c_grid_axis_shift" : - 0.5 }),
72+ "XC" : (["XC" ], np .arange (xdim ) + 0.5 , {"axis" : "X" }),
73+ "XG" : (["XG" ], np .arange (xdim ), {"axis" : "X" , "c_grid_axis_shift" : - 0.5 }),
74+ "lat" : (["YG" ], ds_input ["lat" ].values , {"axis" : "Y" , "c_grid_axis_shift" : 0.5 }),
75+ "lon" : (["XG" ], ds_input ["lon" ].values , {"axis" : "X" , "c_grid_axis_shift" : - 0.5 }),
76+ },
77+ )
78+
79+ grid = XGrid .from_dataset (ds )
80+ U = Field ("U" , ds ["U" ], grid , mesh_type = "flat" , interp_method = interp_methods [interp_name ])
81+ V = Field ("V" , ds ["V" ], grid , mesh_type = "flat" , interp_method = interp_methods [interp_name ])
82+ W = Field ("W" , ds ["W" ], grid , mesh_type = "flat" , interp_method = interp_methods [interp_name ])
83+ fieldset = FieldSet ([U , V , W , VectorField ("UVW" , U , V , W )])
84+
85+ x , y , z = np .meshgrid (np .linspace (0 , 1 , 7 ), np .linspace (0 , 1 , 13 ), np .linspace (0 , 1 , 5 ))
86+
87+ TestP = Particle .add_variable (Variable ("pid" , dtype = np .int32 , initial = 0 ))
88+ pset = ParticleSet (fieldset , pclass = TestP , lon = x , lat = y , depth = z , pid = np .arange (x .size ))
89+
90+ def DeleteParticle (particle , fieldset , time ):
91+ if particle .state >= 50 :
92+ particle .delete ()
93+
94+ outfile = pset .ParticleFile (f"test_interpolation_v4_{ interp_name } " , outputdt = np .timedelta64 (1 , "s" ))
95+ pset .execute (
96+ [AdvectionRK4_3D , DeleteParticle ],
97+ runtime = np .timedelta64 (4 , "s" ),
98+ dt = np .timedelta64 (1 , "s" ),
99+ output_file = outfile ,
100+ )
101+
102+ print (str (TEST_DATA / f"test_interpolation_jit_{ interp_name } .zarr" ))
103+ ds_v3 = xr .open_zarr (str (TEST_DATA / f"test_interpolation_jit_{ interp_name } .zarr" ))
104+ ds_v4 = xr .open_zarr (f"test_interpolation_v4_{ interp_name } .zarr" )
105+
106+ tol = 1e-6
107+ np .testing .assert_allclose (ds_v3 .lon , ds_v4 .lon , atol = tol )
108+ np .testing .assert_allclose (ds_v3 .lat , ds_v4 .lat , atol = tol )
109+ np .testing .assert_allclose (ds_v3 .z , ds_v4 .z , atol = tol )
0 commit comments