-
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path_main.py
More file actions
176 lines (152 loc) · 5.72 KB
/
_main.py
File metadata and controls
176 lines (152 loc) · 5.72 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
import importlib.util
import warnings
from collections.abc import Callable, Mapping, Sequence
from functools import cache, wraps
from types import ModuleType
from typing import Any, ParamSpec, TypeVar
from array_api_compat import (
array_namespace,
is_cupy_namespace,
is_dask_namespace,
is_jax_namespace,
is_numpy_namespace,
is_torch_namespace,
)
# Optional numba integration: only when numba is installed (``find_spec``
# locates the package without importing it), register a numba overload for
# ``array_namespace``. The registered implementation ignores its arguments and
# always returns the ``numpy`` module.
# NOTE(review): relies on numba's ``overload`` contract (the decorated function
# is invoked during compilation and returns the implementation to use) — confirm
# against the numba version in use.
if importlib.util.find_spec("numba"):
    import numpy as np
    from numba.extending import overload

    @overload(array_namespace)
    def _array_namespace_overload(*args: Any) -> Any:
        # Implementation handed back to numba; its ``*args`` signature mirrors
        # the variadic positional arguments of ``array_namespace``.
        def inner(*args: Any) -> Any:
            return np

        return inner
P = ParamSpec("P")
T = TypeVar("T")
Pin = ParamSpec("Pin")
Tin = TypeVar("Tin")
Pinner = ParamSpec("Pinner")
Tinner = TypeVar("Tinner")
# Canonical backend key -> predicate that recognises that array namespace.
# These keys are the names looked up in ``jit``'s ``decorator`` /
# ``decorator_args`` / ``decorator_kwargs`` mappings. The predicates are
# consulted in insertion order, so keep the order stable.
STR_TO_IS_NAMESPACE: dict[str, Callable[[ModuleType], bool]] = {
    "numpy": is_numpy_namespace,  # checked first
    "jax": is_jax_namespace,
    "cupy": is_cupy_namespace,
    "torch": is_torch_namespace,
    "dask": is_dask_namespace,  # checked last
}
def _default_decorator(
    module: ModuleType,
    /,
) -> Callable[[Callable[P, T]], Callable[P, T]]:
    """Return the JIT decorator used for ``module`` when the caller gave none.

    * JAX                -> ``jax.jit``
    * NumPy, CuPy        -> no-op (function returned unchanged)
    * PyTorch            -> ``torch.compile``
    * Dask               -> no-op
    * any other module   -> its ``jit`` attribute if present, else a no-op

    The backend package is imported lazily, only when its namespace is seen.
    """

    def passthrough(func):
        # Identity decorator: leaves the function un-compiled.
        return func

    if is_jax_namespace(module):
        import jax

        return jax.jit
    # numba.jit was considered for NumPy/CuPy, but it succeeded on too few
    # functions to be a safe default; callers can opt in explicitly.
    if is_numpy_namespace(module) or is_cupy_namespace(module):
        return passthrough
    if is_torch_namespace(module):
        import torch

        return torch.compile
    if is_dask_namespace(module):
        return passthrough
    return getattr(module, "jit", passthrough)
# Generic type of a JIT decorator: maps a function to one with the same signature.
Decorator = Callable[[Callable[Pin, Tin]], Callable[Pin, Tin]]


def jit(
    decorator: Mapping[str, Decorator[..., Any]] | None = None,
    /,
    *,
    fail_on_error: bool = False,
    rerun_on_error: bool = False,
    decorator_args: Mapping[str, Sequence[Any]] | None = None,
    decorator_kwargs: Mapping[str, Mapping[str, Any]] | None = None,
) -> Callable[[Callable[P, T]], Callable[P, T]]:
    """
    Just-in-time compilation decorator with multiple backends.

    On each call, the decorated function determines the array namespace of
    its positional arguments with ``array_namespace(*args)``. It then applies
    the JIT decorator configured for that namespace. The JIT-decorated
    function is built on the first call for a given namespace and cached, so
    later calls with the same namespace reuse it.

    Parameters
    ----------
    decorator : Mapping[str, Callable[[Callable[P, T]], Callable[P, T]]] | None, optional
        The JIT decorator to use for each array namespace, by default None.
        Keys are ``"numpy"``, ``"jax"``, ``"cupy"``, ``"torch"`` or
        ``"dask"``. For any other namespace, the key is the top-level package
        name of the namespace module. Namespaces without an entry use a
        built-in default decorator.
    fail_on_error : bool, optional
        If True, raise an error if the JIT decorator fails to apply.
        If False, just warn and return the original function, by default False
    rerun_on_error : bool, optional
        If True, rerun the function without JIT if the function
        with JIT applied fails, by default False
    decorator_args : Mapping[str, Sequence[Any]] | None, optional
        Additional positional arguments to be passed along with the function
        to the decorator for each array namespace, by default None
    decorator_kwargs : Mapping[str, Mapping[str, Any]] | None, optional
        Additional keyword arguments to be passed along with the function
        to the decorator for each array namespace, by default None

    Returns
    -------
    Callable[[Callable[P, T]], Callable[P, T]]
        The JIT decorator that can be applied to a function.

    Raises
    ------
    RuntimeError
        Raised when the decorated function is called, if ``fail_on_error`` is
        True and applying the JIT decorator fails, or if the JIT-decorated
        function raises and ``rerun_on_error`` is False. The underlying
        exception is chained as ``__cause__``.
    TypeError
        Re-raised from ``array_namespace`` unless it reports
        ``"Unrecognized array input"``. In that case the function is simply
        run without JIT.

    Example
    -------
    >>> from array_api_jit import jit
    >>> from array_api_compat import array_namespace
    >>> from typing import Any
    >>> import numba
    >>> @jit(
    ...     {"numpy": numba.jit()},  # numba.jit is not used by default
    ...     decorator_kwargs={"jax": {"static_argnames": ["n"]}},  # jax requires static_argnames
    ... )
    ... def sin_n_times(x: Any, n: int) -> Any:
    ...     xp = array_namespace(x)
    ...     for i in range(n):
    ...         x = xp.sin(x)
    ...     return x
    """

    def new_decorator(f: Callable[Pinner, Tinner]) -> Callable[Pinner, Tinner]:
        decorator_args_ = decorator_args or {}
        decorator_kwargs_ = decorator_kwargs or {}
        decorator_ = decorator or {}

        @cache
        def jit_cached(xp: ModuleType) -> Callable[Pinner, Tinner]:
            # Resolve the configuration key for this namespace; the first
            # matching predicate wins. The ``break`` is essential: without it
            # the loop's ``else`` clause always runs and overwrites ``name``
            # with the top-level package name. A wrapper module such as
            # ``array_api_compat.numpy`` would then map to "array_api_compat",
            # and a user entry like ``{"numpy": ...}`` would never be used.
            for name_, is_namespace in STR_TO_IS_NAMESPACE.items():
                if is_namespace(xp):
                    name = name_
                    break
            else:
                name = xp.__name__.split(".")[0]

            decorator_args__ = decorator_args_.get(name, ())
            decorator_kwargs__ = decorator_kwargs_.get(name, {})
            if name in decorator_:
                decorator_current = decorator_[name]
            else:
                decorator_current = _default_decorator(xp)
            try:
                return decorator_current(f, *decorator_args__, **decorator_kwargs__)
            except Exception as e:
                if fail_on_error:
                    raise RuntimeError(f"Failed to apply JIT decorator for {name}") from e
                warnings.warn(
                    f"Failed to apply JIT decorator for {name}: {e}",
                    RuntimeWarning,
                    stacklevel=2,
                )
                # Best effort: fall back to the plain function (this result is cached).
                return f

        @wraps(f)
        def inner(*args_inner: Pinner.args, **kwargs_inner: Pinner.kwargs) -> Tinner:
            try:
                xp = array_namespace(*args_inner)
            except TypeError as e:
                # Presumably the message array_namespace uses when no positional
                # argument is an array. There is nothing to dispatch on, so run
                # ``f`` directly. Guard ``e.args`` because a TypeError may carry
                # no arguments at all.
                if e.args and e.args[0] == "Unrecognized array input":
                    return f(*args_inner, **kwargs_inner)
                raise
            f_jit = jit_cached(xp)
            try:
                return f_jit(*args_inner, **kwargs_inner)
            except Exception as e:
                if rerun_on_error:
                    warnings.warn(
                        f"JIT failed for {xp.__name__}: {e}. Rerunning without JIT.",
                        RuntimeWarning,
                        stacklevel=2,
                    )
                    return f(*args_inner, **kwargs_inner)
                raise RuntimeError(f"Failed to run JIT function for {xp.__name__}") from e

        return inner

    return new_decorator