Skip to content

Commit 24dae32

Browse files
committed
chore: wip
1 parent 6508d71 commit 24dae32

5 files changed

Lines changed: 179 additions & 26 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ classifiers = [
2727

2828
dependencies = [
2929
"array-api-compat>=1.11.2",
30+
"cm-time>=0.1.2",
3031
"frozendict>=2.4.6",
3132
]
3233
urls."Bug Tracker" = "https://github.com/34j/array-api-jit/issues"

src/array_api_jit/main.py

Lines changed: 44 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# https://github.com/search?q=gumerov+translation+language%3APython&type=code&l=Python
22
import importlib.util
3+
import warnings
34
from collections.abc import Mapping, Sequence
45
from functools import cache, wraps
56
from types import ModuleType
@@ -28,7 +29,11 @@ def _default_decorator(
2829
module: ModuleType,
2930
/,
3031
) -> Callable[[Callable[P, T]], Callable[P, T]]:
31-
if "numpy" in module.__name__:
32+
if "jax" in module.__name__:
33+
import jax
34+
35+
return jax.jit
36+
elif any(x in module.__name__ for x in ("numpy", "cupy")):
3237
import numba
3338

3439
return numba.jit()
@@ -44,6 +49,8 @@ def jit(
4449
decorator: Mapping[str, Callable[[Callable[P, T]], Callable[P, T]]] | None = None,
4550
/,
4651
*,
52+
fail_on_error: bool = False,
53+
rerun_on_error: bool = False,
4754
decorator_args: Mapping[str, Sequence[Any]] | None = None,
4855
decorator_kwargs: Mapping[str, Mapping[str, Any]] | None = None,
4956
) -> Callable[[Callable[P, T]], Callable[P, T]]:
@@ -54,10 +61,18 @@ def jit(
5461
----------
5562
decorator : Mapping[str, Callable[[Callable[P, T]], Callable[P, T]]] | None, optional
5663
The JIT decorator to use for each array namespace, by default None
64+
fail_on_error : bool, optional
65+
If True, raise an error if the JIT decorator fails to apply.
66+
If False, just warn and return the original function, by default False
67+
rerun_on_error : bool, optional
68+
If True, rerun the function without JIT if the function
69+
with JIT applied fails, by default False
5770
decorator_args : Mapping[str, Sequence[Any]] | None, optional
58-
Additional positional arguments for the decorator for each array namespace, by default None
71+
Additional positional arguments to be passed along with the function
72+
to the decorator for each array namespace, by default None
5973
decorator_kwargs : Mapping[str, Mapping[str, Any]] | None, optional
60-
Additional keyword arguments for the decorator for each array namespace, by default None
74+
Additional keyword arguments to be passed along with the function
75+
to the decorator for each array namespace, by default None
6176
6277
Returns
6378
-------
@@ -73,15 +88,26 @@ def new_decorator(f: Callable[P, T]) -> Callable[P, T]:
7388

7489
@cache
7590
def jit_cached(xp: ModuleType) -> Callable[P, T]:
76-
decorator_args__ = decorator_args_.get(xp, ())
77-
decorator_kwargs__ = decorator_kwargs_.get(xp, {})
7891
name = xp.__name__
92+
name = name.replace("array_api_compat.", "")
93+
name = name.split(".")[0]
94+
decorator_args__ = decorator_args_.get(name, ())
95+
decorator_kwargs__ = decorator_kwargs_.get(name, {})
7996
if name in decorator_:
8097
decorator_current = decorator_[name]
8198
else:
8299
decorator_current = _default_decorator(xp)
83-
84-
return decorator_current(f, *decorator_args__, **decorator_kwargs__)
100+
try:
101+
return decorator_current(f, *decorator_args__, **decorator_kwargs__)
102+
except Exception as e:
103+
if fail_on_error:
104+
raise RuntimeError(f"Failed to apply JIT decorator for {name}") from e
105+
warnings.warn(
106+
f"Failed to apply JIT decorator for {name}: {e}",
107+
RuntimeWarning,
108+
stacklevel=2,
109+
)
110+
return f
85111

86112
@wraps(f)
87113
def inner(*args_inner: P.args, **kwargs_inner: P.kwargs) -> T:
@@ -92,7 +118,17 @@ def inner(*args_inner: P.args, **kwargs_inner: P.kwargs) -> T:
92118
return f(*args_inner, **kwargs_inner)
93119
raise
94120
f_jit = jit_cached(xp)
95-
return f_jit(*args_inner, **kwargs_inner)
121+
try:
122+
return f_jit(*args_inner, **kwargs_inner)
123+
except Exception as e:
124+
if rerun_on_error:
125+
warnings.warn(
126+
f"JIT failed for {xp.__name__}: {e}. Rerunning without JIT.",
127+
RuntimeWarning,
128+
stacklevel=2,
129+
)
130+
return f(*args_inner, **kwargs_inner)
131+
raise RuntimeError(f"Failed to run JIT function for {xp.__name__}") from e
96132

97133
return inner
98134

tests/conftest.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,16 @@
33
import pytest
44

55

6-
@pytest.fixture(scope="session", params=["numpy", "torch"])
6+
@pytest.fixture(scope="session", params=["numpy", "torch", "jax"])
77
def xp(request: pytest.FixtureRequest) -> Any:
88
"""Get the array namespace for the given backend."""
99
backend = request.param
1010
if backend == "numpy":
1111
import numpy as xp
1212
elif backend == "torch":
1313
import torch as xp
14+
elif backend == "jax":
15+
import jax.numpy as xp
1416
else:
1517
raise ValueError(f"Unknown backend: {backend}")
1618
return xp

tests/test_main.py

Lines changed: 117 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,125 @@
11
from typing import Any
22

3-
import array_api_compat.torch as torch
3+
import numba
4+
import numpy as np
45
from array_api_compat import array_namespace
6+
from cm_time import timer
7+
from numba import prange
8+
from numba.extending import overload
59

610
from array_api_jit.main import jit
711

812

13+
@overload(np.stack)
14+
def stack(arrays, axis=0):
15+
def inner(arrays, axis=0):
16+
if axis == 0:
17+
shape = (len(arrays),) + arrays[0].shape # noqa: RUF005
18+
stacked_array = np.empty(shape, dtype=arrays[0].dtype)
19+
for j in prange(len(arrays)):
20+
stacked_array[j] = arrays[j]
21+
elif axis == -1:
22+
shape = arrays[0].shape + (len(arrays),) # noqa: RUF005
23+
stacked_array = np.empty(shape, dtype=arrays[0].dtype)
24+
for j in prange(len(arrays)):
25+
stacked_array[..., j] = arrays[j]
26+
return stacked_array
27+
28+
return inner
29+
30+
31+
def legendre(x: Any, n_end: int) -> Any:
32+
"""
33+
Legendre polynomial of order 0 to n_end-1.
34+
35+
Parameters
36+
----------
37+
x : Any
38+
The points at which to evaluate the polynomial.
39+
n_end : int
40+
The number of orders to compute, starting from 0.
41+
42+
Returns
43+
-------
44+
Any
45+
Array of shape (*x.shape, n_end),
46+
where [..., i] is the value of the i-th order polynomial at x.
47+
48+
"""
49+
xp = array_namespace(x)
50+
prevprev = xp.zeros_like(x)
51+
prev = xp.ones_like(x)
52+
result = [prevprev, prev]
53+
for n in range(2, n_end):
54+
prevprev, prev = prev, (((2 * n - 1) * x * prev - (n - 1) * prevprev) / n)
55+
result.append(prev)
56+
return xp.stack(result[:n_end], axis=-1)
57+
58+
59+
def legendre_assign(x: Any, n_end: int) -> Any:
60+
"""
61+
Legendre polynomial of order 0 to n_end-1.
62+
63+
Parameters
64+
----------
65+
x : Any
66+
The points at which to evaluate the polynomial.
67+
n_end : int
68+
The number of orders to compute, starting from 0.
69+
70+
Returns
71+
-------
72+
Any
73+
Array of shape (*x.shape, n_end),
74+
where [..., i] is the value of the i-th order polynomial at x.
75+
76+
"""
77+
xp = array_namespace(x)
78+
prevprev = xp.zeros_like(x)
79+
prev = xp.ones_like(x)
80+
result = xp.empty((*x.shape, n_end), dtype=x.dtype)
81+
if n_end > 0:
82+
result[..., 0] = prevprev
83+
if n_end > 1:
84+
result[..., 1] = prev
85+
for n in range(2, n_end):
86+
prevprev, prev = prev, (((2 * n - 1) * x * prev - (n - 1) * prevprev) / n)
87+
result[..., n] = prev
88+
return result
89+
90+
91+
legendre_jit = jit(
92+
{"numpy": numba.jit(nogil=True, parallel=True)},
93+
decorator_kwargs={"jax": {"static_argnames": ["n_end"]}},
94+
)(legendre)
95+
legendre_assign_jit = jit({"numpy": numba.jit(nogil=True, parallel=True)})(legendre_assign)
96+
97+
998
def test_jit(xp: Any) -> None:
10-
@jit() # type: ignore
11-
def spherical_coordinates(
12-
r: torch.Tensor, theta: torch.Tensor, phi: torch.Tensor
13-
) -> torch.Tensor:
14-
xp = array_namespace(r, theta, phi)
15-
rsin = r * xp.sin(theta)
16-
x = rsin * xp.cos(phi)
17-
y = rsin * xp.sin(phi)
18-
z = r * xp.cos(theta)
19-
return xp.stack((x, y, z), axis=-1)
20-
21-
r = xp.arange(100)
22-
theta = xp.linspace(0, xp.pi, 100)
23-
phi = xp.linspace(0, 2 * xp.pi, 100)
24-
x = spherical_coordinates(r, theta, phi) # type: ignore
25-
assert x.shape == (100, 3)
99+
t = {}
100+
for name, func in [("nojit", legendre), ("jit", legendre_jit)] + (
101+
[
102+
("assign-nojit", legendre_assign),
103+
("assign-jit", legendre_assign_jit),
104+
]
105+
if "jax" not in xp.__name__
106+
else []
107+
):
108+
print(xp.__name__, name)
109+
for i in range(20):
110+
with timer() as timer_:
111+
x = xp.arange(1000000, dtype=xp.float32)
112+
p = func(x, 10) # type: ignore
113+
assert p.shape == (1000000, 10)
114+
t[name] = timer_.elapsed
115+
if i == 0:
116+
print(f"First call: {timer_.elapsed:g}s")
117+
print(f"Last call: {timer_.elapsed:g}s")
118+
if "numpy" not in xp.__name__:
119+
assert t["jit"] < t["nojit"], (
120+
f"JIT time {t['jit']} should be less than non-JIT time {t['nojit']}"
121+
)
122+
assert t["assign-jit"] < t["assign-nojit"], (
123+
f"JIT assign time {t['assign-jit']} should be "
124+
f"less than non-JIT assign time {t['assign-nojit']}"
125+
)

uv.lock

Lines changed: 14 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)