Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 59 additions & 18 deletions src/providers/openai_compatible.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,14 @@
from __future__ import annotations

import json
import logging
from abc import abstractmethod
from typing import Any, Generator, Optional

from .base import BaseProvider, ChatResponse, MessageInput, TextChunkCallback

logger = logging.getLogger(__name__)


def _apply_client_timeout(client: Any) -> Any:
"""Bound an OpenAI-SDK client's read timeout + retries (env-tunable).
Expand Down Expand Up @@ -567,12 +570,12 @@ def chat_stream_response(
"""Stream OpenAI-compatible chunks while rebuilding the final response.

ESC-cancellation runs the SDK iteration on a daemon worker
thread that pushes chunks into a ``queue.Queue``. The main
thread polls the queue with a 100 ms timeout and re-checks
thread that pushes chunks into a bounded ``queue.Queue``. The
main thread polls the queue with a 100 ms timeout and re-checks
``guard.aborted`` between ticks. On abort the main thread
raises ``AbortError`` immediately and orphans the worker
the worker dies when the underlying connection eventually
closes.
raises ``AbortError`` immediately; the worker notices the abort
(or the consumer's exit) at its next put attempt and stops
reading the stream.

Why the worker indirection (vs. the simpler in-loop check
used in earlier revisions): the OpenAI Python SDK uses sync
Expand Down Expand Up @@ -640,36 +643,74 @@ def chat_stream_response(
# thread.
#
# Workaround: hoist the iteration onto a daemon worker thread
# that pushes chunks into a queue. The main thread polls the
# queue with a short timeout and re-checks ``guard.aborted``
# each tick. On abort we raise ``AbortError`` immediately and
# orphan the worker — it'll die when the underlying connection
# eventually closes (server-side, idle timeout, or the SDK's
# natural exhaustion). The cost is some wasted bandwidth on
# the orphaned read; the benefit is that the user's prompt
# comes back in ~100 ms regardless of LiteLLM/httpx behavior.
# that pushes chunks into a bounded queue. The main thread polls
# the queue with a short timeout and re-checks ``guard.aborted``
# each tick. On abort we raise ``AbortError`` immediately; the
# worker notices the abort (or the consumer's exit) at its next
# put attempt and stops reading the stream within one 0.25s
# poll. The benefit is that the user's prompt comes back in
# ~100 ms regardless of LiteLLM/httpx behavior.
import queue as _queue
import threading as _threading

_DONE = object()
chunk_queue: _queue.Queue = _queue.Queue()
# Bounded (#278): after ESC the consumer stops draining, and a
# proxy that keeps sending bytes without closing the iterator
# would otherwise grow the queue without limit. 64 bounds the
# post-abort staleness to a trivial drain while giving the
# producer slack against transient consumer pauses.
chunk_queue: _queue.Queue = _queue.Queue(maxsize=64)
# Set when the consumer loop exits for ANY reason. Without it, a
# consumer that unwinds for a non-abort reason (on_text_chunk
# raising, KeyboardInterrupt) would leave the worker retrying a
# full queue forever — an immortal thread pinning the httpx
# connection open.
consumer_gone = _threading.Event()

def _put_or_drop_on_abort(item: Any) -> bool:
"""Block until ``item`` is enqueued, or drop it once the
abort trips or the consumer exits (either way nobody will
drain it; keeping nothing alive is the point). Returns
False when dropped."""
while True:
if guard.aborted or consumer_gone.is_set():
return False
try:
chunk_queue.put(item, timeout=0.25)
return True
except _queue.Full:
continue

def _drain_stream() -> None:
try:
for c in stream:
chunk_queue.put(c)
if not _put_or_drop_on_abort(c):
return # stop reading; orphaned socket dies upstream
except BaseException as exc: # noqa: BLE001 — surface to consumer
chunk_queue.put(exc)
if not _put_or_drop_on_abort(exc):
# Abort won the race against a genuine error; the
# consumer raises AbortError, so keep the loser
# visible somewhere.
logger.debug(
"stream error dropped after abort", exc_info=exc
)
finally:
chunk_queue.put(_DONE)
_put_or_drop_on_abort(_DONE)

worker = _threading.Thread(
target=_drain_stream,
daemon=True,
name=f"openai-stream-{id(stream)}",
)

with guard.attach(stream):
import contextlib as _contextlib

with _contextlib.ExitStack() as _consumer_scope:
# Releases the worker (sets consumer_gone) no matter how the
# consumer loop exits — abort, callback error, or natural
# break — so a blocked put never outlives its consumer.
_consumer_scope.callback(consumer_gone.set)
_consumer_scope.enter_context(guard.attach(stream))
worker.start()
while True:
try:
Expand Down
95 changes: 95 additions & 0 deletions tests/test_openai_compat_abort_signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,3 +392,98 @@ def test_normal_completion_still_captures_final_usage() -> None:
# ``response.usage`` would be the default empty dict, and the
# ``↓ N tokens`` REPL spinner would silently lose count.
assert response.usage.get("total_tokens") == 15


class _FirehoseStream:
"""Pathological proxy (#278): keeps yielding after ESC, ignores
``response.close()`` entirely, and never raises."""

def __init__(self) -> None:
self.yielded = 0
self.response = MagicMock() # close() is a silent no-op

def __iter__(self):
while True:
self.yielded += 1
time.sleep(0.001)
yield _FakeChunk(content="x")


def test_firehose_stream_aborts_and_stops_accumulating() -> None:
"""ESC against a stream that never goes quiet must still abort
promptly and bounded (#278): the worker stops enqueueing the moment
the abort trips, so the queue (bounded at 64) drains, the consumer
hits the Empty tick, and AbortError raises — instead of the queue
growing for as long as the proxy keeps sending."""
controller = AbortController()
stream = _FirehoseStream()
provider = _provider_with_stream(stream)

timer = threading.Timer(0.2, lambda: controller.abort("user_interrupt"))
timer.daemon = True
timer.start()

start = time.monotonic()
with pytest.raises(AbortError):
provider.chat_stream_response(
messages=[{"role": "user", "content": "hi"}],
abort_signal=controller.signal,
)
elapsed = time.monotonic() - start
assert elapsed < 2.0, f"abort took {elapsed:.2f}s against a firehose stream"

# The worker must notice the abort and stop reading: the yield
# count settles (within the worker's 0.25s put-poll) instead of
# growing for as long as the proxy keeps sending.
time.sleep(0.6)
settled = stream.yielded
time.sleep(0.5)
assert stream.yielded == settled, "worker kept draining after abort"


def test_backpressure_happy_path_preserves_order_and_content() -> None:
"""A stream larger than the queue bound (64) with a slow consumer
exercises the Full -> retry path: no drops, no duplicates, no
reordering (#278)."""
pieces = [f"c{i:03d}," for i in range(150)]
stream = _FakeStream(pieces)
provider = _provider_with_stream(stream)

seen: list[str] = []

def _slow_chunk(piece: str) -> None:
seen.append(piece)
time.sleep(0.001)

response = provider.chat_stream_response(
messages=[{"role": "user", "content": "hi"}],
on_text_chunk=_slow_chunk,
abort_signal=AbortController().signal,
)
assert response.content == "".join(pieces)
assert seen == pieces


def test_consumer_crash_releases_worker() -> None:
"""A consumer that dies for a NON-abort reason (on_text_chunk
raising) must not leave the worker retrying a full queue forever —
consumer_gone unblocks it and the stream stops being read (#278)."""
stream = _FirehoseStream()
provider = _provider_with_stream(stream)

def _exploding_chunk(_piece: str) -> None:
raise ValueError("ui callback bug")

with pytest.raises(ValueError, match="ui callback bug"):
provider.chat_stream_response(
messages=[{"role": "user", "content": "hi"}],
on_text_chunk=_exploding_chunk,
abort_signal=AbortController().signal,
)

# Worker must notice consumer_gone within one 0.25s put-poll and
# stop reading; the yield counter settles instead of growing.
time.sleep(0.6)
settled = stream.yielded
time.sleep(0.5)
assert stream.yielded == settled, "worker kept draining after consumer crash"