Skip to content

Commit 3760109

Browse files
Update fieldset ingestion to use convert modules (#40)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 17c74e1 commit 3760109

9 files changed

Lines changed: 1032 additions & 698 deletions

File tree

Parcels

Submodule Parcels updated 49 files

benchmarks/fesom2.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import numpy as np
22
import uxarray as ux
3-
import xarray as xr
43
from parcels import (
54
FieldSet,
65
Particle,
@@ -9,7 +8,7 @@
98
)
109
from parcels.kernels import AdvectionRK2_3D
1110

12-
from . import PARCELS_BENCHMARKS_DATA_FOLDER
11+
from .catalogs import Catalogs
1312

1413
runtime = np.timedelta64(1, "D")
1514
dt = np.timedelta64(2400, "s")
@@ -18,12 +17,9 @@
1817
def _load_ds():
1918
"""Helper function to load uxarray dataset from datapath"""
2019

21-
grid_file = xr.open_mfdataset(
22-
f"{PARCELS_BENCHMARKS_DATA_FOLDER}/surf-data/parcels-benchmarks/data/Parcelsv4_Benchmarking_data/Parcels_Benchmarks_FESOM-baroclinic-gyre/data/mesh/fesom.mesh.diag.nc"
23-
)
24-
data_files = xr.open_mfdataset(
25-
f"{PARCELS_BENCHMARKS_DATA_FOLDER}/surf-data/parcels-benchmarks/data/Parcelsv4_Benchmarking_data/Parcels_Benchmarks_FESOM-baroclinic-gyre/data/*.nc"
26-
)
20+
cat = Catalogs.CAT_BENCHMARKS
21+
grid_file = cat.fesom_baroclinic_gyre_mesh().to_dask()
22+
data_files = cat.fesom_baroclinic_gyre_data().to_dask()
2723

2824
grid = ux.open_grid(grid_file)
2925
return ux.UxDataset(data_files, uxgrid=grid)
@@ -48,7 +44,7 @@ def pset_execute(self, npart, integrator):
4844
lat = np.linspace(32.0, 19.0, npart)
4945

5046
pset = ParticleSet(fieldset=fieldset, pclass=Particle, lon=lon, lat=lat)
51-
pset.execute(runtime=runtime, dt=dt, pyfunc=integrator)
47+
pset.execute(kernels=integrator, runtime=runtime, dt=dt)
5248

5349
def time_pset_execute(self, npart, integrator):
5450
self.pset_execute(npart, integrator)

benchmarks/moi_curvilinear.py

Lines changed: 19 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -1,62 +1,28 @@
1-
from glob import glob
2-
31
import numpy as np
42
import parcels
5-
import xarray as xr
6-
import xgcm
73
from parcels.interpolators import XLinear
84

5+
from .catalogs import Catalogs
6+
97
runtime = np.timedelta64(2, "D")
108
dt = np.timedelta64(15, "m")
119

1210

13-
PARCELS_DATADIR = ... # TODO: Replace with intake
14-
15-
16-
def download_dataset(*args, **kwargs): ... # TODO: Replace with intake
17-
18-
19-
def _load_ds(datapath, chunk):
20-
"""Helper function to load xarray dataset from datapath with or without chunking"""
21-
22-
fileU = f"{datapath}/psy4v3r1-daily_U_2025-01-0[1-3].nc"
23-
filenames = {
24-
"U": glob(fileU),
25-
"V": glob(fileU.replace("_U_", "_V_")),
26-
"W": glob(fileU.replace("_U_", "_W_")),
27-
}
28-
mesh_mask = f"{datapath}/PSY4V3R1_mesh_hgr.nc"
29-
fileargs = {
30-
"concat_dim": "time_counter",
31-
"combine": "nested",
32-
"data_vars": "minimal",
33-
"coords": "minimal",
34-
"compat": "override",
35-
}
36-
if chunk:
37-
fileargs["chunks"] = {"time_counter": 1, "depth": 2, "y": chunk, "x": chunk}
11+
def _load_ds(chunk):
12+
"""Helper function to load xarray dataset from catalog with or without chunking"""
13+
cat = Catalogs.CAT_BENCHMARKS
14+
chunks = {"time_counter": 1, "depth": 2, "y": chunk, "x": chunk} if chunk else None
3815

39-
ds_u = xr.open_mfdataset(filenames["U"], **fileargs)[["vozocrtx"]].drop_vars(
40-
["nav_lon", "nav_lat"]
16+
ds_u = (
17+
cat.moi_u(chunks=chunks).to_dask()[["vozocrtx"]].rename_vars({"vozocrtx": "U"})
4118
)
42-
ds_v = xr.open_mfdataset(filenames["V"], **fileargs)[["vomecrty"]].drop_vars(
43-
["nav_lon", "nav_lat"]
19+
ds_v = (
20+
cat.moi_v(chunks=chunks).to_dask()[["vomecrty"]].rename_vars({"vomecrty": "V"})
4421
)
45-
ds_depth = xr.open_mfdataset(filenames["W"], **fileargs)[["depthw"]]
46-
ds_mesh = xr.open_dataset(mesh_mask)[["glamf", "gphif"]].isel(t=0)
47-
48-
ds = xr.merge([ds_u, ds_v, ds_depth, ds_mesh], compat="identical")
49-
ds = ds.rename(
50-
{
51-
"vozocrtx": "U",
52-
"vomecrty": "V",
53-
"glamf": "lon",
54-
"gphif": "lat",
55-
"time_counter": "time",
56-
"depthw": "depth",
57-
}
58-
)
59-
ds.deptht.attrs["c_grid_axis_shift"] = -0.5
22+
da_depth = cat.moi_w(chunks=chunks).to_dask()["depthw"]
23+
ds_mesh = cat.moi_mesh(chunks=None).read()[["glamf", "gphif"]].isel(t=0)
24+
ds_mesh["depthw"] = da_depth
25+
ds = parcels.convert.nemo_to_sgrid(fields=dict(U=ds_u, V=ds_v), coords=ds_mesh)
6026

6127
return ds
6228

@@ -75,47 +41,26 @@ class MOICurvilinear:
7541
"npart",
7642
]
7743

78-
def setup(self, interpolator, chunk, npart):
79-
self.datapath = download_dataset("MOi-curvilinear", data_home=PARCELS_DATADIR)
80-
8144
def time_load_data_3d(self, interpolator, chunk, npart):
8245
"""Benchmark that times loading the 'U' and 'V' data arrays only for 3-D"""
8346

8447
# To have a reasonable runtime, we only consider the time it takes to load two time levels
8548
# and two depth levels (at most)
86-
ds = _load_ds(self.datapath, chunk)
49+
ds = _load_ds(chunk)
8750
for j in range(min(ds.coords["deptht"].size, 2)):
8851
for i in range(min(ds.coords["time"].size, 2)):
8952
_u = ds["U"].isel(deptht=j, time=i).compute()
9053
_v = ds["V"].isel(deptht=j, time=i).compute()
9154

9255
def pset_execute_3d(self, interpolator, chunk, npart):
93-
ds = _load_ds(self.datapath, chunk)
94-
coords = {
95-
"X": {"left": "x"},
96-
"Y": {"left": "y"},
97-
"Z": {"center": "deptht", "left": "depth"},
98-
"T": {"center": "time"},
99-
}
100-
101-
grid = parcels._core.xgrid.XGrid(
102-
xgcm.Grid(ds, coords=coords, autoparse_metadata=False, periodic=False),
103-
mesh="spherical",
104-
)
105-
56+
ds = _load_ds(chunk)
57+
fieldset = parcels.FieldSet.from_sgrid_conventions(ds)
10658
if interpolator == "XLinear":
107-
interp_method = XLinear
59+
fieldset.U.interp_method = XLinear
60+
fieldset.V.interp_method = XLinear
10861
else:
10962
raise ValueError(f"Unknown interpolator: {interpolator}")
11063

111-
U = parcels.Field("U", ds["U"], grid, interp_method=interp_method)
112-
V = parcels.Field("V", ds["V"], grid, interp_method=interp_method)
113-
U.units = parcels.GeographicPolar()
114-
V.units = parcels.Geographic()
115-
UV = parcels.VectorField("UV", U, V)
116-
117-
fieldset = parcels.FieldSet([U, V, UV])
118-
11964
pclass = parcels.Particle
12065

12166
lon = np.linspace(-10, 10, npart)

0 commit comments

Comments
 (0)