Skip to content

Commit 6d72df1

Browse files
committed
BUG: make _like functions raise if array arg dtype is incompatible with the device
1 parent 6ba0f5b commit 6d72df1

2 files changed

Lines changed: 18 additions & 1 deletion

File tree

array_api_strict/_creation_functions.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,8 @@ def empty_like(
168168
_check_device(device)
169169
if device is None:
170170
device = x.device
171+
if dtype is None:
172+
dtype = x.dtype
171173
_check_valid_dtype(dtype, device)
172174

173175
return Array._new(np.empty_like(x._array, dtype=_np_dtype(dtype)), device=device)
@@ -279,6 +281,8 @@ def full_like(
279281
_check_device(device)
280282
if device is None:
281283
device = x.device
284+
if dtype is None:
285+
dtype = x.dtype
282286
_check_valid_dtype(dtype, device)
283287

284288
if not isinstance(fill_value, bool | int | float | complex):
@@ -384,6 +388,8 @@ def ones_like(
384388
_check_device(device)
385389
if device is None:
386390
device = x.device
391+
if dtype is None:
392+
dtype = x.dtype
387393
_check_valid_dtype(dtype, device)
388394

389395
return Array._new(np.ones_like(x._array, dtype=_np_dtype(dtype)), device=device)
@@ -451,6 +457,8 @@ def zeros_like(
451457
_check_device(device)
452458
if device is None:
453459
device = x.device
460+
if dtype is None:
461+
dtype = x.dtype
454462
_check_valid_dtype(dtype, device)
455463

456464
return Array._new(np.zeros_like(x._array, dtype=_np_dtype(dtype)), device=device)

array_api_strict/tests/test_creation_functions.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -263,8 +263,17 @@ def test_ones_like_etc_incorrect(self, func):
263263
# >>> torch.ones_like(a, device='mps')
264264
# TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework
265265
# doesn't support float64.
266-
with pytest.raises(TypeError):
266+
267+
# incompatible dtype inferred from `a.dtype`
268+
with pytest.raises((TypeError, ValueError)):
267269
func(a, device=Device('F32_device'))
270+
271+
# `a.dtype` is compatible but the explicit dtype= argument is incompatible
272+
a = ones(2, dtype=float32)
273+
with pytest.raises((TypeError, ValueError)):
274+
func(a, device=Device('F32_device'), dtype=float64)
275+
276+
268277
# TODO:
269278
# def asarray(
270279
# def arange(

0 commit comments

Comments
 (0)