66import numpy as np
77
88from ._dtypes import DType , _all_dtypes , _np_dtype
9+ from ._devices import CPU_DEVICE , Device , device_supports_dtype , check_device as _check_device
910from ._flags import get_array_api_strict_flags
1011from ._typing import NestedSequence , SupportsBufferProtocol , SupportsDLPack
1112
1415 from typing_extensions import TypeIs
1516
1617 # Circular import
17- from ._array_object import Array , Device
18+ from ._array_object import Array
1819
1920
2021class Undef (Enum ):
@@ -24,10 +25,22 @@ class Undef(Enum):
2425_undef = Undef .UNDEF
2526
2627
27- def _check_valid_dtype (dtype : DType | None ) -> None :
28+ def _check_valid_dtype (dtype : DType | None , device : Device | None = None ) -> None :
2829 # Note: Only spelling dtypes as the dtype objects is supported.
29- if dtype not in (None ,) + _all_dtypes :
30- raise ValueError (f"dtype must be one of the supported dtypes, got { dtype !r} " )
30+
31+ if dtype is not None :
32+ if dtype not in _all_dtypes :
33+ raise ValueError (f"dtype must be one of the supported dtypes, got { dtype !r} " )
34+
35+ if device is not None :
36+ if not device_supports_dtype (device , dtype ):
37+ raise ValueError (f"Device { device !r} does not support dtype={ dtype !r} ." )
38+ return dtype
39+ else :
40+ # if dtype=None, return the default for the device
41+ _device = CPU_DEVICE if device is None else device
42+
43+ return
3144
3245
3346def _supports_buffer_protocol (obj : object ) -> TypeIs [SupportsBufferProtocol ]:
@@ -38,18 +51,6 @@ def _supports_buffer_protocol(obj: object) -> TypeIs[SupportsBufferProtocol]:
3851 return True
3952
4053
41- def _check_device (device : Device | None ) -> None :
42- # _array_object imports in this file are inside the functions to avoid
43- # circular imports
44- from ._array_object import ALL_DEVICES , Device
45-
46- if device is not None and not isinstance (device , Device ):
47- raise ValueError (f"Unsupported device { device !r} " )
48-
49- if device is not None and device not in ALL_DEVICES :
50- raise ValueError (f"Unsupported device { device !r} " )
51-
52-
5354def asarray (
5455 obj : Array | complex | NestedSequence [complex ] | SupportsBufferProtocol ,
5556 / ,
@@ -65,11 +66,12 @@ def asarray(
6566 """
6667 from ._array_object import Array
6768
68- _check_valid_dtype (dtype )
69+ _check_device (device )
70+ _check_valid_dtype (dtype , device )
6971 _np_dtype = None
7072 if dtype is not None :
7173 _np_dtype = dtype ._np_dtype
72- _check_device ( device )
74+
7375 if isinstance (obj , Array ) and device is None :
7476 device = obj .device
7577
@@ -127,8 +129,8 @@ def arange(
127129 """
128130 from ._array_object import Array
129131
130- _check_valid_dtype (dtype )
131132 _check_device (device )
133+ _check_valid_dtype (dtype , device )
132134
133135 return Array ._new (
134136 np .arange (start , stop , step , dtype = _np_dtype (dtype )),
@@ -149,8 +151,8 @@ def empty(
149151 """
150152 from ._array_object import Array
151153
152- _check_valid_dtype (dtype )
153154 _check_device (device )
155+ _check_valid_dtype (dtype , device )
154156
155157 return Array ._new (np .empty (shape , dtype = _np_dtype (dtype )), device = device )
156158
@@ -165,10 +167,10 @@ def empty_like(
165167 """
166168 from ._array_object import Array
167169
168- _check_valid_dtype (dtype )
169170 _check_device (device )
170171 if device is None :
171172 device = x .device
173+ _check_valid_dtype (dtype , device )
172174
173175 return Array ._new (np .empty_like (x ._array , dtype = _np_dtype (dtype )), device = device )
174176
@@ -189,8 +191,8 @@ def eye(
189191 """
190192 from ._array_object import Array
191193
192- _check_valid_dtype (dtype )
193194 _check_device (device )
195+ _check_valid_dtype (dtype , device )
194196
195197 return Array ._new (
196198 np .eye (n_rows , M = n_cols , k = k , dtype = _np_dtype (dtype )), device = device
@@ -237,8 +239,8 @@ def full(
237239 """
238240 from ._array_object import Array
239241
240- _check_valid_dtype (dtype )
241242 _check_device (device )
243+ _check_valid_dtype (dtype , device )
242244
243245 if not isinstance (fill_value , bool | int | float | complex ):
244246 msg = f"Expected Python scalar fill_value, got type { type (fill_value )} "
@@ -266,10 +268,10 @@ def full_like(
266268 """
267269 from ._array_object import Array
268270
269- _check_valid_dtype (dtype )
270271 _check_device (device )
271272 if device is None :
272273 device = x .device
274+ _check_valid_dtype (dtype , device )
273275
274276 if not isinstance (fill_value , bool | int | float | complex ):
275277 msg = f"Expected Python scalar fill_value, got type { type (fill_value )} "
@@ -300,8 +302,8 @@ def linspace(
300302 """
301303 from ._array_object import Array
302304
303- _check_valid_dtype (dtype )
304305 _check_device (device )
306+ _check_valid_dtype (dtype , device )
305307
306308 return Array ._new (
307309 np .linspace (start , stop , num , dtype = _np_dtype (dtype ), endpoint = endpoint ),
@@ -353,8 +355,8 @@ def ones(
353355 """
354356 from ._array_object import Array
355357
356- _check_valid_dtype (dtype )
357358 _check_device (device )
359+ _check_valid_dtype (dtype , device )
358360
359361 return Array ._new (np .ones (shape , dtype = _np_dtype (dtype )), device = device )
360362
@@ -369,10 +371,10 @@ def ones_like(
369371 """
370372 from ._array_object import Array
371373
372- _check_valid_dtype (dtype )
373374 _check_device (device )
374375 if device is None :
375376 device = x .device
377+ _check_valid_dtype (dtype , device )
376378
377379 return Array ._new (np .ones_like (x ._array , dtype = _np_dtype (dtype )), device = device )
378380
@@ -418,8 +420,8 @@ def zeros(
418420 """
419421 from ._array_object import Array
420422
421- _check_valid_dtype (dtype )
422423 _check_device (device )
424+ _check_valid_dtype (dtype , device )
423425
424426 return Array ._new (np .zeros (shape , dtype = _np_dtype (dtype )), device = device )
425427
@@ -434,9 +436,9 @@ def zeros_like(
434436 """
435437 from ._array_object import Array
436438
437- _check_valid_dtype (dtype )
438439 _check_device (device )
439440 if device is None :
440441 device = x .device
442+ _check_valid_dtype (dtype , device )
441443
442444 return Array ._new (np .zeros_like (x ._array , dtype = _np_dtype (dtype )), device = device )
0 commit comments