Skip to content

Commit 9ae5083

Browse files
committed
MAINT: add a _canonic_name field to DTypes, refactor _info.dtypes()
1 parent 08dbe3a commit 9ae5083

2 files changed

Lines changed: 26 additions & 70 deletions

File tree

array_api_strict/_dtypes.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,11 @@
1111

1212
class DType:
1313
_np_dtype: Final[np.dtype[Any]]
14-
__slots__ = ("_np_dtype", "__weakref__")
14+
_canonic_name: Final[Any]
15+
__slots__ = ("_np_dtype", "_canonic_name", "__weakref__")
1516

1617
def __init__(self, np_dtype: npt.DTypeLike):
18+
self._canonic_name = np_dtype
1719
self._np_dtype = np.dtype(np_dtype)
1820

1921
def __repr__(self) -> str:

array_api_strict/_info.py

Lines changed: 23 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,26 @@
66
from ._typing import Capabilities, DataTypes, DefaultDataTypes
77

88

9+
def _kind_to_dtypes(kind : str | None = None):
10+
if kind is None:
11+
return {x._canonic_name:x for x in dt._all_dtypes}
12+
if kind == "bool":
13+
return {"bool": dt.bool}
14+
if kind == "signed integer":
15+
return {x._canonic_name:x for x in dt._signed_integer_dtypes}
16+
if kind == "unsigned integer":
17+
return {x._canonic_name:x for x in dt._unsigned_integer_dtypes}
18+
if kind == "integral":
19+
return {x._canonic_name:x for x in dt._integer_dtypes}
20+
if kind == "real floating":
21+
return {x._canonic_name:x for x in dt._real_floating_dtypes}
22+
if kind == "complex floating":
23+
return {x._canonic_name:x for x in dt._complex_floating_dtypes}
24+
if kind == "numeric":
25+
return {x._canonic_name:x for x in dt._numeric_dtypes}
26+
raise ValueError(f"unsupported kind: {kind!r}")
27+
28+
929
@requires_api_version('2023.12')
1030
class __array_namespace_info__:
1131
@requires_api_version('2023.12')
@@ -54,75 +74,9 @@ def dtypes(
5474
device: Device | None = None,
5575
kind: str | tuple[str, ...] | None = None,
5676
) -> DataTypes:
57-
if kind is None:
58-
return {
59-
"bool": dt.bool,
60-
"int8": dt.int8,
61-
"int16": dt.int16,
62-
"int32": dt.int32,
63-
"int64": dt.int64,
64-
"uint8": dt.uint8,
65-
"uint16": dt.uint16,
66-
"uint32": dt.uint32,
67-
"uint64": dt.uint64,
68-
"float32": dt.float32,
69-
"float64": dt.float64,
70-
"complex64": dt.complex64,
71-
"complex128": dt.complex128,
72-
}
73-
if kind == "bool":
74-
return {"bool": dt.bool}
75-
if kind == "signed integer":
76-
return {
77-
"int8": dt.int8,
78-
"int16": dt.int16,
79-
"int32": dt.int32,
80-
"int64": dt.int64,
81-
}
82-
if kind == "unsigned integer":
83-
return {
84-
"uint8": dt.uint8,
85-
"uint16": dt.uint16,
86-
"uint32": dt.uint32,
87-
"uint64": dt.uint64,
88-
}
89-
if kind == "integral":
90-
return {
91-
"int8": dt.int8,
92-
"int16": dt.int16,
93-
"int32": dt.int32,
94-
"int64": dt.int64,
95-
"uint8": dt.uint8,
96-
"uint16": dt.uint16,
97-
"uint32": dt.uint32,
98-
"uint64": dt.uint64,
99-
}
100-
if kind == "real floating":
101-
return {
102-
"float32": dt.float32,
103-
"float64": dt.float64,
104-
}
105-
if kind == "complex floating":
106-
return {
107-
"complex64": dt.complex64,
108-
"complex128": dt.complex128,
109-
}
110-
if kind == "numeric":
111-
return {
112-
"int8": dt.int8,
113-
"int16": dt.int16,
114-
"int32": dt.int32,
115-
"int64": dt.int64,
116-
"uint8": dt.uint8,
117-
"uint16": dt.uint16,
118-
"uint32": dt.uint32,
119-
"uint64": dt.uint64,
120-
"float32": dt.float32,
121-
"float64": dt.float64,
122-
"complex64": dt.complex64,
123-
"complex128": dt.complex128,
124-
}
125-
if isinstance(kind, tuple):
77+
if isinstance(kind, type(None) | str):
78+
return _kind_to_dtypes(kind)
79+
elif isinstance(kind, tuple):
12680
res: DataTypes = {}
12781
for k in kind:
12882
res.update(self.dtypes(kind=k))

0 commit comments

Comments
 (0)