Skip to content

Commit 0e0d528

Browse files
committed
Implement XGrid.search with tests for 2D lon lat
1 parent 1ca8f0b commit 0e0d528

3 files changed

Lines changed: 74 additions & 14 deletions

File tree

parcels/_index_search.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,12 @@
1717
_raise_field_sampling_error,
1818
_raise_time_extrapolation_error,
1919
)
20-
from parcels.xgrid import XGrid
2120

2221
from .grid import GridType
2322

2423
if TYPE_CHECKING:
24+
from parcels.xgrid import XGrid
25+
2526
from .field import Field
2627
# from .grid import Grid
2728

parcels/xgrid.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import xarray as xr
77

88
from parcels import xgcm
9+
from parcels._index_search import _search_indices_curvilinear_2d
910
from parcels.basegrid import BaseGrid
1011
from parcels.tools.converters import TimeConverter
1112

@@ -182,18 +183,26 @@ def _gtype(self):
182183
def search(self, z, y, x, ei=None, search2D=False):
183184
ds = self.xgcm_grid._ds
184185

186+
if search2D:
187+
zi = 0
188+
else:
189+
zi, _ = _search_1d_array(ds.depth.values, z)
190+
185191
if ds.lon.ndim == 1:
186-
yi, bcoord_y = _search_1d_array(ds.lat.values, y)
187-
xi, bcoord_x = _search_1d_array(ds.lon.values, x)
192+
yi, eta = _search_1d_array(ds.lat.values, y)
193+
xi, xsi = _search_1d_array(ds.lon.values, x)
194+
return (zi, yi, xi), np.array([eta, xsi, 1 - eta, 1 - xsi])
188195

189-
if search2D:
190-
zi = 0
191-
else:
192-
zi, _ = _search_1d_array(ds.depth.values, z)
196+
yi, xi = None, None
197+
if ei is not None:
198+
_, yi, xi = self.unravel_index(ei)
199+
200+
if ds.lon.ndim == 2:
201+
eta, xsi, yi, xi = _search_indices_curvilinear_2d(self, y, x, yi, xi)
193202

194-
return (zi, yi, xi), np.array([bcoord_y, bcoord_x, 1 - bcoord_y, 1 - bcoord_x])
203+
return (zi, yi, xi), np.array([eta, xsi, 1 - eta, 1 - xsi])
195204

196-
raise NotImplementedError("Searching in 2D arrays is not implemented yet.")
205+
raise NotImplementedError("Searching in >2D lon/lat arrays is not implemented yet.")
197206

198207
def ravel_index(self, zi, yi, xi):
199208
"""
@@ -415,6 +424,6 @@ def _search_1d_array(
415424
float
416425
Barycentric coordinate.
417426
"""
418-
i = np.argmin(arr < x)
427+
i = np.argmin(arr <= x) - 1
419428
barry = (x - arr[i]) / (arr[i + 1] - arr[i])
420429
return i, barry

tests/v4/test_xgrid.py

Lines changed: 54 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,7 @@
1010
from parcels._datasets.structured.generic import T, X, Y, Z, datasets
1111
from parcels.grid import Grid as OldGrid
1212
from parcels.tools.converters import TimeConverter
13-
from parcels.xgrid import (
14-
XGrid,
15-
_generate_cells,
16-
)
13+
from parcels.xgrid import XGrid, _generate_cells, _search_1d_array
1714

1815
GridTestCase = namedtuple("GridTestCase", ["Grid", "attr", "expected"])
1916

@@ -187,3 +184,56 @@ def _assert_point_is(
187184
raise ValueError(f"Invalid method: {direction}")
188185

189186
np.testing.assert_allclose(reference_cell + delta, test_cell)
187+
188+
189+
@pytest.mark.parametrize(
190+
"ds",
191+
[
192+
pytest.param(datasets["ds_2d_left"], id="1D lon/lat"),
193+
pytest.param(datasets["2d_left_rotated"], id="2D lon/lat"),
194+
],
195+
) # for key, ds in datasets.items()])
196+
def test_xgrid_search_cpoints(ds):
197+
grid = XGrid(xgcm.Grid(ds, periodic=False))
198+
lat_array, lon_array = get_2d_fpoint_mesh(grid)
199+
lat_array, lon_array = corner_to_cell_center_points(lat_array, lon_array)
200+
201+
for xi in range(grid.xdim - 1):
202+
for yi in range(grid.ydim - 1):
203+
lat, lon = lat_array[yi, xi], lon_array[yi, xi]
204+
(zi_test, yi_test, xi_test), bcoords = grid.search(0, lat, lon, ei=None, search2D=True)
205+
assert xi == xi_test
206+
assert yi == yi_test
207+
assert zi_test == 0
208+
209+
# assert np.isclose(bcoords[0], 0.5) #? Should this not be the case with the cell center points?
210+
# assert np.isclose(bcoords[1], 0.5)
211+
212+
213+
def get_2d_fpoint_mesh(grid: XGrid):
214+
lat, lon = grid.lat, grid.lon
215+
if lon.ndim == 1:
216+
lat, lon = np.meshgrid(lat, lon, indexing="ij")
217+
return lat, lon
218+
219+
220+
def corner_to_cell_center_points(lat, lon):
221+
"""Convert F points to C points."""
222+
lon_c = (lon[:-1, :-1] + lon[:-1, 1:]) / 2
223+
lat_c = (lat[:-1, :-1] + lat[1:, :-1]) / 2
224+
return lat_c, lon_c
225+
226+
227+
@pytest.mark.parametrize(
228+
"array, x, expected_xi, expected_xsi",
229+
[
230+
(np.array([1, 2, 3, 4, 5]), 1.1, 0, 0.1),
231+
(np.array([1, 2, 3, 4, 5]), 2.1, 1, 0.1),
232+
(np.array([1, 2, 3, 4, 5]), 3.1, 2, 0.1),
233+
(np.array([1, 2, 3, 4, 5]), 4.5, 3, 0.5),
234+
],
235+
)
236+
def test_search_1d_array(array, x, expected_xi, expected_xsi):
237+
xi, xsi = _search_1d_array(array, x)
238+
assert xi == expected_xi
239+
assert np.isclose(xsi, expected_xsi)

0 commit comments

Comments
 (0)