Skip to content

Commit e80d3cb

Browse files
spencerkclarkjsignell
authored andcommitted
Adapt to deprecation of generic np.timedelta64 dtype (pydata#11281)
This PR modifies the few places we relied on a generic `np.timedelta64` dtype to explicitly specify the time resolution: - It removes `NAT_TYPES` and relies instead on checking the `dtype.kind` in `computation.nanops._maybe_null_out`. - It infers the time `unit` using `np.datetime_data` from the input `dtype` to determine the `unit` on the returned `fill_value` in `core.dtypes.maybe_promote`. - It explicitly constructs a zero-valued `np.timedelta64` or `np.datetime64` object for use downstream in `plot.utils._determine_cmap_params`.
1 parent cdf4443 commit e80d3cb

File tree

4 files changed

+29
-18
lines changed

4 files changed

+29
-18
lines changed

xarray/computation/nanops.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,9 @@ def _maybe_null_out(result, axis, mask, min_count=1):
2929
dtype, fill_value = dtypes.maybe_promote(result.dtype)
3030
result = where(null_mask, fill_value, astype(result, dtype))
3131

32-
elif getattr(result, "dtype", None) not in dtypes.NAT_TYPES:
32+
elif (dtype := getattr(result, "dtype", None)) and getattr(
33+
dtype, "kind", None
34+
) not in {"m", "M"}:
3335
null_mask = mask.size - duck_array_ops.sum(mask)
3436
result = where(null_mask < min_count, np.nan, result)
3537

xarray/core/dtypes.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from xarray.compat import array_api_compat, npcompat
1111
from xarray.compat.npcompat import HAS_STRING_DTYPE
1212
from xarray.core import utils
13+
from xarray.core.types import PDDatetimeUnitOptions
1314

1415
if TYPE_CHECKING:
1516
from typing import Any
@@ -88,7 +89,11 @@ def maybe_promote(dtype: T_dtype) -> tuple[T_dtype, Any]:
8889
# See https://github.com/numpy/numpy/issues/10685
8990
# np.timedelta64 is a subclass of np.integer
9091
# Check np.timedelta64 before np.integer
91-
fill_value = np.timedelta64("NaT")
92+
unit, _ = np.datetime_data(dtype)
93+
# np.datetime_data returns a generic str for the unit so we need to
94+
# cast it to a valid time unit for mypy purposes.
95+
unit = cast(PDDatetimeUnitOptions, unit)
96+
fill_value = np.timedelta64("NaT", unit)
9297
dtype_ = dtype
9398
elif isdtype(dtype, "integral"):
9499
dtype_ = np.float32 if dtype.itemsize <= 2 else np.float64
@@ -97,8 +102,12 @@ def maybe_promote(dtype: T_dtype) -> tuple[T_dtype, Any]:
97102
dtype_ = dtype
98103
fill_value = np.nan + np.nan * 1j
99104
elif np.issubdtype(dtype, np.datetime64):
105+
unit, _ = np.datetime_data(dtype)
106+
# np.datetime_data returns a generic str for the unit so we need to
107+
# cast it to a valid time unit for mypy purposes.
108+
unit = cast(PDDatetimeUnitOptions, unit)
100109
dtype_ = dtype
101-
fill_value = np.datetime64("NaT")
110+
fill_value = np.datetime64("NaT", unit)
102111
else:
103112
dtype_ = object
104113
fill_value = np.nan
@@ -108,9 +117,6 @@ def maybe_promote(dtype: T_dtype) -> tuple[T_dtype, Any]:
108117
return dtype_out, fill_value
109118

110119

111-
NAT_TYPES = {np.datetime64("NaT").dtype, np.timedelta64("NaT").dtype}
112-
113-
114120
def get_fill_value(dtype):
115121
"""Return an appropriate fill value for this dtype.
116122

xarray/plot/utils.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,15 @@ def _determine_cmap_params(
189189
else:
190190
mpl = attempt_import("matplotlib")
191191

192+
if plot_data.dtype.kind == "m":
193+
unit, _ = np.datetime_data(plot_data.dtype)
194+
zero = np.timedelta64(0, unit)
195+
elif plot_data.dtype.kind == "M":
196+
unit, _ = np.datetime_data(plot_data.dtype)
197+
zero = np.datetime64(0, unit)
198+
else:
199+
zero = 0.0
200+
192201
if isinstance(levels, Iterable):
193202
levels = sorted(levels)
194203

@@ -197,15 +206,15 @@ def _determine_cmap_params(
197206
# Handle all-NaN input data gracefully
198207
if calc_data.size == 0:
199208
# Arbitrary default for when all values are NaN
200-
calc_data = np.array(0.0)
209+
calc_data = np.array(zero)
201210

202211
# Setting center=False prevents a divergent cmap
203212
possibly_divergent = center is not False
204213

205214
# Set center to 0 so math below makes sense but remember its state
206215
center_is_none = False
207216
if center is None:
208-
center = 0
217+
center = zero
209218
center_is_none = True
210219

211220
# Setting both vmin and vmax prevents a divergent cmap
@@ -240,10 +249,10 @@ def _determine_cmap_params(
240249

241250
if possibly_divergent:
242251
levels_are_divergent = (
243-
isinstance(levels, Iterable) and levels[0] * levels[-1] < 0
252+
isinstance(levels, Iterable) and levels[0] * levels[-1] < zero
244253
)
245254
# kwargs not specific about divergent or not: infer defaults from data
246-
divergent = (vmin < 0 < vmax) or not center_is_none or levels_are_divergent
255+
divergent = (vmin < zero < vmax) or not center_is_none or levels_are_divergent
247256
else:
248257
divergent = False
249258

xarray/tests/test_dtypes.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -102,8 +102,8 @@ def test_inf(obj) -> None:
102102
("I", (np.float64, "nan")), # dtype('uint32')
103103
("l", (np.float64, "nan")), # dtype('int64')
104104
("L", (np.float64, "nan")), # dtype('uint64')
105-
("m", (np.timedelta64, "NaT")), # dtype('<m8')
106-
("M", (np.datetime64, "NaT")), # dtype('<M8')
105+
("<m8[ns]", (np.dtype("<m8[ns]"), "NaT")), # dtype('<m8[ns]')
106+
("<M8[ns]", (np.dtype("<M8[ns]"), "NaT")), # dtype('<M8[ns]')
107107
("O", (np.dtype("O"), "nan")), # dtype('O')
108108
("p", (np.float64, "nan")), # dtype('int64')
109109
("P", (np.float64, "nan")), # dtype('uint64')
@@ -123,12 +123,6 @@ def test_maybe_promote(kind, expected) -> None:
123123
assert str(actual[1]) == expected[1]
124124

125125

126-
def test_nat_types_membership() -> None:
127-
assert np.datetime64("NaT").dtype in dtypes.NAT_TYPES
128-
assert np.timedelta64("NaT").dtype in dtypes.NAT_TYPES
129-
assert np.float64 not in dtypes.NAT_TYPES
130-
131-
132126
@pytest.mark.parametrize(
133127
["dtype", "kinds", "xp", "expected"],
134128
(

0 commit comments

Comments
 (0)