Skip to content

Commit e81be25

Browse files
committed
chore: wip
1 parent a5558af commit e81be25

1 file changed

Lines changed: 21 additions & 17 deletions

File tree

src/array_api/cli/_main.py

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from __future__ import annotations
22

33
import ast
4-
import runpy
54
import warnings
65
from collections import defaultdict
76
from collections.abc import Iterable, Sequence
@@ -138,7 +137,7 @@ def _class_to_protocol(stmt: ast.ClassDef, typevars: Sequence[TypeVarInfo]) -> P
138137
)
139138

140139

141-
def _attributes_to_protocol(name: str, attributes: Sequence[ModuleAttributes]) -> ProtocolData:
140+
def _attributes_to_protocol(name: str, attributes: Sequence[ModuleAttributes], bases: list[ast.expr] | None = None, typevars: Sequence[TypeVarInfo] | None = None) -> ProtocolData:
142141
"""
143142
Convert a list of module attributes to a Protocol class.
144143
@@ -148,6 +147,10 @@ def _attributes_to_protocol(name: str, attributes: Sequence[ModuleAttributes]) -
148147
The name of the Protocol class.
149148
attributes : Sequence[ModuleAttributes]
150149
The attributes to include in the Protocol class.
150+
bases : list[ast.expr] | None, optional
151+
The base classes for the Protocol class, by default None, which defaults to [Protocol].
152+
typevars : Sequence[TypeVarInfo] | None, optional
153+
The type variables used in the Protocol class, by default None, which defaults to the type variables used in the attributes.
151154
152155
Returns
153156
-------
@@ -166,14 +169,16 @@ def _attributes_to_protocol(name: str, attributes: Sequence[ModuleAttributes]) -
166169
)
167170
if a.docstring is not None:
168171
body.append(ast.Expr(value=ast.Constant(a.docstring)))
169-
170-
typevars = sorted({x for attribute in attributes for x in attribute.typevars_used}, key=lambda x: x.name)
172+
if typevars is None:
173+
typevars = sorted({x for attribute in attributes for x in attribute.typevars_used}, key=lambda x: x.name)
171174
return ProtocolData(
172175
stmt=ast.ClassDef(
173176
name=name,
174177
decorator_list=[ast.Name(id="runtime_checkable")],
175178
keywords=[],
176-
bases=[
179+
bases=bases
180+
if bases
181+
else [
177182
ast.Name(id="Protocol"),
178183
],
179184
body=body,
@@ -272,26 +277,31 @@ def generate(body_module: dict[str, list[ast.stmt]], out_path: Path) -> None:
272277
elif isinstance(b, ast.ClassDef):
273278
data = _class_to_protocol(b, typevars)
274279
# add to output, do not add to module attributes
275-
out.body.append(data.stmt)
280+
# add to first position
281+
out.body.insert(0, data.stmt)
276282
elif isinstance(b, ast.Expr):
277283
pass
278284
else:
279285
warnings.warn(f"Skipping {submodule} {b}", stacklevel=2)
280286

287+
# Create Protocols for the main namespace
288+
OPTIONAL_SUBMODULES = ["fft", "linalg"]
289+
main_attributes = [attribute for submodule, attributes in module_attributes.items() for attribute in attributes if submodule not in OPTIONAL_SUBMODULES]
290+
main_protocol = _attributes_to_protocol("ArrayNamespace", main_attributes).stmt
291+
out.body.append(main_protocol)
292+
281293
# Create Protocols for fft and linalg
282294
submodules: list[ModuleAttributes] = []
283-
OPTIONAL_SUBMODULES = ["fft", "linalg"]
284295
for submodule, attributes in module_attributes.items():
285296
if submodule not in OPTIONAL_SUBMODULES:
286297
continue
287298
data = _attributes_to_protocol(submodule[0].upper() + submodule[1:] + "Namespace", attributes)
288299
out.body.append(data.stmt)
289300
if submodule in OPTIONAL_SUBMODULES:
290-
submodules.append(ModuleAttributes(submodule, data.name, None, []))
301+
submodules.append(ModuleAttributes(submodule, data.name, None, [t for t in typevars if any(t in attr.typevars_used for attr in attributes)]))
291302

292-
# Create Protocols for the main namespace
293-
attributes = [attribute for submodule, attributes in module_attributes.items() for attribute in attributes if submodule not in OPTIONAL_SUBMODULES] + submodules
294-
out.body.append(_attributes_to_protocol("ArrayNamespace", attributes).stmt)
303+
# Create Full Protocol for the main namespace
304+
out.body.append(_attributes_to_protocol("ArrayNamespaceFull", submodules, [ast.Subscript(ast.Name("ArrayNamespace"), ast.Tuple([ast.Name(t.name) for t in main_protocol.type_params]))], typevars=[t for t in typevars if t.name in [s.name for s in main_protocol.type_params]]).stmt) # type: ignore
295305

296306
# Replace TypeVars because of the name conflicts like "array: array"
297307
for node in ast.walk(out):
@@ -374,9 +384,3 @@ def generate_all(
374384
# get module bodies
375385
body_module = {path.stem: ast.parse(path.read_text("utf-8").replace("Dtype", "dtype").replace("Device", "device")).body for path in dir_path.rglob("*.py")}
376386
generate(body_module, (Path(out_path) / dir_path.name).with_suffix(".py"))
377-
378-
import sys
379-
380-
# run ssort, otherwise it is broken
381-
sys.argv = ["ssort", "src/array_api"]
382-
runpy.run_module("ssort")

0 commit comments

Comments
 (0)