Skip to content

Commit ce4024b

Browse files
Support np.timedelta64 as time dimension in Fields (#2039)
* Support np.timedelta64 as time dimension * Remove requirement that time is timedelta64[s] precision * fix test
1 parent b2b0146 commit ce4024b

4 files changed

Lines changed: 64 additions & 23 deletions

File tree

parcels/_core/utils/time.py

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,13 @@
1313

1414

1515
class TimeInterval:
16-
"""A class representing a time interval between two datetime objects.
16+
"""A class representing a time interval between two datetime or np.timedelta64 objects.
1717
1818
Parameters
1919
----------
20-
left : datetime or cftime.datetime
20+
left : np.datetime64 or cftime.datetime or np.timedelta64
2121
The left endpoint of the interval.
22-
right : datetime or cftime.datetime
22+
right : np.datetime64 or cftime.datetime or np.timedelta64
2323
The right endpoint of the interval.
2424
2525
Notes
@@ -28,12 +28,17 @@ class TimeInterval:
2828
"""
2929

3030
def __init__(self, left: T, right: T) -> None:
31-
if not isinstance(left, (datetime, cftime.datetime, np.datetime64)):
32-
raise ValueError(f"Expected right to be a datetime, cftime.datetime, or np.datetime64. Got {type(left)}.")
33-
if not isinstance(right, (datetime, cftime.datetime, np.datetime64)):
34-
raise ValueError(f"Expected right to be a datetime, cftime.datetime, or np.datetime64. Got {type(right)}.")
31+
if not isinstance(left, (np.timedelta64, datetime, cftime.datetime, np.datetime64)):
32+
raise ValueError(
33+
f"Expected right to be a np.timedelta64, datetime, cftime.datetime, or np.datetime64. Got {type(left)}."
34+
)
35+
if not isinstance(right, (np.timedelta64, datetime, cftime.datetime, np.datetime64)):
36+
raise ValueError(
37+
f"Expected right to be a np.timedelta64, datetime, cftime.datetime, or np.datetime64. Got {type(right)}."
38+
)
3539
if left >= right:
3640
raise ValueError(f"Expected left to be strictly less than right, got left={left} and right={right}.")
41+
3742
if not is_compatible(left, right):
3843
raise ValueError(f"Expected left and right to be compatible, got left={left} and right={right}.")
3944

@@ -58,15 +63,26 @@ def intersection(self, other: TimeInterval) -> TimeInterval | None:
5863
"""Return the intersection of two time intervals. Returns None if there is no overlap."""
5964
if not is_compatible(self.left, other.left):
6065
raise ValueError("TimeIntervals are not compatible.")
66+
if not is_compatible(self.right, other.right):
67+
raise ValueError("TimeIntervals are not compatible.")
6168

6269
start = max(self.left, other.left)
6370
end = min(self.right, other.right)
6471

6572
return TimeInterval(start, end) if start <= end else None
6673

6774

68-
def is_compatible(t1: datetime | cftime.datetime, t2: datetime | cftime.datetime) -> bool:
69-
"""Checks whether two (cftime.)datetime objects are compatible."""
75+
def is_compatible(
76+
t1: datetime | cftime.datetime | np.timedelta64, t2: datetime | cftime.datetime | np.timedelta64
77+
) -> bool:
78+
"""
79+
Defines whether two datetime or np.timedelta64 objects are compatible in the context
80+
of being left and right sides of an interval.
81+
"""
82+
# Ensure if either is a timedelta64, both must be
83+
if isinstance(t1, np.timedelta64) ^ isinstance(t2, np.timedelta64):
84+
return False
85+
7086
try:
7187
t1 - t2
7288
except Exception:

tests/v4/test_field.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -72,11 +72,13 @@ def test_field_init_structured_grid(data, grid):
7272
assert field.grid == grid
7373

7474

75-
@pytest.mark.parametrize("numpy_dtype", ["timedelta64[s]", "float64"])
76-
def test_field_init_fail_on_bad_time_type(numpy_dtype):
77-
"""Tests that field initialisation fails when the time isn't given as datetime object (i.e., is float or timedelta)."""
75+
def test_field_init_fail_on_float_time_dim():
76+
"""Test field initialisation fails when given float array as time dimension.
77+
78+
(users are expected to use timedelta64 or datetime).
79+
"""
7880
ds = datasets_structured["ds_2d_left"].copy()
79-
ds["time"] = np.arange(0, T_structured, dtype=numpy_dtype)
81+
ds["time"] = np.arange(0, T_structured, dtype="float64")
8082

8183
data = ds["data_g"]
8284
grid = XGrid(xgcm.Grid(ds))

tests/v4/test_fieldset.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,14 @@ def test_fieldset_add_field_incompatible_calendars(fieldset):
128128
with pytest.raises(CalendarError, match="Expected field '.*' to have calendar compatible with datetime object"):
129129
fieldset.add_field(field, "test_field")
130130

131+
ds_test = ds.copy()
132+
ds_test["time"] = np.linspace(0, 100, T_structured, dtype="timedelta64[s]")
133+
grid = XGrid(xgcm.Grid(ds_test))
134+
field = Field("test_field", ds_test["data_g"], grid, mesh_type="flat")
135+
136+
with pytest.raises(CalendarError, match="Expected field '.*' to have calendar compatible with datetime object"):
137+
fieldset.add_field(field, "test_field")
138+
131139

132140
@pytest.mark.parametrize(
133141
"input_, expected",

tests/v4/utils/test_time.py

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from datetime import datetime, timedelta
3+
from datetime import datetime
44

55
import numpy as np
66
import pytest
@@ -11,17 +11,36 @@
1111
from parcels._core.utils.time import TimeInterval
1212

1313
calendar_strategy = st.sampled_from(
14-
["gregorian", "proleptic_gregorian", "365_day", "360_day", "julian", "366_day", np.datetime64, datetime]
14+
[
15+
"gregorian",
16+
"proleptic_gregorian",
17+
"365_day",
18+
"360_day",
19+
"julian",
20+
"366_day",
21+
np.datetime64,
22+
datetime,
23+
np.timedelta64,
24+
]
1525
)
1626

1727

28+
@st.composite
29+
def np_timedelta64_strategy(draw):
30+
"""Strategy for generating np.timedelta64 objects."""
31+
return np.timedelta64(draw(st.integers(1, 60 * 60 * 24 * 100 * 365)), "s")
32+
33+
1834
@st.composite
1935
def datetime_strategy(draw, calendar=None):
36+
if calendar is None:
37+
calendar = draw(calendar_strategy)
38+
if calendar is np.timedelta64:
39+
return draw(np_timedelta64_strategy())
40+
2041
year = draw(st.integers(1900, 2100))
2142
month = draw(st.integers(1, 12))
2243
day = draw(st.integers(1, 28))
23-
if calendar is None:
24-
calendar = draw(calendar_strategy)
2544
if calendar is datetime:
2645
return datetime(year, month, day)
2746
if calendar is np.datetime64:
@@ -34,12 +53,8 @@ def datetime_strategy(draw, calendar=None):
3453
def time_interval_strategy(draw, left=None, calendar=None):
3554
if left is None:
3655
left = draw(datetime_strategy(calendar=calendar))
37-
right = left + draw(
38-
st.timedeltas(
39-
min_value=timedelta(seconds=1),
40-
max_value=timedelta(days=100 * 365),
41-
)
42-
)
56+
right = left + draw(np_timedelta64_strategy())
57+
4358
return TimeInterval(left, right)
4459

4560

0 commit comments

Comments
 (0)