Skip to content

Commit 8d4eb67

Browse files
committed
ENH: asarray handles f32-only device
1 parent 59a143c commit 8d4eb67

2 files changed

Lines changed: 68 additions & 3 deletions

File tree

array_api_strict/_creation_functions.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,27 @@ def asarray(
106106
raise OverflowError("Integer out of bounds for array dtypes")
107107

108108
res = np.array(obj, dtype=_np_dtype, copy=copy)
109+
110+
# numpy default dtype may differ; if so, adjust the dtype
111+
if dtype is None and device is not None:
112+
res_dtype = DType(res.dtype)
113+
if not device_supports_dtype(device, res_dtype):
114+
# find out the default dtype for the device
115+
from ._data_type_functions import isdtype
116+
if isdtype(res_dtype, "bool"):
117+
targ_dtype = DType("bool")
118+
elif isdtype(res_dtype, "integral"):
119+
targ_dtype = get_default_dtypes(device)["integral"]
120+
elif isdtype(res_dtype, "real floating"):
121+
targ_dtype = get_default_dtypes(device)["real floating"]
122+
elif isdtype(res_dtype, "complex floating"):
123+
targ_dtype = get_default_dtypes(device)["complex floating"]
124+
else:
125+
raise ValueError(f"{res_dtype = } not understood.")
126+
del isdtype
127+
128+
res = res.astype(targ_dtype._np_dtype)
129+
109130
return Array._new(res, device=device)
110131

111132

array_api_strict/tests/test_creation_functions.py

Lines changed: 47 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
zeros,
2323
zeros_like,
2424
)
25-
from .._dtypes import float32, float64
25+
from .._dtypes import float32, float64, bool as xp_bool
2626
from .._array_object import Array
2727
from .._devices import CPU_DEVICE, ALL_DEVICES, Device
2828
from .._info import __array_namespace_info__
@@ -250,6 +250,7 @@ def test_ones_like_etc_correct(self, func):
250250
device = Device('F32_device')
251251
b = func(a, device=device)
252252
assert b.dtype == self.info.default_dtypes(device=device)["real floating"]
253+
assert b.device == device
253254

254255
@pytest.mark.parametrize("func", [empty_like, zeros_like, ones_like, _full_like])
255256
def test_ones_like_etc_incorrect(self, func):
@@ -277,6 +278,7 @@ def test_eye(self):
277278
device = Device('F32_device')
278279
a = eye(3, device=device)
279280
assert a.dtype == self.info.default_dtypes(device=device)["real floating"]
281+
assert a.device == device
280282

281283
with pytest.raises((TypeError, ValueError)):
282284
eye(3, device=device, dtype=float64)
@@ -286,6 +288,7 @@ def test_linspace(self):
286288

287289
a = linspace(1, 10, 11, device=device)
288290
assert a.dtype == self.info.default_dtypes(device=device)["real floating"]
291+
assert a.device == device
289292

290293
a = linspace(1+0j, 10, 11, device=device)
291294
assert a.dtype == self.info.default_dtypes(device=device)["complex floating"]
@@ -298,18 +301,59 @@ def test_arange(self):
298301

299302
a = arange(0, 10, 1, device=device)
300303
assert a.dtype == self.info.default_dtypes(device=device)["integral"]
304+
assert a.device == device
301305

302306
a = arange(0.0, 10, 1, device=device)
303307
assert a.dtype == self.info.default_dtypes(device=device)["real floating"]
308+
assert a.device == device
304309

305310
with pytest.raises((TypeError, ValueError)):
306311
arange(0, 10, 1, device=device, dtype=float64)
307312

308313
with pytest.raises((TypeError, ValueError)):
309314
arange(0.0, 10, 1, device=device, dtype=float64)
310315

311-
# TODO:
312-
# def asarray(
316+
def test_asarray(self):
317+
device = Device('F32_device')
318+
319+
### asarray(python_object)
320+
for x in (True, [False,]):
321+
arr = asarray(x, device=device)
322+
assert arr.dtype == xp_bool
323+
assert arr.device == device
324+
325+
for x in [1, [1,]]:
326+
arr = asarray(x, device=device)
327+
assert arr.dtype == self.info.default_dtypes(device=device)['integral']
328+
assert arr.device == device
329+
330+
for x in [1.0, [1.0,]]:
331+
arr = asarray(x, device=device)
332+
assert arr.dtype == self.info.default_dtypes(device=device)['real floating']
333+
assert arr.device == device
334+
335+
for x in [1j, [1j,]]:
336+
arr = asarray(x, device=device)
337+
assert arr.dtype == self.info.default_dtypes(device=device)['complex floating']
338+
assert arr.device == device
339+
340+
# asarray(python_object, dtype=unsupported_by_device)
341+
with pytest.raises(ValueError, match="Device"):
342+
asarray(1, dtype=float64, device=device)
343+
344+
### asarray(array)
345+
346+
# compatible dtypes, device transfer
347+
src = asarray(1, dtype=float32, device=Device('device1'))
348+
dst = asarray(src, device=device)
349+
assert dst.device == device
350+
assert dst.dtype == float32
351+
352+
# incompatible dtypes, device transfer
353+
src = asarray(1, dtype=float64, device=Device('device1'))
354+
355+
with pytest.raises(ValueError, match="Device"):
356+
asarray(src, device=device)
313357

314358

315359
@pytest.mark.parametrize("api_version", ['2021.12', '2022.12', '2023.12'])

0 commit comments

Comments
 (0)