Skip to content

Commit 385227a

Browse files
Address PR #69 review: validate timeout + gather all workers
drain() now validates at the API boundary that `timeout` is non-negative (and not NaN). Negative values previously fell through to asyncio.wait and behaved as an immediate cancel; they now raise ValueError with a clear message. Restructured the post-wait branch to gather all workers (both _done and pending) with return_exceptions=True after cancellation. The previous shape only awaited pending in the timeout-fired branch and skipped the gather entirely on the clean path, so any exception escaping a delivery worker would go unretrieved and surface only as a "Task exception was never retrieved" warning. This is defensive — deliver_loop catches observer exceptions internally — but it is cheap and prevents that silent-failure mode. Two new unit tests cover the validation behavior (negative and NaN inputs raise ValueError with the expected message).
1 parent 9b56c44 commit 385227a

2 files changed

Lines changed: 54 additions & 16 deletions

File tree

src/openarmature/graph/compiled.py

Lines changed: 35 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -719,7 +719,18 @@ async def drain(self, timeout: float | None = None) -> DrainSummary:
719719
(idempotent writes, try/finally cleanup) so that interruption
720720
by drain timeout does not leave partial side effects in an
721721
inconsistent state.
722+
723+
Raises ``ValueError`` if ``timeout`` is negative or NaN.
724+
Non-numeric input raises ``TypeError`` from the comparison.
722725
"""
726+
# ``not (timeout >= 0)`` is the right check: catches negative
727+
# values, catches NaN (ordering comparisons with NaN return False),
728+
# and lets non-numeric input raise ``TypeError`` from the
729+
# comparison operator itself. Silently treating a negative
730+
# timeout as "immediate cancel" would be a user-hostile failure
731+
# mode — the spec contract is non-negative seconds.
732+
if timeout is not None and not (timeout >= 0):
733+
raise ValueError(f"drain timeout must be non-negative, got {timeout!r}")
723734
if not self._active_workers:
724735
return DrainSummary(undelivered_count=0, timeout_reached=False)
725736
# Snapshot the dict: each worker's done-callback removes its
@@ -734,22 +745,30 @@ async def drain(self, timeout: float | None = None) -> DrainSummary:
734745
return_when=asyncio.ALL_COMPLETED,
735746
)
736747

737-
if not pending:
738-
return DrainSummary(undelivered_count=0, timeout_reached=False)
739-
740-
undelivered = sum(
741-
snapshot[w].drain_counters.dispatched - snapshot[w].drain_counters.delivered for w in pending
742-
)
743-
for w in pending:
744-
w.cancel()
745-
# Wait for cancellations to settle so the done-callbacks fire
746-
# and `_active_workers` cleans for the next invocation —
747-
# load-bearing for the cross-invocation cleanliness contract.
748-
# ``return_exceptions`` absorbs the ``CancelledError`` each
749-
# cancelled worker raises.
750-
await asyncio.gather(*pending, return_exceptions=True)
751-
752-
return DrainSummary(undelivered_count=undelivered, timeout_reached=True)
748+
if pending:
749+
undelivered = sum(
750+
snapshot[w].drain_counters.dispatched - snapshot[w].drain_counters.delivered for w in pending
751+
)
752+
timeout_reached = True
753+
for w in pending:
754+
w.cancel()
755+
else:
756+
undelivered = 0
757+
timeout_reached = False
758+
759+
# Gather ALL workers (done + pending) so any exception that
760+
# escaped a delivery worker is retrieved here instead of leaking
761+
# as a "Task exception was never retrieved" warning. The
762+
# ``return_exceptions=True`` absorbs both the synthetic
763+
# ``CancelledError`` from cancelled workers and any genuine
764+
# bug-escape from a ``deliver_loop`` that ever raised past
765+
# its inner ``warnings.warn`` isolation. Also load-bearing
766+
# for the cross-invocation cleanliness contract — done-
767+
# callbacks fire on cancellation, so ``_active_workers`` is
768+
# empty by the time we return.
769+
await asyncio.gather(*workers, return_exceptions=True)
770+
771+
return DrainSummary(undelivered_count=undelivered, timeout_reached=timeout_reached)
753772

754773
# ------------------------------------------------------------------
755774
# Public invocation

tests/unit/test_drain.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,3 +215,22 @@ async def obs(event: NodeEvent) -> None:
215215
# before the zero-second deadline fired.
216216
assert summary.undelivered_count == 6
217217
assert len(received) == 0
218+
219+
220+
async def test_drain_rejects_negative_timeout() -> None:
221+
# Spec §6: timeout is "a non-negative duration in seconds". A
222+
# negative value is a user mistake — surface it as ValueError at
223+
# the API boundary rather than silently treating it like an
224+
# immediate cancel.
225+
compiled = _build_compiled()
226+
with pytest.raises(ValueError, match="non-negative"):
227+
await compiled.drain(timeout=-1.0)
228+
229+
230+
async def test_drain_rejects_nan_timeout() -> None:
231+
# NaN makes every ordering comparison False, so `not (timeout >= 0)`
232+
# catches it just like negative values. Without the validation it
233+
# would silently fall through to `asyncio.wait` as an immediate cancel.
234+
compiled = _build_compiled()
235+
with pytest.raises(ValueError, match="non-negative"):
236+
await compiled.drain(timeout=float("nan"))

0 commit comments

Comments
 (0)