Commit 687c1e1

Merge branch 'v4-dev' into feature/uxarray_xarray_fields
2 parents: 819a077 + 1316f7a

16 files changed

Lines changed: 1851 additions & 15 deletions

parcels/_datasets/__init__.py

Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
"""
Datasets compatible with Parcels.

This subpackage uses xarray to generate *idealised* structured and unstructured hydrodynamical datasets that are compatible with Parcels. The goals are three-fold:

1. To provide users with documentation for the types of datasets they can expect Parcels to work with.
2. To supply our tutorials with hydrodynamical datasets.
3. To offer developers datasets for use in test cases.

Note that this subpackage is part of the private API for Parcels. Users should not rely directly on the functions defined within this module. Instead, if you want to generate your own datasets, copy the functions from this module into your own code.

Developers, note that you should only add functions that create idealised datasets to this subpackage if they (a) are quick to generate, and (b) use only dependencies already shipped with Parcels. No data files should be added to this subpackage. Real-world data files should be added to the `OceanParcels/parcels-data` repository on GitHub.

Parcels Dataset Philosophy
--------------------------

When adding datasets, there may be a tension between adding a specific dataset and adding machinery to generate fully parameterised datasets (e.g., with different grid resolutions, coordinate ranges, datetimes, etc.). There are trade-offs to both approaches.

Working with specific hardcoded datasets:

* Pros
    * the example is stable and self-contained
    * it is easy to see exactly what the dataset is; there is little to no dependency on other functions defined in the same module
    * datasets don't "break" due to changes in other functions (e.g., grid edges becoming out of sync with grid centres)
* Cons
    * inflexible for tests that need to cover a large range of datasets, or a specific resolution

Working with generated datasets has the opposite pros and cons.

Most of the time we only want a single dataset: for use in a tutorial, for example, or for testing a specific feature of Parcels, such as (in the case of structured grids) checking that the grid from a certain (ocean) circulation model is correctly parsed, or that indexing is correctly picked up. As such, one should often opt for hardcoded datasets: they are more stable, and it is easy to see exactly what the dataset is. We may have specific examples that become the default "go to" dataset for testing when we don't care about the details of the dataset.

Sometimes we may want to test Parcels against a whole range of datasets varying in a certain way, to ensure Parcels works as expected. For these cases, we should add machinery to create generated datasets.

Structure
---------

This subpackage is broken down into structured and unstructured parts. Each of these has common submodules:

* ``circulation_model`` -> hardcoded datasets intended to mimic datasets from a certain (ocean) circulation model
* ``generic`` -> hardcoded datasets that are generic and not tied to a certain (ocean) circulation model; these focus on the fundamental properties of the dataset
* ``generated`` -> functions to generate datasets with varying properties
* ``utils`` -> utility functions related to generating or validating datasets

There may be extra submodules beyond the ones listed above.
"""
Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
"""Structured datasets."""
Lines changed: 169 additions & 0 deletions
@@ -0,0 +1,169 @@
"""Datasets focussing on grid geometry."""

import numpy as np
import xarray as xr

N = 30
T = 10


def rotated_curvilinear_grid():
    XG = np.arange(N)
    YG = np.arange(2 * N)
    LON, LAT = np.meshgrid(XG, YG)

    angle = -np.pi / 24
    rotation = np.array([[np.cos(angle), -np.sin(angle)], [np.sin(angle), np.cos(angle)]])

    # rotate the LON and LAT grids
    LON, LAT = np.einsum("ji, mni -> jmn", rotation, np.dstack([LON, LAT]))

    return xr.Dataset(
        {
            "data_g": (["ZG", "YG", "XG"], np.random.rand(3 * N, 2 * N, N)),
            "data_c": (["ZC", "YC", "XC"], np.random.rand(3 * N, 2 * N, N)),
        },
        coords={
            "XG": (["XG"], XG, {"axis": "X", "c_grid_axis_shift": -0.5}),
            "YG": (["YG"], YG, {"axis": "Y", "c_grid_axis_shift": -0.5}),
            "XC": (["XC"], XG + 0.5, {"axis": "X"}),
            "YC": (["YC"], YG + 0.5, {"axis": "Y"}),
            "ZG": (
                ["ZG"],
                np.arange(3 * N),
                {"axis": "Z", "c_grid_axis_shift": -0.5},
            ),
            "ZC": (
                ["ZC"],
                np.arange(3 * N) + 0.5,
                {"axis": "Z"},
            ),
            "depth": (["ZG"], np.arange(3 * N), {"axis": "Z"}),
            "time": (["time"], np.arange(T), {"axis": "T"}),
            "lon": (
                ["YG", "XG"],
                LON,
                {"axis": "X", "c_grid_axis_shift": -0.5},  # ? Needed?
            ),
            "lat": (
                ["YG", "XG"],
                LAT,
                {"axis": "Y", "c_grid_axis_shift": -0.5},  # ? Needed?
            ),
        },
    )


def _cartesian_to_polar(x, y):
    r = np.sqrt(x**2 + y**2)
    theta = np.arctan2(y, x)
    return r, theta


def _polar_to_cartesian(r, theta):
    x = r * np.cos(theta)
    y = r * np.sin(theta)
    return x, y


def unrolled_cone_curvilinear_grid():
    # Not a great unrolled cone, but this is good enough for testing.
    # You can use matplotlib's pcolormesh to plot it.
    XG = np.arange(N)
    YG = np.arange(2 * N) * 0.25

    pivot = -10, 0
    LON, LAT = np.meshgrid(XG, YG)

    new_lon_lat = []

    min_lon = np.min(XG)
    for lon, lat in zip(LON.flatten(), LAT.flatten(), strict=True):
        r, _ = _cartesian_to_polar(lon - pivot[0], lat - pivot[1])
        _, theta = _cartesian_to_polar(min_lon - pivot[0], lat - pivot[1])
        theta *= 1.2
        r *= 1.2
        lon, lat = _polar_to_cartesian(r, theta)
        new_lon_lat.append((lon + pivot[0], lat + pivot[1]))

    new_lon, new_lat = zip(*new_lon_lat, strict=True)
    LON, LAT = np.array(new_lon).reshape(LON.shape), np.array(new_lat).reshape(LAT.shape)

    return xr.Dataset(
        {
            "data_g": (["ZG", "YG", "XG"], np.random.rand(3 * N, 2 * N, N)),
            "data_c": (["ZC", "YC", "XC"], np.random.rand(3 * N, 2 * N, N)),
        },
        coords={
            "XG": (["XG"], XG, {"axis": "X", "c_grid_axis_shift": -0.5}),
            "YG": (["YG"], YG, {"axis": "Y", "c_grid_axis_shift": -0.5}),
            "XC": (["XC"], XG + 0.5, {"axis": "X"}),
            "YC": (["YC"], YG + 0.5, {"axis": "Y"}),
            "ZG": (
                ["ZG"],
                np.arange(3 * N),
                {"axis": "Z", "c_grid_axis_shift": -0.5},
            ),
            "ZC": (
                ["ZC"],
                np.arange(3 * N) + 0.5,
                {"axis": "Z"},
            ),
            "depth": (["ZG"], np.arange(3 * N), {"axis": "Z"}),
            "time": (["time"], np.arange(T), {"axis": "T"}),
            "lon": (
                ["YG", "XG"],
                LON,
                {"axis": "X", "c_grid_axis_shift": -0.5},  # ? Needed?
            ),
            "lat": (
                ["YG", "XG"],
                LAT,
                {"axis": "Y", "c_grid_axis_shift": -0.5},  # ? Needed?
            ),
        },
    )


datasets = {
    "2d_left_rotated": rotated_curvilinear_grid(),
    "ds_2d_left": xr.Dataset(
        {
            "data_g": (["time", "ZG", "YG", "XG"], np.random.rand(T, 3 * N, 2 * N, N)),
            "data_c": (["time", "ZC", "YC", "XC"], np.random.rand(T, 3 * N, 2 * N, N)),
        },
        coords={
            "XG": (
                ["XG"],
                2 * np.pi / N * np.arange(0, N),
                {"axis": "X", "c_grid_axis_shift": -0.5},
            ),
            "XC": (["XC"], 2 * np.pi / N * (np.arange(0, N) + 0.5), {"axis": "X"}),
            "YG": (
                ["YG"],
                2 * np.pi / (2 * N) * np.arange(0, 2 * N),
                {"axis": "Y", "c_grid_axis_shift": -0.5},
            ),
            "YC": (
                ["YC"],
                2 * np.pi / (2 * N) * (np.arange(0, 2 * N) + 0.5),
                {"axis": "Y"},
            ),
            "ZG": (
                ["ZG"],
                np.arange(3 * N),
                {"axis": "Z", "c_grid_axis_shift": -0.5},
            ),
            "ZC": (
                ["ZC"],
                np.arange(3 * N) + 0.5,
                {"axis": "Z"},
            ),
            "lon": (["XG"], 2 * np.pi / N * np.arange(0, N)),
            "lat": (["YG"], 2 * np.pi / (2 * N) * np.arange(0, 2 * N)),
            "depth": (["ZG"], np.arange(3 * N)),
            "time": (["time"], np.arange(T), {"axis": "T"}),
        },
    ),
    "2d_left_unrolled_cone": unrolled_cone_curvilinear_grid(),
}
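
The in-code comment above suggests inspecting these curvilinear grids with matplotlib's pcolormesh. A minimal sketch of that follows, assuming the function above is in scope (e.g., run from within the same module) and that matplotlib is available; matplotlib is not a dependency introduced by this commit.

import matplotlib.pyplot as plt

ds = unrolled_cone_curvilinear_grid()
# lon/lat are 2D (YG, XG) arrays, so pcolormesh shows the grid's curvature;
# take a single vertical level of the corner-point data
plt.pcolormesh(ds["lon"], ds["lat"], ds["data_g"].isel(ZG=0), shading="auto")
plt.colorbar(label="data_g (ZG=0)")
plt.show()
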
Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
"""Unstructured datasets."""

parcels/_datasets/utils.py

Lines changed: 69 additions & 0 deletions
@@ -0,0 +1,69 @@
from typing import Any

import numpy as np
import xarray as xr

from parcels._compat import add_note

_SUPPORTED_ATTR_TYPES = int | float | str | np.ndarray


def _print_mismatched_keys(d1: dict[Any, Any], d2: dict[Any, Any]) -> None:
    k1 = set(d1.keys())
    k2 = set(d2.keys())
    if len(k1 ^ k2) == 0:
        return
    print("Mismatched keys:")
    print(f"L: {k1 - k2!r}")
    print(f"R: {k2 - k1!r}")


def assert_common_attrs_equal(
    xr_attrs_1: dict[str, _SUPPORTED_ATTR_TYPES], xr_attrs_2: dict[str, _SUPPORTED_ATTR_TYPES], *, verbose: bool = True
) -> None:
    d1, d2 = xr_attrs_1, xr_attrs_2

    common_keys = set(d1.keys()) & set(d2.keys())
    if verbose:
        _print_mismatched_keys(d1, d2)

    for key in common_keys:
        try:
            if isinstance(d1[key], np.ndarray):
                np.testing.assert_array_equal(d1[key], d2[key])
            else:
                assert d1[key] == d2[key], f"{d1[key]} != {d2[key]}"
        except AssertionError as e:
            add_note(e, f"error on key {key!r}")
            raise


def assert_common_variables_common_attrs_equal(ds1: xr.Dataset, ds2: xr.Dataset, *, verbose: bool = True) -> None:
    if verbose:
        print("Checking dataset attrs...")

    assert_common_attrs_equal(ds1.attrs, ds2.attrs, verbose=verbose)

    ds1_vars = set(ds1.variables)
    ds2_vars = set(ds2.variables)

    common_variables = ds1_vars & ds2_vars
    if len(ds1_vars ^ ds2_vars) > 0 and verbose:
        print("Mismatched variables:")
        print(f"L: {ds1_vars - ds2_vars}")
        print(f"R: {ds2_vars - ds1_vars}")

    for var in common_variables:
        if verbose:
            print(f"Checking {var!r} attrs")
        assert_common_attrs_equal(ds1[var].attrs, ds2[var].attrs, verbose=verbose)


def dataset_repr_diff(ds1: xr.Dataset, ds2: xr.Dataset) -> str:
    """Return a text diff of two datasets."""
    repr1 = repr(ds1)
    repr2 = repr(ds2)
    import difflib

    diff = difflib.ndiff(repr1.splitlines(keepends=True), repr2.splitlines(keepends=True))
    return "".join(diff)

parcels/grid.py

Lines changed: 4 additions & 12 deletions
@@ -10,7 +10,6 @@
     "CurvilinearSGrid",
     "CurvilinearZGrid",
     "Grid",
-    "GridCode",
     "GridType",
     "RectilinearSGrid",
     "RectilinearZGrid",
@@ -24,11 +23,6 @@ class GridType(IntEnum):
     CurvilinearSGrid = 3
 
 
-# GridCode has been renamed to GridType for consistency.
-# TODO: Remove alias in Parcels v4
-GridCode = GridType
-
-
 class Grid:
     """Grid class that defines a (spatial and temporal) grid on which Fields are defined."""
 
@@ -40,7 +34,6 @@ def __init__(
         time_origin: TimeConverter | None,
         mesh: Mesh,
     ):
-        self._ti = -1
         lon = np.array(lon)
         lat = np.array(lat)
         time = np.zeros(1, dtype=np.float64) if time is None else time
@@ -112,7 +105,6 @@ def create_grid(
         time,
         time_origin,
         mesh: Mesh,
-        **kwargs,
     ):
         lon = np.array(lon)
         lat = np.array(lat)
@@ -122,14 +114,14 @@ def create_grid(
 
         if len(lon.shape) <= 1:
             if depth is None or len(depth.shape) <= 1:
-                return RectilinearZGrid(lon, lat, depth, time, time_origin=time_origin, mesh=mesh, **kwargs)
+                return RectilinearZGrid(lon, lat, depth, time, time_origin=time_origin, mesh=mesh)
             else:
-                return RectilinearSGrid(lon, lat, depth, time, time_origin=time_origin, mesh=mesh, **kwargs)
+                return RectilinearSGrid(lon, lat, depth, time, time_origin=time_origin, mesh=mesh)
         else:
             if depth is None or len(depth.shape) <= 1:
-                return CurvilinearZGrid(lon, lat, depth, time, time_origin=time_origin, mesh=mesh, **kwargs)
+                return CurvilinearZGrid(lon, lat, depth, time, time_origin=time_origin, mesh=mesh)
             else:
-                return CurvilinearSGrid(lon, lat, depth, time, time_origin=time_origin, mesh=mesh, **kwargs)
+                return CurvilinearSGrid(lon, lat, depth, time, time_origin=time_origin, mesh=mesh)
 
     def _check_zonal_periodic(self):
         if self.zonal_periodic or self.mesh == "flat" or self.lon.size == 1:
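
Since the deprecated GridCode = GridType alias is deleted here, downstream code that still imports GridCode needs a one-line rename. A hedged sketch follows; the helper function is illustrative only and not from the commit, but the membership check mirrors the change to populate_indices in parcels/particleset.py below.

from parcels.grid import GridType  # GridCode is gone; GridType holds the same members


def _is_curvilinear(grid) -> bool:
    # same check that populate_indices now uses instead of isinstance(grid, CurvilinearGrid)
    return grid._gtype in (GridType.CurvilinearZGrid, GridType.CurvilinearSGrid)
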

parcels/particleset.py

Lines changed: 2 additions & 2 deletions
@@ -13,7 +13,7 @@
 from parcels._compat import MPI
 from parcels.application_kernels.advection import AdvectionRK4
 from parcels.field import Field
-from parcels.grid import CurvilinearGrid, GridType
+from parcels.grid import GridType
 from parcels.interaction.interactionkernel import InteractionKernel
 from parcels.interaction.neighborsearch import (
     BruteFlatNeighborSearch,
@@ -430,7 +430,7 @@ def populate_indices(self):
         may be quite expensive.
         """
         for i, grid in enumerate(self.fieldset.gridset.grids):
-            if not isinstance(grid, CurvilinearGrid):
+            if grid._gtype not in [GridType.CurvilinearZGrid, GridType.CurvilinearSGrid]:
                 continue
 
             tree_data = np.stack((grid.lon.flat, grid.lat.flat), axis=-1)

parcels/v4/__init__.py

Whitespace-only changes.

0 commit comments
