Skip to content

Commit a6a3be6

Browse files
committed
chore: wip
1 parent a8b61fa commit a6a3be6

1 file changed

Lines changed: 16 additions & 3 deletions

File tree

src/array_api/cli/_main.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -140,13 +140,14 @@ def generate(body_module: Mapping[str, list[ast.stmt]], out_path: Path) -> None:
140140
body_module.pop("__init__")
141141

142142
# Get all TypeVars
143-
typevars = []
143+
typevars: list[str] = []
144144
for b in body_typevars:
145145
if isinstance(b, ast.Assign):
146146
value = b.value
147147
if isinstance(value, ast.Call):
148148
if value.func.id == "TypeVar":
149-
typevars.append(value.args[0].s)
149+
name = value.args[0].s
150+
typevars.append(name)
150151
print(typevars)
151152

152153
# Dict of module attributes per submodule
@@ -175,6 +176,18 @@ def generate(body_module: Mapping[str, list[ast.stmt]], out_path: Path) -> None:
175176
level=0,
176177
),
177178
)
179+
180+
for typevar in typevars:
181+
out.body.append(
182+
ast.Assign(
183+
targets=[ast.Name(id=typevar, ctx=ast.Store())],
184+
value=ast.Call(
185+
func=ast.Name(id="TypeVar", ctx=ast.Load()),
186+
args=[ast.Constant(value=typevar)],
187+
keywords=[],
188+
),
189+
)
190+
)
178191

179192
# Create Protocols with __call__, representing functions
180193
for submodule, body in body_module.items():
@@ -239,4 +252,4 @@ def generate(body_module: Mapping[str, list[ast.stmt]], out_path: Path) -> None:
239252
_attributes_to_protocol("ArrayNamespace", attributes, typevars).stmt
240253
)
241254
out_path.parent.mkdir(parents=True, exist_ok=True)
242-
out_path.write_text(ast.unparse(out), "utf-8")
255+
out_path.write_text(ast.unparse(ast.fix_missing_locations(out)), "utf-8")

0 commit comments

Comments
 (0)