11import logging
22from collections .abc import Callable
3- from typing import TypeVar , overload
3+ from typing import ParamSpec , TypeVar , overload
44
55from pyfuse .graph .graph import Graph
6+ from pyfuse .typing import TraceDecorator , TracedFunction
67
78logger = logging .getLogger (__name__ )
89
9- _F = TypeVar ("_F" , bound = Callable [..., object ])
10+ _R = TypeVar ("_R" )
11+ _P = ParamSpec ("_P" )
1012
1113
1214@overload
13- def trace (func : _F ) -> _F : ...
15+ def trace (func : Callable [ _P , _R ] ) -> TracedFunction [ _P , _R ] : ...
1416@overload
15- def trace (* , timeout : float | None = ..., retries : int = ..., retry_delay : float = ...) -> Callable [[ _F ], _F ] : ...
17+ def trace (* , timeout : float | None = ..., retries : int = ..., retry_delay : float = ...) -> TraceDecorator : ...
1618
1719
1820def trace (
19- func : _F | None = None ,
21+ func : Callable [..., object ] | None = None ,
2022 * ,
2123 timeout : float | None = None ,
2224 retries : int = 0 ,
2325 retry_delay : float = 1.0 ,
24- ) -> _F | Callable [[ _F ], _F ] :
26+ ) -> object :
2527 """Enable a function for serialization and remote execution.
2628
2729 The decorated function works normally when called directly.
@@ -38,19 +40,19 @@ def flaky(x): ...
3840 if func is not None :
3941 return _apply_trace (func , timeout = timeout , retries = retries , retry_delay = retry_delay )
4042
41- def decorator (f : _F ) -> _F :
43+ def decorator (f : Callable [ _P , _R ] ) -> object :
4244 return _apply_trace (f , timeout = timeout , retries = retries , retry_delay = retry_delay )
4345
4446 return decorator
4547
4648
4749def _apply_trace (
48- func : _F ,
50+ func : Callable [ _P , _R ] ,
4951 * ,
5052 timeout : float | None = None ,
5153 retries : int = 0 ,
5254 retry_delay : float = 1.0 ,
53- ) -> _F :
55+ ) -> TracedFunction [ _P , _R ] :
5456 logger .debug ("@trace applied to %s" , func .__qualname__ )
5557 graph = Graph .default ()
5658 graph .register (func )
0 commit comments