Skip to content

Commit 1ff1758

Browse files
committed
Fix min_count support for multi-dimensional reductions
- Remove the restriction that prevented `min_count` from being used with tuple/list axes.
- Use `np.take(mask.shape, axis).prod()` to compute the total number of elements across multiple axes.
- Add an `_is_nat_dtype` helper to fix a NaT dtype comparison bug.
- Update docstrings for clarity.
1 parent e05fdde commit 1ff1758

1 file changed

Lines changed: 21 additions & 10 deletions

File tree

xarray/core/nanops.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,23 +22,34 @@ def _replace_nan(a, val):
2222
return where_method(val, mask, a), mask
2323

2424

25+
def _is_nat_dtype(dtype):
    """Return True if ``dtype`` is a NumPy datetime64 or timedelta64 dtype.

    These are the dtypes whose missing-value marker is NaT. The check looks
    only at ``dtype.kind``, so it does not depend on the time unit.

    Parameters
    ----------
    dtype : object
        Candidate dtype. Anything that is not an ``np.dtype`` instance
        (``None``, strings, scalar types, ...) is reported as not NaT-typed.

    Returns
    -------
    bool
        True for kind ``"m"`` (timedelta64) or ``"M"`` (datetime64),
        False otherwise.

    Notes
    -----
    This helper is used in place of a membership test against a collection
    of NaT types, which relies on NumPy's ``==`` between a dtype and the
    collection's elements.
    NOTE(review): the exact failure mode of that membership test is not
    verified here; confirm it against the NumPy version in use.
    """
    # Guard first: non-dtype objects are not guaranteed to expose ``.kind``.
    if not isinstance(dtype, np.dtype):
        return False
    # 'm' is timedelta64, 'M' is datetime64.
    return dtype.kind in ("m", "M")
34+
35+
2536
def _maybe_null_out(result, axis, mask, min_count=1):
2637
"""
2738
xarray version of pandas.core.nanops._maybe_null_out
2839
"""
29-
if hasattr(axis, "__len__"): # if tuple or list
30-
raise ValueError(
31-
"min_count is not available for reduction with more than one dimensions."
32-
)
33-
3440
if axis is not None and getattr(result, "ndim", False):
35-
null_mask = (mask.shape[axis] - mask.sum(axis) - min_count) < 0
41+
if hasattr(axis, "__len__"): # if tuple or list
42+
# For multiple axes, compute total elements along those axes
43+
total_elements = np.take(mask.shape, axis).prod()
44+
null_mask = (total_elements - mask.sum(axis) - min_count) < 0
45+
else:
46+
null_mask = (mask.shape[axis] - mask.sum(axis) - min_count) < 0
3647
if null_mask.any():
3748
dtype, fill_value = dtypes.maybe_promote(result.dtype)
3849
result = result.astype(dtype)
3950
result[null_mask] = fill_value
4051

41-
elif getattr(result, "dtype", None) not in dtypes.NAT_TYPES:
52+
elif getattr(result, "dtype", None) is not None and not _is_nat_dtype(result.dtype):
4253
null_mask = mask.size - mask.sum()
4354
if null_mask < min_count:
4455
result = np.nan
@@ -47,7 +58,7 @@ def _maybe_null_out(result, axis, mask, min_count=1):
4758

4859

4960
def _nan_argminmax_object(func, fill_value, value, axis=None, **kwargs):
50-
""" In house nanargmin, nanargmax for object arrays. Always return integer
61+
"""In house nanargmin, nanargmax for object arrays. Always return integer
5162
type
5263
"""
5364
valid_count = count(value, axis=axis)
@@ -62,7 +73,7 @@ def _nan_argminmax_object(func, fill_value, value, axis=None, **kwargs):
6273

6374

6475
def _nan_minmax_object(func, fill_value, value, axis=None, **kwargs):
65-
""" In house nanmin and nanmax for object array """
76+
"""In house nanmin and nanmax for object array"""
6677
valid_count = count(value, axis=axis)
6778
filled_value = fillna(value, fill_value)
6879
data = getattr(np, func)(filled_value, axis=axis, **kwargs)
@@ -118,7 +129,7 @@ def nansum(a, axis=None, dtype=None, out=None, min_count=None):
118129

119130

120131
def _nanmean_ddof_object(ddof, value, axis=None, dtype=None, **kwargs):
121-
""" In house nanmean. ddof argument will be used in _nanvar method """
132+
"""In house nanmean. ddof argument will be used in _nanvar method"""
122133
from .duck_array_ops import _dask_or_eager_func, count, fillna, where_method
123134

124135
valid_count = count(value, axis=axis)

0 commit comments

Comments
 (0)