diff --git a/src/anthropic/_streaming.py b/src/anthropic/_streaming.py index a6d0d20db..d4cce68f1 100644 --- a/src/anthropic/_streaming.py +++ b/src/anthropic/_streaming.py @@ -11,6 +11,7 @@ import httpx +from ._exceptions import APIConnectionError from ._utils import is_dict, extract_type_var_from_base if TYPE_CHECKING: @@ -72,7 +73,23 @@ def __iter__(self) -> Iterator[_T]: yield item def _iter_events(self) -> Iterator[ServerSentEvent]: - yield from self._decoder.iter_bytes(self.response.iter_bytes()) + try: + yield from self._decoder.iter_bytes(self.response.iter_bytes()) + except httpx.TimeoutException: + # Mid-stream timeouts are already handled by `_base_client._request` for the + # initial request, but the SSE body iteration doesn't go through that path — + # re-raise as-is so callers can distinguish a hung stream from a dropped one. + # APITimeoutError is an APIConnectionError subclass, so customers catching + # the latter will still see it; this clause only exists so the next clause + # doesn't double-wrap it (TimeoutException is also a TransportError). + raise + except httpx.TransportError as exc: + # Mid-stream transport drops (RemoteProtocolError, ReadError, ConnectError, …) + # leak through as bare httpx exceptions because the SDK's wrapping in + # `_base_client._request` only covers the pre-body request. Re-wrap them so + # `except anthropic.APIConnectionError:` catches mid-stream drops the same way + # it catches connection failures, and the original is preserved as `__cause__`. + raise APIConnectionError(message=f"Stream interrupted: {exc}", request=self.response.request) from exc def __stream__(self) -> Iterator[_T]: cast_to = cast(Any, self._cast_to) @@ -226,8 +243,17 @@ async def __aiter__(self) -> AsyncIterator[_T]: yield item async def _iter_events(self) -> AsyncIterator[ServerSentEvent]: - async for sse in self._decoder.aiter_bytes(self.response.aiter_bytes()): - yield sse + try: + async for sse in self._decoder.aiter_bytes(self.response.aiter_bytes()): + yield sse + except httpx.TimeoutException: + # See sync `_iter_events` — let timeouts pass through so the next clause + # doesn't double-wrap them (TimeoutException is also a TransportError). + raise + except httpx.TransportError as exc: + # See sync `_iter_events` — wrap mid-stream transport drops so + # `except anthropic.APIConnectionError:` catches them. + raise APIConnectionError(message=f"Stream interrupted: {exc}", request=self.response.request) from exc async def __stream__(self) -> AsyncIterator[_T]: cast_to = cast(Any, self._cast_to) diff --git a/tests/test_streaming.py b/tests/test_streaming.py index ac8cc0299..472068251 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -7,7 +7,7 @@ from anthropic import Anthropic, AsyncAnthropic from anthropic._streaming import Stream, AsyncStream, ServerSentEvent -from anthropic._exceptions import APIStatusError +from anthropic._exceptions import APIConnectionError, APIStatusError _T = TypeVar("_T") @@ -219,6 +219,68 @@ def body() -> Iterator[bytes]: assert sse.json() == {"content": "известни"} +@pytest.mark.asyncio +@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"]) +async def test_mid_stream_transport_error_is_wrapped( + sync: bool, + client: Anthropic, + async_client: AsyncAnthropic, +) -> None: + """A transport drop mid-SSE-stream (RemoteProtocolError, ReadError, …) raises + APIConnectionError with the original httpx exception as __cause__, so that + `except anthropic.APIConnectionError:` catches mid-stream drops the same way + it catches initial-connection failures. + """ + + def body() -> Iterator[bytes]: + yield b"event: completion\n" + yield b'data: {"foo":1}\n' + yield b"\n" + raise httpx.RemoteProtocolError("peer closed connection without sending complete message body") + + request = httpx.Request("POST", "http://test") + if sync: + iterator: Iterator[ServerSentEvent] | AsyncIterator[ServerSentEvent] = Stream( + cast_to=object, client=client, response=httpx.Response(200, content=body(), request=request) + )._iter_events() + else: + iterator = AsyncStream( + cast_to=object, client=async_client, response=httpx.Response(200, content=to_aiter(body()), request=request) + )._iter_events() + + # First event arrives normally — the drop is mid-stream, not at connect. + sse = await iter_next(iterator) + assert sse.event == "completion" + + with pytest.raises(APIConnectionError) as exc_info: + await iter_next(iterator) + assert isinstance(exc_info.value.__cause__, httpx.RemoteProtocolError) + assert "Stream interrupted" in str(exc_info.value) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"]) +async def test_mid_stream_timeout_is_not_wrapped( + sync: bool, + client: Anthropic, + async_client: AsyncAnthropic, +) -> None: + """TimeoutException is a TransportError subclass, but the wrapping clause must + NOT double-wrap it — APITimeoutError already exists for timeouts and is itself + an APIConnectionError subclass. The bare httpx.TimeoutException should pass + through so callers can map it to APITimeoutError if they want.""" + + def body() -> Iterator[bytes]: + yield b"event: completion\n" + raise httpx.ReadTimeout("read timeout") + + iterator = make_event_iterator(content=body(), sync=sync, client=client, async_client=async_client) + + with pytest.raises(httpx.ReadTimeout): + await iter_next(iterator) + await iter_next(iterator) + + @pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"]) async def test_error_type( sync: bool,