Skip to content

Commit 4f19ef1

Browse files
committed
chore: wip
1 parent d81f68e commit 4f19ef1

1 file changed

Lines changed: 39 additions & 39 deletions

File tree

src/array_api/cli/_main.py

Lines changed: 39 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,25 @@
22

33
import ast
44
from collections import defaultdict
5-
from collections.abc import Mapping, Sequence
5+
from collections.abc import Iterable, Mapping, Sequence
66
from copy import deepcopy
77
from pathlib import Path
88

99
import attrs
1010

11+
@attrs.frozen()
12+
class TypeVarInfo:
13+
name: str
14+
bound: str | None = None
1115

1216
@attrs.frozen()
1317
class ProtocolData:
1418
stmt: ast.ClassDef
15-
typevars_used: Sequence[str]
19+
typevars_used: Iterable[TypeVarInfo]
1620
name: str
1721

1822

19-
def _function_to_protocol(stmt: ast.FunctionDef, typevars: list[str]) -> ProtocolData:
23+
def _function_to_protocol(stmt: ast.FunctionDef, typevars: list[TypeVarInfo]) -> ProtocolData:
2024
"""
2125
Convert a function definition to a Protocol class.
2226
@@ -41,51 +45,48 @@ def _function_to_protocol(stmt: ast.FunctionDef, typevars: list[str]) -> Protoco
4145
stmt.args.posonlyargs.insert(0, ast.arg(arg="self"))
4246
stmt.decorator_list.append(ast.Name(id="abstractmethod"))
4347
args = ast.unparse(stmt.args)
44-
typevars = [typevar for typevar in typevars if typevar in args]
48+
typevars = [typevar for typevar in typevars if typevar.name in args]
4549

4650
# Construct the protocol
4751
stmt_new = ast.ClassDef(
4852
name=name,
4953
decorator_list=[ast.Name(id="runtime_checkable")],
5054
keywords=[],
5155
bases=[
52-
ast.Subscript(
53-
value=ast.Name(id="Protocol"),
54-
slice=ast.Tuple(elts=[ast.Name(typevar) for typevar in typevars]),
55-
)
56+
ast.Name(id="Protocol"),
5657
],
5758
body=(
5859
[ast.Expr(value=ast.Constant(docstring, kind=None))]
5960
if docstring is not None
6061
else []
6162
)
6263
+ [stmt],
63-
type_params=[],
64+
type_params=[ast.TypeVar(name=t.name, bound=ast.Name(id=t.bound) if t.bound else None) for t in typevars],
6465
) # type: ignore[call-arg]
6566
return ProtocolData(
6667
stmt=stmt_new,
6768
typevars_used=typevars,
68-
name=name + (f"[{', '.join(typevars)}]" if typevars else ""),
69+
name=name + (f"[{', '.join([t.name for t in typevars])}]" if typevars else ""),
6970
)
7071

7172

72-
def _class_to_protocol(stmt: ast.ClassDef, typevars: list[str]) -> ProtocolData:
73-
typevars = [typevar for typevar in typevars if typevar in ast.unparse(stmt)]
73+
def _class_to_protocol(stmt: ast.ClassDef, typevars: list[TypeVarInfo]) -> ProtocolData:
74+
typevars = [typevar for typevar in typevars if typevar.name in ast.unparse(stmt)]
7475
stmt.bases = [
75-
ast.Subscript(
76-
value=ast.Name(id="Protocol"),
77-
slice=ast.Tuple(elts=[ast.Name(typevar) for typevar in typevars]),
78-
)
76+
ast.Name(id="Protocol"),
77+
]
78+
stmt.type_params = [
79+
ast.TypeVar(name=t.name, bound=ast.Name(id=t.bound) if t.bound else None) for t in typevars
7980
]
8081
return ProtocolData(
8182
stmt=stmt,
8283
typevars_used=typevars,
83-
name=stmt.name + (f"[{', '.join(typevars)}]" if typevars else ""),
84+
name=stmt.name + (f"[{', '.join([t.name for t in typevars])}]" if typevars else ""),
8485
)
8586

8687

8788
def _attributes_to_protocol(
88-
name, attributes: list[tuple[str, str, str | None, list]], typevars: list[str]
89+
name, attributes: list[tuple[str, str, str | None, list[TypeVarInfo]]]
8990
) -> ProtocolData:
9091
body = []
9192
for attribute, type, docstring, _ in attributes:
@@ -106,16 +107,15 @@ def _attributes_to_protocol(
106107
decorator_list=[ast.Name(id="runtime_checkable")],
107108
keywords=[],
108109
bases=[
109-
ast.Subscript(
110-
value=ast.Name(id="Protocol"),
111-
slice=ast.Tuple(elts=[ast.Name(typevar) for typevar in typevars]),
112-
)
110+
ast.Name(id="Protocol"),
113111
],
114112
body=body,
115-
type_params=[],
113+
type_params=[
114+
ast.TypeVar(name=t.name, bound=ast.Name(id=t.bound) if t.bound else None) for t in typevars
115+
],
116116
),
117117
typevars_used=typevars,
118-
name=name + (f"[{', '.join(typevars)}]" if typevars else ""),
118+
name=name + (f"[{', '.join([t.name for t in typevars])}]" if typevars else ""),
119119
)
120120

121121

@@ -147,14 +147,14 @@ def generate(body_module: Mapping[str, list[ast.stmt]], out_path: Path) -> None:
147147
body_module.pop("__init__")
148148

149149
# Get all TypeVars
150-
typevars: list[str] = []
150+
typevars: list[TypeVarInfo] = []
151151
for b in body_typevars:
152152
if isinstance(b, ast.Assign):
153153
value = b.value
154154
if isinstance(value, ast.Call):
155155
if value.func.id == "TypeVar":
156156
name = value.args[0].s
157-
typevars.append(name)
157+
typevars.append(TypeVarInfo(name=name, bound=None))
158158
print(typevars)
159159

160160
# Dict of module attributes per submodule
@@ -184,17 +184,17 @@ def generate(body_module: Mapping[str, list[ast.stmt]], out_path: Path) -> None:
184184
),
185185
)
186186

187-
for typevar in typevars:
188-
out.body.append(
189-
ast.Assign(
190-
targets=[ast.Name(id=typevar, ctx=ast.Store())],
191-
value=ast.Call(
192-
func=ast.Name(id="TypeVar", ctx=ast.Load()),
193-
args=[ast.Constant(value=typevar)],
194-
keywords=[],
195-
),
196-
)
197-
)
187+
# for typevar in typevars:
188+
# out.body.append(
189+
# ast.Assign(
190+
# targets=[ast.Name(id=typevar, ctx=ast.Store())],
191+
# value=ast.Call(
192+
# func=ast.Name(id="TypeVar", ctx=ast.Load()),
193+
# args=[ast.Constant(value=typevar)],
194+
# keywords=[],
195+
# ),
196+
# )
197+
# )
198198

199199
# Create Protocols with __call__, representing functions
200200
for submodule, body in body_module.items():
@@ -244,7 +244,7 @@ def generate(body_module: Mapping[str, list[ast.stmt]], out_path: Path) -> None:
244244
if submodule not in OPTIONAL_SUBMODULES:
245245
continue
246246
data = _attributes_to_protocol(
247-
submodule[0].upper() + submodule[1:] + "Namespace", attributes, typevars
247+
submodule[0].upper() + submodule[1:] + "Namespace", attributes
248248
)
249249
out.body.append(data.stmt)
250250
if submodule in OPTIONAL_SUBMODULES:
@@ -258,7 +258,7 @@ def generate(body_module: Mapping[str, list[ast.stmt]], out_path: Path) -> None:
258258
if submodule not in OPTIONAL_SUBMODULES
259259
] + submodules
260260
out.body.append(
261-
_attributes_to_protocol("ArrayNamespace", attributes, typevars).stmt
261+
_attributes_to_protocol("ArrayNamespace", attributes).stmt
262262
)
263263
out_path.parent.mkdir(parents=True, exist_ok=True)
264264
out_path.write_text(ast.unparse(ast.fix_missing_locations(out)), "utf-8")

0 commit comments

Comments
 (0)