1- from glob import glob
2-
31import numpy as np
42import parcels
5- import xarray as xr
6- import xgcm
73from parcels .interpolators import XLinear
84
5+ from .catalogs import Catalogs
6+
97runtime = np .timedelta64 (2 , "D" )
108dt = np .timedelta64 (15 , "m" )
119
1210
13- PARCELS_DATADIR = ... # TODO: Replace with intake
14-
15-
16- def download_dataset (* args , ** kwargs ): ... # TODO: Replace with intake
17-
18-
19- def _load_ds (datapath , chunk ):
20- """Helper function to load xarray dataset from datapath with or without chunking"""
21-
22- fileU = f"{ datapath } /psy4v3r1-daily_U_2025-01-0[1-3].nc"
23- filenames = {
24- "U" : glob (fileU ),
25- "V" : glob (fileU .replace ("_U_" , "_V_" )),
26- "W" : glob (fileU .replace ("_U_" , "_W_" )),
27- }
28- mesh_mask = f"{ datapath } /PSY4V3R1_mesh_hgr.nc"
29- fileargs = {
30- "concat_dim" : "time_counter" ,
31- "combine" : "nested" ,
32- "data_vars" : "minimal" ,
33- "coords" : "minimal" ,
34- "compat" : "override" ,
35- }
36- if chunk :
37- fileargs ["chunks" ] = {"time_counter" : 1 , "depth" : 2 , "y" : chunk , "x" : chunk }
11+ def _load_ds (chunk ):
12+ """Helper function to load xarray dataset from catalog with or without chunking"""
13+ cat = Catalogs .CAT_BENCHMARKS
14+ chunks = {"time_counter" : 1 , "depth" : 2 , "y" : chunk , "x" : chunk } if chunk else None
3815
39- ds_u = xr . open_mfdataset ( filenames [ "U" ], ** fileargs )[[ "vozocrtx" ]]. drop_vars (
40- [ "nav_lon" , "nav_lat" ]
16+ ds_u = (
17+ cat . moi_u ( chunks = chunks ). to_dask ()[[ "vozocrtx" ]]. rename_vars ({ "vozocrtx" : "U" })
4118 )
42- ds_v = xr . open_mfdataset ( filenames [ "V" ], ** fileargs )[[ "vomecrty" ]]. drop_vars (
43- [ "nav_lon" , "nav_lat" ]
19+ ds_v = (
20+ cat . moi_v ( chunks = chunks ). to_dask ()[[ "vomecrty" ]]. rename_vars ({ "vomecrty" : "V" })
4421 )
45- ds_depth = xr .open_mfdataset (filenames ["W" ], ** fileargs )[["depthw" ]]
46- ds_mesh = xr .open_dataset (mesh_mask )[["glamf" , "gphif" ]].isel (t = 0 )
47-
48- ds = xr .merge ([ds_u , ds_v , ds_depth , ds_mesh ], compat = "identical" )
49- ds = ds .rename (
50- {
51- "vozocrtx" : "U" ,
52- "vomecrty" : "V" ,
53- "glamf" : "lon" ,
54- "gphif" : "lat" ,
55- "time_counter" : "time" ,
56- "depthw" : "depth" ,
57- }
58- )
59- ds .deptht .attrs ["c_grid_axis_shift" ] = - 0.5
22+ da_depth = cat .moi_w (chunks = chunks ).to_dask ()["depthw" ]
23+ ds_mesh = cat .moi_mesh (chunks = None ).read ()[["glamf" , "gphif" ]].isel (t = 0 )
24+ ds_mesh ["depthw" ] = da_depth
25+ ds = parcels .convert .nemo_to_sgrid (fields = dict (U = ds_u , V = ds_v ), coords = ds_mesh )
6026
6127 return ds
6228
@@ -75,47 +41,26 @@ class MOICurvilinear:
7541 "npart" ,
7642 ]
7743
78- def setup (self , interpolator , chunk , npart ):
79- self .datapath = download_dataset ("MOi-curvilinear" , data_home = PARCELS_DATADIR )
80-
8144 def time_load_data_3d (self , interpolator , chunk , npart ):
8245 """Benchmark that times loading the 'U' and 'V' data arrays only for 3-D"""
8346
8447 # To have a reasonable runtime, we only consider the time it takes to load two time levels
8548 # and two depth levels (at most)
86- ds = _load_ds (self . datapath , chunk )
49+ ds = _load_ds (chunk )
8750 for j in range (min (ds .coords ["deptht" ].size , 2 )):
8851 for i in range (min (ds .coords ["time" ].size , 2 )):
8952 _u = ds ["U" ].isel (deptht = j , time = i ).compute ()
9053 _v = ds ["V" ].isel (deptht = j , time = i ).compute ()
9154
9255 def pset_execute_3d (self , interpolator , chunk , npart ):
93- ds = _load_ds (self .datapath , chunk )
94- coords = {
95- "X" : {"left" : "x" },
96- "Y" : {"left" : "y" },
97- "Z" : {"center" : "deptht" , "left" : "depth" },
98- "T" : {"center" : "time" },
99- }
100-
101- grid = parcels ._core .xgrid .XGrid (
102- xgcm .Grid (ds , coords = coords , autoparse_metadata = False , periodic = False ),
103- mesh = "spherical" ,
104- )
105-
56+ ds = _load_ds (chunk )
57+ fieldset = parcels .FieldSet .from_sgrid_conventions (ds )
10658 if interpolator == "XLinear" :
107- interp_method = XLinear
59+ fieldset .U .interp_method = XLinear
60+ fieldset .V .interp_method = XLinear
10861 else :
10962 raise ValueError (f"Unknown interpolator: { interpolator } " )
11063
111- U = parcels .Field ("U" , ds ["U" ], grid , interp_method = interp_method )
112- V = parcels .Field ("V" , ds ["V" ], grid , interp_method = interp_method )
113- U .units = parcels .GeographicPolar ()
114- V .units = parcels .Geographic ()
115- UV = parcels .VectorField ("UV" , U , V )
116-
117- fieldset = parcels .FieldSet ([U , V , UV ])
118-
11964 pclass = parcels .Particle
12065
12166 lon = np .linspace (- 10 , 10 , npart )
0 commit comments