Skip to content

Commit 4817e08

Browse files
committed
ENH: info.default_dtypes: make device-aware
1 parent 29e8720 commit 4817e08

3 files changed

Lines changed: 32 additions & 9 deletions

File tree

array_api_strict/_devices.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from typing import Final
22

3-
from ._dtypes import DType, float64, complex128
43
from ._dtypes import (
4+
DType, float32, float64, complex64, complex128, int64,
55
_all_dtypes, _boolean_dtypes, _signed_integer_dtypes,
66
_unsigned_integer_dtypes, _integer_dtypes, _real_floating_dtypes,
77
_complex_floating_dtypes, _numeric_dtypes
@@ -50,6 +50,22 @@ def check_device(device: Device | None) -> None:
5050

5151
# Helpers for device-specific dtype support
5252

53+
def get_default_dtypes(device: Device | None = None) -> dict[str, Device]:
54+
if device == _F32_DEVICE:
55+
return {
56+
"real floating": float32,
57+
"complex floating": complex64,
58+
"integral": int64,
59+
"indexing": int64,
60+
}
61+
else:
62+
return {
63+
"real floating": float64,
64+
"complex floating": complex128,
65+
"integral": int64,
66+
"indexing": int64,
67+
}
68+
5369

5470
def device_supports_dtype(device: Device | None, dtype: DType |None) -> bool:
5571
"""True if `device` supports `dtype`, False otherwise."""

array_api_strict/_info.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
import numpy as np
22

3-
from . import _dtypes as dt
43
from . import _devices
5-
from ._array_object import ALL_DEVICES, CPU_DEVICE, Device
4+
from ._devices import ALL_DEVICES, CPU_DEVICE, Device
65
from ._flags import get_array_api_strict_flags, requires_api_version
76
from ._typing import Capabilities, DataTypes, DefaultDataTypes
87

@@ -41,12 +40,7 @@ def default_dtypes(
4140
*,
4241
device: Device | None = None,
4342
) -> DefaultDataTypes:
44-
return {
45-
"real floating": dt.float64,
46-
"complex floating": dt.complex128,
47-
"integral": dt.int64,
48-
"indexing": dt.int64,
49-
}
43+
return _devices.get_default_dtypes(device)
5044

5145
@requires_api_version('2023.12')
5246
def dtypes(

array_api_strict/tests/test_device_support.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,3 +53,16 @@ def test_info_no_f64(self):
5353
all_dtypes = info.dtypes(device=f32_device)
5454
assert "float64" not in all_dtypes
5555
assert "complex128" not in all_dtypes
56+
57+
def test_info_default_dtypes(self):
58+
f32_device = xp.Device("F32_device")
59+
info = xp.__array_namespace_info__()
60+
defaults = info.default_dtypes(device=f32_device)
61+
assert defaults["real floating"] == xp.float32
62+
assert defaults["complex floating"] == xp.complex64
63+
64+
cpu_device = xp.Device()
65+
info = xp.__array_namespace_info__()
66+
defaults = info.default_dtypes(device=cpu_device)
67+
assert defaults["real floating"] == xp.float64
68+
assert defaults["complex floating"] == xp.complex128

0 commit comments

Comments
 (0)