Skip to content

Commit 99eab4a

Browse files
committed
chore: wip
1 parent b40a6d1 commit 99eab4a

6 files changed

Lines changed: 204 additions & 13 deletions

File tree

.pre-commit-config.yaml

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,13 @@ repos:
1313
hooks:
1414
- id: commitizen
1515
stages: [commit-msg]
16+
# run .pre-commit/update_array_namespace.py
17+
- repo: local
18+
hooks:
19+
- id: update-array-namespace
20+
name: Update array namespace
21+
entry: .pre-commit/update_array_namespace.py
22+
language: python
1623
- repo: https://github.com/pre-commit/pre-commit-hooks
1724
rev: v5.0.0
1825
hooks:
@@ -44,7 +51,7 @@ repos:
4451
rev: v0.11.5
4552
hooks:
4653
- id: ruff
47-
args: [--fix, --exit-non-zero-on-fix]
54+
args: [--fix, --unsafe-fixes, --exit-non-zero-on-fix]
4855
- id: ruff-format
4956
- repo: https://github.com/codespell-project/codespell
5057
rev: v2.4.1

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,8 @@ lint.ignore = [
6868
"D104", # Missing docstring in public package
6969
"D107", # Missing docstring in `__init__`
7070
"D401", # First line of docstring should be in imperative mood
71+
"S603",
72+
"S607",
7173
]
7274
lint.per-file-ignores."conftest.py" = [ "D100" ]
7375
lint.per-file-ignores."docs/conf.py" = [ "D100" ]

src/array_api/cli.py

Lines changed: 0 additions & 12 deletions
This file was deleted.

src/array_api/cli/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .cli import app
2+
3+
__all__ = ["app"]

src/array_api/cli/_main.py

Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
from __future__ import annotations
2+
3+
import ast
4+
import sys
5+
from collections import defaultdict
6+
from copy import deepcopy
7+
from pathlib import Path
8+
9+
10+
def _function_to_protocol(
11+
stmt: ast.FunctionDef, typevars: list[str]
12+
) -> tuple[ast.ClassDef, list[str], str]:
13+
stmt = deepcopy(stmt)
14+
name = stmt.name
15+
docstring = ast.get_docstring(stmt)
16+
stmt.name = "__call__"
17+
stmt.body = [ast.Raise(exc=ast.Name(id="NotImplementedError"), cause=None)]
18+
stmt.args.posonlyargs.insert(0, ast.arg(arg="self"))
19+
stmt.decorator_list.append(ast.Name(id="abstractmethod"))
20+
args = ast.unparse(stmt.args)
21+
typevars = [typevar for typevar in typevars if typevar in args]
22+
23+
# Construct the protocol
24+
cls_def = ast.ClassDef(
25+
name=name,
26+
decorator_list=[ast.Name(id="runtime_checkable")],
27+
keywords=[],
28+
bases=[
29+
ast.Subscript(
30+
value=ast.Name(id="Protocol"),
31+
slice=ast.Tuple(elts=[ast.Name(typevar) for typevar in typevars]),
32+
)
33+
],
34+
body=[stmt],
35+
type_params=[],
36+
)
37+
if docstring is not None:
38+
cls_def.body.insert(0, ast.Expr(value=ast.Constant(docstring, kind=None)))
39+
if sys.version_info >= (3, 12):
40+
cls_def.type_params = []
41+
return cls_def, typevars, name + (f"[{', '.join(typevars)}]" if typevars else "")
42+
43+
44+
def _attributes_to_protocol(
45+
name, attributes: list[tuple[str, str, str | None, list]], typevars: list[str]
46+
) -> tuple[ast.ClassDef, set[str], str]:
47+
body = []
48+
for attribute, type, docstring, _ in attributes:
49+
body.append(
50+
ast.AnnAssign(
51+
target=ast.Name(id=attribute),
52+
annotation=ast.Name(id=type) if type is not None else None,
53+
simple=1,
54+
)
55+
)
56+
if docstring is not None:
57+
body.append(ast.Expr(value=ast.Constant(docstring)))
58+
59+
typevars = {x for attribute in attributes for x in attribute[3]}
60+
return (
61+
ast.ClassDef(
62+
name=name,
63+
decorator_list=[ast.Name(id="runtime_checkable")],
64+
keywords=[],
65+
bases=[
66+
ast.Subscript(
67+
value=ast.Name(id="Protocol"),
68+
slice=ast.Tuple(elts=[ast.Name(typevar) for typevar in typevars]),
69+
)
70+
],
71+
body=body,
72+
type_params=[],
73+
),
74+
typevars,
75+
name + (f"[{', '.join(typevars)}]" if typevars else ""),
76+
)
77+
78+
79+
def generate(cache_dir: Path | str = ".cache", out_name: str = "_namespace.py") -> None:
80+
import subprocess as sp
81+
82+
Path(cache_dir).mkdir(exist_ok=True)
83+
sp.run(["git", "clone", "https://github.com/data-apis/array-api", ".cache"])
84+
# main working directory
85+
draft_path = Path(cache_dir) / Path("src") / "array_api_stubs" / "_draft"
86+
87+
# get module bodies
88+
body_module = {
89+
path.stem: ast.parse(path.read_text("utf-8")).body
90+
for path in draft_path.rglob("*.py")
91+
if path.name != out_name
92+
}
93+
body_typevars = body_module.pop("_types")
94+
body_module.pop("__init__")
95+
96+
# Get all TypeVars
97+
typevars = []
98+
for b in body_typevars:
99+
if isinstance(b, ast.Assign):
100+
value = b.value
101+
if isinstance(value, ast.Call):
102+
if value.func.id == "TypeVar":
103+
typevars.append(value.args[0].s)
104+
print(typevars)
105+
106+
# Dict of module attributes per submodule
107+
module_attributes = defaultdict(list)
108+
109+
# Import `abc.abstractmethod`, `typing.Protocol` and `typing.runtime_checkable`
110+
out = ast.Module(body=[], type_ignores=[])
111+
out.body.append(
112+
ast.Expr(value=ast.Constant("Auto generated Protocol classes (Do not edit)"))
113+
)
114+
out.body.append(
115+
ast.ImportFrom(
116+
module="typing",
117+
names=[
118+
ast.alias(name="Protocol", alias=None),
119+
ast.alias(name="runtime_checkable", alias=None),
120+
],
121+
level=0,
122+
),
123+
)
124+
out.body.append(
125+
ast.ImportFrom(
126+
module="abc",
127+
names=[ast.alias(name="abstractmethod", alias=None)],
128+
level=0,
129+
),
130+
)
131+
132+
# Create Protocols with __call__, representing functions
133+
for submodule, body in body_module.items():
134+
for b in body:
135+
if isinstance(b, (ast.Import, ast.ImportFrom)):
136+
out.body.insert(0, b)
137+
elif isinstance(b, ast.FunctionDef):
138+
cls_def, typevars_, type = _function_to_protocol(b, typevars)
139+
module_attributes[submodule].append((b.name, type, None, typevars_))
140+
out.body.append(cls_def)
141+
elif isinstance(b, ast.Assign):
142+
id = b.targets[0].id
143+
if id == "__all__":
144+
pass
145+
else:
146+
docstring = None
147+
docstring_expr = body[body.index(b) + 1]
148+
if isinstance(docstring_expr, ast.Expr):
149+
if isinstance(docstring_expr.value, ast.Constant):
150+
docstring = docstring_expr.value.value
151+
module_attributes[submodule].append((id, "float", docstring, []))
152+
elif isinstance(b, ast.Expr):
153+
pass
154+
else:
155+
print(f"Skipping {submodule} {b} {ast.dump(b)} \n\n")
156+
157+
# Create Protocols for fft and linalg
158+
submodules = []
159+
OPTIONAL_SUBMODULES = ["fft", "linalg"]
160+
for submodule, attributes in module_attributes.items():
161+
if submodule not in OPTIONAL_SUBMODULES:
162+
continue
163+
cls_def, typevars_, type = _attributes_to_protocol(
164+
submodule[0].upper() + submodule[1:] + "Namespace", attributes, typevars
165+
)
166+
out.body.append(cls_def)
167+
if submodule in OPTIONAL_SUBMODULES:
168+
submodules.append((submodule, type, None, []))
169+
170+
# Create Protocols for the main namespace
171+
attributes = [
172+
attribute
173+
for submodule, attributes in module_attributes.items()
174+
for attribute in attributes
175+
if submodule not in OPTIONAL_SUBMODULES
176+
] + submodules
177+
out.body.append(_attributes_to_protocol("ArrayNamespace", attributes, typevars)[0])
178+
179+
out_path = draft_path / out_name
180+
out_path.write_text(ast.unparse(out), "utf-8")

src/array_api/cli/cli.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
import typer
2+
3+
from ._main import generate
4+
5+
app = typer.Typer()
6+
7+
8+
@app.command()
9+
def main() -> None:
10+
"""Add the arguments and print the result."""
11+
generate()

0 commit comments

Comments
 (0)