Skip to content

Commit 156a2b1

Browse files
committed
squash
1 parent daeea1d commit 156a2b1

2 files changed

Lines changed: 19 additions & 2 deletions

File tree

array_api_strict/_array_object.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ class Device:
4949
__slots__ = ("_device", "__weakref__")
5050

5151
def __init__(self, device: str = "CPU_DEVICE"):
52-
if device not in ("CPU_DEVICE", "device1", "device2"):
52+
if device not in ("CPU_DEVICE", "device1", "device2", "F32_device"):
5353
raise ValueError(f"The device '{device}' is not a valid choice.")
5454
self._device = device
5555

@@ -66,7 +66,7 @@ def __hash__(self) -> int:
6666

6767

6868
CPU_DEVICE = Device()
69-
ALL_DEVICES = (CPU_DEVICE, Device("device1"), Device("device2"))
69+
ALL_DEVICES = (CPU_DEVICE, Device("device1"), Device("device2"), Device("F32_device"))
7070

7171

7272
class Array:

array_api_strict/tests/test_device_support.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,3 +36,20 @@ def test_fft_device_support_real(func_name):
3636
y = func(x)
3737

3838
assert x.device == y.device
39+
40+
41+
class TestF32Device:
42+
@pytest.mark.parametrize("dtype_str", ["float64", "complex128"])
43+
def test_f64_raises(self, dtype_str):
44+
f32_device = array_api_strict.Device("F32_device")
45+
dtype = getattr(array_api_strict, dtype_str)
46+
with pytest.raises(ValueError):
47+
array_api_strict.arange(3, device=f32_device, dtype=dtype)
48+
49+
def test_info_no_f64(self):
50+
f32_device = array_api_strict.Device("F32_device")
51+
52+
info = array_api_strict.__array_namespace_info__()
53+
all_dtypes = info.dtypes(device=f32_device)
54+
assert "float64" not in all_dtypes
55+
assert "complex128" not in all_dtypes

0 commit comments

Comments
 (0)