1+ from __future__ import annotations
2+
3+ import warnings
14from types import ModuleType
25from typing import Any
36
1316)
1417from xarray .namedarray .core import NamedArray
1518
19+ with warnings .catch_warnings ():
20+ warnings .filterwarnings (
21+ "ignore" ,
22+ r"The numpy.array_api submodule is still experimental" ,
23+ category = UserWarning ,
24+ )
25+ import numpy .array_api as nxp # noqa: F401
26+
1627
1728def _get_data_namespace (x : NamedArray [Any , Any ]) -> ModuleType :
1829 if isinstance (x ._data , _arrayapi ):
1930 return x ._data .__array_namespace__ ()
20- else :
21- return np
31+
32+ return np
33+
34+
35+ # %% Creation Functions
2236
2337
2438def astype (
@@ -49,18 +63,25 @@ def astype(
4963
5064 Examples
5165 --------
52- >>> narr = NamedArray(("x",), np.array([1.5, 2.5]))
53- >>> astype(narr, np.dtype(int)).data
54- array([1, 2])
66+ >>> narr = NamedArray(("x",), nxp.asarray([1.5, 2.5]))
67+ >>> narr
68+ <xarray.NamedArray (x: 2)>
69+ Array([1.5, 2.5], dtype=float64)
70+ >>> astype(narr, np.dtype(np.int32))
71+ <xarray.NamedArray (x: 2)>
72+ Array([1, 2], dtype=int32)
5573 """
5674 if isinstance (x ._data , _arrayapi ):
5775 xp = x ._data .__array_namespace__ ()
58- return x ._new (data = xp .astype (x , dtype , copy = copy ))
76+ return x ._new (data = xp .astype (x . _data , dtype , copy = copy ))
5977
6078 # np.astype doesn't exist yet:
6179 return x ._new (data = x ._data .astype (dtype , copy = copy )) # type: ignore[attr-defined]
6280
6381
82+ # %% Elementwise Functions
83+
84+
6485def imag (
6586 x : NamedArray [_ShapeType , np .dtype [_SupportsImag [_ScalarType ]]], / # type: ignore[type-var]
6687) -> NamedArray [_ShapeType , np .dtype [_ScalarType ]]:
@@ -83,8 +104,9 @@ def imag(
83104
84105 Examples
85106 --------
86- >>> narr = NamedArray(("x",), np.array([1 + 2j, 2 + 4j]))
87- >>> imag(narr).data
107+ >>> narr = NamedArray(("x",), np.asarray([1.0 + 2j, 2 + 4j])) # TODO: Use nxp
108+ >>> imag(narr)
109+ <xarray.NamedArray (x: 2)>
88110 array([2., 4.])
89111 """
90112 xp = _get_data_namespace (x )
@@ -114,9 +136,11 @@ def real(
114136
115137 Examples
116138 --------
117- >>> narr = NamedArray(("x",), np.array([1 + 2j, 2 + 4j]))
118- >>> real(narr).data
139+ >>> narr = NamedArray(("x",), np.asarray([1.0 + 2j, 2 + 4j])) # TODO: Use nxp
140+ >>> real(narr)
141+ <xarray.NamedArray (x: 2)>
119142 array([1., 2.])
120143 """
121144 xp = _get_data_namespace (x )
122- return x ._new (data = xp .real (x ._data ))
145+ out = x ._new (data = xp .real (x ._data ))
146+ return out
0 commit comments