Skip to content

Commit abf8275

Browse files
committed
WIP
1 parent 156a2b1 commit abf8275

6 files changed

Lines changed: 209 additions & 86 deletions

File tree

array_api_strict/_array_object.py

Lines changed: 2 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from collections.abc import Iterator
2121
from enum import IntEnum
2222
from types import EllipsisType, ModuleType
23-
from typing import Any, Final, Literal, SupportsIndex, Callable
23+
from typing import Any, Literal, SupportsIndex, Callable
2424

2525
import numpy as np
2626
import numpy.typing as npt
@@ -40,33 +40,11 @@
4040
_real_to_complex_map,
4141
_result_type,
4242
)
43+
from ._devices import CPU_DEVICE, ALL_DEVICES, Device
4344
from ._flags import get_array_api_strict_flags, set_array_api_strict_flags
4445
from ._typing import PyCapsule
4546

4647

47-
class Device:
48-
_device: Final[str]
49-
__slots__ = ("_device", "__weakref__")
50-
51-
def __init__(self, device: str = "CPU_DEVICE"):
52-
if device not in ("CPU_DEVICE", "device1", "device2", "F32_device"):
53-
raise ValueError(f"The device '{device}' is not a valid choice.")
54-
self._device = device
55-
56-
def __repr__(self) -> str:
57-
return f"array_api_strict.Device('{self._device}')"
58-
59-
def __eq__(self, other: object) -> bool:
60-
if not isinstance(other, Device):
61-
return False
62-
return self._device == other._device
63-
64-
def __hash__(self) -> int:
65-
return hash(("Device", self._device))
66-
67-
68-
CPU_DEVICE = Device()
69-
ALL_DEVICES = (CPU_DEVICE, Device("device1"), Device("device2"), Device("F32_device"))
7048

7149

7250
class Array:

array_api_strict/_creation_functions.py

Lines changed: 31 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import numpy as np
77

88
from ._dtypes import DType, _all_dtypes, _np_dtype
9+
from ._devices import CPU_DEVICE, Device, device_supports_dtype, check_device as _check_device
910
from ._flags import get_array_api_strict_flags
1011
from ._typing import NestedSequence, SupportsBufferProtocol, SupportsDLPack
1112

@@ -14,7 +15,7 @@
1415
from typing_extensions import TypeIs
1516

1617
# Circular import
17-
from ._array_object import Array, Device
18+
from ._array_object import Array
1819

1920

2021
class Undef(Enum):
@@ -24,10 +25,22 @@ class Undef(Enum):
2425
_undef = Undef.UNDEF
2526

2627

27-
def _check_valid_dtype(dtype: DType | None) -> None:
28+
def _check_valid_dtype(dtype: DType | None, device: Device | None = None) -> None:
2829
# Note: Only spelling dtypes as the dtype objects is supported.
29-
if dtype not in (None,) + _all_dtypes:
30-
raise ValueError(f"dtype must be one of the supported dtypes, got {dtype!r}")
30+
31+
if dtype is not None:
32+
if dtype not in _all_dtypes:
33+
raise ValueError(f"dtype must be one of the supported dtypes, got {dtype!r}")
34+
35+
if device is not None:
36+
if not device_supports_dtype(device, dtype):
37+
raise ValueError(f"Device {device!r} does not support dtype={dtype!r}.")
38+
return dtype
39+
else:
40+
# if dtype=None, return the default for the device
41+
_device = CPU_DEVICE if device is None else device
42+
43+
return
3144

3245

3346
def _supports_buffer_protocol(obj: object) -> TypeIs[SupportsBufferProtocol]:
@@ -38,18 +51,6 @@ def _supports_buffer_protocol(obj: object) -> TypeIs[SupportsBufferProtocol]:
3851
return True
3952

4053

41-
def _check_device(device: Device | None) -> None:
42-
# _array_object imports in this file are inside the functions to avoid
43-
# circular imports
44-
from ._array_object import ALL_DEVICES, Device
45-
46-
if device is not None and not isinstance(device, Device):
47-
raise ValueError(f"Unsupported device {device!r}")
48-
49-
if device is not None and device not in ALL_DEVICES:
50-
raise ValueError(f"Unsupported device {device!r}")
51-
52-
5354
def asarray(
5455
obj: Array | complex | NestedSequence[complex] | SupportsBufferProtocol,
5556
/,
@@ -65,11 +66,12 @@ def asarray(
6566
"""
6667
from ._array_object import Array
6768

68-
_check_valid_dtype(dtype)
69+
_check_device(device)
70+
_check_valid_dtype(dtype, device)
6971
_np_dtype = None
7072
if dtype is not None:
7173
_np_dtype = dtype._np_dtype
72-
_check_device(device)
74+
7375
if isinstance(obj, Array) and device is None:
7476
device = obj.device
7577

@@ -127,8 +129,8 @@ def arange(
127129
"""
128130
from ._array_object import Array
129131

130-
_check_valid_dtype(dtype)
131132
_check_device(device)
133+
_check_valid_dtype(dtype, device)
132134

133135
return Array._new(
134136
np.arange(start, stop, step, dtype=_np_dtype(dtype)),
@@ -149,8 +151,8 @@ def empty(
149151
"""
150152
from ._array_object import Array
151153

152-
_check_valid_dtype(dtype)
153154
_check_device(device)
155+
_check_valid_dtype(dtype, device)
154156

155157
return Array._new(np.empty(shape, dtype=_np_dtype(dtype)), device=device)
156158

@@ -165,10 +167,10 @@ def empty_like(
165167
"""
166168
from ._array_object import Array
167169

168-
_check_valid_dtype(dtype)
169170
_check_device(device)
170171
if device is None:
171172
device = x.device
173+
_check_valid_dtype(dtype, device)
172174

173175
return Array._new(np.empty_like(x._array, dtype=_np_dtype(dtype)), device=device)
174176

@@ -189,8 +191,8 @@ def eye(
189191
"""
190192
from ._array_object import Array
191193

192-
_check_valid_dtype(dtype)
193194
_check_device(device)
195+
_check_valid_dtype(dtype, device)
194196

195197
return Array._new(
196198
np.eye(n_rows, M=n_cols, k=k, dtype=_np_dtype(dtype)), device=device
@@ -237,8 +239,8 @@ def full(
237239
"""
238240
from ._array_object import Array
239241

240-
_check_valid_dtype(dtype)
241242
_check_device(device)
243+
_check_valid_dtype(dtype, device)
242244

243245
if not isinstance(fill_value, bool | int | float | complex):
244246
msg = f"Expected Python scalar fill_value, got type {type(fill_value)}"
@@ -266,10 +268,10 @@ def full_like(
266268
"""
267269
from ._array_object import Array
268270

269-
_check_valid_dtype(dtype)
270271
_check_device(device)
271272
if device is None:
272273
device = x.device
274+
_check_valid_dtype(dtype, device)
273275

274276
if not isinstance(fill_value, bool | int | float | complex):
275277
msg = f"Expected Python scalar fill_value, got type {type(fill_value)}"
@@ -300,8 +302,8 @@ def linspace(
300302
"""
301303
from ._array_object import Array
302304

303-
_check_valid_dtype(dtype)
304305
_check_device(device)
306+
_check_valid_dtype(dtype, device)
305307

306308
return Array._new(
307309
np.linspace(start, stop, num, dtype=_np_dtype(dtype), endpoint=endpoint),
@@ -353,8 +355,8 @@ def ones(
353355
"""
354356
from ._array_object import Array
355357

356-
_check_valid_dtype(dtype)
357358
_check_device(device)
359+
_check_valid_dtype(dtype, device)
358360

359361
return Array._new(np.ones(shape, dtype=_np_dtype(dtype)), device=device)
360362

@@ -369,10 +371,10 @@ def ones_like(
369371
"""
370372
from ._array_object import Array
371373

372-
_check_valid_dtype(dtype)
373374
_check_device(device)
374375
if device is None:
375376
device = x.device
377+
_check_valid_dtype(dtype, device)
376378

377379
return Array._new(np.ones_like(x._array, dtype=_np_dtype(dtype)), device=device)
378380

@@ -418,8 +420,8 @@ def zeros(
418420
"""
419421
from ._array_object import Array
420422

421-
_check_valid_dtype(dtype)
422423
_check_device(device)
424+
_check_valid_dtype(dtype, device)
423425

424426
return Array._new(np.zeros(shape, dtype=_np_dtype(dtype)), device=device)
425427

@@ -434,9 +436,9 @@ def zeros_like(
434436
"""
435437
from ._array_object import Array
436438

437-
_check_valid_dtype(dtype)
438439
_check_device(device)
439440
if device is None:
440441
device = x.device
442+
_check_valid_dtype(dtype, device)
441443

442444
return Array._new(np.zeros_like(x._array, dtype=_np_dtype(dtype)), device=device)

array_api_strict/_devices.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
from typing import Final
2+
3+
from ._dtypes import DType, float64, complex128
4+
from ._dtypes import (
5+
# _kind_to_dtypes,
6+
_all_dtypes, _boolean_dtypes, _signed_integer_dtypes,
7+
_unsigned_integer_dtypes, _integer_dtypes, _real_floating_dtypes,
8+
_complex_floating_dtypes, _numeric_dtypes
9+
)
10+
11+
_ALL_DEVICE_NAMES = ("CPU_DEVICE", "device1", "device2", "F32_device")
12+
13+
class Device:
14+
_device: Final[str]
15+
__slots__ = ("_device", "__weakref__")
16+
17+
def __init__(self, device: str = "CPU_DEVICE"):
18+
if device not in _ALL_DEVICE_NAMES:
19+
raise ValueError(f"The device '{device}' is not a valid choice.")
20+
self._device = device
21+
22+
def __repr__(self) -> str:
23+
return f"array_api_strict.Device('{self._device}')"
24+
25+
def __eq__(self, other: object) -> bool:
26+
if not isinstance(other, Device):
27+
return False
28+
return self._device == other._device
29+
30+
def __hash__(self) -> int:
31+
return hash(("Device", self._device))
32+
33+
def _supported_dtypes(self) -> list[DType]:
34+
return list(dt for dt in _all_dtypes if device_supports_dtype(self, dt))
35+
36+
37+
CPU_DEVICE = Device()
38+
_F32_DEVICE = Device("F32_device")
39+
40+
ALL_DEVICES = (CPU_DEVICE, Device("device1"), Device("device2"), _F32_DEVICE)
41+
42+
43+
def check_device(device: Device | None) -> None:
44+
if device is not None and not isinstance(device, Device):
45+
raise ValueError(f"Unsupported device {device!r}")
46+
47+
if device is not None and device not in ALL_DEVICES:
48+
raise ValueError(f"Unsupported device {device!r}")
49+
50+
51+
def device_supports_dtype(device: Device | None, dtype: DType |None) -> bool:
52+
"""True if `device` supports `dtype`, False otherwise."""
53+
# special-case F32_device
54+
if device == _F32_DEVICE:
55+
return dtype not in (float64, complex128)
56+
57+
# All other devices support all dtypes
58+
return True
59+
60+
61+
# Device-specific dtype maps
62+
63+
def _map_helper(dtypes: list[DType], device: Device) -> dict[str, DType]:
64+
return {
65+
dt._canonic_name: dt
66+
for dt in dtypes
67+
if device_supports_dtype(device, dt)
68+
}
69+
70+
71+
# _info.dtypes() maps "kind" -> dict of {name: dtype}
72+
# Note that "kinds" differ from "categories" above, per the spec.
73+
74+
_kind_to_dtypes = {
75+
None: _all_dtypes,
76+
"bool": _boolean_dtypes,
77+
"signed integer": _signed_integer_dtypes,
78+
"unsigned integer": _unsigned_integer_dtypes,
79+
"integral": _integer_dtypes,
80+
"real floating": _real_floating_dtypes,
81+
"complex floating": _complex_floating_dtypes,
82+
"numeric": _numeric_dtypes
83+
}
84+
85+
_device_dtype_maps = {}
86+
87+
for device in ALL_DEVICES:
88+
_map = {}
89+
for kind, dtypes in _kind_to_dtypes.items():
90+
_map[kind] = {
91+
dt._canonic_name: dt
92+
for dt in dtypes
93+
if device_supports_dtype(device, dt)
94+
}
95+
_device_dtype_maps[device] = _map
96+
97+
###breakpoint()
98+
99+
100+
'''
101+
_kind_to_dtypes = {
102+
None: {x._canonic_name:x for x in _all_dtypes},
103+
"bool": {"bool": bool},
104+
"signed integer": {x._canonic_name: x for x in _signed_integer_dtypes},
105+
"unsigned integer": {x._canonic_name: x for x in _unsigned_integer_dtypes},
106+
"integral": {x._canonic_name: x for x in _integer_dtypes},
107+
"real floating": {x._canonic_name: x for x in _real_floating_dtypes},
108+
"complex floating": {x._canonic_name: x for x in _complex_floating_dtypes},
109+
"numeric": {x._canonic_name: x for x in _numeric_dtypes}
110+
}
111+
'''

array_api_strict/_dtypes.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,48 @@ def _np_dtype(dtype: DType | None) -> np.dtype[Any] | None:
141141

142142
_real_to_complex_map = {float32: complex64, float64: complex128}
143143

144+
145+
146+
147+
'''
148+
_kind_to_dtypes = {
149+
None: {x._canonic_name:x for x in _all_dtypes},
150+
"bool": {"bool": bool},
151+
"signed integer": {x._canonic_name: x for x in _signed_integer_dtypes},
152+
"unsigned integer": {x._canonic_name: x for x in _unsigned_integer_dtypes},
153+
"integral": {x._canonic_name: x for x in _integer_dtypes},
154+
"real floating": {x._canonic_name: x for x in _real_floating_dtypes},
155+
"complex floating": {x._canonic_name: x for x in _complex_floating_dtypes},
156+
"numeric": {x._canonic_name: x for x in _numeric_dtypes}
157+
}
158+
'''
159+
160+
'''
161+
def _kind_to_dtypes(kind : str | None = None) -> dict[str, ...]:
162+
"""A helper for Device / _info.dtypes().
163+
164+
The `kind`s here are those _info.dtypes() accepts. Which differs from
165+
"categories" above, yes.
166+
"""
167+
if kind is None:
168+
return {x._canonic_name:x for x in dt._all_dtypes}
169+
if kind == "bool":
170+
return {"bool": dt.bool}
171+
if kind == "signed integer":
172+
return {x._canonic_name: x for x in dt._signed_integer_dtypes}
173+
if kind == "unsigned integer":
174+
return {x._canonic_name: x for x in dt._unsigned_integer_dtypes}
175+
if kind == "integral":
176+
return {x._canonic_name: x for x in dt._integer_dtypes}
177+
if kind == "real floating":
178+
return {x._canonic_name: x for x in dt._real_floating_dtypes}
179+
if kind == "complex floating":
180+
return {x._canonic_name: x for x in dt._complex_floating_dtypes}
181+
if kind == "numeric":
182+
return {x._canonic_name: x for x in dt._numeric_dtypes}
183+
raise ValueError(f"unsupported kind: {kind!r}")
184+
'''
185+
144186
# Note: the spec defines a restricted type promotion table compared to NumPy.
145187
# In particular, cross-kind promotions like integer + float or boolean +
146188
# integer are not allowed, even for functions that accept both kinds.

0 commit comments

Comments
 (0)