Skip to content

Commit a1934a0

Browse files
Merge branch 'v4-dev' into nearest_interpolation
2 parents 689061d + 807d2ee commit a1934a0

19 files changed

Lines changed: 121 additions & 129 deletions

docs/examples/tutorial_stommel_uxarray.ipynb

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@
8888
"\n",
8989
"A `UXArray.Dataset` consists of multiple `UXArray.UxDataArray`'s and a `UXArray.UxGrid`. Parcels views general circulation model data through the `Field` and `VectorField` classes. A `Field` is defined by its `name`, `data`, `grid`, and `interp_method`. A `VectorField` can be constructed by using 2 or 3 `Field`'s. The `Field.data` attribute can be either an `XArray.DataArray` or `UXArray.UxDataArray` object. The `Field.grid` attribute is of type `Parcels.XGrid` or `Parcels.UXGrid`. Last, the `interp_method` is a dynamic function that can be set at runtime to define the interpolation procedure for the `Field`. This gives you the flexibility to use one of the pre-defined interpolation methods included with Parcels v4, or to create your own interpolator. \n",
9090
"\n",
91-
"The first step to creating a `Field` (or `VectorField`) is to define the Grid. For an unstructured grid, we will create a `Parcels.UXGrid` object, which requires a `UxArray.grid` and the vertical layer interface positions."
91+
"The first step to creating a `Field` (or `VectorField`) is to define the Grid. For an unstructured grid, we will create a `Parcels.UXGrid` object, which requires a `UxArray.grid` and the vertical layer interface positions. Setting the `mesh` to `\"spherical\"` is a legacy feature from Parcels v3 that enables unit conversion from `m/s` to `deg/s`; this is needed in this case since the grid locations are defined in units of degrees."
9292
]
9393
},
9494
{
@@ -99,7 +99,7 @@
9999
"source": [
100100
"from parcels.uxgrid import UxGrid\n",
101101
"\n",
102-
"grid = UxGrid(grid=ds.uxgrid, z=ds.coords[\"nz\"])\n",
102+
"grid = UxGrid(grid=ds.uxgrid, z=ds.coords[\"nz\"], mesh=\"spherical\")\n",
103103
"# You can view the uxgrid object with the following command:\n",
104104
"grid.uxgrid"
105105
]
@@ -112,7 +112,7 @@
112112
"\n",
113113
"In Parcels, grid searching is conducted with respect to the faces. In other words, when a grid index `ei` is provided to an interpolation method, this refers the face index `fi` at vertical layer `zi` (when unraveled). Within the interpolation method, the `field.grid.uxgrid.face_node_connectivity` attribute can be used to obtain the node indices that surround the face. Using these connectivity tables is necessary for properly indexing node registered data.\n",
114114
"\n",
115-
"For the example Stommel gyre dataset in this tutorial, the `u` and `v` velocity components are face registered (similar to FESOM). Parcels includes a nearest neighbor interpolator for face registered unstructured grid data through `Parcels.application_kernels.interpolation.UXPiecewiseConstantFace`. Below, we create the `Field`s `U` and `V` and associate them with the `UxGrid` we created previously and this interpolation method. Setting the `mesh_type` to `\"spherical\"` is a legacy feature from Parcels v3 that enables unit conversion from `m/s` to `deg/s`; this is needed in this case since the grid locations are defined in units of degrees."
115+
"For the example Stommel gyre dataset in this tutorial, the `u` and `v` velocity components are face registered (similar to FESOM). Parcels includes a nearest neighbor interpolator for face registered unstructured grid data through `Parcels.application_kernels.interpolation.UXPiecewiseConstantFace`. Below, we create the `Field`s `U` and `V` and associate them with the `UxGrid` we created previously and this interpolation method."
116116
]
117117
},
118118
{
@@ -128,21 +128,18 @@
128128
" name=\"U\",\n",
129129
" data=ds.U,\n",
130130
" grid=grid,\n",
131-
" mesh_type=\"spherical\",\n",
132131
" interp_method=UXPiecewiseConstantFace,\n",
133132
")\n",
134133
"V = Field(\n",
135134
" name=\"V\",\n",
136135
" data=ds.V,\n",
137136
" grid=grid,\n",
138-
" mesh_type=\"spherical\",\n",
139137
" interp_method=UXPiecewiseConstantFace,\n",
140138
")\n",
141139
"P = Field(\n",
142140
" name=\"P\",\n",
143141
" data=ds.p,\n",
144142
" grid=grid,\n",
145-
" mesh_type=\"spherical\",\n",
146143
" interp_method=UXPiecewiseConstantFace,\n",
147144
")"
148145
]

parcels/_datasets/structured/generated.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
import xarray as xr
55

66

7-
def simple_UV_dataset(dims=(360, 2, 30, 4), maxdepth=1, mesh_type="spherical"):
8-
max_lon = 180.0 if mesh_type == "spherical" else 1e6
7+
def simple_UV_dataset(dims=(360, 2, 30, 4), maxdepth=1, mesh="spherical"):
8+
max_lon = 180.0 if mesh == "spherical" else 1e6
99

1010
return xr.Dataset(
1111
{"U": (["time", "depth", "YG", "XG"], np.zeros(dims)), "V": (["time", "depth", "YG", "XG"], np.zeros(dims))},

parcels/_index_search.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from parcels._typing import (
99
GridIndexingType,
1010
InterpMethodOption,
11+
Mesh,
1112
)
1213
from parcels.tools.statuscodes import (
1314
FieldOutOfBoundError,
@@ -174,7 +175,7 @@ def _search_indices_rectilinear(
174175
_raise_field_out_of_bound_error(z, y, x)
175176

176177
if field.xdim > 1:
177-
if field._mesh_type != "spherical":
178+
if field._mesh != "spherical":
178179
lon_index = field.lon < x
179180
if lon_index.all():
180181
xi = len(field.lon) - 2
@@ -305,7 +306,7 @@ def _search_indices_curvilinear_2d(
305306
xi = np.where(xsi < -tol, xi - 1, np.where(xsi > 1 + tol, xi + 1, xi))
306307
yi = np.where(eta < -tol, yi - 1, np.where(eta > 1 + tol, yi + 1, yi))
307308

308-
(yi, xi) = _reconnect_bnd_indices(yi, xi, grid.ydim, grid.xdim, grid.mesh)
309+
(yi, xi) = _reconnect_bnd_indices(yi, xi, grid.ydim, grid.xdim, grid._mesh)
309310
it += 1
310311
if it > maxIterSearch:
311312
print(f"Correct cell not found after {maxIterSearch} iterations")
@@ -408,11 +409,11 @@ def _search_indices_curvilinear(field, time, z, y, x, ti, particle=None, search2
408409
return (zeta, eta, xsi, zi, yi, xi)
409410

410411

411-
def _reconnect_bnd_indices(yi: int, xi: int, ydim: int, xdim: int, sphere_mesh: bool):
412-
xi = np.where(xi < 0, (xdim - 2) if sphere_mesh else 0, xi)
413-
xi = np.where(xi > xdim - 2, 0 if sphere_mesh else (xdim - 2), xi)
412+
def _reconnect_bnd_indices(yi: int, xi: int, ydim: int, xdim: int, mesh: Mesh):
413+
xi = np.where(xi < 0, (xdim - 2) if mesh == "spherical" else 0, xi)
414+
xi = np.where(xi > xdim - 2, 0 if mesh == "spherical" else (xdim - 2), xi)
414415

415-
xi = np.where(yi > ydim - 2, xdim - xi if sphere_mesh else xi, xi)
416+
xi = np.where(yi > ydim - 2, xdim - xi if mesh == "spherical" else xi, xi)
416417

417418
yi = np.where(yi < 0, 0, yi)
418419
yi = np.where(yi > ydim - 2, ydim - 2, yi)

parcels/field.py

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,7 @@
1111

1212
from parcels._core.utils.time import TimeInterval
1313
from parcels._reprs import default_repr
14-
from parcels._typing import (
15-
Mesh,
16-
VectorType,
17-
assert_valid_mesh,
18-
)
14+
from parcels._typing import VectorType
1915
from parcels.application_kernels.interpolation import UXPiecewiseLinearNode, XLinear, ZeroInterpolator
2016
from parcels.particle import KernelParticle
2117
from parcels.tools.converters import (
@@ -86,7 +82,7 @@ class Field:
8682
-----
8783
The xarray.DataArray or uxarray.UxDataArray object contains the field data and metadata.
8884
* dims: (time, [nz1 | nz], [face_lat | node_lat | edge_lat], [face_lon | node_lon | edge_lon])
89-
* attrs: (location, mesh, mesh_type)
85+
* attrs: (location, mesh, mesh)
9086
9187
When using a xarray.DataArray object,
9288
* The xarray.DataArray object must have the "location" and "mesh" attributes set.
@@ -114,7 +110,6 @@ def __init__(
114110
name: str,
115111
data: xr.DataArray | ux.UxDataArray,
116112
grid: UxGrid | XGrid,
117-
mesh_type: Mesh = "flat",
118113
interp_method: Callable | None = None,
119114
):
120115
if not isinstance(data, (ux.UxDataArray, xr.DataArray)):
@@ -126,8 +121,6 @@ def __init__(
126121
if not isinstance(grid, (UxGrid, XGrid)):
127122
raise ValueError(f"Expected `grid` to be a parcels UxGrid, or parcels XGrid object, got {type(grid)}.")
128123

129-
assert_valid_mesh(mesh_type)
130-
131124
_assert_compatible_combination(data, grid)
132125

133126
if isinstance(grid, XGrid):
@@ -155,8 +148,6 @@ def __init__(
155148
e.add_note(f"Error validating field {name!r}.")
156149
raise e
157150

158-
self._mesh_type = mesh_type
159-
160151
# Setting the interpolation method dynamically
161152
if interp_method is None:
162153
self._interp_method = _DEFAULT_INTERPOLATOR_MAPPING[type(self.grid)]
@@ -166,12 +157,10 @@ def __init__(
166157

167158
self.igrid = -1 # Default the grid index to -1
168159

169-
if self._mesh_type == "flat" or (self.name not in unitconverters_map.keys()):
160+
if self.grid._mesh == "flat" or (self.name not in unitconverters_map.keys()):
170161
self.units = UnitConverter()
171-
elif self._mesh_type == "spherical":
162+
elif self.grid._mesh == "spherical":
172163
self.units = unitconverters_map[self.name]
173-
else:
174-
raise ValueError("Unsupported mesh type in data array attributes. Choose either: 'spherical' or 'flat'")
175164

176165
if self.data.shape[0] > 1:
177166
if "time" not in self.data.coords:

parcels/kernel.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ def check_fieldsets_in_kernels(self, pyfunc): # TODO v4: this can go into anoth
143143
stacklevel=2,
144144
)
145145
self.fieldset.add_constant("RK45_tol", 10)
146-
if self.fieldset.U.grid.mesh == "spherical":
146+
if self.fieldset.U.grid._mesh == "spherical":
147147
self.fieldset.RK45_tol /= (
148148
1852 * 60
149149
) # TODO does not account for zonal variation in meter -> degree conversion

parcels/spatialhash.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def _hash_index2d(self, coords):
8686
as the source grid coordinates
8787
"""
8888
# Wrap longitude to [-180, 180]
89-
if self._source_grid.mesh == "spherical":
89+
if self._source_grid._mesh == "spherical":
9090
lon = (coords[:, 1] + 180.0) % (360.0) - 180.0
9191
else:
9292
lon = coords[:, 1]

parcels/uxgrid.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import numpy as np
66
import uxarray as ux
77

8+
from parcels._typing import assert_valid_mesh
89
from parcels.spatialhash import _barycentric_coordinates
910
from parcels.tools.statuscodes import FieldOutOfBoundError
1011
from parcels.xgrid import _search_1d_array
@@ -20,7 +21,7 @@ class UxGrid(BaseGrid):
2021
for interpolation on unstructured grids.
2122
"""
2223

23-
def __init__(self, grid: ux.grid.Grid, z: ux.UxDataArray) -> UxGrid:
24+
def __init__(self, grid: ux.grid.Grid, z: ux.UxDataArray, mesh="flat") -> UxGrid:
2425
"""
2526
Initializes the UxGrid with a uxarray grid and vertical coordinate array.
2627
@@ -32,13 +33,18 @@ def __init__(self, grid: ux.grid.Grid, z: ux.UxDataArray) -> UxGrid:
3233
A 1D array of vertical coordinates (depths) associated with the layer interface heights (not the mid-layer depths).
3334
While uxarray allows nz to be spatially and temporally varying, the parcels.UxGrid class considers the case where
3435
the vertical coordinate is constant in time and space. This implies flat bottom topography and no moving ALE vertical grid.
36+
mesh : str, optional
37+
The type of mesh used for the grid. Either "flat" (default) or "spherical".
3538
"""
3639
self.uxgrid = grid
3740
if not isinstance(z, ux.UxDataArray):
3841
raise TypeError("z must be an instance of ux.UxDataArray")
3942
if z.ndim != 1:
4043
raise ValueError("z must be a 1D array of vertical coordinates")
4144
self.z = z
45+
self._mesh = mesh
46+
47+
assert_valid_mesh(mesh)
4248

4349
@property
4450
def depth(self):

parcels/xgrid.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from parcels import xgcm
1010
from parcels._index_search import _search_indices_curvilinear_2d
11+
from parcels._typing import assert_valid_mesh
1112
from parcels.basegrid import BaseGrid
1213
from parcels.spatialhash import SpatialHash
1314

@@ -97,13 +98,15 @@ class XGrid(BaseGrid):
9798

9899
def __init__(self, grid: xgcm.Grid, mesh="flat"):
99100
self.xgcm_grid = grid
100-
self.mesh = mesh
101+
self._mesh = mesh
101102
self._spatialhash = None
102103
ds = grid._ds
103104

104105
if len(set(grid.axes) & {"X", "Y", "Z"}) > 0: # Only if spatial grid is >0D (see #2054 for further development)
105106
assert_valid_lat_lon(ds["lat"], ds["lon"], grid.axes)
106107

108+
assert_valid_mesh(mesh)
109+
107110
@classmethod
108111
def from_dataset(cls, ds: xr.Dataset, mesh="flat", xgcm_kwargs=None):
109112
"""WARNING: unstable API, subject to change in future versions.""" # TODO v4: make private or remove warning on v4 release

tests/utils.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -68,15 +68,15 @@ def create_fieldset_global(xdim=200, ydim=100):
6868
return FieldSet.from_data(data, dimensions, mesh="flat")
6969

7070

71-
def create_fieldset_zeros_conversion(mesh_type="spherical", xdim=200, ydim=100) -> FieldSet:
71+
def create_fieldset_zeros_conversion(mesh="spherical", xdim=200, ydim=100) -> FieldSet:
7272
"""Zero velocity field with lat and lon determined by a conversion factor."""
73-
mesh_conversion = 1 / 1852.0 / 60 if mesh_type == "spherical" else 1
74-
ds = simple_UV_dataset(dims=(2, 1, ydim, xdim), mesh_type=mesh_type)
73+
mesh_conversion = 1 / 1852.0 / 60 if mesh == "spherical" else 1
74+
ds = simple_UV_dataset(dims=(2, 1, ydim, xdim), mesh=mesh)
7575
ds["lon"].data = np.linspace(-1e6 * mesh_conversion, 1e6 * mesh_conversion, xdim)
7676
ds["lat"].data = np.linspace(-1e6 * mesh_conversion, 1e6 * mesh_conversion, ydim)
77-
grid = XGrid.from_dataset(ds)
78-
U = Field("U", ds["U"], grid, mesh_type=mesh_type, interp_method=XLinear)
79-
V = Field("V", ds["V"], grid, mesh_type=mesh_type, interp_method=XLinear)
77+
grid = XGrid.from_dataset(ds, mesh=mesh)
78+
U = Field("U", ds["U"], grid, interp_method=XLinear)
79+
V = Field("V", ds["V"], grid, interp_method=XLinear)
8080

8181
UV = VectorField("UV", U, V)
8282
return FieldSet([U, V, UV])

tests/v4/test_advection.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -32,21 +32,21 @@
3232
}
3333

3434

35-
@pytest.mark.parametrize("mesh_type", ["spherical", "flat"])
36-
def test_advection_zonal(mesh_type, npart=10):
35+
@pytest.mark.parametrize("mesh", ["spherical", "flat"])
36+
def test_advection_zonal(mesh, npart=10):
3737
"""Particles at high latitude move geographically faster due to the pole correction in `GeographicPolar`."""
38-
ds = simple_UV_dataset(mesh_type=mesh_type)
38+
ds = simple_UV_dataset(mesh=mesh)
3939
ds["U"].data[:] = 1.0
40-
grid = XGrid.from_dataset(ds)
41-
U = Field("U", ds["U"], grid, mesh_type=mesh_type, interp_method=XLinear)
42-
V = Field("V", ds["V"], grid, mesh_type=mesh_type, interp_method=XLinear)
40+
grid = XGrid.from_dataset(ds, mesh=mesh)
41+
U = Field("U", ds["U"], grid, interp_method=XLinear)
42+
V = Field("V", ds["V"], grid, interp_method=XLinear)
4343
UV = VectorField("UV", U, V)
4444
fieldset = FieldSet([U, V, UV])
4545

4646
pset = ParticleSet(fieldset, lon=np.zeros(npart) + 20.0, lat=np.linspace(0, 80, npart))
4747
pset.execute(AdvectionRK4, runtime=np.timedelta64(2, "h"), dt=np.timedelta64(15, "m"))
4848

49-
if mesh_type == "spherical":
49+
if mesh == "spherical":
5050
assert (np.diff(pset.lon) > 1.0e-4).all()
5151
else:
5252
assert (np.diff(pset.lon) < 1.0e-4).all()
@@ -58,7 +58,7 @@ def periodicBC(particle, fieldset, time):
5858

5959

6060
def test_advection_zonal_periodic():
61-
ds = simple_UV_dataset(dims=(2, 2, 2, 2), mesh_type="flat")
61+
ds = simple_UV_dataset(dims=(2, 2, 2, 2), mesh="flat")
6262
ds["U"].data[:] = 0.1
6363
ds["lon"].data = np.array([0, 2])
6464
ds["lat"].data = np.array([0, 2])
@@ -86,7 +86,7 @@ def test_advection_zonal_periodic():
8686

8787
def test_horizontal_advection_in_3D_flow(npart=10):
8888
"""Flat 2D zonal flow that increases linearly with depth from 0 m/s to 1 m/s."""
89-
ds = simple_UV_dataset(mesh_type="flat")
89+
ds = simple_UV_dataset(mesh="flat")
9090
ds["U"].data[:] = 1.0
9191
grid = XGrid.from_dataset(ds)
9292
U = Field("U", ds["U"], grid, interp_method=XLinear)
@@ -105,7 +105,7 @@ def test_horizontal_advection_in_3D_flow(npart=10):
105105
@pytest.mark.parametrize("direction", ["up", "down"])
106106
@pytest.mark.parametrize("wErrorThroughSurface", [True, False])
107107
def test_advection_3D_outofbounds(direction, wErrorThroughSurface):
108-
ds = simple_UV_dataset(mesh_type="flat")
108+
ds = simple_UV_dataset(mesh="flat")
109109
grid = XGrid.from_dataset(ds)
110110
U = Field("U", ds["U"], grid, interp_method=XLinear)
111111
U.data[:] = 0.01 # Set U to small value (to avoid horizontal out of bounds)
@@ -208,9 +208,9 @@ def test_length1dimensions(u, v, w): # TODO: Refactor this test to be more read
208208

209209
def test_radialrotation(npart=10):
210210
ds = radial_rotation_dataset()
211-
grid = XGrid.from_dataset(ds)
212-
U = parcels.Field("U", ds["U"], grid, mesh_type="flat", interp_method=XLinear)
213-
V = parcels.Field("V", ds["V"], grid, mesh_type="flat", interp_method=XLinear)
211+
grid = XGrid.from_dataset(ds, mesh="flat")
212+
U = parcels.Field("U", ds["U"], grid, interp_method=XLinear)
213+
V = parcels.Field("V", ds["V"], grid, interp_method=XLinear)
214214
UV = parcels.VectorField("UV", U, V)
215215
fieldset = parcels.FieldSet([U, V, UV])
216216

0 commit comments

Comments
 (0)