Skip to content

Commit 1bf63d3

Browse files
committed
fix: fix type stub
1 parent 0ffdb73 commit 1bf63d3

6 files changed

Lines changed: 232 additions & 203 deletions

File tree

src/array_api/_2022_12.py

Lines changed: 28 additions & 33 deletions
Large diffs are not rendered by default.

src/array_api/_2023_12.py

Lines changed: 58 additions & 42 deletions
Large diffs are not rendered by default.

src/array_api/_2024_12.py

Lines changed: 64 additions & 48 deletions
Large diffs are not rendered by default.

src/array_api/_draft.py

Lines changed: 64 additions & 48 deletions
Large diffs are not rendered by default.

src/array_api/cli/_main.py

Lines changed: 17 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ def _class_to_protocol(stmt: ast.ClassDef, typevars: Sequence[TypeVarInfo]) -> P
135135
)
136136

137137

138-
def _attributes_to_protocol(name: str, attributes: Sequence[ModuleAttributes], bases: list[ast.expr] | None = None, typevars: Sequence[TypeVarInfo] | None = None) -> ProtocolData:
138+
def _attributes_to_protocol(name: str, attributes: Sequence[ModuleAttributes], /, *, typevars: Sequence[TypeVarInfo], bases: list[ast.expr] | None = None, typevars_force: Sequence[TypeVarInfo] | None = None) -> ProtocolData:
139139
"""
140140
Convert a list of module attributes to a Protocol class.
141141
@@ -147,7 +147,9 @@ def _attributes_to_protocol(name: str, attributes: Sequence[ModuleAttributes], b
147147
The attributes to include in the Protocol class.
148148
bases : list[ast.expr] | None, optional
149149
The base classes for the Protocol class, by default None, which defaults to [Protocol].
150-
typevars : Sequence[TypeVarInfo] | None, optional
150+
typevars : Sequence[TypeVarInfo]
151+
The type variables used in the class.
152+
typevars_force : Sequence[TypeVarInfo] | None, optional
151153
The type variables used in the Protocol class, by default None, which defaults to the type variables used in the attributes.
152154
153155
Returns
@@ -167,8 +169,8 @@ def _attributes_to_protocol(name: str, attributes: Sequence[ModuleAttributes], b
167169
)
168170
if a.docstring is not None:
169171
body.append(ast.Expr(value=ast.Constant(a.docstring)))
170-
if typevars is None:
171-
typevars = sorted({x for attribute in attributes for x in attribute.typevars_used}, key=lambda x: x.name.lower())
172+
if typevars_force is None:
173+
typevars_force = [t for t in typevars if any(t in attr.typevars_used for attr in attributes)]
172174
return ProtocolData(
173175
stmt=ast.ClassDef(
174176
name=name,
@@ -179,9 +181,9 @@ def _attributes_to_protocol(name: str, attributes: Sequence[ModuleAttributes], b
179181
ast.Name(id="Protocol"),
180182
],
181183
body=body,
182-
type_params=[ast.TypeVar(name=t.name, bound=ast.Name(id=t.bound) if t.bound else None) for t in typevars],
184+
type_params=[ast.TypeVar(name=t.name, bound=ast.Name(id=t.bound) if t.bound else None) for t in typevars_force],
183185
),
184-
typevars_used=typevars,
186+
typevars_used=typevars_force,
185187
)
186188

187189

@@ -197,31 +199,11 @@ def generate(body_module: dict[str, list[ast.stmt]], out_path: Path) -> None:
197199
The output path where the generated Protocol classes will be saved.
198200
199201
"""
200-
body_typevars = body_module["_types"]
202+
body_module["_types"]
201203
del body_module["__init__"]
202204

203205
# Get all TypeVars
204-
typevars: list[TypeVarInfo] = []
205-
for b in body_typevars:
206-
if isinstance(b, ast.Assign):
207-
value = b.value
208-
if isinstance(value, ast.Call):
209-
if isinstance(value.func, ast.Name):
210-
if value.func.id == "TypeVar":
211-
if isinstance(value.args[0], ast.Constant):
212-
name = value.args[0].s
213-
typevars.append(
214-
TypeVarInfo(
215-
name=name,
216-
bound={
217-
"array": "_array",
218-
}.get(name, None),
219-
)
220-
)
221-
typevars += [TypeVarInfo(name=x) for x in ["Capabilities", "DefaultDataTypes", "DataTypes"]]
222-
typevars = [t for t in typevars if t.name not in ["ellipsis", "PyCapsule", "SupportsBufferProtocol"]]
223-
print(typevars)
224-
typevars = sorted(typevars, key=lambda x: x.name.lower())
206+
typevars = [TypeVarInfo("array", "_array"), TypeVarInfo("dtype"), TypeVarInfo("device"), TypeVarInfo("_T_co")]
225207

226208
# Dict of module attributes per submodule
227209
module_attributes: defaultdict[str, list[ModuleAttributes]] = defaultdict(list)
@@ -254,6 +236,9 @@ def generate(body_module: dict[str, list[ast.stmt]], out_path: Path) -> None:
254236
elif isinstance(b, ast.Assign):
255237
# _types.py contains Assigns which are not part of the Namespace
256238
if submodule == "_types":
239+
if isinstance(b.targets[0], ast.Name) and b.targets[0].id in ["Capabilities", "DefaultDataTypes", "DataTypes"]:
240+
b = ast.parse(ast.unparse(b).replace("dtype", "Any")) # type: ignore
241+
out.body = [b, *out.body]
257242
continue
258243
if not isinstance(b.targets[0], ast.Name):
259244
continue
@@ -291,21 +276,21 @@ def generate(body_module: dict[str, list[ast.stmt]], out_path: Path) -> None:
291276
# Create Protocols for the main namespace
292277
OPTIONAL_SUBMODULES = ["fft", "linalg"]
293278
main_attributes = [attribute for submodule, attributes in module_attributes.items() for attribute in attributes if submodule not in OPTIONAL_SUBMODULES]
294-
main_protocol = _attributes_to_protocol("ArrayNamespace", main_attributes).stmt
279+
main_protocol = _attributes_to_protocol("ArrayNamespace", main_attributes, typevars=typevars).stmt
295280
out.body.append(main_protocol)
296281

297282
# Create Protocols for fft and linalg
298283
submodules: list[ModuleAttributes] = []
299284
for submodule, attributes in module_attributes.items():
300285
if submodule not in OPTIONAL_SUBMODULES:
301286
continue
302-
data = _attributes_to_protocol(submodule[0].upper() + submodule[1:] + "Namespace", attributes)
287+
data = _attributes_to_protocol(submodule[0].upper() + submodule[1:] + "Namespace", attributes, typevars=typevars)
303288
out.body.append(data.stmt)
304289
if submodule in OPTIONAL_SUBMODULES:
305290
submodules.append(ModuleAttributes(submodule, data.name, None, [t for t in typevars if any(t in attr.typevars_used for attr in attributes)]))
306291

307292
# Create Full Protocol for the main namespace
308-
out.body.append(_attributes_to_protocol("ArrayNamespaceFull", submodules, [ast.Subscript(ast.Name("ArrayNamespace"), ast.Tuple([ast.Name(t.name) for t in main_protocol.type_params]))], typevars=[t for t in typevars if t.name in [s.name for s in main_protocol.type_params]]).stmt) # type: ignore
293+
out.body.append(_attributes_to_protocol("ArrayNamespaceFull", submodules, bases=[ast.Subscript(ast.Name("ArrayNamespace"), ast.Tuple([ast.Name(t.name) for t in main_protocol.type_params]))], typevars=typevars, typevars_force=[t for t in typevars if t.name in [s.name for s in main_protocol.type_params]]).stmt) # type: ignore
309294

310295
# Replace TypeVars because of the name conflicts like "array: array"
311296
for node in ast.walk(out):
@@ -344,6 +329,7 @@ def generate(body_module: dict[str, list[ast.stmt]], out_path: Path) -> None:
344329
Tuple,
345330
List,
346331
runtime_checkable,
332+
TypedDict
347333
)
348334
from types import EllipsisType as ellipsis
349335
from typing_extensions import CapsuleType as PyCapsule

src/array_api_compat/__init__.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,4 @@ def array_namespace[TArray: Array](
88
*xs: TArray | complex | None,
99
api_version: Literal["2024.12"] | None = None,
1010
use_compat: bool | None = None,
11-
) -> ArrayNamespaceFull[TArray, Any, Any, Any, Any, Any]: ...
11+
) -> ArrayNamespaceFull[TArray, Any, Any]: ...

0 commit comments

Comments
 (0)