Skip to content

Commit 2b8facb

Browse files
committed
feat: wip
1 parent 8a6794f commit 2b8facb

5 files changed

Lines changed: 1090 additions & 11 deletions

File tree

pyproject.toml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ classifiers = [
2626
]
2727

2828
dependencies = [
29+
"array-api-compat>=1.11.2",
30+
"frozendict>=2.4.6",
2931
]
3032
urls."Bug Tracker" = "https://github.com/34j/array-api-jit/issues"
3133
urls.Changelog = "https://github.com/34j/array-api-jit/blob/main/CHANGELOG.md"
@@ -34,8 +36,11 @@ urls.repository = "https://github.com/34j/array-api-jit"
3436

3537
[dependency-groups]
3638
dev = [
39+
"jax>=0.4.30",
40+
"numba>=0.60.0",
3741
"pytest>=8,<9",
3842
"pytest-cov>=6,<7",
43+
"torch>=2.7.1",
3944
]
4045
docs = [
4146
"furo>=2023.5.20; python_version>='3.11'",
@@ -45,7 +50,7 @@ docs = [
4550
]
4651

4752
[tool.ruff]
48-
line-length = 88
53+
line-length = 100
4954
lint.select = [
5055
"B", # flake8-bugbear
5156
"D", # flake8-docstrings

src/array_api_jit/main.py

Lines changed: 85 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,85 @@
1-
def add(n1: int, n2: int) -> int:
2-
"""Add the arguments."""
3-
return n1 + n2
1+
# https://github.com/search?q=gumerov+translation+language%3APython&type=code&l=Python
2+
from collections.abc import Mapping, Sequence
3+
from functools import cache, wraps
4+
from types import ModuleType
5+
from typing import Any, Callable, ParamSpec, TypeVar
6+
7+
from array_api_compat import array_namespace
8+
from frozendict import frozendict
9+
10+
P = ParamSpec("P")
11+
T = TypeVar("T")
12+
13+
14+
def _default_decorator(
15+
module: ModuleType, /, *args: Any, **kwargs: Any
16+
) -> Callable[[Callable[P, T]], Callable[P, T]]:
17+
if "numpy" in module.__name__:
18+
import numba
19+
20+
return numba.jit(*args, **kwargs)
21+
elif "torch" in module.__name__:
22+
import torch
23+
24+
return torch.compile
25+
else:
26+
return getattr(module, "jit", lambda x: x)(*args, **kwargs)
27+
28+
29+
def jit(
30+
decorator: Mapping[str, Callable[[Callable[P, T]], Callable[P, T]]] | None = None,
31+
/,
32+
*,
33+
decorator_args: Mapping[str, Sequence[Any]] | None = None,
34+
decorator_kwargs: Mapping[str, Mapping[str, Any]] | None = None,
35+
) -> Callable[[Callable[P, T]], Callable[P, T]]:
36+
"""
37+
Just-in-time compilation decorator with multiple backends.
38+
39+
Parameters
40+
----------
41+
decorator : Mapping[str, Callable[[Callable[P, T]], Callable[P, T]]] | None, optional
42+
The JIT decorator to use for each array namespace, by default None
43+
decorator_args : Mapping[str, Sequence[Any]] | None, optional
44+
Additional positional arguments for the decorator for each array namespace, by default None
45+
decorator_kwargs : Mapping[str, Mapping[str, Any]] | None, optional
46+
Additional keyword arguments for the decorator for each array namespace, by default None
47+
48+
Returns
49+
-------
50+
Callable[[Callable[P, T]], Callable[P, T]]
51+
The JIT decorator that can be applied to a function.
52+
53+
"""
54+
55+
def new_decorator(f: Callable[P, T]) -> Callable[P, T]:
56+
decorator_args_ = frozendict(decorator_args or {})
57+
decorator_kwargs_ = frozendict(decorator_kwargs or {})
58+
decorator_ = decorator or {}
59+
60+
@cache
61+
def jit_cached(xp: ModuleType) -> Callable[P, T]:
62+
decorator_args__ = decorator_args_.get(xp, ())
63+
decorator_kwargs__ = decorator_kwargs_.get(xp, {})
64+
name = xp.__name__
65+
if name in decorator_:
66+
decorator_current = decorator_[name]
67+
else:
68+
decorator_current = _default_decorator(xp, *decorator_args_, **decorator_kwargs_)
69+
70+
return decorator_current(f, *decorator_args__, **decorator_kwargs__)
71+
72+
@wraps(f)
73+
def inner(*args_inner: P.args, **kwargs_inner: P.kwargs) -> T:
74+
try:
75+
xp = array_namespace(*args_inner)
76+
except TypeError as e:
77+
if e.args[0] == "Unrecognized array input":
78+
return f(*args_inner, **kwargs_inner)
79+
raise
80+
f_jit = jit_cached(xp)
81+
return f_jit(*args_inner, **kwargs_inner)
82+
83+
return inner
84+
85+
return new_decorator

tests/conftest.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
from typing import Any
2+
3+
import pytest
4+
5+
6+
@pytest.fixture(scope="session", params=["numpy", "torch"])
7+
def xp(request: pytest.FixtureRequest) -> Any:
8+
"""Get the array namespace for the given backend."""
9+
backend = request.param
10+
if backend == "numpy":
11+
import numpy as xp
12+
elif backend == "torch":
13+
import torch as xp
14+
else:
15+
raise ValueError(f"Unknown backend: {backend}")
16+
return xp

tests/test_main.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,26 @@
1-
from array_api_jit.main import add
1+
from typing import Any
22

3+
import array_api_compat.torch as torch
4+
from array_api_compat import array_namespace
35

4-
def test_add():
5-
"""Adding two number works as expected."""
6-
assert add(1, 1) == 2
6+
from array_api_jit.main import jit
7+
8+
9+
def test_jit(xp: Any) -> None:
10+
@jit(decorator_kwargs={"numpy": {"forceobj": True, "nopython": False}}) # type: ignore
11+
def spherical_coordinates(
12+
r: torch.Tensor, theta: torch.Tensor, phi: torch.Tensor
13+
) -> torch.Tensor:
14+
xp = array_namespace(r, theta, phi)
15+
rsin = r * xp.sin(theta)
16+
x = rsin * xp.cos(phi)
17+
y = rsin * xp.sin(phi)
18+
z = r * xp.cos(theta)
19+
return xp.stack((x, y, z), axis=-1)
20+
21+
r = xp.arange(100)
22+
theta = xp.linspace(0, xp.pi, 100)
23+
phi = xp.linspace(0, 2 * xp.pi, 100)
24+
x = spherical_coordinates(r, theta, phi) # type: ignore
25+
print(x.shape)
26+
assert x.shape == (100, 3)

0 commit comments

Comments
 (0)