1818from numpy import any as array_any # noqa
1919from numpy import (
2020 around , # noqa
21+ einsum , # noqa
22+ gradient , # noqa
2123 isclose ,
24+ isin , # noqa
2225 isnat ,
2326 take , # noqa
24- zeros_like , # noqa
27+ tensordot , # noqa
28+ transpose , # noqa
29+ unravel_index , # noqa
30+ zeros_like ,
2531)
2632from numpy import concatenate as _concatenate
2733from numpy .core .multiarray import normalize_axis_index # type: ignore[attr-defined]
2834from numpy .lib .stride_tricks import sliding_window_view # noqa
2935
3036from xarray .core import dask_array_ops , dtypes , nputils
3137from xarray .core .utils import module_available
38+ from xarray .namedarray ._array_api import _get_data_namespace
3239from xarray .namedarray ._typing import _arrayfunction_or_api
3340from xarray .namedarray .parallelcompat import get_chunked_array_type , is_chunked_array
3441from xarray .namedarray .pycompat import array_type
3744dask_available = module_available ("dask" )
3845
3946
40- def get_array_namespace (x ):
41- if hasattr (x , "__array_namespace__" ):
42- return x .__array_namespace__ ()
43- else :
44- return np
45-
46-
4747def _dask_or_eager_func (
4848 name ,
4949 eager_module = np ,
@@ -121,7 +121,7 @@ def isnull(data):
121121 return isnat (data )
122122 elif issubclass (scalar_type , np .inexact ):
123123 # float types use NaN for null
124- xp = get_array_namespace (data )
124+ xp = _get_data_namespace (data )
125125 return xp .isnan (data )
126126 elif issubclass (scalar_type , (np .bool_ , np .integer , np .character , np .void )):
127127 # these types cannot represent missing values
@@ -179,7 +179,7 @@ def cumulative_trapezoid(y, x, axis):
179179
180180def astype (data , dtype , ** kwargs ):
181181 if hasattr (data , "__array_namespace__" ):
182- xp = get_array_namespace (data )
182+ xp = _get_data_namespace (data )
183183 if xp == np :
184184 # numpy currently doesn't have a astype:
185185 return data .astype (dtype , ** kwargs )
@@ -211,7 +211,7 @@ def as_shared_dtype(scalars_or_arrays, xp=np):
211211
212212
213213def broadcast_to (array , shape ):
214- xp = get_array_namespace (array )
214+ xp = _get_data_namespace (array )
215215 return xp .broadcast_to (array , shape )
216216
217217
@@ -289,7 +289,7 @@ def count(data, axis=None):
289289
290290
291291def sum_where (data , axis = None , dtype = None , where = None ):
292- xp = get_array_namespace (data )
292+ xp = _get_data_namespace (data )
293293 if where is not None :
294294 a = where_method (xp .zeros_like (data ), where , data )
295295 else :
@@ -300,7 +300,7 @@ def sum_where(data, axis=None, dtype=None, where=None):
300300
301301def where (condition , x , y ):
302302 """Three argument where() with better dtype promotion rules."""
303- xp = get_array_namespace (condition )
303+ xp = _get_data_namespace (condition )
304304 return xp .where (condition , * as_shared_dtype ([x , y ], xp = xp ))
305305
306306
@@ -320,19 +320,19 @@ def fillna(data, other):
320320def concatenate (arrays , axis = 0 ):
321321 """concatenate() with better dtype promotion rules."""
322322 if hasattr (arrays [0 ], "__array_namespace__" ):
323- xp = get_array_namespace (arrays [0 ])
323+ xp = _get_data_namespace (arrays [0 ])
324324 return xp .concat (as_shared_dtype (arrays , xp = xp ), axis = axis )
325325 return _concatenate (as_shared_dtype (arrays ), axis = axis )
326326
327327
328328def stack (arrays , axis = 0 ):
329329 """stack() with better dtype promotion rules."""
330- xp = get_array_namespace (arrays [0 ])
330+ xp = _get_data_namespace (arrays [0 ])
331331 return xp .stack (as_shared_dtype (arrays , xp = xp ), axis = axis )
332332
333333
334334def reshape (array , shape ):
335- xp = get_array_namespace (array )
335+ xp = _get_data_namespace (array )
336336 return xp .reshape (array , shape )
337337
338338
@@ -376,7 +376,7 @@ def f(values, axis=None, skipna=None, **kwargs):
376376 if name in ["sum" , "prod" ]:
377377 kwargs .pop ("min_count" , None )
378378
379- xp = get_array_namespace (values )
379+ xp = _get_data_namespace (values )
380380 func = getattr (xp , name )
381381
382382 try :
0 commit comments