Commit 25557e4

cancellation and progress
1 parent e8bf3b3 commit 25557e4

21 files changed

Lines changed: 1116 additions & 40 deletions

README.md

Lines changed: 5 additions & 0 deletions
@@ -81,6 +81,9 @@ Common mappings (`cv2` -> `opencv-python`, `PIL` -> `Pillow`, etc.) are built in
8181
- **Third-party package auto-install** -- Workers install missing packages via pip before execution.
8282
- **Async-native** -- The entire I/O layer is built on `asyncio`. `.run()`, `.start()`, `.map()`, `await result`, and `asyncio.gather` all work out of the box.
8383
- **Heartbeat & stall detection** -- Workers send periodic heartbeats. Clients raise `TaskStalled` when a worker stops responding.
84+
- **Task cancellation** -- `await future.cancel()` cancels pending or in-progress tasks.
85+
- **Progress reporting** -- Call `pyfuse.progress(75.0)` or `pyfuse.progress(3, 10)` inside tasks; query with `await future.progress()`.
86+
- **Graceful shutdown** -- Workers finish in-progress tasks before stopping. Second Ctrl+C force-quits.
8487
- **Class methods** -- `self.method()` and `cls.method()` dependencies are detected. Entire class hierarchies (including `super()`), class-level attributes, decorators (`@dataclass`, etc.), and metaclass keywords are reconstructed.
8588
- **Retry and timeout** -- `@trace(timeout=30, retries=3)` with exponential backoff.
8689
- **Batch submission** -- `await func.map([(a1, b1), (a2, b2)])` submits and awaits multiple tasks.
@@ -98,6 +101,8 @@ pyfuse run examples/remote_execution.py
98101
- **[`examples/remote_execution.py`](examples/remote_execution.py)** -- Remote execution with auto-discovered dependencies
99102
- **[`examples/async_execution.py`](examples/async_execution.py)** -- Async: `.run()`, `.start()`, `.map()`, `asyncio.gather`
100103
- **[`examples/package_installation.py`](examples/package_installation.py)** -- Auto-installing third-party packages on workers
104+
- **[`examples/progress_reporting.py`](examples/progress_reporting.py)** -- Real-time progress tracking from long-running tasks
105+
- **[`examples/cancellation.py`](examples/cancellation.py)** -- Cancelling pending or in-progress tasks
101106
- **[`examples/large_module.py`](examples/large_module.py)** -- Stress test: 47 functions across 7 files, one `@trace`
102107

103108
## Documentation

docs/CONTEXT.md

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@ pyfuse/
1818
│ ├── task.py # Task dataclass: serializable envelope (graph + args + options)
1919
│ ├── models.py # FunctionNode and ImportInfo dataclasses, content hashing
2020
│ ├── version.py # _VERSION = "0.4.0"
21-
│ └── errors.py # Error, WorkerError, RemoteError, DependencyError, TaskStalled
21+
│ ├── errors.py # Error, WorkerError, RemoteError, DependencyError, TaskStalled, TaskCancelled
22+
│ └── progress.py # ProgressInfo dataclass, progress() function, context variable
2223
├── graph/
2324
│ ├── decorator.py # @trace: marks functions, adds .run()/.start()/.map()
2425
│ ├── graph.py # Graph class: registration, auto-discovery, serialization
@@ -54,9 +55,9 @@ Data models and error types. `FunctionNode` represents a function in the graph,
5455

5556
Handles remote execution. Built entirely on `asyncio`:
5657

57-
- **`worker.py`**: `Worker` class reconstructs functions from serialized stores, caches compiled namespaces by subgraph content hash, and executes with retry/timeout policies. Async user functions are awaited directly; sync user functions run in `loop.run_in_executor()`. Timeouts use `asyncio.wait_for()`.
58-
- **`remote.py`**: Orchestrates the connection lifecycle, worker event loop (`asyncio.TaskGroup` + `asyncio.Semaphore` for bounded concurrency), and heartbeat tasks (`asyncio.create_task`).
59-
- **`result.py`**: `Result` is an awaitable future returned by `.start()`. Simple async polling loop for stall detection.
58+
- **`worker.py`**: `Worker` class reconstructs functions from serialized stores, caches compiled namespaces by subgraph content hash, and executes with retry/timeout policies. Async user functions are awaited directly; sync user functions run in `loop.run_in_executor()` with explicit context propagation via `contextvars.copy_context()`. Timeouts use `asyncio.wait_for()`.
59+
- **`remote.py`**: Orchestrates the connection lifecycle, worker event loop (`asyncio.Semaphore` for bounded concurrency), heartbeat tasks (`asyncio.create_task`), progress injection, cancellation checking, and graceful shutdown via signal handling.
60+
- **`result.py`**: `Result` is an awaitable future returned by `.start()`. Simple async polling loop for stall detection. Supports `cancel()` and `progress()` methods.
6061
- **`deps.py`**: Package installation via `asyncio.create_subprocess_exec`.
6162
- **`backends/`**: All backend methods are `async def`. `listen()` and `subscribe_results()` are async generators.
6263

@@ -99,8 +100,11 @@ await Worker.run(task)
99100
- **Decorator stripping**: `@trace` lines are removed from captured source so reconstructed code doesn't depend on pyfuse.
100101
- **Backend auto-detection**: `connect()` picks Redis or local TCP backend based on URL scheme. Falls back to `PYFUSE_BACKEND` env var.
101102
- **Worker caching**: Keyed by SHA-256 of all reachable content hashes (sorted + joined). Same code from different clients = cache hit.
102-
- **Async-native I/O**: All backend methods, worker execution, result handling, pip installation, and subprocess management use `asyncio`. Sync user functions run in `loop.run_in_executor()` to avoid blocking the event loop.
103+
- **Async-native I/O**: All backend methods, worker execution, result handling, pip installation, and subprocess management use `asyncio`. Sync user functions run in `loop.run_in_executor()` with explicit `contextvars.copy_context()` to propagate progress callbacks.
103104
- **Heartbeat**: Workers send heartbeats via `asyncio.create_task`. Client-side stall detection tracks when heartbeat *values* last changed using local monotonic clock (no cross-machine timestamp comparison).
105+
- **Task cancellation**: Cooperative via backend flags. Workers check before execution; clients store a "cancelled" result envelope. `TaskCancelled` is raised on await.
106+
- **Progress reporting**: Uses `contextvars.ContextVar` for the progress callback. Sync functions get context propagated via explicit copy. Progress updates are fire-and-forget async tasks.
107+
- **Graceful shutdown**: Signal-based (`SIGINT`/`SIGTERM`). First signal stops the listener and waits for in-flight tasks. Second signal cancels all tasks.
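The heartbeat bullet in this list describes stall detection that watches for changes in the heartbeat *value* on a local monotonic clock, rather than comparing timestamps across machines. A minimal sketch of that idea follows. `StallDetector` and its injectable `clock` parameter are hypothetical names used for illustration, not part of pyfuse:

```python
import time
from typing import Callable, Optional


class StallDetector:
    """Hypothetical helper: reports a stall when the heartbeat value stops changing."""

    def __init__(self, stall_timeout: float,
                 clock: Callable[[], float] = time.monotonic) -> None:
        self.stall_timeout = stall_timeout
        self._clock = clock  # local clock only; remote timestamps are never compared
        self._last_value: Optional[str] = None
        self._last_change: Optional[float] = None

    def observe(self, heartbeat: Optional[str]) -> bool:
        """Record the latest heartbeat value; return True if the task looks stalled."""
        now = self._clock()
        if heartbeat is None and self._last_value is None:
            # No heartbeat has been observed yet, so a stall cannot be declared.
            return False
        if heartbeat is not None and heartbeat != self._last_value:
            # The value changed: the worker is alive, restart the local timer.
            self._last_value = heartbeat
            self._last_change = now
            return False
        # Value unchanged (or expired): compare against the local timer only.
        return (now - self._last_change) > self.stall_timeout
```

Because only the local clock is read, clock skew between client and worker cannot cause a false stall in this sketch.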
104108

105109
## Serialization format (v0.4.0)
106110

@@ -146,7 +150,16 @@ results = await func.map([(a1, b1), ...]) # batch submit + await all (returns va
146150
result = await future # shorthand for await future.result()
147151
result = await future.result(timeout=10, stall_timeout=10.0) # with options
148152
await future.done() # non-blocking check
149-
await future.status() # "pending", "success", or "error"
153+
await future.status() # "pending", "success", "error", or "cancelled"
154+
155+
# Cancellation
156+
await future.cancel() # cancel task; awaiting the future afterwards raises TaskCancelled
157+
158+
# Progress reporting
159+
pyfuse.progress(75.0) # report percentage (no-op locally)
160+
pyfuse.progress(3, 10, message="step 3") # report current/total with message
161+
p = await future.progress() # get latest ProgressInfo (or None)
162+
if p: print(f"{p.current}/{p.total} {p.percent:.0f}%")
150163

151164
# Serialization (sync -- pure CPU)
152165
pyfuse.serialize(func) # -> JSON string
@@ -166,7 +179,7 @@ pytest # run all tests
166179
pytest tests/test_api.py # specific module
167180
```
168181

169-
15 test modules covering: API surface, AST analysis, async features (Result.result, await, .run(), .start(), .map(), gather, heartbeat, stall detection), auto-discovery (including metaclass keywords, class attributes, class decorators, `__init_subclass__`), dependency management, graph operations, integration scenarios, local backend (async-native TCP), remote execution, runtime tracing (including closure capture of non-traced functions, lambdas, constructor expressions, pickle fallback), store operations, stress tests (47 functions across 7 files), task serialization, temp venv management, and worker caching/execution.
182+
16 test modules covering: API surface, AST analysis, async features (Result.result, await, .run(), .start(), .map(), gather, heartbeat, stall detection), auto-discovery (including metaclass keywords, class attributes, class decorators, `__init_subclass__`), cancellation and progress reporting, dependency management, graph operations, integration scenarios, local backend (async-native TCP), remote execution, runtime tracing (including closure capture of non-traced functions, lambdas, constructor expressions, pickle fallback), store operations, stress tests (47 functions across 7 files), task serialization, temp venv management, and worker caching/execution.
170183

171184
All async tests use `pytest-asyncio` with `asyncio_mode = "auto"`.
172185

@@ -185,7 +198,7 @@ pytest # test suite
185198
- `analyzer.py` is the core of static analysis (~365 lines). Changes here affect what gets captured.
186199
- `tracing.py` uses `contextvars.ContextVar` for thread/async safety. The `_runtime_deps` dict is guarded by `threading.Lock`.
187200
- The `Task` wire format keeps `graph` as a JSON string (not nested object) to keep the envelope flat.
188-
- Backend implementations must satisfy the `Backend` ABC in `backends/base.py`. All methods are `async def`. New methods (`notify_result`, `subscribe_results`, `get_heartbeats`) are non-abstract with safe defaults -- custom backends don't break.
201+
- Backend implementations must satisfy the `Backend` ABC in `backends/base.py`. All methods are `async def`. New methods (`notify_result`, `subscribe_results`, `get_heartbeats`, `cancel_task`, `is_cancelled`, `send_progress`, `get_progress`) are non-abstract with safe defaults -- custom backends don't break.
189202
- `install_package_as()` is a no-op at runtime; the AST analyzer in `decorator.py`/`analyzer.py` detects the `with` block pattern and tags `ImportInfo` objects with the package name.
190203
- `_capture_closure()` in `graph.py` uses a multi-tier strategy: repr validation → traced functions → lambdas (source extraction) → non-traced user functions (auto-registration) → constructor expressions (defaultdict/Counter/deque) → pickle fallback → warning. Returns function objects for auto-registration.
191204
- `_set_class_metadata()` in `graph.py` captures class-level attributes and decorators from the class source AST. Called from both `_auto_register_class` and `_discover_self_call_deps` to handle both constructor-discovered and directly-traced method classes.

docs/QUICK_START.md

Lines changed: 62 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,57 @@ result = await to_yaml.run({"key": "value"})
118118

119119
Common mappings like `cv2` -> `opencv-python` and `PIL` -> `Pillow` are built in.
120120

121+
## Task cancellation
122+
123+
Cancel a pending or in-progress task with `await future.cancel()`:
124+
125+
```python
126+
from pyfuse import TaskCancelled
127+
128+
future = await slow_task.start(data)
129+
await asyncio.sleep(1)
130+
await future.cancel()
131+
132+
try:
133+
result = await future
134+
except TaskCancelled:
135+
print("Task was cancelled")
136+
```
137+
138+
If the worker hasn't started execution yet, it skips the task entirely. If execution is already in progress, the worker lets it finish but discards the result, and the client receives `TaskCancelled` when awaiting the future.
139+
140+
## Progress reporting
141+
142+
Long-running tasks can report progress back to the client:
143+
144+
```python
145+
from pyfuse import trace, progress
146+
147+
@trace
148+
def process_batch(items: list[str]) -> list[str]:
149+
results = []
150+
for i, item in enumerate(items):
151+
results.append(transform(item))
152+
progress(i + 1, len(items), message=f"Processing {item}")
153+
return results
154+
```
155+
156+
Query progress from the client:
157+
158+
```python
159+
future = await process_batch.start(items)
160+
161+
while not await future.done():
162+
p = await future.progress()
163+
if p is not None:
164+
print(f"{p.current}/{p.total} ({p.percent:.0f}%) - {p.message}")
165+
await asyncio.sleep(0.5)
166+
167+
result = await future
168+
```
169+
170+
`progress()` is a silent no-op when called outside a worker, so the function works unchanged locally.
171+
121172
## Async API
122173

123174
pyfuse is async-native. All remote execution methods are coroutines.
@@ -259,7 +310,14 @@ future = await hypotenuse.start(3.0, 4.0)
259310

260311
# Check status without blocking
261312
await future.done() # True / False
262-
await future.status() # "pending", "success", or "error"
313+
await future.status() # "pending", "success", "error", or "cancelled"
314+
315+
# Check progress (from pyfuse.progress() calls)
316+
p = await future.progress() # ProgressInfo or None
317+
if p: print(f"{p.current}/{p.total}")
318+
319+
# Cancel a task
320+
await future.cancel() # awaiting the future afterwards raises TaskCancelled
263321

264322
# Await the result
265323
result = await future # shorthand
@@ -428,7 +486,7 @@ merged = json.dumps({
428486
## Error handling
429487

430488
```python
431-
from pyfuse import trace, Error, RemoteError
489+
from pyfuse import trace, Error, RemoteError, TaskCancelled
432490

433491
# Tracing errors
434492
try:
@@ -441,6 +499,8 @@ try:
441499
result = await future.result()
442500
except RemoteError as e:
443501
print(e) # includes remote traceback
502+
except TaskCancelled:
503+
print("Task was cancelled")
444504
```
445505

446506
## Running the examples

docs/TECHNICAL_OVERVIEW.md

Lines changed: 65 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -58,12 +58,13 @@ pyfuse/
5858
### Worker side: `await serve()` / `python -m pyfuse worker`
5959

6060
1. **Listen** -- `async for task_json in backend.listen()` yields tasks as they arrive.
61-
2. **Deserialize** -- Parse the JSON graph into a `Store`.
62-
3. **Cache check** -- Compute a subgraph key (SHA-256 of all reachable content hashes). If cached, skip to step 6.
63-
4. **Install dependencies** -- Extract third-party imports, install missing packages via `asyncio.create_subprocess_exec`.
64-
5. **Reconstruct** -- Produce a self-contained Python script from the store. `compile()` and `exec()` it into a fresh namespace.
65-
6. **Execute** -- Call the function with the provided arguments. Async functions are awaited directly; sync functions run in an executor. Apply retry/timeout policies via `asyncio.wait_for`.
66-
7. **Send result** -- Wrap the return value (or exception traceback) in a `ResultEnvelope` and send it back.
61+
2. **Cancellation check** -- If `await backend.is_cancelled(task_id)` returns `True`, skip execution and log the cancellation.
62+
3. **Deserialize** -- Parse the JSON graph into a `Store`.
63+
4. **Cache check** -- Compute a subgraph key (SHA-256 of all reachable content hashes). If cached, skip to step 7.
64+
5. **Install dependencies** -- Extract third-party imports, install missing packages via `asyncio.create_subprocess_exec`.
65+
6. **Reconstruct** -- Produce a self-contained Python script from the store. `compile()` and `exec()` it into a fresh namespace.
66+
7. **Execute** -- Call the function with the provided arguments. Async functions are awaited directly; sync functions run in an executor with explicit context propagation (for progress reporting). Apply retry/timeout policies via `asyncio.wait_for`.
67+
8. **Send result** -- Wrap the return value (or exception traceback) in a `ResultEnvelope` and send it back. If cancelled during execution, skip result delivery.
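The cache check in step 4 keys the compiled namespace on a SHA-256 of all reachable content hashes. A sketch of such a key is shown below. The function name `subgraph_key` and the newline separator are assumptions made for illustration; the source only says the hashes are sorted and joined:

```python
import hashlib
from typing import Iterable


def subgraph_key(content_hashes: Iterable[str]) -> str:
    """Hypothetical cache key: SHA-256 over the sorted, joined content hashes."""
    # Sorting makes the key independent of the order in which nodes were discovered.
    joined = "\n".join(sorted(content_hashes))  # separator is an assumption
    return hashlib.sha256(joined.encode("utf-8")).hexdigest()
```

Two clients that submit the same code produce the same set of content hashes and therefore the same key, which is what allows a cache hit across clients.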
6768

6869
### Client side: `await future` / `await future.result()`
6970

@@ -89,6 +90,10 @@ The `Backend` ABC defines an async transport interface:
8990
| `async send_heartbeat(task_id)` | Signal active processing (no-op default) |
9091
| `async get_heartbeat(task_id)` | Get last heartbeat timestamp |
9192
| `async get_heartbeats(task_ids)` | Batch heartbeat fetch (default loops over `get_heartbeat`) |
93+
| `async cancel_task(task_id)` | Mark a task as cancelled (no-op default) |
94+
| `async is_cancelled(task_id)` | Check cancellation flag (default returns `False`) |
95+
| `async send_progress(task_id, json)` | Store latest progress data (no-op default) |
96+
| `async get_progress(task_id)` | Get latest progress JSON (default returns `None`) |
9297
| `async notify_result(task_id)` | Push notification that result is ready (no-op default) |
9398
| `async subscribe_results()` | Async iterator yielding task_ids on result arrival |
9499
| `async close()` | Release resources |
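The four new methods ship with safe defaults so that existing custom backends keep working. The sketch below illustrates that pattern. `BackendSketch` and `InMemoryBackend` are hypothetical stand-ins, not the real `Backend` ABC, and only the four new methods from the table are shown:

```python
import asyncio
from typing import Dict, Optional, Set


class BackendSketch:
    """Hypothetical stand-in for the Backend ABC (new methods only)."""

    async def cancel_task(self, task_id: str) -> None:
        return None  # no-op default

    async def is_cancelled(self, task_id: str) -> bool:
        return False  # default: nothing is ever cancelled

    async def send_progress(self, task_id: str, data: str) -> None:
        return None  # no-op default

    async def get_progress(self, task_id: str) -> Optional[str]:
        return None  # default: no progress available


class InMemoryBackend(BackendSketch):
    """Overrides the defaults with simple in-memory storage."""

    def __init__(self) -> None:
        self._cancelled: Set[str] = set()
        self._progress: Dict[str, str] = {}

    async def cancel_task(self, task_id: str) -> None:
        self._cancelled.add(task_id)

    async def is_cancelled(self, task_id: str) -> bool:
        return task_id in self._cancelled

    async def send_progress(self, task_id: str, data: str) -> None:
        self._progress[task_id] = data  # keep only the latest update

    async def get_progress(self, task_id: str) -> Optional[str]:
        return self._progress.get(task_id)
```

A backend that overrides none of these methods still works in this sketch: cancellation and progress quietly do nothing.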
@@ -101,6 +106,8 @@ Uses `redis.asyncio.Redis` with `RPUSH`/`BLPOP` patterns. Keys:
101106
- `pyfuse:tasks` -- task queue
102107
- `pyfuse:result:{task_id}` -- per-task result (TTL: 300s)
103108
- `pyfuse:heartbeat:{task_id}` -- worker heartbeat timestamp (TTL: 30s)
109+
- `pyfuse:cancel:{task_id}` -- cancellation flag (TTL: 3600s)
110+
- `pyfuse:progress:{task_id}` -- latest progress JSON (TTL: 300s)
104111
- `pyfuse:notify` -- Pub/Sub channel for result notifications
105112

106113
Result notifications use Redis Pub/Sub (`PUBLISH`/`SUBSCRIBE`). Batch heartbeat fetching uses `MGET` for efficiency.
@@ -283,9 +290,11 @@ Error (includes remote traceback):
283290
| Method / Property | Description |
284291
|------------------|-------------|
285292
| `await result` | Shorthand for `await result.result()` |
286-
| `await result.result(timeout, stall_timeout=10.0)` | Await with options; raises `RemoteError` on failure, `TaskStalled` on stall |
293+
| `await result.result(timeout, stall_timeout=10.0)` | Await with options; raises `RemoteError` on failure, `TaskStalled` on stall, `TaskCancelled` on cancel |
294+
| `await result.cancel()` | Cancel the task; subsequently awaiting the result raises `TaskCancelled` |
295+
| `await result.progress()` | Latest `ProgressInfo` (or `None`) |
287296
| `await result.done()` | Non-blocking check |
288-
| `await result.status()` | `"pending"`, `"success"`, or `"error"` |
297+
| `await result.status()` | `"pending"`, `"success"`, `"error"`, or `"cancelled"` |
289298
| `.task_id` | The task identifier |
290299

291300
## Serialization format
@@ -417,6 +426,54 @@ Stall detection only triggers after at least one heartbeat has been observed, av
417426
|--------|----------------|
418427
| `await result.result()` | On by default (`stall_timeout=10.0`). Disable with `stall_timeout=None` |
419428

429+
## Task cancellation
430+
431+
Tasks can be cancelled via `await result.cancel()`. Cancellation is cooperative:
432+
433+
1. **Client** calls `cancel()`, which sets a cancellation flag in the backend and stores a `"cancelled"` result envelope.
434+
2. **Worker** checks `is_cancelled()` before starting execution. If cancelled, the task is skipped entirely (no result sent, since the cancel already stored one).
435+
3. **During execution**: if `cancel()` is called while a task is running, the task is not interrupted and runs to completion. When the worker finishes, it checks `is_cancelled()` again and discards the result if the flag is set.
436+
4. **Client** receives `TaskCancelled` when awaiting a cancelled task.
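The four steps above can be sketched with an in-memory stand-in for the backend. Every name here (`MemoryBackend`, `client_cancel`, `worker_handle`, `client_result`, `TaskCancelledSketch`) is hypothetical and only illustrates the cooperative protocol, not pyfuse's actual implementation:

```python
import asyncio
from typing import Any, Awaitable, Callable, Dict, Set


class TaskCancelledSketch(Exception):
    """Stand-in for pyfuse's TaskCancelled."""


class MemoryBackend:
    """Hypothetical backend holding cancellation flags and result envelopes."""

    def __init__(self) -> None:
        self.cancelled: Set[str] = set()
        self.results: Dict[str, Dict[str, Any]] = {}


async def client_cancel(backend: MemoryBackend, task_id: str) -> None:
    # Step 1: set the flag and store a "cancelled" envelope.
    backend.cancelled.add(task_id)
    backend.results[task_id] = {"status": "cancelled"}


async def worker_handle(backend: MemoryBackend, task_id: str,
                        fn: Callable[[], Awaitable[Any]]) -> bool:
    """Run fn unless cancelled; return True if fn was executed."""
    # Step 2: check before execution and skip entirely if cancelled.
    if task_id in backend.cancelled:
        return False
    value = await fn()
    # Step 3: check again after execution and discard the result if cancelled.
    if task_id in backend.cancelled:
        return True
    backend.results[task_id] = {"status": "success", "value": value}
    return True


def client_result(backend: MemoryBackend, task_id: str) -> Any:
    # Step 4: a "cancelled" envelope surfaces as an exception on the client.
    envelope = backend.results[task_id]
    if envelope["status"] == "cancelled":
        raise TaskCancelledSketch(task_id)
    return envelope["value"]
```

Since the client stores the envelope itself, it never has to wait for the worker to acknowledge the cancellation.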
437+
438+
### Backend storage
439+
440+
| Backend | Cancellation storage |
441+
|---------|---------------------|
442+
| Redis | `SET pyfuse:cancel:{task_id} 1 EX 3600` |
443+
| Local | In-memory `set()` in the broker |
444+
445+
## Progress reporting
446+
447+
Tasks can report progress via `pyfuse.progress(percent)` or `pyfuse.progress(current, total)`.
448+
449+
### How it works
450+
451+
1. **Context variable**: A `contextvars.ContextVar` holds the progress callback. The worker sets it before executing each task.
452+
2. **Sync function support**: `Worker.run()` explicitly propagates context variables to executor threads via `contextvars.copy_context().run()`.
453+
3. **Rate-limited sends**: The progress callback rate-limits backend sends to one per 50 ms. Intermediate updates are stored locally. A `flush()` coroutine sends the final state after execution completes.
454+
4. **No-op locally**: When called outside a worker, `progress()` is a silent no-op (context variable is `None`).
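The context-variable mechanism in steps 1, 2, and 4 can be illustrated as follows. This is a simplified sketch under stated assumptions: `_progress_cb`, `run_with_progress`, and the single-argument `progress()` are hypothetical, and the rate limiting from step 3 is omitted:

```python
import asyncio
import contextvars
from typing import Any, Callable

# Holds the active progress callback; None means "not running inside a worker".
_progress_cb = contextvars.ContextVar("progress_cb", default=None)


def progress(percent: float) -> None:
    """Report progress if a callback is set; otherwise do nothing."""
    cb = _progress_cb.get()
    if cb is not None:
        cb(percent)


def sync_task() -> str:
    progress(50.0)
    progress(100.0)
    return "done"


async def run_with_progress(func: Callable[[], Any],
                            on_progress: Callable[[float], None]) -> Any:
    """Run a sync function in an executor thread with the callback visible to it."""
    token = _progress_cb.set(on_progress)
    try:
        # Executor threads do not inherit context automatically, so copy the
        # current context and run the function inside that copy.
        ctx = contextvars.copy_context()
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, ctx.run, func)
    finally:
        _progress_cb.reset(token)
```

Note that in this sketch the callback runs on the executor thread, so it must be thread-safe.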
455+
456+
### Backend storage
457+
458+
| Backend | Progress storage |
459+
|---------|-----------------|
460+
| Redis | `SET pyfuse:progress:{task_id} <json> EX 300` |
461+
| Local | In-memory `dict` in the broker |
462+
463+
## Graceful shutdown
464+
465+
Workers support graceful shutdown via signal handling:
466+
467+
1. **First SIGINT/SIGTERM**: Sets a shutdown event, stops accepting new tasks, and waits for in-progress tasks to complete.
468+
2. **Second SIGINT/SIGTERM**: Cancels all in-progress tasks immediately and exits.
469+
470+
Signal handlers are installed via `loop.add_signal_handler()` (Unix). On Windows, the worker falls back to `KeyboardInterrupt` handling. The worker logs shutdown progress:
471+
472+
```
473+
12:30:00 INFO Graceful shutdown: waiting for 2 task(s) to complete... (Ctrl+C to force quit)
474+
12:30:02 INFO Worker stopped.
475+
```
476+
420477
## Thread and task safety
421478

422479
- The runtime call stack uses `contextvars.ContextVar`, providing per-thread isolation in sync code and per-task isolation in async code.
