Skip to content

Commit 553f1fc

Browse files
committed
Update Field init to take XGrid
1 parent 5a68537 commit 553f1fc

4 files changed

Lines changed: 39 additions & 46 deletions

File tree

parcels/field.py

Lines changed: 10 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
_raise_field_out_of_bound_error,
3131
)
3232
from parcels.uxgrid import UxGrid
33-
from parcels.xgcm.grid import Grid
3433
from parcels.xgrid import XGrid
3534

3635
from ._index_search import _search_time_index
@@ -145,7 +144,7 @@ def __init__(
145144
self,
146145
name: str,
147146
data: xr.DataArray | ux.UxDataArray,
148-
grid: UxGrid | Grid,
147+
grid: UxGrid | XGrid,
149148
mesh_type: Mesh = "flat",
150149
interp_method: Callable | None = None,
151150
):
@@ -155,7 +154,7 @@ def __init__(
155154
)
156155
if not isinstance(name, str):
157156
raise ValueError(f"Expected `name` to be a string, got {type(name)}.")
158-
if not isinstance(grid, (UxGrid, Grid)):
157+
if not isinstance(grid, (UxGrid, XGrid)):
159158
raise ValueError(f"Expected `grid` to be a parcels UxGrid, or parcels Grid object, got {type(grid)}.")
160159

161160
assert_valid_mesh(mesh_type)
@@ -174,13 +173,6 @@ def __init__(
174173
)
175174
raise e
176175

177-
# For compatibility with parts of the codebase that rely on v3 definition of Grid.
178-
# Should be worked to be removed in v4
179-
if isinstance(grid, Grid):
180-
self.gridadapter = XGrid(grid)
181-
else:
182-
self.gridadapter = None
183-
184176
try:
185177
if isinstance(data, ux.UxDataArray):
186178
_assert_valid_uxdataarray(data)
@@ -232,7 +224,7 @@ def lat(self):
232224
elif self.data.attrs["location"] == "edge":
233225
return self.grid.edge_lat
234226
else:
235-
return self.gridadapter.lat
227+
return self.grid.lat
236228

237229
@property
238230
def lon(self):
@@ -244,7 +236,7 @@ def lon(self):
244236
elif self.data.attrs["location"] == "edge":
245237
return self.grid.edge_lon
246238
else:
247-
return self.gridadapter.lon
239+
return self.grid.lon
248240

249241
@property
250242
def depth(self):
@@ -255,26 +247,26 @@ def depth(self):
255247
elif vertical_location == "face":
256248
return self.grid.nz
257249
else:
258-
return self.gridadapter.depth
250+
return self.grid.depth
259251

260252
@property
261253
def xdim(self):
262254
if type(self.data) is xr.DataArray:
263-
return self.gridadapter.xdim
255+
return self.grid.xdim
264256
else:
265257
raise NotImplementedError("xdim not implemented for unstructured grids")
266258

267259
@property
268260
def ydim(self):
269261
if type(self.data) is xr.DataArray:
270-
return self.gridadapter.ydim
262+
return self.grid.ydim
271263
else:
272264
raise NotImplementedError("ydim not implemented for unstructured grids")
273265

274266
@property
275267
def zdim(self):
276268
if type(self.data) is xr.DataArray:
277-
return self.gridadapter.zdim
269+
return self.grid.zdim
278270
else:
279271
if "nz1" in self.data.dims:
280272
return self.data.sizes["nz1"]
@@ -523,14 +515,14 @@ def _assert_valid_uxgrid(grid):
523515
)
524516

525517

526-
def _assert_compatible_combination(data: xr.DataArray | ux.UxDataArray, grid: ux.Grid | Grid):
518+
def _assert_compatible_combination(data: xr.DataArray | ux.UxDataArray, grid: ux.Grid | XGrid):
527519
if isinstance(data, ux.UxDataArray):
528520
if not isinstance(grid, UxGrid):
529521
raise ValueError(
530522
f"Incompatible data-grid combination. Data is a uxarray.UxDataArray, expected `grid` to be a UxGrid object, got {type(grid)}."
531523
)
532524
elif isinstance(data, xr.DataArray):
533-
if not isinstance(grid, Grid):
525+
if not isinstance(grid, XGrid):
534526
raise ValueError(
535527
f"Incompatible data-grid combination. Data is a xarray.DataArray, expected `grid` to be a parcels Grid object, got {type(grid)}."
536528
)

parcels/fieldset.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,12 @@
77
import numpy as np
88
import xarray as xr
99

10+
from parcels import xgcm
1011
from parcels._core.utils.time import get_datetime_type_calendar
1112
from parcels._core.utils.time import is_compatible as datetime_is_compatible
1213
from parcels._typing import Mesh
1314
from parcels.field import Field, VectorField
14-
from parcels.xgcm.grid import Grid
15+
from parcels.xgrid import XGrid
1516

1617
if TYPE_CHECKING:
1718
from parcels._typing import DatetimeLike
@@ -171,7 +172,7 @@ def add_constant_field(self, name: str, value, mesh: Mesh = "flat"):
171172
"time": (["time"], np.arange(1), {"axis": "T"}),
172173
},
173174
)
174-
grid = Grid(da)
175+
grid = XGrid(xgcm.Grid(da))
175176
self.add_field(
176177
Field(
177178
name,

tests/v4/test_field.py

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,39 +3,34 @@
33
import uxarray as ux
44
import xarray as xr
55

6-
from parcels import Field
6+
from parcels import Field, xgcm
77
from parcels._datasets.structured.generic import T as T_structured
88
from parcels._datasets.structured.generic import datasets as datasets_structured
99
from parcels._datasets.unstructured.generic import datasets as datasets_unstructured
1010
from parcels.uxgrid import UxGrid
11-
from parcels.xgcm import Grid
11+
from parcels.xgrid import XGrid
1212

1313

1414
def test_field_init_param_types():
15-
data = xr.DataArray(
16-
attrs={
17-
"location": "node",
18-
"mesh": "flat",
19-
}
20-
)
21-
grid = Grid(data)
15+
data = datasets_structured["ds_2d_left"]
16+
grid = XGrid(xgcm.Grid(data))
2217
with pytest.raises(ValueError, match="Expected `name` to be a string"):
23-
Field(name=123, data=data, grid=grid)
18+
Field(name=123, data=data["data_g"], grid=grid)
2419

2520
with pytest.raises(ValueError, match="Expected `data` to be a uxarray.UxDataArray or xarray.DataArray"):
2621
Field(name="test", data=123, grid=grid)
2722

28-
with pytest.raises(ValueError, match="Expected `grid` to be a parcels UxGrid, or parcels Grid"):
29-
Field(name="test", data=data, grid=123)
23+
with pytest.raises(ValueError, match="Expected `grid` to be a parcels UxGrid, or parcels XGrid"):
24+
Field(name="test", data=data["data_g"], grid=123)
3025

3126
with pytest.raises(ValueError, match="Invalid value 'invalid'. Valid options are.*"):
32-
Field(name="test", data=data, grid=grid, mesh_type="invalid")
27+
Field(name="test", data=data["data_g"], grid=grid, mesh_type="invalid")
3328

3429

3530
@pytest.mark.parametrize(
3631
"data,grid",
3732
[
38-
pytest.param(ux.UxDataArray(), Grid(xr.Dataset()), id="uxdata-grid"),
33+
pytest.param(ux.UxDataArray(), XGrid(xgcm.Grid(datasets_structured["ds_2d_left"])), id="uxdata-grid"),
3934
pytest.param(
4035
xr.DataArray(),
4136
UxGrid(datasets_unstructured["stommel_gyre_delaunay"].uxgrid),
@@ -56,7 +51,9 @@ def test_field_incompatible_combination(data, grid):
5651
"data,grid",
5752
[
5853
pytest.param(
59-
datasets_structured["ds_2d_left"]["data_g"], Grid(datasets_structured["ds_2d_left"]), id="ds_2d_left"
54+
datasets_structured["ds_2d_left"]["data_g"],
55+
XGrid(xgcm.Grid(datasets_structured["ds_2d_left"])),
56+
id="ds_2d_left",
6057
), # TODO: Perhaps this test should be expanded to cover more datasets?
6158
],
6259
)
@@ -79,7 +76,7 @@ def test_field_init_fail_on_bad_time_type(numpy_dtype):
7976
ds["time"] = np.arange(0, T_structured, dtype=numpy_dtype)
8077

8178
data = ds["data_g"]
82-
grid = Grid(ds)
79+
grid = XGrid(xgcm.Grid(ds))
8380
with pytest.raises(
8481
ValueError,
8582
match="Error getting time interval.*. Are you sure that the time dimension on the xarray dataset is stored as datetime or cftime datetime objects\?",
@@ -95,7 +92,9 @@ def test_field_init_fail_on_bad_time_type(numpy_dtype):
9592
"data,grid",
9693
[
9794
pytest.param(
98-
datasets_structured["ds_2d_left"]["data_g"], Grid(datasets_structured["ds_2d_left"]), id="ds_2d_left"
95+
datasets_structured["ds_2d_left"]["data_g"],
96+
XGrid(xgcm.Grid(datasets_structured["ds_2d_left"])),
97+
id="ds_2d_left",
9998
),
10099
],
101100
)

tests/v4/test_fieldset.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,20 @@
55
import pytest
66
import xarray as xr
77

8+
from parcels import xgcm
89
from parcels._datasets.structured.generic import T as T_structured
910
from parcels._datasets.structured.generic import datasets as datasets_structured
1011
from parcels.field import Field, VectorField
1112
from parcels.fieldset import CalendarError, FieldSet, _datetime_to_msg
12-
from parcels.xgcm import Grid
13+
from parcels.xgrid import XGrid
1314

1415
ds = datasets_structured["ds_2d_left"]
1516

1617

1718
@pytest.fixture
1819
def fieldset() -> FieldSet:
1920
"""Fixture to create a FieldSet object for testing."""
20-
grid = Grid(ds)
21+
grid = XGrid(xgcm.Grid(ds))
2122
U = Field("U", ds["U (A grid)"], grid, mesh_type="flat")
2223
V = Field("V", ds["V (A grid)"], grid, mesh_type="flat")
2324
UV = VectorField("UV", U, V)
@@ -51,7 +52,7 @@ def test_fieldset_add_constant_field(fieldset):
5152

5253

5354
def test_fieldset_add_field(fieldset):
54-
grid = Grid(ds)
55+
grid = XGrid(xgcm.Grid(ds))
5556
field = Field("test_field", ds["U (A grid)"], grid, mesh_type="flat")
5657
fieldset.add_field(field)
5758
assert fieldset.test_field == field
@@ -64,7 +65,7 @@ def test_fieldset_add_field_wrong_type(fieldset):
6465

6566

6667
def test_fieldset_add_field_already_exists(fieldset):
67-
grid = Grid(ds)
68+
grid = XGrid(xgcm.Grid(ds))
6869
field = Field("test_field", ds["U (A grid)"], grid, mesh_type="flat")
6970
fieldset.add_field(field, "test_field")
7071
with pytest.raises(ValueError, match="FieldSet already has a Field with name 'test_field'"):
@@ -77,12 +78,12 @@ def test_fieldset_gridset_size(fieldset):
7778

7879

7980
def test_fieldset_time_interval():
80-
grid1 = Grid(ds)
81+
grid1 = XGrid(xgcm.Grid(ds))
8182
field1 = Field("field1", ds["U (A grid)"], grid1, mesh_type="flat")
8283

8384
ds2 = ds.copy()
8485
ds2["time"] = ds2["time"] + np.timedelta64(timedelta(days=1))
85-
grid2 = Grid(ds2)
86+
grid2 = XGrid(xgcm.Grid(ds2))
8687
field2 = Field("field2", ds2["U (A grid)"], grid2, mesh_type="flat")
8788

8889
fieldset = FieldSet([field1, field2])
@@ -96,14 +97,14 @@ def test_fieldset_init_incompatible_calendars():
9697
ds1 = ds.copy()
9798
ds1["time"] = xr.date_range("2000", "2001", T_structured, calendar="365_day", use_cftime=True)
9899

99-
grid = Grid(ds1)
100+
grid = XGrid(xgcm.Grid(ds1))
100101
U = Field("U", ds1["U (A grid)"], grid, mesh_type="flat")
101102
V = Field("V", ds1["V (A grid)"], grid, mesh_type="flat")
102103
UV = VectorField("UV", U, V)
103104

104105
ds2 = ds.copy()
105106
ds2["time"] = xr.date_range("2000", "2001", T_structured, calendar="360_day", use_cftime=True)
106-
grid2 = Grid(ds2)
107+
grid2 = XGrid(xgcm.Grid(ds2))
107108
incompatible_calendar = Field("test", ds2["data_g"], grid2, mesh_type="flat")
108109

109110
with pytest.raises(CalendarError, match="Expected field '.*' to have calendar compatible with datetime object"):
@@ -113,7 +114,7 @@ def test_fieldset_init_incompatible_calendars():
113114
def test_fieldset_add_field_incompatible_calendars(fieldset):
114115
ds_test = ds.copy()
115116
ds_test["time"] = xr.date_range("2000", "2001", T_structured, calendar="360_day", use_cftime=True)
116-
grid = Grid(ds_test)
117+
grid = XGrid(xgcm.Grid(ds_test))
117118
field = Field("test_field", ds_test["data_g"], grid, mesh_type="flat")
118119

119120
with pytest.raises(CalendarError, match="Expected field '.*' to have calendar compatible with datetime object"):

0 commit comments

Comments
 (0)