Skip to content

Commit 412f152

Browse files
committed
fix(bedrock): consume orphaned task exception on stream cancellation
When a consumer cancels, breaks from, or times out on BedrockModel.stream, the internal asyncio.Task wrapping asyncio.to_thread is never awaited. If boto3 eventually raises, asyncio emits 'Task exception was never retrieved'. Add a done-callback on the unhappy path that retrieves the exception, silencing the warning without interrupting the background thread. Resolves: #2266
1 parent f862185 commit 412f152

2 files changed

Lines changed: 118 additions & 8 deletions

File tree

src/strands/models/bedrock.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,12 @@ def _clear_unsupported_count_tokens_cache() -> None:
6464
_UNSUPPORTED_COUNT_TOKENS_MODELS.clear()
6565

6666

67+
def _suppress_task_exception(task: "asyncio.Task[None]") -> None:
68+
"""Consume exception from orphaned stream task to silence 'never retrieved' warning."""
69+
if not task.cancelled():
70+
task.exception()
71+
72+
6773
T = TypeVar("T", bound=BaseModel)
6874

6975
DEFAULT_READ_TIMEOUT = 120
@@ -898,14 +904,17 @@ def callback(event: StreamEvent | None = None) -> None:
898904
thread = asyncio.to_thread(self._stream, callback, messages, tool_specs, system_prompt_content, tool_choice)
899905
task = asyncio.create_task(thread)
900906

901-
while True:
902-
event = await queue.get()
903-
if event is None:
904-
break
905-
906-
yield event
907-
908-
await task
907+
try:
908+
while True:
909+
event = await queue.get()
910+
if event is None:
911+
break
912+
913+
yield event
914+
await task
915+
except BaseException:
916+
task.add_done_callback(_suppress_task_exception)
917+
raise
909918

910919
def _stream(
911920
self,

tests/strands/models/test_bedrock.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
import copy
23
import logging
34
import os
@@ -20,6 +21,7 @@
2021
DEFAULT_BEDROCK_REGION,
2122
DEFAULT_READ_TIMEOUT,
2223
_clear_unsupported_count_tokens_cache,
24+
_suppress_task_exception,
2325
)
2426
from strands.types.exceptions import ContextWindowOverflowException, ModelThrottledException
2527
from strands.types.tools import ToolSpec
@@ -3495,3 +3497,102 @@ async def test_skip_native_api_when_use_native_token_count_false(self, bedrock_c
34953497
bedrock_client.count_tokens.assert_not_called()
34963498
assert isinstance(result, int)
34973499
assert result >= 0
3500+
3501+
3502+
@pytest.mark.asyncio
3503+
async def test_suppress_task_exception(bedrock_client, model, messages):
3504+
"""_suppress_task_exception consumes exception from a failed task without re-raising."""
3505+
3506+
async def fail() -> None:
3507+
raise RuntimeError("inner task failure")
3508+
3509+
task = asyncio.create_task(fail())
3510+
await asyncio.sleep(0) # let the task complete with exception
3511+
3512+
assert task.done()
3513+
assert task.exception() is not None
3514+
3515+
# Calling the helper should not raise — it simply retrieves the exception
3516+
_suppress_task_exception(task)
3517+
3518+
3519+
@pytest.mark.asyncio
3520+
async def test_suppress_task_exception_skips_cancelled():
3521+
"""_suppress_task_exception is a no-op for cancelled tasks."""
3522+
3523+
async def hang() -> None:
3524+
await asyncio.sleep(999)
3525+
3526+
task = asyncio.create_task(hang())
3527+
task.cancel()
3528+
with pytest.raises(asyncio.CancelledError):
3529+
await task
3530+
3531+
# Should not raise — cancelled tasks are skipped
3532+
_suppress_task_exception(task)
3533+
3534+
3535+
@pytest.mark.asyncio
3536+
async def test_stream_break_does_not_leak_task_exception(bedrock_client, model, messages, caplog, alist):
3537+
"""Breaking from an async-for on BedrockModel.stream must not leak the inner task's exception."""
3538+
caplog.set_level(logging.WARNING, logger="asyncio")
3539+
3540+
# Mock converse_stream to yield one event then raise — simulates e.g. ReadTimeoutError
3541+
# in the boto3 thread *after* the consumer has disconnected.
3542+
3543+
def stream_with_error():
3544+
yield {"messageStart": {"role": "assistant"}}
3545+
raise RuntimeError("simulated boto3 timeout after consumer disconnect")
3546+
3547+
bedrock_client.converse_stream.return_value = {"stream": stream_with_error()}
3548+
3549+
stream = model.stream(messages)
3550+
collected: list = []
3551+
async for event in stream:
3552+
collected.append(event)
3553+
break # disconnect before the generator raises
3554+
3555+
# Let the event loop process the done-callback and the thread task
3556+
await asyncio.sleep(0.01)
3557+
3558+
# Verify we got the event before breaking
3559+
assert len(collected) == 1
3560+
3561+
# The critical assertion: no "Task exception was never retrieved" warning
3562+
assert "Task exception was never retrieved" not in caplog.text
3563+
# Also ensure no exception propagates to consumer
3564+
assert "exception was never retrieved" not in caplog.text.lower()
3565+
3566+
3567+
@pytest.mark.asyncio
3568+
async def test_stream_timeout_cancellation_does_not_leak(
3569+
bedrock_client,
3570+
model,
3571+
messages,
3572+
caplog,
3573+
):
3574+
"""Applying asyncio.wait_for on BedrockModel.stream must not leak the inner task's exception."""
3575+
caplog.set_level(logging.WARNING, logger="asyncio")
3576+
3577+
# Make converse_stream yield slowly so wait_for fires first
3578+
import time
3579+
3580+
def slow_stream():
3581+
time.sleep(0.05) # simulate a slow network call
3582+
yield {"messageStart": {"role": "assistant"}}
3583+
time.sleep(0.05)
3584+
raise RuntimeError("boto3 timeout after consumer disconnected")
3585+
3586+
bedrock_client.converse_stream.return_value = {"stream": slow_stream()}
3587+
3588+
stream = model.stream(messages)
3589+
with pytest.raises(TimeoutError):
3590+
# Very short timeout — fires before the slow stream finishes
3591+
await asyncio.wait_for(stream.__anext__(), timeout=0.001)
3592+
3593+
# Let event loop settle
3594+
await asyncio.sleep(0.01)
3595+
3596+
# Critical: no orphaned-task warning
3597+
assert "Task exception was never retrieved" not in caplog.text
3598+
assert "exception was never retrieved" not in caplog.text.lower()

0 commit comments

Comments
 (0)