diff --git a/xarray/compat/array_api_compat.py b/xarray/compat/array_api_compat.py index e1e5d5c5bdc..63a7cd8ac44 100644 --- a/xarray/compat/array_api_compat.py +++ b/xarray/compat/array_api_compat.py @@ -1,7 +1,20 @@ +import datetime as dt + import numpy as np from xarray.namedarray.pycompat import array_type +builtin_types = ( + bool, + int, + float, + complex, + str, + bytes, + dt.datetime, + dt.timedelta, +) + def is_weak_scalar_type(t): return isinstance(t, bool | int | float | complex | str | bytes) @@ -38,12 +51,15 @@ def _future_array_api_result_type(*arrays_and_dtypes, xp): def result_type(*arrays_and_dtypes, xp) -> np.dtype: - if xp is np or any( - isinstance(getattr(t, "dtype", t), np.dtype) for t in arrays_and_dtypes - ): - return xp.result_type(*arrays_and_dtypes) - else: - return _future_array_api_result_type(*arrays_and_dtypes, xp=xp) + try: + if xp is np or any( + isinstance(getattr(t, "dtype", t), np.dtype) for t in arrays_and_dtypes + ): + return xp.result_type(*arrays_and_dtypes) + else: + return _future_array_api_result_type(*arrays_and_dtypes, xp=xp) + except TypeError: + return np.dtype(object) def get_array_namespace(*values): diff --git a/xarray/core/dtypes.py b/xarray/core/dtypes.py index bb2fe26d727..1e9a69ea977 100644 --- a/xarray/core/dtypes.py +++ b/xarray/core/dtypes.py @@ -278,17 +278,11 @@ def should_promote_to_object( """ np_result_types = set() for arr_or_dtype in arrays_and_dtypes: - try: - result_type = array_api_compat.result_type( - maybe_promote_to_variable_width(arr_or_dtype), xp=xp - ) - if isinstance(result_type, np.dtype): - np_result_types.add(result_type) - except TypeError: - # passing individual objects to xp.result_type (i.e., what `array_api_compat.result_type` calls) means NEP-18 implementations won't have - # a chance to intercept special values (such as NA) that numpy core cannot handle. - # Thus they are considered as types that don't need promotion i.e., the `arr_or_dtype` that rose the `TypeError` will not contribute to `np_result_types`. - pass + result_type = array_api_compat.result_type( + maybe_promote_to_variable_width(arr_or_dtype), xp=xp + ) + if isinstance(result_type, np.dtype): + np_result_types.add(result_type) if np_result_types: for left, right in PROMOTE_TO_OBJECT: @@ -328,6 +322,7 @@ def result_type( if should_promote_to_object(arrays_and_dtypes, xp): return np.dtype(object) + maybe_promote = functools.partial( maybe_promote_to_variable_width, # let extension arrays handle their own str/bytes diff --git a/xarray/core/utils.py b/xarray/core/utils.py index 100c256fa9d..f94487759c2 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -213,7 +213,7 @@ def maybe_coerce_to_str(index, original_coords): try: result_type = dtypes.result_type(*original_coords) - except (TypeError, ValueError): + except ValueError: pass else: if result_type.kind in "SU": diff --git a/xarray/tests/test_dtypes.py b/xarray/tests/test_dtypes.py index 4ed66509725..5a772b1691f 100644 --- a/xarray/tests/test_dtypes.py +++ b/xarray/tests/test_dtypes.py @@ -32,6 +32,10 @@ class DummyArrayAPINamespace: ([np.dtype(" None: