Skip to content

Commit d4ea6a2

Browse files
committed
Merge branch 'master' into better-prints
2 parents 02d36fc + bbe8448 commit d4ea6a2

14 files changed

Lines changed: 174 additions & 92 deletions

docs/examples/example_globcurrent.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -219,23 +219,21 @@ def test__particles_init_time():
219219
assert pset[0].time - pset4[0].time == 0
220220

221221

222-
@pytest.mark.xfail(reason="Time extrapolation error expected to be thrown", strict=True)
223222
@pytest.mark.parametrize("mode", ["scipy", "jit"])
224223
@pytest.mark.parametrize("use_xarray", [True, False])
225224
def test_globcurrent_time_extrapolation_error(mode, use_xarray):
226225
fieldset = set_globcurrent_fieldset(use_xarray=use_xarray)
227-
228226
pset = parcels.ParticleSet(
229227
fieldset,
230228
pclass=ptype[mode],
231229
lon=[25],
232230
lat=[-35],
233-
time=fieldset.U.time[0] - timedelta(days=1).total_seconds(),
234-
)
235-
236-
pset.execute(
237-
parcels.AdvectionRK4, runtime=timedelta(days=1), dt=timedelta(minutes=5)
231+
time=fieldset.U.grid.time[0] - timedelta(days=1).total_seconds(),
238232
)
233+
with pytest.raises(parcels.TimeExtrapolationError):
234+
pset.execute(
235+
parcels.AdvectionRK4, runtime=timedelta(days=1), dt=timedelta(minutes=5)
236+
)
239237

240238

241239
@pytest.mark.parametrize("mode", ["scipy", "jit"])

parcels/field.py

Lines changed: 25 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
import collections
2-
import datetime
32
import math
43
import warnings
54
from collections.abc import Iterable
65
from ctypes import POINTER, Structure, c_float, c_int, pointer
76
from pathlib import Path
8-
from typing import TYPE_CHECKING
7+
from typing import TYPE_CHECKING, Literal
98

109
import dask.array as da
1110
import numpy as np
@@ -21,7 +20,7 @@
2120
assert_valid_gridindexingtype,
2221
assert_valid_interp_method,
2322
)
24-
from parcels.tools._helpers import default_repr, deprecated_made_private, field_repr
23+
from parcels.tools._helpers import default_repr, deprecated_made_private, field_repr, timedelta_to_float
2524
from parcels.tools.converters import (
2625
Geographic,
2726
GeographicPolar,
@@ -152,6 +151,7 @@ class Field:
152151

153152
allow_time_extrapolation: bool
154153
time_periodic: TimePeriodic
154+
_cast_data_dtype: type[np.float32] | type[np.float64]
155155

156156
def __init__(
157157
self,
@@ -165,16 +165,16 @@ def __init__(
165165
mesh: Mesh = "flat",
166166
timestamps=None,
167167
fieldtype=None,
168-
transpose=False,
169-
vmin=None,
170-
vmax=None,
171-
cast_data_dtype="float32",
172-
time_origin=None,
168+
transpose: bool = False,
169+
vmin: float | None = None,
170+
vmax: float | None = None,
171+
cast_data_dtype: type[np.float32] | type[np.float64] | Literal["float32", "float64"] = "float32",
172+
time_origin: TimeConverter | None = None,
173173
interp_method: InterpMethod = "linear",
174174
allow_time_extrapolation: bool | None = None,
175175
time_periodic: TimePeriodic = False,
176176
gridindexingtype: GridIndexingType = "nemo",
177-
to_write=False,
177+
to_write: bool = False,
178178
**kwargs,
179179
):
180180
if kwargs.get("netcdf_decodewarning") is not None:
@@ -250,8 +250,8 @@ def __init__(
250250
"Unsupported time_periodic=True. time_periodic must now be either False or the length of the period (either float in seconds or datetime.timedelta object."
251251
)
252252
if self.time_periodic is not False:
253-
if isinstance(self.time_periodic, datetime.timedelta):
254-
self.time_periodic = self.time_periodic.total_seconds()
253+
self.time_periodic = timedelta_to_float(self.time_periodic)
254+
255255
if not np.isclose(self.grid.time[-1] - self.grid.time[0], self.time_periodic):
256256
if self.grid.time[-1] - self.grid.time[0] > self.time_periodic:
257257
raise ValueError("Time series provided is longer than the time_periodic parameter")
@@ -261,11 +261,19 @@ def __init__(
261261

262262
self.vmin = vmin
263263
self.vmax = vmax
264-
self._cast_data_dtype = cast_data_dtype
265-
if self.cast_data_dtype == "float32":
266-
self._cast_data_dtype = np.float32
267-
elif self.cast_data_dtype == "float64":
268-
self._cast_data_dtype = np.float64
264+
265+
match cast_data_dtype:
266+
case "float32":
267+
self._cast_data_dtype = np.float32
268+
case "float64":
269+
self._cast_data_dtype = np.float64
270+
case _:
271+
self._cast_data_dtype = cast_data_dtype
272+
273+
if self.cast_data_dtype not in [np.float32, np.float64]:
274+
raise ValueError(
275+
f"Unsupported cast_data_dtype {self.cast_data_dtype!r}. Choose either: 'float32' or 'float64'"
276+
)
269277

270278
if not self.grid.defer_load:
271279
self.data = self._reshape(self.data, transpose)
@@ -803,7 +811,7 @@ def from_xarray(
803811
lat = da[dimensions["lat"]].values
804812

805813
time_origin = TimeConverter(time[0])
806-
time = time_origin.reltime(time)
814+
time = time_origin.reltime(time) # type: ignore[assignment]
807815

808816
grid = Grid.create_grid(lon, lat, depth, time, time_origin=time_origin, mesh=mesh)
809817
kwargs["time_periodic"] = time_periodic

parcels/fieldfilebuffer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -388,7 +388,7 @@ def close(self):
388388
self.chunk_mapping = None
389389

390390
@classmethod
391-
def add_to_dimension_name_map_global(self, name_map):
391+
def add_to_dimension_name_map_global(cls, name_map):
392392
"""
393393
[externally callable]
394394
This function adds entries to the name map from parcels_dim -> netcdf_dim. This is required if you want to
@@ -406,9 +406,9 @@ def add_to_dimension_name_map_global(self, name_map):
406406
for pcls_dim_name in name_map.keys():
407407
if isinstance(name_map[pcls_dim_name], list):
408408
for nc_dim_name in name_map[pcls_dim_name]:
409-
self._static_name_maps[pcls_dim_name].append(nc_dim_name)
409+
cls._static_name_maps[pcls_dim_name].append(nc_dim_name)
410410
elif isinstance(name_map[pcls_dim_name], str):
411-
self._static_name_maps[pcls_dim_name].append(name_map[pcls_dim_name])
411+
cls._static_name_maps[pcls_dim_name].append(name_map[pcls_dim_name])
412412

413413
def add_to_dimension_name_map(self, name_map):
414414
"""

parcels/fieldset.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -347,8 +347,8 @@ def check_velocityfields(U, V, W):
347347

348348
@classmethod
349349
@deprecated_made_private # TODO: Remove 6 months after v3.1.0
350-
def parse_wildcards(self, *args, **kwargs):
351-
return self._parse_wildcards(*args, **kwargs)
350+
def parse_wildcards(cls, *args, **kwargs):
351+
return cls._parse_wildcards(*args, **kwargs)
352352

353353
@classmethod
354354
def _parse_wildcards(cls, paths, filenames, var):

parcels/kernel.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,6 @@ def __init__(
7777
self.funccode = funccode
7878
self.py_ast = py_ast
7979
self.dyn_srcs = []
80-
self.static_srcs = []
8180
self.src_file = None
8281
self.lib_file = None
8382
self.log_file = None
@@ -562,9 +561,11 @@ def from_list(cls, fieldset, ptype, pyfunc_list, *args, **kwargs):
562561
def cleanup_remove_files(lib_file, all_files_array, delete_cfiles):
563562
if lib_file is not None:
564563
if os.path.isfile(lib_file): # and delete_cfiles
565-
[os.remove(s) for s in [lib_file] if os.path is not None and os.path.exists(s)]
566-
if delete_cfiles and len(all_files_array) > 0:
567-
[os.remove(s) for s in all_files_array if os.path is not None and os.path.exists(s)]
564+
os.remove(lib_file)
565+
if delete_cfiles:
566+
for s in all_files_array:
567+
if os.path.exists(s):
568+
os.remove(s)
568569

569570
@staticmethod
570571
def cleanup_unload_lib(lib):

parcels/particle.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -201,13 +201,13 @@ def __del__(self):
201201

202202
def __repr__(self):
203203
time_string = "not_yet_set" if self.time is None or np.isnan(self.time) else f"{self.time:f}"
204-
str = "P[%d](lon=%f, lat=%f, depth=%f, " % (self.id, self.lon, self.lat, self.depth)
204+
p_string = "P[%d](lon=%f, lat=%f, depth=%f, " % (self.id, self.lon, self.lat, self.depth)
205205
for var in vars(type(self)):
206206
if var in ["lon_nextloop", "lat_nextloop", "depth_nextloop", "time_nextloop"]:
207207
continue
208208
if type(getattr(type(self), var)) is Variable and getattr(type(self), var).to_write is True:
209-
str += f"{var}={getattr(self, var):f}, "
210-
return str + f"time={time_string})"
209+
p_string += f"{var}={getattr(self, var):f}, "
210+
return p_string + f"time={time_string})"
211211

212212
@classmethod
213213
def add_variable(cls, var, *args, **kwargs):

parcels/particledata.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -460,7 +460,7 @@ def getPType(self):
460460

461461
def __repr__(self):
462462
time_string = "not_yet_set" if self.time is None or np.isnan(self.time) else f"{self.time:f}"
463-
str = "P[%d](lon=%f, lat=%f, depth=%f, " % (self.id, self.lon, self.lat, self.depth)
463+
p_string = "P[%d](lon=%f, lat=%f, depth=%f, " % (self.id, self.lon, self.lat, self.depth)
464464
for var in self._pcoll.ptype.variables:
465465
if var.name in [
466466
"lon_nextloop",
@@ -470,8 +470,8 @@ def __repr__(self):
470470
]: # TODO check if time_nextloop is needed (or can work with time-dt?)
471471
continue
472472
if var.to_write is not False and var.name not in ["id", "lon", "lat", "depth", "time"]:
473-
str += f"{var.name}={getattr(self, var.name):f}, "
474-
return str + f"time={time_string})"
473+
p_string += f"{var.name}={getattr(self, var.name):f}, "
474+
return p_string + f"time={time_string})"
475475

476476
def delete(self):
477477
"""Signal the particle for deletion."""

parcels/particlefile.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
import parcels
1212
from parcels._compat import MPI
13-
from parcels.tools._helpers import default_repr, deprecated, deprecated_made_private
13+
from parcels.tools._helpers import default_repr, deprecated, deprecated_made_private, timedelta_to_float
1414
from parcels.tools.warnings import FileWarning
1515

1616
__all__ = ["ParticleFile"]
@@ -48,7 +48,7 @@ class ParticleFile:
4848
"""
4949

5050
def __init__(self, name, particleset, outputdt=np.inf, chunks=None, create_new_zarrfile=True):
51-
self._outputdt = outputdt.total_seconds() if isinstance(outputdt, timedelta) else outputdt
51+
self._outputdt = timedelta_to_float(outputdt)
5252
self._chunks = chunks
5353
self._particleset = particleset
5454
self._parcels_mesh = "spherical"
@@ -263,7 +263,7 @@ def _extend_zarr_dims(self, Z, store, dtype, axis):
263263
Z.append(a, axis=axis)
264264
zarr.consolidate_metadata(store)
265265

266-
def write(self, pset, time, indices=None):
266+
def write(self, pset, time: float | timedelta | np.timedelta64 | None, indices=None):
267267
"""Write all data from one time step to the zarr file,
268268
before the particle locations are updated.
269269
@@ -274,7 +274,7 @@ def write(self, pset, time, indices=None):
274274
time :
275275
Time at which to write ParticleSet
276276
"""
277-
time = time.total_seconds() if isinstance(time, timedelta) else time
277+
time = timedelta_to_float(time) if time is not None else None
278278

279279
if pset.particledata._ncount == 0:
280280
warnings.warn(
@@ -305,18 +305,18 @@ def write(self, pset, time, indices=None):
305305
if self.create_new_zarrfile:
306306
if self.chunks is None:
307307
self._chunks = (len(ids), 1)
308-
if pset._repeatpclass is not None and self.chunks[0] < 1e4:
308+
if pset._repeatpclass is not None and self.chunks[0] < 1e4: # type: ignore[index]
309309
warnings.warn(
310310
f"ParticleFile chunks are set to {self.chunks}, but this may lead to "
311311
f"a significant slowdown in Parcels when many calls to repeatdt. "
312312
f"Consider setting a larger chunk size for your ParticleFile (e.g. chunks=(int(1e4), 1)).",
313313
FileWarning,
314314
stacklevel=2,
315315
)
316-
if (self._maxids > len(ids)) or (self._maxids > self.chunks[0]):
317-
arrsize = (self._maxids, self.chunks[1])
316+
if (self._maxids > len(ids)) or (self._maxids > self.chunks[0]): # type: ignore[index]
317+
arrsize = (self._maxids, self.chunks[1]) # type: ignore[index]
318318
else:
319-
arrsize = (len(ids), self.chunks[1])
319+
arrsize = (len(ids), self.chunks[1]) # type: ignore[index]
320320
ds = xr.Dataset(
321321
attrs=self.metadata,
322322
coords={"trajectory": ("trajectory", pids), "obs": ("obs", np.arange(arrsize[1], dtype=np.int32))},
@@ -341,7 +341,7 @@ def write(self, pset, time, indices=None):
341341
data[ids, 0] = pset.particledata.getvardata(var, indices_to_write)
342342
dims = ["trajectory", "obs"]
343343
ds[varout] = xr.DataArray(data=data, dims=dims, attrs=attrs[varout])
344-
ds[varout].encoding["chunks"] = self.chunks[0] if self._write_once(var) else self.chunks
344+
ds[varout].encoding["chunks"] = self.chunks[0] if self._write_once(var) else self.chunks # type: ignore[index]
345345
ds.to_zarr(self.fname, mode="w")
346346
self._create_new_zarrfile = False
347347
else:

parcels/particleset.py

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from parcels.particle import JITParticle, Variable
2828
from parcels.particledata import ParticleData, ParticleDataIterator
2929
from parcels.particlefile import ParticleFile
30-
from parcels.tools._helpers import deprecated, deprecated_made_private, particleset_repr
30+
from parcels.tools._helpers import deprecated, deprecated_made_private, particleset_repr, timedelta_to_float
3131
from parcels.tools.converters import _get_cftime_calendars, convert_to_flat_array
3232
from parcels.tools.global_statics import get_package_dir
3333
from parcels.tools.loggers import logger
@@ -189,12 +189,13 @@ def ArrayClass_init(self, *args, **kwargs):
189189
lon.size == kwargs[kwvar].size
190190
), f"{kwvar} and positions (lon, lat, depth) don't have the same lengths."
191191

192-
self.repeatdt = repeatdt.total_seconds() if isinstance(repeatdt, timedelta) else repeatdt
192+
self.repeatdt = timedelta_to_float(repeatdt) if repeatdt is not None else None
193+
193194
if self.repeatdt:
194195
if self.repeatdt <= 0:
195-
raise "Repeatdt should be > 0"
196+
raise ValueError("Repeatdt should be > 0")
196197
if time[0] and not np.allclose(time, time[0]):
197-
raise "All Particle.time should be the same when repeatdt is not None"
198+
raise ValueError("All Particle.time should be the same when repeatdt is not None")
198199
self._repeatpclass = pclass
199200
self._repeatkwargs = kwargs
200201
self._repeatkwargs.pop("partition_function", None)
@@ -986,13 +987,13 @@ def execute(
986987
pyfunc=AdvectionRK4,
987988
pyfunc_inter=None,
988989
endtime=None,
989-
runtime=None,
990-
dt=1.0,
990+
runtime: float | timedelta | np.timedelta64 | None = None,
991+
dt: float | timedelta | np.timedelta64 = 1.0,
991992
output_file=None,
992993
verbose_progress=True,
993994
postIterationCallbacks=None,
994-
callbackdt=None,
995-
delete_cfiles=True,
995+
callbackdt: float | timedelta | np.timedelta64 | None = None,
996+
delete_cfiles: bool = True,
996997
):
997998
"""Execute a given kernel function over the particle set for multiple timesteps.
998999
@@ -1072,22 +1073,23 @@ def execute(
10721073
if self.time_origin.calendar is None:
10731074
raise NotImplementedError("If fieldset.time_origin is not a date, execution endtime must be a double")
10741075
endtime = self.time_origin.reltime(endtime)
1075-
if isinstance(runtime, timedelta):
1076-
runtime = runtime.total_seconds()
1077-
if isinstance(dt, timedelta):
1078-
dt = dt.total_seconds()
1076+
1077+
if runtime is not None:
1078+
runtime = timedelta_to_float(runtime)
1079+
1080+
dt = timedelta_to_float(dt)
1081+
10791082
if abs(dt) <= 1e-6:
10801083
raise ValueError("Time step dt is too small")
10811084
if (dt * 1e6) % 1 != 0:
10821085
raise ValueError("Output interval should not have finer precision than 1e-6 s")
1083-
outputdt = output_file.outputdt if output_file else np.inf
1084-
if isinstance(outputdt, timedelta):
1085-
outputdt = outputdt.total_seconds()
1086-
if outputdt is not None:
1086+
outputdt = timedelta_to_float(output_file.outputdt) if output_file else np.inf
1087+
1088+
if np.isfinite(outputdt):
10871089
_warn_outputdt_release_desync(outputdt, self.particledata.data["time_nextloop"])
10881090

1089-
if isinstance(callbackdt, timedelta):
1090-
callbackdt = callbackdt.total_seconds()
1091+
if callbackdt is not None:
1092+
callbackdt = timedelta_to_float(callbackdt)
10911093

10921094
assert runtime is None or runtime >= 0, "runtime must be positive"
10931095
assert outputdt is None or outputdt >= 0, "outputdt must be positive"
@@ -1240,7 +1242,7 @@ def execute(
12401242

12411243
def _warn_outputdt_release_desync(outputdt: float, release_times: Iterable[float]):
12421244
"""Gives the user a warning if the release time isn't a multiple of outputdt."""
1243-
if any(t % outputdt != 0 for t in release_times):
1245+
if any((np.isfinite(t) and t % outputdt != 0) for t in release_times):
12441246
warnings.warn(
12451247
"Some of the particles have a start time that is not a multiple of outputdt. "
12461248
"This could cause the first output to be at a different time than expected.",

parcels/tools/_helpers.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,11 @@
66
import textwrap
77
import warnings
88
from collections.abc import Callable
9+
from datetime import timedelta
910
from typing import TYPE_CHECKING, Any
1011

12+
import numpy as np
13+
1114
if TYPE_CHECKING:
1215
from parcels import Field, FieldSet, ParticleSet
1316

@@ -134,3 +137,12 @@ def fieldset_repr(fieldset: FieldSet) -> str:
134137

135138
def default_repr(obj: Any):
136139
return object.__repr__(obj)
140+
141+
142+
def timedelta_to_float(dt: float | timedelta | np.timedelta64) -> float:
143+
"""Convert a timedelta to a float in seconds."""
144+
if isinstance(dt, timedelta):
145+
return dt.total_seconds()
146+
if isinstance(dt, np.timedelta64):
147+
return float(dt / np.timedelta64(1, "s"))
148+
return float(dt)

0 commit comments

Comments
 (0)