Skip to content

Commit 1fb1dce

Browse files
committed
Typing
1 parent 45cdfb0 commit 1fb1dce

8 files changed

Lines changed: 53 additions & 21 deletions

File tree

examples/async_execution.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
pyfuse.connect("local://localhost:9748")
1616

17-
async def add(a: int, b: int) -> int:
17+
async def add(a: float, b: float) -> float:
1818
return a + b
1919

2020
@trace
@@ -39,12 +39,12 @@ async def main() -> None:
3939
print(f"map results = {results}") # [5.0, 13.0, 17.0]
4040

4141
# 4. asyncio.gather -- submit multiple tasks, await concurrently
42-
results = await asyncio.gather(
42+
gathered = await asyncio.gather(
4343
hypotenuse.run(3.0, 4.0),
4444
hypotenuse.run(5.0, 12.0),
4545
hypotenuse.run(8.0, 15.0),
4646
)
47-
print(f"gather: {results}") # [5.0, 13.0, 17.0]
47+
print(f"gather: {gathered}") # [5.0, 13.0, 17.0]
4848

4949

5050
asyncio.run(main())

examples/remote_execution.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,19 @@
11
import asyncio
22
import math
3+
from typing import overload
34

45
import pyfuse
56
from pyfuse import trace
67

78
pyfuse.connect("local://localhost:9748")
89

9-
def add(a: int, b: int) -> int:
10+
@overload
11+
def add(a: int, b: int) -> int: ...
12+
13+
@overload
14+
def add(a: float, b: float) -> float: ...
15+
16+
def add(a: int | float, b: int | float) -> int | float:
1017
return a + b
1118

1219
@trace

pyfuse/graph/decorator.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,29 @@
11
import logging
22
from collections.abc import Callable
3-
from typing import TypeVar, overload
3+
from typing import ParamSpec, TypeVar, overload
44

55
from pyfuse.graph.graph import Graph
6+
from pyfuse.typing import TraceDecorator, TracedFunction
67

78
logger = logging.getLogger(__name__)
89

9-
_F = TypeVar("_F", bound=Callable[..., object])
10+
_R = TypeVar("_R")
11+
_P = ParamSpec("_P")
1012

1113

1214
@overload
13-
def trace(func: _F) -> _F: ...
15+
def trace(func: Callable[_P, _R]) -> TracedFunction[_P, _R]: ...
1416
@overload
15-
def trace(*, timeout: float | None = ..., retries: int = ..., retry_delay: float = ...) -> Callable[[_F], _F]: ...
17+
def trace(*, timeout: float | None = ..., retries: int = ..., retry_delay: float = ...) -> TraceDecorator: ...
1618

1719

1820
def trace(
19-
func: _F | None = None,
21+
func: Callable[..., object] | None = None,
2022
*,
2123
timeout: float | None = None,
2224
retries: int = 0,
2325
retry_delay: float = 1.0,
24-
) -> _F | Callable[[_F], _F]:
26+
) -> object:
2527
"""Enable a function for serialization and remote execution.
2628
2729
The decorated function works normally when called directly.
@@ -38,19 +40,19 @@ def flaky(x): ...
3840
if func is not None:
3941
return _apply_trace(func, timeout=timeout, retries=retries, retry_delay=retry_delay)
4042

41-
def decorator(f: _F) -> _F:
43+
def decorator(f: Callable[_P, _R]) -> object:
4244
return _apply_trace(f, timeout=timeout, retries=retries, retry_delay=retry_delay)
4345

4446
return decorator
4547

4648

4749
def _apply_trace(
48-
func: _F,
50+
func: Callable[_P, _R],
4951
*,
5052
timeout: float | None = None,
5153
retries: int = 0,
5254
retry_delay: float = 1.0,
53-
) -> _F:
55+
) -> TracedFunction[_P, _R]:
5456
logger.debug("@trace applied to %s", func.__qualname__)
5557
graph = Graph.default()
5658
graph.register(func)

pyfuse/graph/tracing.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,16 @@
99
import threading
1010
from collections.abc import AsyncGenerator, Callable, Generator
1111
from pathlib import Path
12-
from typing import Any, TypeVar
12+
from typing import Any, ParamSpec, TypeVar, cast
1313

14+
from pyfuse.typing import TracedFunction
1415
from pyfuse.worker.backends.base import Backend
1516

1617
logger = logging.getLogger(__name__)
1718

1819
_F = TypeVar("_F", bound=Callable[..., object])
20+
_P = ParamSpec("_P")
21+
_R = TypeVar("_R")
1922

2023
_BUILTIN_NAMES = set(dir(builtins))
2124

@@ -164,7 +167,7 @@ def _record_edge(self, stack: list[str], qualified_name: str) -> None:
164167
with self._lock:
165168
self._runtime_deps.setdefault(stack[-1], set()).add(qualified_name)
166169

167-
def create_wrapper(self, func: _F) -> _F:
170+
def create_wrapper(self, func: Callable[_P, _R]) -> TracedFunction[_P, _R]:
168171
"""Wrap func to record runtime caller-callee edges.
169172
170173
The wrapper preserves the original function signature via
@@ -182,7 +185,7 @@ def create_wrapper(self, func: _F) -> _F:
182185
wrapper = self._wrap_generator(func, qualified_name)
183186
else:
184187
wrapper = self._wrap_sync(func, qualified_name)
185-
return wrapper # type: ignore[no-any-return]
188+
return cast(TracedFunction[_P, _R], wrapper)
186189

187190
def _wrap_async_generator(self, func: Any, qualified_name: str) -> Any:
188191
logger.debug("Creating async generator wrapper for %s", qualified_name)

pyfuse/typing.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
from collections.abc import Callable
2+
from typing import Any, ParamSpec, Protocol, TypeVar
3+
4+
from pyfuse.worker.result import Result
5+
6+
P = ParamSpec("P")
7+
R = TypeVar("R")
8+
9+
10+
class TracedFunction(Protocol[P, R]):
11+
__pyfuse_traced__: bool
12+
__wrapped__: Callable[P, R]
13+
14+
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R: ...
15+
16+
async def start(self, *args: P.args, **kwargs: P.kwargs) -> Result: ...
17+
18+
async def run(self, *args: P.args, **kwargs: P.kwargs) -> Any: ...
19+
20+
async def map(self, args_list: list[tuple[Any, ...]], **kwargs: Any) -> list[Any]: ...
21+
22+
23+
class TraceDecorator(Protocol):
24+
def __call__(self, func: Callable[P, R]) -> TracedFunction[P, R]: ...

pyfuse/worker/backends/rabbitmq.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from typing import Any
1818

1919
try:
20-
import aio_pika
20+
import aio_pika # type: ignore[import-not-found]
2121
except ImportError:
2222
raise ImportError(
2323
"aio-pika package is required for RabbitMQBackend. "

tests/test_packaging.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,6 @@
1515

1616
from pyfuse.core.version import _FALLBACK_VERSION, _VERSION
1717

18-
pytestmark = pytest.mark.slow
19-
2018

2119
def _project_root() -> Path:
2220
"""Return the repository root (contains pyproject.toml)."""

tests/test_venv.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,6 @@
2525
temp_venv,
2626
)
2727

28-
pytestmark = pytest.mark.slow
29-
3028

3129
class TestFindProjectRoot:
3230
def test_finds_root(self) -> None:

0 commit comments

Comments
 (0)