Skip to content

Commit d81f68e

Browse files
committed
chore: wip
1 parent a6a3be6 commit d81f68e

1 file changed

Lines changed: 17 additions & 8 deletions

File tree

src/array_api/cli/_main.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,6 @@ class ProtocolData:
1616
name: str
1717

1818

19-
20-
2119
def _function_to_protocol(stmt: ast.FunctionDef, typevars: list[str]) -> ProtocolData:
2220
"""
2321
Convert a function definition to a Protocol class.
@@ -56,7 +54,12 @@ def _function_to_protocol(stmt: ast.FunctionDef, typevars: list[str]) -> Protoco
5654
slice=ast.Tuple(elts=[ast.Name(typevar) for typevar in typevars]),
5755
)
5856
],
59-
body=([ast.Expr(value=ast.Constant(docstring, kind=None))] if docstring is not None else []) + [stmt],
57+
body=(
58+
[ast.Expr(value=ast.Constant(docstring, kind=None))]
59+
if docstring is not None
60+
else []
61+
)
62+
+ [stmt],
6063
type_params=[],
6164
) # type: ignore[call-arg]
6265
return ProtocolData(
@@ -65,9 +68,8 @@ def _function_to_protocol(stmt: ast.FunctionDef, typevars: list[str]) -> Protoco
6568
name=name + (f"[{', '.join(typevars)}]" if typevars else ""),
6669
)
6770

68-
def _class_to_protocol(
69-
stmt: ast.ClassDef, typevars: list[str]
70-
) -> ProtocolData:
71+
72+
def _class_to_protocol(stmt: ast.ClassDef, typevars: list[str]) -> ProtocolData:
7173
typevars = [typevar for typevar in typevars if typevar in ast.unparse(stmt)]
7274
stmt.bases = [
7375
ast.Subscript(
@@ -81,6 +83,7 @@ def _class_to_protocol(
8183
name=stmt.name + (f"[{', '.join(typevars)}]" if typevars else ""),
8284
)
8385

86+
8487
def _attributes_to_protocol(
8588
name, attributes: list[tuple[str, str, str | None, list]], typevars: list[str]
8689
) -> ProtocolData:
@@ -129,7 +132,11 @@ def generate_all(
129132
for dir_path in (Path(cache_dir) / Path("src") / "array_api_stubs").glob("**/"):
130133
# get module bodies
131134
body_module = {
132-
path.stem: ast.parse(path.read_text("utf-8")).body
135+
path.stem: ast.parse(
136+
path.read_text("utf-8")
137+
.replace("Dtype", "dtype")
138+
.replace("Device", "device")
139+
).body
133140
for path in dir_path.rglob("*.py")
134141
}
135142
generate(body_module, Path(out_path) / dir_path.name / out_name)
@@ -176,7 +183,7 @@ def generate(body_module: Mapping[str, list[ast.stmt]], out_path: Path) -> None:
176183
level=0,
177184
),
178185
)
179-
186+
180187
for typevar in typevars:
181188
out.body.append(
182189
ast.Assign(
@@ -202,6 +209,8 @@ def generate(body_module: Mapping[str, list[ast.stmt]], out_path: Path) -> None:
202209
module_attributes[submodule].append(
203210
(b.name, data.name, None, data.typevars_used)
204211
)
212+
if "alias" in (ast.get_docstring(b) or ""):
213+
continue
205214
out.body.append(data.stmt)
206215
elif isinstance(b, ast.Assign):
207216
if submodule == "_types":

0 commit comments

Comments
 (0)