diff --git a/src/providers/openai_compatible.py b/src/providers/openai_compatible.py index e76f4617..b91076cc 100644 --- a/src/providers/openai_compatible.py +++ b/src/providers/openai_compatible.py @@ -7,11 +7,14 @@ from __future__ import annotations import json +import logging from abc import abstractmethod from typing import Any, Generator, Optional from .base import BaseProvider, ChatResponse, MessageInput, TextChunkCallback +logger = logging.getLogger(__name__) + def _apply_client_timeout(client: Any) -> Any: """Bound an OpenAI-SDK client's read timeout + retries (env-tunable). @@ -567,12 +570,12 @@ def chat_stream_response( """Stream OpenAI-compatible chunks while rebuilding the final response. ESC-cancellation runs the SDK iteration on a daemon worker - thread that pushes chunks into a ``queue.Queue``. The main - thread polls the queue with a 100 ms timeout and re-checks + thread that pushes chunks into a bounded ``queue.Queue``. The + main thread polls the queue with a 100 ms timeout and re-checks ``guard.aborted`` between ticks. On abort the main thread - raises ``AbortError`` immediately and orphans the worker — - the worker dies when the underlying connection eventually - closes. + raises ``AbortError`` immediately; the worker notices the abort + (or the consumer's exit) at its next put attempt and stops + reading the stream. Why the worker indirection (vs. the simpler in-loop check used in earlier revisions): the OpenAI Python SDK uses sync @@ -640,28 +643,59 @@ def chat_stream_response( # thread. # # Workaround: hoist the iteration onto a daemon worker thread - # that pushes chunks into a queue. The main thread polls the - # queue with a short timeout and re-checks ``guard.aborted`` - # each tick. On abort we raise ``AbortError`` immediately and - # orphan the worker — it'll die when the underlying connection - # eventually closes (server-side, idle timeout, or the SDK's - # natural exhaustion). The cost is some wasted bandwidth on - # the orphaned read; the benefit is that the user's prompt - # comes back in ~100 ms regardless of LiteLLM/httpx behavior. + # that pushes chunks into a bounded queue. The main thread polls + # the queue with a short timeout and re-checks ``guard.aborted`` + # each tick. On abort we raise ``AbortError`` immediately; the + # worker notices the abort (or the consumer's exit) at its next + # put attempt and stops reading the stream within one 0.25s + # poll. The benefit is that the user's prompt comes back in + # ~100 ms regardless of LiteLLM/httpx behavior. import queue as _queue import threading as _threading _DONE = object() - chunk_queue: _queue.Queue = _queue.Queue() + # Bounded (#278): after ESC the consumer stops draining, and a + # proxy that keeps sending bytes without closing the iterator + # would otherwise grow the queue without limit. 64 bounds the + # post-abort staleness to a trivial drain while giving the + # producer slack against transient consumer pauses. + chunk_queue: _queue.Queue = _queue.Queue(maxsize=64) + # Set when the consumer loop exits for ANY reason. Without it, a + # consumer that unwinds for a non-abort reason (on_text_chunk + # raising, KeyboardInterrupt) would leave the worker retrying a + # full queue forever — an immortal thread pinning the httpx + # connection open. + consumer_gone = _threading.Event() + + def _put_or_drop_on_abort(item: Any) -> bool: + """Block until ``item`` is enqueued, or drop it once the + abort trips or the consumer exits (either way nobody will + drain it; keeping nothing alive is the point). Returns + False when dropped.""" + while True: + if guard.aborted or consumer_gone.is_set(): + return False + try: + chunk_queue.put(item, timeout=0.25) + return True + except _queue.Full: + continue def _drain_stream() -> None: try: for c in stream: - chunk_queue.put(c) + if not _put_or_drop_on_abort(c): + return # stop reading; orphaned socket dies upstream except BaseException as exc: # noqa: BLE001 — surface to consumer - chunk_queue.put(exc) + if not _put_or_drop_on_abort(exc): + # Abort won the race against a genuine error; the + # consumer raises AbortError, so keep the loser + # visible somewhere. + logger.debug( + "stream error dropped after abort", exc_info=exc + ) finally: - chunk_queue.put(_DONE) + _put_or_drop_on_abort(_DONE) worker = _threading.Thread( target=_drain_stream, @@ -669,7 +703,14 @@ def _drain_stream() -> None: name=f"openai-stream-{id(stream)}", ) - with guard.attach(stream): + import contextlib as _contextlib + + with _contextlib.ExitStack() as _consumer_scope: + # Releases the worker (sets consumer_gone) no matter how the + # consumer loop exits — abort, callback error, or natural + # break — so a blocked put never outlives its consumer. + _consumer_scope.callback(consumer_gone.set) + _consumer_scope.enter_context(guard.attach(stream)) worker.start() while True: try: diff --git a/tests/test_openai_compat_abort_signal.py b/tests/test_openai_compat_abort_signal.py index 617d55ad..577ca90c 100644 --- a/tests/test_openai_compat_abort_signal.py +++ b/tests/test_openai_compat_abort_signal.py @@ -392,3 +392,98 @@ def test_normal_completion_still_captures_final_usage() -> None: # ``response.usage`` would be the default empty dict, and the # ``↓ N tokens`` REPL spinner would silently lose count. assert response.usage.get("total_tokens") == 15 + + +class _FirehoseStream: + """Pathological proxy (#278): keeps yielding after ESC, ignores + ``response.close()`` entirely, and never raises.""" + + def __init__(self) -> None: + self.yielded = 0 + self.response = MagicMock() # close() is a silent no-op + + def __iter__(self): + while True: + self.yielded += 1 + time.sleep(0.001) + yield _FakeChunk(content="x") + + +def test_firehose_stream_aborts_and_stops_accumulating() -> None: + """ESC against a stream that never goes quiet must still abort + promptly and bounded (#278): the worker stops enqueueing the moment + the abort trips, so the queue (bounded at 64) drains, the consumer + hits the Empty tick, and AbortError raises — instead of the queue + growing for as long as the proxy keeps sending.""" + controller = AbortController() + stream = _FirehoseStream() + provider = _provider_with_stream(stream) + + timer = threading.Timer(0.2, lambda: controller.abort("user_interrupt")) + timer.daemon = True + timer.start() + + start = time.monotonic() + with pytest.raises(AbortError): + provider.chat_stream_response( + messages=[{"role": "user", "content": "hi"}], + abort_signal=controller.signal, + ) + elapsed = time.monotonic() - start + assert elapsed < 2.0, f"abort took {elapsed:.2f}s against a firehose stream" + + # The worker must notice the abort and stop reading: the yield + # count settles (within the worker's 0.25s put-poll) instead of + # growing for as long as the proxy keeps sending. + time.sleep(0.6) + settled = stream.yielded + time.sleep(0.5) + assert stream.yielded == settled, "worker kept draining after abort" + + +def test_backpressure_happy_path_preserves_order_and_content() -> None: + """A stream larger than the queue bound (64) with a slow consumer + exercises the Full -> retry path: no drops, no duplicates, no + reordering (#278).""" + pieces = [f"c{i:03d}," for i in range(150)] + stream = _FakeStream(pieces) + provider = _provider_with_stream(stream) + + seen: list[str] = [] + + def _slow_chunk(piece: str) -> None: + seen.append(piece) + time.sleep(0.001) + + response = provider.chat_stream_response( + messages=[{"role": "user", "content": "hi"}], + on_text_chunk=_slow_chunk, + abort_signal=AbortController().signal, + ) + assert response.content == "".join(pieces) + assert seen == pieces + + +def test_consumer_crash_releases_worker() -> None: + """A consumer that dies for a NON-abort reason (on_text_chunk + raising) must not leave the worker retrying a full queue forever — + consumer_gone unblocks it and the stream stops being read (#278).""" + stream = _FirehoseStream() + provider = _provider_with_stream(stream) + + def _exploding_chunk(_piece: str) -> None: + raise ValueError("ui callback bug") + + with pytest.raises(ValueError, match="ui callback bug"): + provider.chat_stream_response( + messages=[{"role": "user", "content": "hi"}], + on_text_chunk=_exploding_chunk, + abort_signal=AbortController().signal, + ) + + # Worker must notice consumer_gone within one 0.25s put-poll and + # stop reading; the yield counter settles instead of growing. + time.sleep(0.6) + settled = stream.yielded + time.sleep(0.5) + assert stream.yielded == settled, "worker kept draining after consumer crash"