3434from numpy .lib .stride_tricks import sliding_window_view # noqa
3535
3636from xarray .core import dask_array_ops , dtypes , nputils
37- from xarray .core .utils import module_available
38- from xarray .namedarray ._array_api import _get_data_namespace
3937from xarray .namedarray ._typing import _arrayfunction_or_api
4038from xarray .namedarray .parallelcompat import get_chunked_array_type , is_chunked_array
4139from xarray .namedarray .pycompat import array_type
42- from xarray .namedarray .utils import is_duck_dask_array
40+ from xarray .namedarray .utils import is_duck_dask_array , module_available
4341
4442dask_available = module_available ("dask" )
4543
4644
45+ def get_array_namespace (x ):
46+ if hasattr (x , "__array_namespace__" ):
47+ return x .__array_namespace__ ()
48+ else :
49+ return np
50+
51+
4752def _dask_or_eager_func (
4853 name ,
4954 eager_module = np ,
@@ -121,7 +126,7 @@ def isnull(data):
121126 return isnat (data )
122127 elif issubclass (scalar_type , np .inexact ):
123128 # float types use NaN for null
124- xp = _get_data_namespace (data )
129+ xp = get_array_namespace (data )
125130 return xp .isnan (data )
126131 elif issubclass (scalar_type , (np .bool_ , np .integer , np .character , np .void )):
127132 # these types cannot represent missing values
@@ -179,7 +184,7 @@ def cumulative_trapezoid(y, x, axis):
179184
180185def astype (data , dtype , ** kwargs ):
181186 if hasattr (data , "__array_namespace__" ):
182- xp = _get_data_namespace (data )
187+ xp = get_array_namespace (data )
183188 if xp == np :
184189 # numpy currently doesn't have a astype:
185190 return data .astype (dtype , ** kwargs )
@@ -211,7 +216,7 @@ def as_shared_dtype(scalars_or_arrays, xp=np):
211216
212217
213218def broadcast_to (array , shape ):
214- xp = _get_data_namespace (array )
219+ xp = get_array_namespace (array )
215220 return xp .broadcast_to (array , shape )
216221
217222
@@ -289,7 +294,7 @@ def count(data, axis=None):
289294
290295
291296def sum_where (data , axis = None , dtype = None , where = None ):
292- xp = _get_data_namespace (data )
297+ xp = get_array_namespace (data )
293298 if where is not None :
294299 a = where_method (xp .zeros_like (data ), where , data )
295300 else :
@@ -300,7 +305,7 @@ def sum_where(data, axis=None, dtype=None, where=None):
300305
301306def where (condition , x , y ):
302307 """Three argument where() with better dtype promotion rules."""
303- xp = _get_data_namespace (condition )
308+ xp = get_array_namespace (condition )
304309 return xp .where (condition , * as_shared_dtype ([x , y ], xp = xp ))
305310
306311
@@ -320,19 +325,19 @@ def fillna(data, other):
320325def concatenate (arrays , axis = 0 ):
321326 """concatenate() with better dtype promotion rules."""
322327 if hasattr (arrays [0 ], "__array_namespace__" ):
323- xp = _get_data_namespace (arrays [0 ])
328+ xp = get_array_namespace (arrays [0 ])
324329 return xp .concat (as_shared_dtype (arrays , xp = xp ), axis = axis )
325330 return _concatenate (as_shared_dtype (arrays ), axis = axis )
326331
327332
328333def stack (arrays , axis = 0 ):
329334 """stack() with better dtype promotion rules."""
330- xp = _get_data_namespace (arrays [0 ])
335+ xp = get_array_namespace (arrays [0 ])
331336 return xp .stack (as_shared_dtype (arrays , xp = xp ), axis = axis )
332337
333338
334339def reshape (array , shape ):
335- xp = _get_data_namespace (array )
340+ xp = get_array_namespace (array )
336341 return xp .reshape (array , shape )
337342
338343
@@ -376,7 +381,7 @@ def f(values, axis=None, skipna=None, **kwargs):
376381 if name in ["sum" , "prod" ]:
377382 kwargs .pop ("min_count" , None )
378383
379- xp = _get_data_namespace (values )
384+ xp = get_array_namespace (values )
380385 func = getattr (xp , name )
381386
382387 try :
0 commit comments