Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 21 additions & 10 deletions xarray/core/nanops.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,23 +22,34 @@ def _replace_nan(a, val):
return where_method(val, mask, a), mask


def _is_nat_dtype(dtype):
"""Check if dtype is a datetime64 or timedelta64 type (NaT types).

This is needed because numpy's __eq__ behavior makes dtype in (nat1, nat2)
return True even for non-NAT dtypes due to numpy scalar comparison quirks.
"""
if not isinstance(dtype, np.dtype):
return False
return dtype.kind in "mM" # 'm' for timedelta64, 'M' for datetime64


def _maybe_null_out(result, axis, mask, min_count=1):
"""
xarray version of pandas.core.nanops._maybe_null_out
"""
if hasattr(axis, "__len__"): # if tuple or list
raise ValueError(
"min_count is not available for reduction with more than one dimensions."
)

if axis is not None and getattr(result, "ndim", False):
null_mask = (mask.shape[axis] - mask.sum(axis) - min_count) < 0
if hasattr(axis, "__len__"): # if tuple or list
# For multiple axes, compute total elements along those axes
total_elements = np.take(mask.shape, axis).prod()
null_mask = (total_elements - mask.sum(axis) - min_count) < 0
else:
null_mask = (mask.shape[axis] - mask.sum(axis) - min_count) < 0
if null_mask.any():
dtype, fill_value = dtypes.maybe_promote(result.dtype)
result = result.astype(dtype)
result[null_mask] = fill_value

elif getattr(result, "dtype", None) not in dtypes.NAT_TYPES:
elif getattr(result, "dtype", None) is not None and not _is_nat_dtype(result.dtype):
null_mask = mask.size - mask.sum()
if null_mask < min_count:
result = np.nan
Expand All @@ -47,7 +58,7 @@ def _maybe_null_out(result, axis, mask, min_count=1):


def _nan_argminmax_object(func, fill_value, value, axis=None, **kwargs):
""" In house nanargmin, nanargmax for object arrays. Always return integer
"""In house nanargmin, nanargmax for object arrays. Always return integer
type
"""
valid_count = count(value, axis=axis)
Expand All @@ -62,7 +73,7 @@ def _nan_argminmax_object(func, fill_value, value, axis=None, **kwargs):


def _nan_minmax_object(func, fill_value, value, axis=None, **kwargs):
""" In house nanmin and nanmax for object array """
"""In house nanmin and nanmax for object array"""
valid_count = count(value, axis=axis)
filled_value = fillna(value, fill_value)
data = getattr(np, func)(filled_value, axis=axis, **kwargs)
Expand Down Expand Up @@ -118,7 +129,7 @@ def nansum(a, axis=None, dtype=None, out=None, min_count=None):


def _nanmean_ddof_object(ddof, value, axis=None, dtype=None, **kwargs):
""" In house nanmean. ddof argument will be used in _nanvar method """
"""In house nanmean. ddof argument will be used in _nanvar method"""
from .duck_array_ops import _dask_or_eager_func, count, fillna, where_method

valid_count = count(value, axis=axis)
Expand Down