Skip to content

Commit 02535a7

Browse files
committed
chore: wip
1 parent 80c7410 commit 02535a7

4 files changed

Lines changed: 166 additions & 31 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ scripts.array-api = "array_api.cli:app"
3838

3939
[dependency-groups]
4040
dev = [
41+
"array-api-strict>=2.3.1",
4142
"pytest>=8,<9",
4243
"pytest-cov>=6,<7",
4344
]

src/array_api/cli/_main.py

Lines changed: 24 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -150,31 +150,6 @@ def _attributes_to_protocol(
150150
)
151151

152152

153-
def generate_all(
154-
cache_dir: Path | str = ".cache",
155-
out_path: Path | str = "src/array-api",
156-
out_name: str = "_namespace.py",
157-
) -> None:
158-
import subprocess as sp
159-
160-
Path(cache_dir).mkdir(exist_ok=True)
161-
sp.run(["git", "clone", "https://github.com/data-apis/array-api", ".cache"])
162-
163-
for dir_path in (Path(cache_dir) / Path("src") / "array_api_stubs").iterdir():
164-
if not dir_path.is_dir():
165-
continue
166-
# get module bodies
167-
body_module = {
168-
path.stem: ast.parse(
169-
path.read_text("utf-8")
170-
.replace("Dtype", "dtype")
171-
.replace("Device", "device")
172-
).body
173-
for path in dir_path.rglob("*.py")
174-
}
175-
generate(body_module, Path(out_path) / dir_path.name / out_name)
176-
177-
178153
def generate(body_module: dict[str, list[ast.stmt]], out_path: Path) -> None:
179154
body_typevars = body_module["_types"]
180155
del body_module["__init__"]
@@ -309,3 +284,27 @@ def generate(body_module: dict[str, list[ast.stmt]], out_path: Path) -> None:
309284
)
310285
out_path.parent.mkdir(parents=True, exist_ok=True)
311286
out_path.write_text(text, "utf-8")
287+
288+
289+
def generate_all(
290+
cache_dir: Path | str = ".cache",
291+
out_path: Path | str = "src/array_api",
292+
) -> None:
293+
import subprocess as sp
294+
295+
Path(cache_dir).mkdir(exist_ok=True)
296+
sp.run(["git", "clone", "https://github.com/data-apis/array-api", ".cache"])
297+
298+
for dir_path in (Path(cache_dir) / Path("src") / "array_api_stubs").iterdir():
299+
if not dir_path.is_dir():
300+
continue
301+
# get module bodies
302+
body_module = {
303+
path.stem: ast.parse(
304+
path.read_text("utf-8")
305+
.replace("Dtype", "dtype")
306+
.replace("Device", "device")
307+
).body
308+
for path in dir_path.rglob("*.py")
309+
}
310+
generate(body_module, (Path(out_path) / dir_path.name).with_suffix(".py"))

tests/test_main.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
1-
from array_api.main import add
2-
3-
4-
def test_add():
5-
"""Adding two number works as expected."""
6-
assert add(1, 1) == 2
1+
from array_api._2024_12 import ArrayNamespace, add
2+
import array_api_strict
3+
def test_main():
4+
assert isinstance(array_api_strict.add, add)
5+
assert isinstance(array_api_strict, ArrayNamespace)

0 commit comments

Comments
 (0)