Skip to content

Commit 92548e1

Browse files
Merge pull request #1772 from OceanParcels/CROCO_fix_sigma_calculation
Implementing correct depth-to-sigma calculation
2 parents ef9eaa9 + 4c92221 commit 92548e1

8 files changed

Lines changed: 199 additions & 68 deletions

File tree

docs/examples/tutorial_croco_3D.ipynb

Lines changed: 27 additions & 21 deletions
Large diffs are not rendered by default.

parcels/compilation/codegenerator.py

Lines changed: 26 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -425,10 +425,13 @@ def __init__(self, fieldset=None, ptype=JITParticle):
425425
self.fieldset = fieldset
426426
self.ptype = ptype
427427
self.field_args = collections.OrderedDict()
428-
if isinstance(fieldset.U, Field) and fieldset.U.gridindexingtype == "croco" and hasattr(fieldset, "H"):
429-
self.field_args["H"] = fieldset.H # CROCO requires H field
430428
self.vector_field_args = collections.OrderedDict()
431429
self.const_args = collections.OrderedDict()
430+
if isinstance(fieldset.U, Field) and fieldset.U.gridindexingtype == "croco" and hasattr(fieldset, "H"):
431+
self.field_args["H"] = fieldset.H # CROCO requires H field
432+
self.field_args["Zeta"] = fieldset.Zeta # CROCO requires Zeta field
433+
self.field_args["Cs_w"] = fieldset.Cs_w # CROCO requires CS_w field
434+
self.const_args["hc"] = fieldset.hc # CROCO requires hc constant
432435

433436
def generate(self, py_ast, funcvars: list[str]):
434437
# Replace occurrences of intrinsic objects in Python AST
@@ -825,16 +828,18 @@ def visit_FieldEvalNode(self, node):
825828
self.visit(node.field)
826829
self.visit(node.args)
827830
args = self._check_FieldSamplingArguments(node.args.ccode)
828-
statements_croco = []
829-
if "croco" in node.field.obj.gridindexingtype and node.field.obj.name != "H":
830-
statements_croco.append(
831-
c.Assign(
832-
"parcels_interp_state",
833-
f"temporal_interpolation({args[3]}, {args[2]}, 0, time, H, &particles->xi[pnum*ngrid], &particles->yi[pnum*ngrid], &particles->zi[pnum*ngrid], &particles->ti[pnum*ngrid], &{node.var}, LINEAR, {node.field.obj.gridindexingtype.upper()})",
834-
)
835-
)
836-
statements_croco.append(c.Statement(f"{node.var} = {args[1]}/{node.var}"))
831+
if "croco" in node.field.obj.gridindexingtype and node.field.obj.name != "H" and node.field.obj.name != "Zeta":
832+
# Get Cs_w values directly from fieldset (since they are 1D in vertical only)
833+
Cs_w = [float(self.fieldset.Cs_w.data[0][zi][0][0]) for zi in range(self.fieldset.Cs_w.data.shape[1])]
834+
statements_croco = [
835+
c.Statement(f"float cs_w[] = {*Cs_w, }".replace("(", "{").replace(")", "}")),
836+
c.Statement(
837+
f"{node.var} = croco_from_z_to_sigma(U, H, Zeta, {args[3]}, {args[2]}, {args[1]}, time, &particles->xi[pnum*ngrid], &particles->yi[pnum*ngrid], &particles->zi[pnum*ngrid], &particles->ti[pnum*ngrid], hc, &cs_w)"
838+
),
839+
]
837840
args = (args[0], node.var, args[2], args[3])
841+
else:
842+
statements_croco = []
838843
ccode_eval = node.field.obj._ccode_eval(node.var, *args)
839844
stmts = [
840845
c.Assign("parcels_interp_state", ccode_eval),
@@ -852,16 +857,18 @@ def visit_VectorFieldEvalNode(self, node):
852857
self.visit(node.field)
853858
self.visit(node.args)
854859
args = self._check_FieldSamplingArguments(node.args.ccode)
855-
statements_croco = []
856860
if "3DSigma" in node.field.obj.vector_type:
857-
statements_croco.append(
858-
c.Assign(
859-
"parcels_interp_state",
860-
f"temporal_interpolation({args[3]}, {args[2]}, 0, time, H, &particles->xi[pnum*ngrid], &particles->yi[pnum*ngrid], &particles->zi[pnum*ngrid], &particles->ti[pnum*ngrid], &{node.var}, LINEAR, {node.field.obj.U.gridindexingtype.upper()})",
861-
)
862-
)
863-
statements_croco.append(c.Statement(f"{node.var4} = {args[1]}/{node.var}"))
861+
# Get Cs_w values directly from fieldset (since they are 1D in vertical only)
862+
Cs_w = [float(self.fieldset.Cs_w.data[0][zi][0][0]) for zi in range(self.fieldset.Cs_w.data.shape[1])]
863+
statements_croco = [
864+
c.Statement(f"float cs_w[] = {*Cs_w, }".replace("(", "{").replace(")", "}")),
865+
c.Statement(
866+
f"{node.var4} = croco_from_z_to_sigma(U, H, Zeta, {args[3]}, {args[2]}, {args[1]}, time, &particles->xi[pnum*ngrid], &particles->yi[pnum*ngrid], &particles->zi[pnum*ngrid], &particles->ti[pnum*ngrid], hc, &cs_w)"
867+
),
868+
]
864869
args = (args[0], node.var4, args[2], args[3])
870+
else:
871+
statements_croco = []
865872
ccode_eval = node.field.obj._ccode_eval(
866873
node.var, node.var2, node.var3, node.field.obj.U, node.field.obj.V, node.field.obj.W, *args
867874
)

parcels/field.py

Lines changed: 40 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,26 @@ def _deal_with_errors(error, key, vector_type: VectorType):
7676
return 0
7777

7878

79+
def _croco_from_z_to_sigma_scipy(fieldset, time, z, y, x, particle):
80+
"""Calculate local sigma level of the particle, by linearly interpolating the
81+
scaling function that maps sigma to depth (using local ocean depth H,
82+
sea-surface Zeta and stretching parameters Cs_w and hc).
83+
See also https://croco-ocean.gitlabpages.inria.fr/croco_doc/model/model.grid.html#vertical-grid-parameters
84+
"""
85+
h = fieldset.H.eval(time, 0, y, x, particle=particle, applyConversion=False)
86+
zeta = fieldset.Zeta.eval(time, 0, y, x, particle=particle, applyConversion=False)
87+
sigma_levels = fieldset.U.grid.depth
88+
z0 = fieldset.hc * sigma_levels + (h - fieldset.hc) * fieldset.Cs_w.data[0, :, 0, 0]
89+
zvec = z0 + zeta * (1 + (z0 / h))
90+
zinds = zvec <= z
91+
if z >= zvec[-1]:
92+
zi = len(zvec) - 2
93+
else:
94+
zi = zinds.argmin() - 1 if z >= zvec[0] else 0
95+
96+
return sigma_levels[zi] + (z - zvec[zi]) * (sigma_levels[zi + 1] - sigma_levels[zi]) / (zvec[zi + 1] - zvec[zi])
97+
98+
7999
class Field:
80100
"""Class that encapsulates access to field data.
81101
@@ -617,18 +637,23 @@ def from_netcdf(
617637

618638
_grid_fb_class = NetcdfFileBuffer
619639

620-
with _grid_fb_class(
621-
lonlat_filename,
622-
dimensions,
623-
indices,
624-
netcdf_engine,
625-
gridindexingtype=gridindexingtype,
626-
) as filebuffer:
627-
lon, lat = filebuffer.lonlat
628-
indices = filebuffer.indices
629-
# Check if parcels_mesh has been explicitly set in file
630-
if "parcels_mesh" in filebuffer.dataset.attrs:
631-
mesh = filebuffer.dataset.attrs["parcels_mesh"]
640+
if "lon" in dimensions and "lat" in dimensions:
641+
with _grid_fb_class(
642+
lonlat_filename,
643+
dimensions,
644+
indices,
645+
netcdf_engine,
646+
gridindexingtype=gridindexingtype,
647+
) as filebuffer:
648+
lon, lat = filebuffer.lonlat
649+
indices = filebuffer.indices
650+
# Check if parcels_mesh has been explicitly set in file
651+
if "parcels_mesh" in filebuffer.dataset.attrs:
652+
mesh = filebuffer.dataset.attrs["parcels_mesh"]
653+
else:
654+
lon = 0
655+
lat = 0
656+
mesh = "flat"
632657

633658
if "depth" in dimensions:
634659
with _grid_fb_class(
@@ -1537,8 +1562,8 @@ def eval(self, time, z, y, x, particle=None, applyConversion=True):
15371562
"""
15381563
(ti, periods) = self._time_index(time)
15391564
time -= periods * (self.grid.time_full[-1] - self.grid.time_full[0])
1540-
if self.gridindexingtype == "croco" and self is not self.fieldset.H:
1541-
z = z / self.fieldset.H.eval(time, 0, y, x, particle=particle, applyConversion=False)
1565+
if self.gridindexingtype == "croco" and self not in [self.fieldset.H, self.fieldset.Zeta]:
1566+
z = _croco_from_z_to_sigma_scipy(self.fieldset, time, z, y, x, particle=particle)
15421567
if ti < self.grid.tdim - 1 and time > self.grid.time[ti]:
15431568
f0 = self._spatial_interpolation(ti, z, y, x, time, particle=particle)
15441569
f1 = self._spatial_interpolation(ti + 1, z, y, x, time, particle=particle)
@@ -2250,7 +2275,7 @@ def spatial_c_grid_interpolation3D(self, ti, z, y, x, time, particle=None, apply
22502275
(u, v, w) = self.spatial_c_grid_interpolation3D_full(ti, z, y, x, time, particle=particle)
22512276
else:
22522277
if self.gridindexingtype == "croco":
2253-
z = z / self.fieldset.H.eval(time, 0, y, x, particle=particle, applyConversion=False)
2278+
z = _croco_from_z_to_sigma_scipy(self.fieldset, time, z, y, x, particle=particle)
22542279
(u, v) = self.spatial_c_grid_interpolation2D(ti, z, y, x, time, particle=particle)
22552280
w = self.W.eval(time, z, y, x, particle=particle, applyConversion=False)
22562281
if applyConversion:

parcels/fieldfilebuffer.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,10 @@ def _check_extend_depth(self, data, di):
188188
)
189189

190190
def _apply_indices(self, data, ti):
191-
if len(data.shape) == 2:
191+
if len(data.shape) == 1:
192+
if self.indices["depth"] is not None:
193+
data = data[self.indices["depth"]]
194+
elif len(data.shape) == 2:
192195
if self.nolonlatindices:
193196
pass
194197
else:

parcels/fieldset.py

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -713,6 +713,7 @@ def from_croco(
713713
filenames,
714714
variables,
715715
dimensions,
716+
hc: float | None = None,
716717
indices=None,
717718
mesh="spherical",
718719
allow_time_extrapolation=None,
@@ -723,11 +724,14 @@ def from_croco(
723724
):
724725
"""Initialises FieldSet object from NetCDF files of CROCO fields.
725726
All parameters and keywords are exactly the same as for FieldSet.from_nemo(), except that
726-
the vertical coordinate is scaled by the bathymetry (``h``) field from CROCO, in order to
727-
account for the sigma-grid. The horizontal interpolation uses the MITgcm grid indexing
728-
as described in FieldSet.from_mitgcm().
727+
in order to scale the vertical coordinate in CROCO, the following fields are required:
728+
the bathymetry (``h``), the sea-surface height (``zeta``), the S-coordinate stretching curves
729+
at W-points (``Cs_w``), and the stretching parameter (``hc``).
730+
The horizontal interpolation uses the MITgcm grid indexing as described in FieldSet.from_mitgcm().
729731
730-
The sigma grid scaling means that FieldSet.from_croco() requires a variable ``H: h`` to work.
732+
In 3D, when there is a ``depth`` dimension, the sigma grid scaling means that FieldSet.from_croco()
733+
requires variables ``H: h`` and ``Zeta: zeta``, ``Cs_w: Cs_w``, as well as the stretching parameter ``hc``
734+
(as an extra input) parameter to work.
731735
732736
See `the CROCO 3D tutorial <../examples/tutorial_croco_3D.ipynb>`__ for more infomation.
733737
"""
@@ -739,14 +743,23 @@ def from_croco(
739743
)
740744

741745
dimsU = dimensions["U"] if "U" in dimensions else dimensions
742-
if "depth" in dimsU:
743-
warnings.warn(
744-
"Note that it is unclear which vertical velocity ('w' or 'omega') to use in 3D CROCO fields.\nSee https://docs.oceanparcels.org/en/latest/examples/tutorial_croco_3D.html for more information",
745-
FieldSetWarning,
746-
stacklevel=2,
747-
)
746+
croco3D = True if "depth" in dimsU else False
747+
748+
if croco3D:
749+
if "W" in variables and variables["W"] == "omega":
750+
warnings.warn(
751+
"Note that Parcels expects 'w' for vertical velicites in 3D CROCO fields.\nSee https://docs.oceanparcels.org/en/latest/examples/tutorial_croco_3D.html for more information",
752+
FieldSetWarning,
753+
stacklevel=2,
754+
)
748755
if "H" not in variables:
749-
raise ValueError("FieldSet.from_croco() requires a field 'H' for the bathymetry")
756+
raise ValueError("FieldSet.from_croco() requires a bathymetry field 'H' for 3D CROCO fields")
757+
if "Zeta" not in variables:
758+
raise ValueError("FieldSet.from_croco() requires a free-surface field 'Zeta' for 3D CROCO fields")
759+
if "Cs_w" not in variables:
760+
raise ValueError(
761+
"FieldSet.from_croco() requires the S-coordinate stretching curves at W-points 'Cs_w' for 3D CROCO fields"
762+
)
750763

751764
interp_method = {}
752765
for v in variables:
@@ -776,6 +789,10 @@ def from_croco(
776789
gridindexingtype="croco",
777790
**kwargs,
778791
)
792+
if croco3D:
793+
if hc is None:
794+
raise ValueError("FieldSet.from_croco() requires the hc parameter for 3D CROCO fields")
795+
fieldset.add_constant("hc", hc)
779796
return fieldset
780797

781798
@classmethod

parcels/include/parcels.h

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1242,6 +1242,33 @@ static inline StatusCode temporal_interpolationUVW(type_coord x, type_coord y, t
12421242
return SUCCESS;
12431243
}
12441244

1245+
1246+
static inline double croco_from_z_to_sigma(CField *U, CField *H, CField *Zeta,
1247+
type_coord x, type_coord y, type_coord z, double time,
1248+
int *xi, int *yi, int *zi, int *ti, double hc, float *cs_w)
1249+
{
1250+
float local_h, local_zeta, z0;
1251+
int status, zii;
1252+
CStructuredGrid *grid = U->grid->grid;
1253+
float *sigma_levels = grid->depth;
1254+
int zdim = grid->zdim;
1255+
float zvec[zdim];
1256+
status = temporal_interpolation(x, y, 0, time, H, xi, yi, zi, ti, &local_h, LINEAR, CROCO); CHECKSTATUS(status);
1257+
status = temporal_interpolation(x, y, 0, time, Zeta, xi, yi, zi, ti, &local_zeta, LINEAR, CROCO); CHECKSTATUS(status);
1258+
for (zii = 0; zii < zdim; zii++) {
1259+
z0 = hc*sigma_levels[zii] + (local_h - hc) *cs_w[zii];
1260+
zvec[zii] = z0 + local_zeta * (1 + z0 / local_h);
1261+
}
1262+
if (z >= zvec[zdim-1])
1263+
zii = zdim - 2;
1264+
else
1265+
for (zii = 0; zii < zdim-1; zii++)
1266+
if ((z >= zvec[zii]) && (z < zvec[zii+1]))
1267+
break;
1268+
1269+
return sigma_levels[zii] + (z - zvec[zii]) * (sigma_levels[zii + 1] - sigma_levels[zii]) / (zvec[zii + 1] - zvec[zii]);
1270+
}
1271+
12451272
#ifdef __cplusplus
12461273
}
12471274
#endif

tests/test_advection.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,47 @@ def test_advection_RK45(lon, lat, mode, rk45_tol):
195195
print(fieldset.RK45_tol)
196196

197197

198+
def test_conversion_3DCROCO():
199+
"""Test of the (SciPy) version of the conversion from depth to sigma in CROCO
200+
201+
Values below are retrieved using xroms and hardcoded in the method (to avoid dependency on xroms):
202+
```py
203+
x, y = 10, 20
204+
s_xroms = ds.s_w.values
205+
z_xroms = ds.z_w.isel(time=0).isel(eta_rho=y).isel(xi_rho=x).values
206+
lat, lon = ds.y_rho.values[y, x], ds.x_rho.values[y, x]
207+
```
208+
"""
209+
fieldset = FieldSet.from_modulefile(TEST_DATA / "fieldset_CROCO3D.py")
210+
211+
lat, lon = 78000.0, 38000.0
212+
s_xroms = np.array([-1.0, -0.9, -0.8, -0.7, -0.6, -0.5, -0.4, -0.3, -0.2, -0.1, 0.0], dtype=np.float32)
213+
z_xroms = np.array(
214+
[
215+
-1.26000000e02,
216+
-1.10585846e02,
217+
-9.60985413e01,
218+
-8.24131317e01,
219+
-6.94126511e01,
220+
-5.69870148e01,
221+
-4.50318756e01,
222+
-3.34476166e01,
223+
-2.21383114e01,
224+
-1.10107975e01,
225+
2.62768921e-02,
226+
],
227+
dtype=np.float32,
228+
)
229+
230+
sigma = np.zeros_like(z_xroms)
231+
from parcels.field import _croco_from_z_to_sigma_scipy
232+
233+
for zi, z in enumerate(z_xroms):
234+
sigma[zi] = _croco_from_z_to_sigma_scipy(fieldset, 0, z, lat, lon, None)
235+
236+
assert np.allclose(sigma, s_xroms, atol=1e-3)
237+
238+
198239
@pytest.mark.parametrize("mode", ["scipy", "jit"])
199240
def test_advection_3DCROCO(mode):
200241
fieldset = FieldSet.from_modulefile(TEST_DATA / "fieldset_CROCO3D.py")
Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,22 @@
11
import os
22

3+
import xarray as xr
4+
35
import parcels
46

57

68
def create_fieldset(indices=None):
79
example_dataset_folder = parcels.download_example_dataset("CROCOidealized_data")
810
file = os.path.join(example_dataset_folder, "CROCO_idealized.nc")
911

10-
variables = {"U": "u", "V": "v", "W": "w", "H": "h"}
12+
variables = {"U": "u", "V": "v", "W": "w", "H": "h", "Zeta": "zeta", "Cs_w": "Cs_w"}
1113
dimensions = {
1214
"U": {"lon": "x_rho", "lat": "y_rho", "depth": "s_w", "time": "time"},
1315
"V": {"lon": "x_rho", "lat": "y_rho", "depth": "s_w", "time": "time"},
1416
"W": {"lon": "x_rho", "lat": "y_rho", "depth": "s_w", "time": "time"},
1517
"H": {"lon": "x_rho", "lat": "y_rho"},
18+
"Zeta": {"lon": "x_rho", "lat": "y_rho", "time": "time"},
19+
"Cs_w": {"depth": "s_w"},
1620
}
1721
fieldset = parcels.FieldSet.from_croco(
1822
file,
@@ -21,6 +25,7 @@ def create_fieldset(indices=None):
2125
allow_time_extrapolation=True,
2226
mesh="flat",
2327
indices=indices,
28+
hc=xr.open_dataset(file).hc.values,
2429
)
2530

2631
return fieldset

0 commit comments

Comments
 (0)