Skip to content

Commit 9535024

Browse files
committed
fix: wip
1 parent aa87a67 commit 9535024

2 files changed

Lines changed: 12 additions & 5 deletions

File tree

src/array_api_jit/_main.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@ def inner(*args: Any) -> Any:
3030

3131
P = ParamSpec("P")
3232
T = TypeVar("T")
33+
Pin = ParamSpec("Pin")
34+
Tin = TypeVar("Tin")
35+
Pinner = ParamSpec("Pinner")
36+
Tinner = TypeVar("Tinner")
3337
STR_TO_IS_NAMESPACE = {
3438
"numpy": is_numpy_namespace,
3539
"jax": is_jax_namespace,
@@ -63,8 +67,11 @@ def _default_decorator(
6367
return getattr(module, "jit", lambda x: x)
6468

6569

70+
Decorator = Callable[[Callable[Pin, Tin]], Callable[Pin, Tin]]
71+
72+
6673
def jit(
67-
decorator: Mapping[str, Callable[[Callable[P, T]], Callable[P, T]]] | None = None,
74+
decorator: Mapping[str, Decorator[..., Any]] | None = None,
6875
/,
6976
*,
7077
fail_on_error: bool = False,
@@ -115,13 +122,13 @@ def jit(
115122
116123
"""
117124

118-
def new_decorator(f: Callable[P, T]) -> Callable[P, T]:
125+
def new_decorator(f: Callable[Pinner, Tinner]) -> Callable[Pinner, Tinner]:
119126
decorator_args_ = frozendict(decorator_args or {})
120127
decorator_kwargs_ = frozendict(decorator_kwargs or {})
121128
decorator_ = decorator or {}
122129

123130
@cache
124-
def jit_cached(xp: ModuleType) -> Callable[P, T]:
131+
def jit_cached(xp: ModuleType) -> Callable[Pinner, Tinner]:
125132
for name_, is_namespace in STR_TO_IS_NAMESPACE.items():
126133
if is_namespace(xp):
127134
name = name_
@@ -146,7 +153,7 @@ def jit_cached(xp: ModuleType) -> Callable[P, T]:
146153
return f
147154

148155
@wraps(f)
149-
def inner(*args_inner: P.args, **kwargs_inner: P.kwargs) -> T:
156+
def inner(*args_inner: Pinner.args, **kwargs_inner: Pinner.kwargs) -> Tinner:
150157
try:
151158
xp = array_namespace(*args_inner)
152159
except TypeError as e:

tests/test_main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ def test_jit(xp: Any) -> None:
109109
for i in range(20):
110110
with timer() as timer_:
111111
x = xp.arange(1000000, dtype=xp.float32)
112-
p = func(x, 10) # type: ignore
112+
p = func(x, 10)
113113
assert p.shape == (1000000, 10)
114114
t[name] = timer_.elapsed
115115
if i == 0:

0 commit comments

Comments
 (0)