|
93 | 93 | from .nodes import Node |
94 | 94 | from .observer import ( |
95 | 95 | _DRAIN_SENTINEL, |
| 96 | + DrainSummary, |
96 | 97 | Observer, |
97 | 98 | RemoveHandle, |
98 | 99 | SubscribedObserver, |
@@ -523,10 +524,14 @@ class CompiledGraph[StateT: State]: |
523 | 524 | # dataclass: the list reference is fixed but its contents change. |
524 | 525 | # Parameterized factories so pyright infers the element types. |
525 | 526 | _attached_observers: list[SubscribedObserver] = field(default_factory=list[SubscribedObserver]) |
526 | | - # `set` (not list) so a per-task `add_done_callback(self._active_workers.discard)` |
527 | | - # auto-removes completed workers — long-running services that never call |
528 | | - # drain() don't accumulate completed Task references indefinitely. |
529 | | - _active_workers: set[asyncio.Task[None]] = field(default_factory=set[asyncio.Task[None]]) |
| 527 | + # Keys are delivery-worker Tasks. A per-task `add_done_callback` |
| 528 | + # removes each entry on completion, so long-running services that never |
| 529 | + # call drain() don't accumulate Task references. Values are the per-invocation |
| 530 | + # `_InvocationContext` so `drain()` can read each worker's |
| 531 | + # `drain_counters` to compute the undelivered-event count at timeout. |
| 532 | + _active_workers: dict[asyncio.Task[None], _InvocationContext] = field( |
| 533 | + default_factory=dict[asyncio.Task[None], _InvocationContext] |
| 534 | + ) |
530 | 535 | # Single-element list so the frozen-dataclass binding is stable but |
531 | 536 | # the user can swap the registered Checkpointer via |
532 | 537 | # ``attach_checkpointer``. ``None`` when no backend is registered. |
@@ -680,35 +685,90 @@ async def _migrate_record( |
680 | 685 | ) |
681 | 686 | return migrated, summary |
682 | 687 |
|
683 | | - async def drain(self) -> None: |
| 688 | + async def drain(self, timeout: float | None = None) -> DrainSummary: |
684 | 689 | """Await delivery of every observer event produced by prior |
685 | | - invocations of this graph. |
| 690 | + invocations of this graph, optionally bounded by ``timeout``. |
686 | 691 |
|
687 | 692 | Callers running in short-lived processes (scripts, serverless |
688 | | - functions, CLIs) MUST use drain to avoid losing observer |
689 | | - events that were dispatched but not yet delivered. |
| 693 | + functions, CLIs) MUST use drain to avoid losing observer events |
| 694 | + that were dispatched but not yet delivered. |
690 | 695 |
|
691 | 696 | Only events dispatched before this call are awaited; events |
692 | 697 | from invocations started concurrently with drain may or may |
693 | 698 | not be included. Subgraph events from active invocations are |
694 | 699 | part of the parent invocation's worker and are covered |
695 | 700 | automatically. |
696 | 701 |
|
697 | | - **Unbounded by design.** Drain blocks until every queued event has |
698 | | - been delivered to every subscribed observer. A slow, hung, or |
699 | | - misbehaving observer can therefore hold drain, and the calling |
700 | | - process, indefinitely. If you need a bounded wait, wrap the call |
701 | | - in `asyncio.wait_for` and accept that events still queued when the |
702 | | - deadline elapses will not be delivered:: |
703 | | -
|
704 | | - await asyncio.wait_for(compiled.drain(), timeout=5.0) |
| 702 | + ``timeout`` is a non-negative duration in seconds. If omitted |
| 703 | + or ``None``, drain waits indefinitely — a slow, hung, or |
| 704 | + misbehaving observer can therefore hold drain (and the calling |
| 705 | + process) indefinitely. If supplied, drain stops waiting once |
| 706 | + ``timeout`` seconds have elapsed; any observer events still |
| 707 | + queued or in-flight at that point are considered undelivered. |
| 708 | + The remaining workers are cancelled via ``Task.cancel()`` and |
| 709 | + awaited until they finish unwinding, so drain can return after |
| 710 | + the deadline if a worker is slow to honour cancellation. The |
| 711 | + graph stays usable: no partial state leaks into the next invocation. |
| 712 | +
|
| 713 | + Returns a :class:`DrainSummary` with ``undelivered_count`` and |
| 714 | + ``timeout_reached`` fields. The shape is the same whether or |
| 715 | + not a timeout was supplied; when no timeout fires they are |
| 716 | + ``0`` and ``False`` respectively. |
| 717 | +
|
| 718 | + Observers SHOULD be written to be cancellation-safe |
| 719 | + (idempotent writes, try/finally cleanup) so that interruption |
| 720 | + by drain timeout does not leave partial side effects in an |
| 721 | + inconsistent state. |
| 722 | +
|
| 723 | + Raises ``ValueError`` if ``timeout`` is negative or NaN. |
| 724 | + Non-numeric input raises ``TypeError`` from the comparison. |
705 | 725 | """ |
| 726 | + # ``not (timeout >= 0)`` is the right check: catches negative |
| 727 | + # values, catches NaN (all comparisons with NaN return False), |
| 728 | + # and lets non-numeric input raise ``TypeError`` from the |
| 729 | + # comparison operator itself. Silently treating a negative |
| 730 | + # timeout as "immediate cancel" would be a user-hostile failure |
| 731 | + # mode — the spec contract is non-negative seconds. |
| 732 | + if timeout is not None and not (timeout >= 0): |
| 733 | + raise ValueError(f"drain timeout must be non-negative, got {timeout!r}") |
706 | 734 | if not self._active_workers: |
707 | | - return |
708 | | - # Snapshot the set: each worker's done-callback removes itself |
709 | | - # from `_active_workers`, so iterating it directly while gather |
710 | | - # awaits would mutate during iteration. |
711 | | - await asyncio.gather(*list(self._active_workers), return_exceptions=True) |
| 735 | + return DrainSummary(undelivered_count=0, timeout_reached=False) |
| 736 | + # Snapshot the dict: each worker's done-callback removes its |
| 737 | + # entry from `_active_workers`, so iterating directly while |
| 738 | + # `asyncio.wait` awaits would mutate during iteration. |
| 739 | + snapshot = dict(self._active_workers) |
| 740 | + workers = list(snapshot.keys()) |
| 741 | + |
| 742 | + _done, pending = await asyncio.wait( |
| 743 | + workers, |
| 744 | + timeout=timeout, |
| 745 | + return_when=asyncio.ALL_COMPLETED, |
| 746 | + ) |
| 747 | + |
| 748 | + if pending: |
| 749 | + undelivered = sum( |
| 750 | + snapshot[w].drain_counters.dispatched - snapshot[w].drain_counters.delivered for w in pending |
| 751 | + ) |
| 752 | + timeout_reached = True |
| 753 | + for w in pending: |
| 754 | + w.cancel() |
| 755 | + else: |
| 756 | + undelivered = 0 |
| 757 | + timeout_reached = False |
| 758 | + |
| 759 | + # Gather ALL workers (done + pending) so any exception that |
| 760 | + # escaped a delivery worker is retrieved here instead of |
| 761 | + # leaking as a "Task exception was never retrieved" warning. |
| 762 | + # ``return_exceptions=True`` absorbs (and discards) both the |
| 763 | + # ``CancelledError`` from cancelled workers and any genuine |
| 764 | + # bug-escape from a ``deliver_loop`` that ever raised past |
| 765 | + # its inner ``warnings.warn`` isolation. Also load-bearing |
| 766 | + # for the cross-invocation cleanliness contract: done-callbacks |
| 767 | + # fire on cancellation too, so every snapshotted worker has |
| 768 | + # left ``_active_workers`` by the time we return. |
| 769 | + await asyncio.gather(*workers, return_exceptions=True) |
| 770 | + |
| 771 | + return DrainSummary(undelivered_count=undelivered, timeout_reached=timeout_reached) |
712 | 772 |
|
713 | 773 | # ------------------------------------------------------------------ |
714 | 774 | # Public invocation |
@@ -893,12 +953,16 @@ async def invoke( |
893 | 953 | # "per-invocation is OUTERMOST invoke" wording). |
894 | 954 | correlation_token = _set_correlation_id(resolved_correlation_id) |
895 | 955 | invocation_token = _set_invocation_id(invocation_id) |
896 | | - worker = asyncio.create_task(deliver_loop(queue)) |
897 | | - self._active_workers.add(worker) |
| 956 | + worker = asyncio.create_task(deliver_loop(queue, context.drain_counters)) |
| 957 | + self._active_workers[worker] = context |
898 | 958 | # Auto-prune: when the worker completes (after the sentinel is |
899 | | - # processed), remove it from the active set so long-running |
900 | | - # services don't leak Task references between drain() calls. |
901 | | - worker.add_done_callback(self._active_workers.discard) |
| 959 | + # processed, or after cancellation by drain() on timeout), remove |
| 960 | + # it from the active dict so long-running services don't leak Task |
| 961 | + # references between drain() calls. ``pop(key, None)`` is the |
| 962 | + # idempotent form — if the entry was already removed (it shouldn't |
| 963 | + # be with the current design, but ``pop(key)`` without a default |
| 964 | + # would raise KeyError), this is a safe no-op. |
| 965 | + worker.add_done_callback(lambda t: self._active_workers.pop(t, None)) |
902 | 966 | # Per spec §6 cross-ref in proposal 0014: dispatch the |
903 | 967 | # ``checkpoint_migrated`` event as soon as the delivery |
904 | 968 | # worker is alive but before any node runs, so the OTel |
|
0 commit comments