Skip to content

Commit 7e3b9ad

Browse files
Merge branch 'v4-dev' into xarray_dataset_for_particle_data
2 parents 8fc2b06 + b510f11 commit 7e3b9ad

6 files changed

Lines changed: 206 additions & 3 deletions

File tree

environment.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ dependencies: #! Keep in sync with [tool.pixi.dependencies] in pyproject.toml
4141

4242
# Docs
4343
- ipython
44-
- numpydoc
44+
- numpydoc!=1.9.0
4545
- nbsphinx
4646
- sphinx
4747
- pandoc

parcels/_datasets/structured/generic.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ def _unrolled_cone_curvilinear_grid():
136136

137137
datasets = {
138138
"2d_left_rotated": _rotated_curvilinear_grid(),
139-
"ds_2d_left": xr.Dataset(
139+
"ds_2d_left": xr.Dataset( # MITgcm indexing style
140140
{
141141
"data_g": (["time", "ZG", "YG", "XG"], np.random.rand(T, Z, Y, X)),
142142
"data_c": (["time", "ZC", "YC", "XC"], np.random.rand(T, Z, Y, X)),
@@ -178,5 +178,47 @@ def _unrolled_cone_curvilinear_grid():
178178
"time": (["time"], TIME, {"axis": "T"}),
179179
},
180180
),
181+
"ds_2d_right": xr.Dataset( # NEMO indexing style
182+
{
183+
"data_g": (["time", "ZG", "YG", "XG"], np.random.rand(T, Z, Y, X)),
184+
"data_c": (["time", "ZC", "YC", "XC"], np.random.rand(T, Z, Y, X)),
185+
"U (A grid)": (["time", "ZG", "YG", "XG"], np.random.rand(T, Z, Y, X)),
186+
"V (A grid)": (["time", "ZG", "YG", "XG"], np.random.rand(T, Z, Y, X)),
187+
"U (C grid)": (["time", "ZG", "YC", "XG"], np.random.rand(T, Z, Y, X)),
188+
"V (C grid)": (["time", "ZG", "YG", "XC"], np.random.rand(T, Z, Y, X)),
189+
},
190+
coords={
191+
"XG": (
192+
["XG"],
193+
2 * np.pi / X * np.arange(0, X),
194+
{"axis": "X", "c_grid_axis_shift": 0.5},
195+
),
196+
"XC": (["XC"], 2 * np.pi / X * (np.arange(0, X) - 0.5), {"axis": "X"}),
197+
"YG": (
198+
["YG"],
199+
2 * np.pi / (Y) * np.arange(0, Y),
200+
{"axis": "Y", "c_grid_axis_shift": 0.5},
201+
),
202+
"YC": (
203+
["YC"],
204+
2 * np.pi / (Y) * (np.arange(0, Y) - 0.5),
205+
{"axis": "Y"},
206+
),
207+
"ZG": (
208+
["ZG"],
209+
np.arange(Z),
210+
{"axis": "Z", "c_grid_axis_shift": 0.5},
211+
),
212+
"ZC": (
213+
["ZC"],
214+
np.arange(Z) - 0.5,
215+
{"axis": "Z"},
216+
),
217+
"lon": (["XG"], 2 * np.pi / X * np.arange(0, X)),
218+
"lat": (["YG"], 2 * np.pi / (Y) * np.arange(0, Y)),
219+
"depth": (["ZG"], np.arange(Z)),
220+
"time": (["time"], TIME, {"axis": "T"}),
221+
},
222+
),
181223
"2d_left_unrolled_cone": _unrolled_cone_curvilinear_grid(),
182224
}

parcels/xgrid.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,53 @@ def get_axis_dim(self, axis: _XGRID_AXES) -> int:
129129

130130
return get_cell_count_along_dim(self.xgcm_grid.axes[axis])
131131

132+
def localize(self, position: dict[_XGRID_AXES, tuple[int, float]], dims: list[str]) -> dict[str, tuple[int, float]]:
133+
"""
134+
Uses the grid context (i.e., the staggering of the grid) to convert a position relative
135+
to the F-points in the grid to a position relative to the staggered grid the array
136+
of interest is defined on.
137+
138+
Uses dimensions of the DataArray to determine the staggered grid.
139+
140+
WARNING: This API is unstable and subject to change in future versions.
141+
142+
Parameters
143+
----------
144+
position : dict
145+
A mapping of the axis to a tuple of (index, barycentric coordinate) for the
146+
F-points in the grid.
147+
dims : list[str]
148+
A list of dimension names that the DataArray is defined on. This is used to determine
149+
the staggering of the grid and which axis each dimension corresponds to.
150+
151+
Returns
152+
-------
153+
dict[str, tuple[int, float]]
154+
A mapping of the dimension names to a tuple of (index, barycentric coordinate) for
155+
the staggered grid the DataArray is defined on.
156+
157+
Example
158+
-------
159+
>>> position = {'X': (5, 0.51), 'Y': (
160+
10, 0.25), 'Z': (3, 0.75)}
161+
>>> dims = ['time', 'depth', 'YC', 'XC']
162+
>>> grid.localize(position, dims)
163+
{'depth': (3, 0.75), 'YC': (9, 0.75), 'XC': (5, 0.01)}
164+
"""
165+
axis_to_var = {get_axis_from_dim_name(self.xgcm_grid.axes, dim): dim for dim in dims}
166+
var_positions = {
167+
axis: get_xgcm_position_from_dim_name(self.xgcm_grid.axes, dim) for axis, dim in axis_to_var.items()
168+
}
169+
return {
170+
axis_to_var[axis]: _convert_center_pos_to_fpoint(
171+
index=index,
172+
bcoord=bcoord,
173+
xgcm_position=var_positions[axis],
174+
f_points_xgcm_position=self._fpoint_info[axis],
175+
)
176+
for axis, (index, bcoord) in position.items()
177+
}
178+
132179
@property
133180
def _z4d(self) -> Literal[0, 1]:
134181
"""
@@ -185,6 +232,20 @@ def search(self, z, y, x, ei=None):
185232

186233
raise NotImplementedError("Searching in >2D lon/lat arrays is not implemented yet.")
187234

235+
@cached_property
236+
def _fpoint_info(self):
237+
"""Returns a mapping of the spatial axes in the Grid to their XGCM positions."""
238+
xgcm_axes = self.xgcm_grid.axes
239+
f_point_positions = ["left", "right", "inner", "outer"]
240+
axis_position_mapping = {}
241+
for axis in self.axes:
242+
coords = xgcm_axes[axis].coords
243+
edge_positions = [pos for pos in coords.keys() if pos in f_point_positions]
244+
assert len(edge_positions) == 1, f"Axis {axis} has multiple edge positions: {edge_positions}"
245+
axis_position_mapping[axis] = edge_positions[0]
246+
247+
return axis_position_mapping
248+
188249

189250
def get_axis_from_dim_name(axes: _XGCM_AXES, dim: str) -> _XGCM_AXIS_DIRECTION | None:
190251
"""For a given dimension name in a grid, returns the direction axis it is on."""
@@ -337,3 +398,28 @@ def _search_1d_array(
337398
i = np.argmin(arr <= x) - 1
338399
bcoord = (x - arr[i]) / (arr[i + 1] - arr[i])
339400
return i, bcoord
401+
402+
403+
def _convert_center_pos_to_fpoint(
404+
*, index: int, bcoord: float, xgcm_position: _XGCM_AXIS_POSITION, f_points_xgcm_position: _XGCM_AXIS_POSITION
405+
) -> tuple[int, float]:
406+
"""Converts a physical position relative to the cell edges defined in the grid to be relative to the center point.
407+
408+
This is used to "localize" a position to be relative to the staggered grid at which the field is defined, so that
409+
it can be easily interpolated.
410+
411+
This also handles different model input cell edges and centers are staggered in different directions (e.g., with NEMO and MITgcm).
412+
"""
413+
if xgcm_position != "center": # Data is already defined on the F points
414+
return index, bcoord
415+
416+
bcoord = bcoord - 0.5
417+
if bcoord < 0:
418+
bcoord += 1.0
419+
index -= 1
420+
421+
# Correct relative to the f-point position
422+
if f_points_xgcm_position in ["inner", "right"]:
423+
index += 1
424+
425+
return index, bcoord

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ pre_commit = "*"
9696

9797
# Docs
9898
ipython = "*"
99-
numpydoc = "*"
99+
numpydoc = "!=1.9.0"
100100
nbsphinx = "*"
101101
sphinx = "*"
102102
pandoc = "*"

tests/v4/test_datasets.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
from parcels._datasets.structured.generic import datasets
2+
from parcels.xgcm import Grid
3+
4+
5+
def test_left_indexed_dataset():
6+
"""Checks that 'ds_2d_left' is right indexed on all variables."""
7+
ds = datasets["ds_2d_left"]
8+
grid = Grid(ds)
9+
10+
for _axis_name, axis in grid.axes.items():
11+
for pos, _dim_name in axis.coords.items():
12+
assert pos in ["left", "center"]
13+
14+
15+
def test_right_indexed_dataset():
16+
"""Checks that 'ds_2d_right' is right indexed on all variables."""
17+
ds = datasets["ds_2d_right"]
18+
grid = Grid(ds)
19+
for _axis_name, axis in grid.axes.items():
20+
for pos, _dim_name in axis.coords.items():
21+
assert pos in ["center", "right"]

tests/v4/test_xgrid.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,3 +145,57 @@ def test_search_1d_array(array, x, expected_xi, expected_xsi):
145145
xi, xsi = _search_1d_array(array, x)
146146
assert xi == expected_xi
147147
assert np.isclose(xsi, expected_xsi)
148+
149+
150+
@pytest.mark.parametrize(
151+
"grid, da_name, expected",
152+
[
153+
pytest.param(
154+
XGrid(xgcm.Grid(datasets["ds_2d_left"], periodic=False)),
155+
"U (C grid)",
156+
{
157+
"XG": (np.int64(0), np.float64(0.0)),
158+
"YC": (np.int64(-1), np.float64(0.5)),
159+
"ZG": (np.int64(0), np.float64(0.0)),
160+
},
161+
id="MITgcm indexing style U (C grid)",
162+
),
163+
pytest.param(
164+
XGrid(xgcm.Grid(datasets["ds_2d_left"], periodic=False)),
165+
"V (C grid)",
166+
{
167+
"XC": (np.int64(-1), np.float64(0.5)),
168+
"YG": (np.int64(0), np.float64(0.0)),
169+
"ZG": (np.int64(0), np.float64(0.0)),
170+
},
171+
id="MITgcm indexing style V (C grid)",
172+
),
173+
pytest.param(
174+
XGrid(xgcm.Grid(datasets["ds_2d_right"], periodic=False)),
175+
"U (C grid)",
176+
{
177+
"XG": (np.int64(0), np.float64(0.0)),
178+
"YC": (np.int64(0), np.float64(0.5)),
179+
"ZG": (np.int64(0), np.float64(0.0)),
180+
},
181+
id="NEMO indexing style U (C grid)",
182+
),
183+
pytest.param(
184+
XGrid(xgcm.Grid(datasets["ds_2d_right"], periodic=False)),
185+
"V (C grid)",
186+
{
187+
"XC": (np.int64(0), np.float64(0.5)),
188+
"YG": (np.int64(0), np.float64(0.0)),
189+
"ZG": (np.int64(0), np.float64(0.0)),
190+
},
191+
id="NEMO indexing style V (C grid)",
192+
),
193+
],
194+
)
195+
def test_xgrid_localize_zero_position(grid, da_name, expected):
196+
"""Test localize function using left and right datasets."""
197+
position = grid.search(0, 0, 0)
198+
da = grid.xgcm_grid._ds[da_name]
199+
200+
local_position = grid.localize(position, da.dims)
201+
assert local_position == expected, f"Expected {expected}, got {local_position}"

0 commit comments

Comments
 (0)