Skip to content

Commit 94e0bee

Browse files
Add drain timeout + DrainSummary (proposal 0010)
CompiledGraph.drain() gains an optional timeout parameter and returns a DrainSummary frozen dataclass (undelivered_count, timeout_reached). The timeout-fired path cancels in-flight delivery workers cleanly so the graph remains usable for subsequent invocations. Per-invocation dispatched/delivered counters on _InvocationContext track undelivered events; _active_workers changes from set[Task] to dict[Task, _InvocationContext] so drain() can read each worker's counters at cancellation time. Solves the slow-observer-blocks-process-exit footgun for short-lived processes (CLIs, scripts, serverless functions).
1 parent f381fef commit 94e0bee

5 files changed

Lines changed: 365 additions & 35 deletions

File tree

src/openarmature/graph/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848
exponential_jitter_backoff,
4949
)
5050
from .nodes import FunctionNode, Node
51-
from .observer import Observer, RemoveHandle, SubscribedObserver
51+
from .observer import DrainSummary, Observer, RemoveHandle, SubscribedObserver
5252
from .parallel_branches import BranchSpec, ParallelBranchesNode
5353
from .projection import ExplicitMapping, FieldNameMatching, ProjectionStrategy
5454
from .reducers import Reducer, append, last_write_wins, merge
@@ -62,6 +62,7 @@
6262
"ConditionalEdge",
6363
"ConflictingReducers",
6464
"DanglingEdge",
65+
"DrainSummary",
6566
"EdgeException",
6667
"EndSentinel",
6768
"ExplicitMapping",

src/openarmature/graph/compiled.py

Lines changed: 71 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@
9393
from .nodes import Node
9494
from .observer import (
9595
_DRAIN_SENTINEL,
96+
DrainSummary,
9697
Observer,
9798
RemoveHandle,
9899
SubscribedObserver,
@@ -523,10 +524,14 @@ class CompiledGraph[StateT: State]:
523524
# dataclass: the list reference is fixed but its contents change.
524525
# Parameterized factories so pyright infers the element types.
525526
_attached_observers: list[SubscribedObserver] = field(default_factory=list[SubscribedObserver])
526-
# `set` (not list) so a per-task `add_done_callback(self._active_workers.discard)`
527-
# auto-removes completed workers — long-running services that never call
528-
# drain() don't accumulate completed Task references indefinitely.
529-
_active_workers: set[asyncio.Task[None]] = field(default_factory=set[asyncio.Task[None]])
527+
# Per-task `add_done_callback` auto-removes completed workers — long-
528+
# running services that never call drain() don't accumulate completed
529+
# Task references indefinitely. Values are the per-invocation
530+
# `_InvocationContext` so `drain()` can read each worker's
531+
# `drain_counters` to compute the undelivered-event count at timeout.
532+
_active_workers: dict[asyncio.Task[None], _InvocationContext] = field(
533+
default_factory=dict[asyncio.Task[None], _InvocationContext]
534+
)
530535
# Single-element list so the frozen-dataclass binding is stable but
531536
# the user can swap the registered Checkpointer via
532537
# ``attach_checkpointer``. ``None`` when no backend is registered.
@@ -680,35 +685,71 @@ async def _migrate_record(
680685
)
681686
return migrated, summary
682687

683-
async def drain(self) -> None:
688+
async def drain(self, timeout: float | None = None) -> DrainSummary:
684689
"""Await delivery of every observer event produced by prior
685-
invocations of this graph.
690+
invocations of this graph, optionally bounded by ``timeout``.
686691
687692
Callers running in short-lived processes (scripts, serverless
688-
functions, CLIs) MUST use drain to avoid losing observer
689-
events that were dispatched but not yet delivered.
693+
functions, CLIs) MUST use drain to avoid losing observer events
694+
that were dispatched but not yet delivered.
690695
691696
Only events dispatched before this call are awaited; events
692697
from invocations started concurrently with drain may or may
693698
not be included. Subgraph events from active invocations are
694699
part of the parent invocation's worker and are covered
695700
automatically.
696701
697-
**Unbounded by design.** Drain blocks until every queued event has
698-
been delivered to every subscribed observer. A slow, hung, or
699-
misbehaving observer can therefore hold drain, and the calling
700-
process, indefinitely. If you need a bounded wait, wrap the call
701-
in `asyncio.wait_for` and accept that events still queued when the
702-
deadline elapses will not be delivered::
703-
704-
await asyncio.wait_for(compiled.drain(), timeout=5.0)
702+
``timeout`` is a non-negative duration in seconds. If omitted
703+
or ``None``, drain waits indefinitely — a slow, hung, or
704+
misbehaving observer can therefore hold drain (and the calling
705+
process) indefinitely. If supplied, drain returns no later
706+
than ``timeout`` seconds after the call begins; any observer
707+
events still queued or in-flight at that point are considered
708+
undelivered. Workers are cancelled via ``Task.cancel()`` so
709+
the compiled graph remains usable for subsequent invocations
710+
— partial delivery state from one drain does NOT leak into
711+
the next invocation.
712+
713+
Returns a :class:`DrainSummary` with ``undelivered_count`` and
714+
``timeout_reached`` fields. The shape is the same whether or
715+
not a timeout was supplied; on the no-timeout / timeout-not-
716+
fired path both fields are zero / false.
717+
718+
Observers SHOULD be written to be cancellation-safe
719+
(idempotent writes, try/finally cleanup) so that interruption
720+
by drain timeout does not leave partial side effects in an
721+
inconsistent state.
705722
"""
706723
if not self._active_workers:
707-
return
708-
# Snapshot the set: each worker's done-callback removes itself
709-
# from `_active_workers`, so iterating it directly while gather
710-
# awaits would mutate during iteration.
711-
await asyncio.gather(*list(self._active_workers), return_exceptions=True)
724+
return DrainSummary(undelivered_count=0, timeout_reached=False)
725+
# Snapshot the dict: each worker's done-callback removes its
726+
# entry from `_active_workers`, so iterating directly while
727+
# `asyncio.wait` awaits would mutate during iteration.
728+
snapshot = dict(self._active_workers)
729+
workers = list(snapshot.keys())
730+
731+
_done, pending = await asyncio.wait(
732+
workers,
733+
timeout=timeout,
734+
return_when=asyncio.ALL_COMPLETED,
735+
)
736+
737+
if not pending:
738+
return DrainSummary(undelivered_count=0, timeout_reached=False)
739+
740+
undelivered = sum(
741+
snapshot[w].drain_counters.dispatched - snapshot[w].drain_counters.delivered for w in pending
742+
)
743+
for w in pending:
744+
w.cancel()
745+
# Wait for cancellations to settle so the done-callbacks fire
746+
# and `_active_workers` cleans for the next invocation —
747+
# load-bearing for the cross-invocation cleanliness contract.
748+
# ``return_exceptions`` absorbs the ``CancelledError`` each
749+
# cancelled worker raises.
750+
await asyncio.gather(*pending, return_exceptions=True)
751+
752+
return DrainSummary(undelivered_count=undelivered, timeout_reached=True)
712753

713754
# ------------------------------------------------------------------
714755
# Public invocation
@@ -893,12 +934,16 @@ async def invoke(
893934
# "per-invocation is OUTERMOST invoke" wording).
894935
correlation_token = _set_correlation_id(resolved_correlation_id)
895936
invocation_token = _set_invocation_id(invocation_id)
896-
worker = asyncio.create_task(deliver_loop(queue))
897-
self._active_workers.add(worker)
937+
worker = asyncio.create_task(deliver_loop(queue, context.drain_counters))
938+
self._active_workers[worker] = context
898939
# Auto-prune: when the worker completes (after the sentinel is
899-
# processed), remove it from the active set so long-running
900-
# services don't leak Task references between drain() calls.
901-
worker.add_done_callback(self._active_workers.discard)
940+
# processed, or after cancellation by drain() on timeout), remove
941+
# it from the active set so long-running services don't leak Task
942+
# references between drain() calls. ``pop(key, None)`` is the
943+
# idempotent form — if a concurrent drain() removed the entry
944+
# already (it shouldn't with the current design, but the no-arg
945+
# form would raise KeyError), this is a safe no-op.
946+
worker.add_done_callback(lambda t: self._active_workers.pop(t, None))
902947
# Per spec §6 cross-ref in proposal 0014: dispatch the
903948
# ``checkpoint_migrated`` event as soon as the delivery
904949
# worker is alive but before any node runs, so the OTel

src/openarmature/graph/observer.py

Lines changed: 67 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,46 @@ class _QueuedItem:
212212
_DRAIN_SENTINEL = None
213213

214214

215+
# Spec: realizes graph-engine §6 Drain undelivered-count bookkeeping
216+
# (proposal 0010). Per-invocation mutable counters; `_dispatch` bumps
217+
# `dispatched` after a successful `queue.put_nowait`; `deliver_loop`
218+
# bumps `delivered` after the per-event observer for-loop completes.
219+
# `undelivered = dispatched - delivered` at any point in time — and
220+
# specifically at `CompiledGraph.drain()` cancellation time when the
221+
# timeout has elapsed and pending workers' counters get summed into
222+
# the returned `DrainSummary`.
223+
@dataclass
224+
class _DrainCounters:
225+
dispatched: int = 0
226+
delivered: int = 0
227+
228+
229+
# Spec: realizes graph-engine §6 Drain summary return shape (proposal
230+
# 0010). The two declared fields are the spec-mandated minimum;
231+
# implementations MAY add richer detail in future PRs (per-observer
232+
# counts, sampled event metadata) without breaking the v0.19.0 shape.
233+
@dataclass(frozen=True)
234+
class DrainSummary:
235+
"""Outcome of a `CompiledGraph.drain()` call.
236+
237+
Returned from `drain()` regardless of whether a `timeout` was
238+
supplied. When no timeout was supplied, or the timeout did not
239+
fire, `undelivered_count == 0` and `timeout_reached is False`.
240+
When the timeout fired, `undelivered_count` reports the number of
241+
events that were dispatched to the delivery worker but not fully
242+
delivered to every subscribed observer before cancellation, and
243+
`timeout_reached is True`.
244+
245+
The spec-mandated minimum is these two fields. Implementations MAY
246+
extend the shape with diagnostic detail (per-observer counts,
247+
sampled event metadata) in subsequent versions; v0.19.0 ships the
248+
minimum.
249+
"""
250+
251+
undelivered_count: int
252+
timeout_reached: bool
253+
254+
215255
# Spec: realizes pipeline-utilities §10.11 per-instance progress
216256
# tracking in the engine. These are the MUTABLE internal-state
217257
# counterparts to the FROZEN public ``FanOutProgress`` /
@@ -353,6 +393,12 @@ class _InvocationContext:
353393
fan_out_progress_state: dict[tuple[tuple[str, ...], str], _FanOutExecutionState] = field(
354394
default_factory=dict[tuple[tuple[str, ...], str], _FanOutExecutionState]
355395
)
396+
# Per spec §6 Drain (proposal 0010): shared mutable counters that
397+
# the worker reads at drain-cancel time to report undelivered events
398+
# in the returned ``DrainSummary``. Subgraphs share the parent's
399+
# counters because subgraphs share the parent's queue + worker, so
400+
# the parent context's counts naturally cover subgraph events.
401+
drain_counters: _DrainCounters = field(default_factory=_DrainCounters)
356402

357403
def full_observers(self) -> tuple[SubscribedObserver, ...]:
358404
"""Return the ordered observer list to deliver for events from
@@ -395,6 +441,7 @@ def descend_into_subgraph(
395441
pending_resume_states=self.pending_resume_states,
396442
resume_invocation=self.resume_invocation,
397443
fan_out_progress_state=self.fan_out_progress_state,
444+
drain_counters=self.drain_counters,
398445
)
399446

400447
def descend_into_fan_out_instance(
@@ -440,6 +487,7 @@ def descend_into_fan_out_instance(
440487
# inner-instance node can update its own entry and so the
441488
# outer save sees consistent sibling state.
442489
fan_out_progress_state=self.fan_out_progress_state,
490+
drain_counters=self.drain_counters,
443491
)
444492

445493
def descend_into_parallel_branch(
@@ -485,6 +533,7 @@ def descend_into_parallel_branch(
485533
pending_resume_states={},
486534
resume_invocation=self.resume_invocation,
487535
fan_out_progress_state=self.fan_out_progress_state,
536+
drain_counters=self.drain_counters,
488537
)
489538

490539
def take_step(self) -> int:
@@ -579,9 +628,16 @@ def _dispatch(context: _InvocationContext, event: NodeEvent) -> None:
579628
stacklevel=2,
580629
)
581630
context.queue.put_nowait(_QueuedItem(event=event, observers=observers))
631+
# Per spec §6 Drain (proposal 0010): increment AFTER the put so a
632+
# raise from ``put_nowait`` (queue full on a bounded queue — we
633+
# don't bound, but the invariant holds) doesn't desync the counter.
634+
context.drain_counters.dispatched += 1
582635

583636

584-
async def deliver_loop(queue: asyncio.Queue[_QueuedItem | None]) -> None:
637+
async def deliver_loop(
638+
queue: asyncio.Queue[_QueuedItem | None],
639+
counters: _DrainCounters,
640+
) -> None:
585641
"""Background worker: read queued events, deliver to observers serially.
586642
587643
- No two observers receive the same event concurrently (we await
@@ -610,17 +666,27 @@ async def deliver_loop(queue: asyncio.Queue[_QueuedItem | None]) -> None:
610666
f"observer raised {type(e).__name__}: {e}",
611667
stacklevel=1,
612668
)
669+
# Per spec §6 Drain (proposal 0010): increment AFTER the
670+
# observer for-loop completes for this event, so an event
671+
# cancelled mid-for-loop is counted as undelivered
672+
# (``dispatched - delivered`` includes it). The phase-filter
673+
# ``continue`` above does NOT skip the increment — an event
674+
# filtered out for every observer is still considered
675+
# delivered (we did all the work there was to do for it).
676+
counters.delivered += 1
613677

614678

615679
__all__ = [
616680
"ALL_PHASES",
681+
"DrainSummary",
617682
"Observer",
618683
"RemoveHandle",
619684
"SubscribedObserver",
620685
# Engine-internal but listed so pyright sees them as exported (they're
621686
# imported by `compiled.py` and `subgraph.py`). The underscore prefix
622687
# is the user-facing "don't import these" signal.
623688
"_DRAIN_SENTINEL",
689+
"_DrainCounters",
624690
"_FanOutExecutionState",
625691
"_FanOutInstanceState",
626692
"_InvocationContext",

0 commit comments

Comments
 (0)