@@ -30,6 +30,10 @@ def inner(*args: Any) -> Any:
3030
3131P = ParamSpec ("P" )
3232T = TypeVar ("T" )
33+ Pin = ParamSpec ("Pin" )
34+ Tin = TypeVar ("Tin" )
35+ Pinner = ParamSpec ("Pinner" )
36+ Tinner = TypeVar ("Tinner" )
3337STR_TO_IS_NAMESPACE = {
3438 "numpy" : is_numpy_namespace ,
3539 "jax" : is_jax_namespace ,
@@ -63,8 +67,11 @@ def _default_decorator(
6367 return getattr (module , "jit" , lambda x : x )
6468
6569
70+ Decorator = Callable [[Callable [Pin , Tin ]], Callable [Pin , Tin ]]
71+
72+
6673def jit (
67- decorator : Mapping [str , Callable [[ Callable [ P , T ]], Callable [ P , T ] ]] | None = None ,
74+ decorator : Mapping [str , Decorator [..., Any ]] | None = None ,
6875 / ,
6976 * ,
7077 fail_on_error : bool = False ,
@@ -115,13 +122,13 @@ def jit(
115122
116123 """
117124
118- def new_decorator (f : Callable [P , T ]) -> Callable [P , T ]:
125+ def new_decorator (f : Callable [Pinner , Tinner ]) -> Callable [Pinner , Tinner ]:
119126 decorator_args_ = frozendict (decorator_args or {})
120127 decorator_kwargs_ = frozendict (decorator_kwargs or {})
121128 decorator_ = decorator or {}
122129
123130 @cache
124- def jit_cached (xp : ModuleType ) -> Callable [P , T ]:
131+ def jit_cached (xp : ModuleType ) -> Callable [Pinner , Tinner ]:
125132 for name_ , is_namespace in STR_TO_IS_NAMESPACE .items ():
126133 if is_namespace (xp ):
127134 name = name_
@@ -146,7 +153,7 @@ def jit_cached(xp: ModuleType) -> Callable[P, T]:
146153 return f
147154
148155 @wraps (f )
149- def inner (* args_inner : P .args , ** kwargs_inner : P .kwargs ) -> T :
156+ def inner (* args_inner : Pinner .args , ** kwargs_inner : Pinner .kwargs ) -> Tinner :
150157 try :
151158 xp = array_namespace (* args_inner )
152159 except TypeError as e :
0 commit comments