|
| 1 | +import asyncio |
1 | 2 | import copy |
2 | 3 | import logging |
3 | 4 | import os |
|
20 | 21 | DEFAULT_BEDROCK_REGION, |
21 | 22 | DEFAULT_READ_TIMEOUT, |
22 | 23 | _clear_unsupported_count_tokens_cache, |
| 24 | + _suppress_task_exception, |
23 | 25 | ) |
24 | 26 | from strands.types.exceptions import ContextWindowOverflowException, ModelThrottledException |
25 | 27 | from strands.types.tools import ToolSpec |
@@ -3495,3 +3497,102 @@ async def test_skip_native_api_when_use_native_token_count_false(self, bedrock_c |
3495 | 3497 | bedrock_client.count_tokens.assert_not_called() |
3496 | 3498 | assert isinstance(result, int) |
3497 | 3499 | assert result >= 0 |
| 3500 | + |
| 3501 | + |
| 3502 | +@pytest.mark.asyncio |
| 3503 | +async def test_suppress_task_exception(bedrock_client, model, messages): |
| 3504 | + """_suppress_task_exception consumes exception from a failed task without re-raising.""" |
| 3505 | + |
| 3506 | + async def fail() -> None: |
| 3507 | + raise RuntimeError("inner task failure") |
| 3508 | + |
| 3509 | + task = asyncio.create_task(fail()) |
| 3510 | + await asyncio.sleep(0) # let the task complete with exception |
| 3511 | + |
| 3512 | + assert task.done() |
| 3513 | + assert task.exception() is not None |
| 3514 | + |
| 3515 | + # Calling the helper should not raise — it simply retrieves the exception |
| 3516 | + _suppress_task_exception(task) |
| 3517 | + |
| 3518 | + |
| 3519 | +@pytest.mark.asyncio |
| 3520 | +async def test_suppress_task_exception_skips_cancelled(): |
| 3521 | + """_suppress_task_exception is a no-op for cancelled tasks.""" |
| 3522 | + |
| 3523 | + async def hang() -> None: |
| 3524 | + await asyncio.sleep(999) |
| 3525 | + |
| 3526 | + task = asyncio.create_task(hang()) |
| 3527 | + task.cancel() |
| 3528 | + with pytest.raises(asyncio.CancelledError): |
| 3529 | + await task |
| 3530 | + |
| 3531 | + # Should not raise — cancelled tasks are skipped |
| 3532 | + _suppress_task_exception(task) |
| 3533 | + |
| 3534 | + |
| 3535 | +@pytest.mark.asyncio |
| 3536 | +async def test_stream_break_does_not_leak_task_exception(bedrock_client, model, messages, caplog, alist): |
| 3537 | + """Breaking from an async-for on BedrockModel.stream must not leak the inner task's exception.""" |
| 3538 | + caplog.set_level(logging.WARNING, logger="asyncio") |
| 3539 | + |
| 3540 | + # Mock converse_stream to yield one event then raise — simulates e.g. ReadTimeoutError |
| 3541 | + # in the boto3 thread *after* the consumer has disconnected. |
| 3542 | + |
| 3543 | + def stream_with_error(): |
| 3544 | + yield {"messageStart": {"role": "assistant"}} |
| 3545 | + raise RuntimeError("simulated boto3 timeout after consumer disconnect") |
| 3546 | + |
| 3547 | + bedrock_client.converse_stream.return_value = {"stream": stream_with_error()} |
| 3548 | + |
| 3549 | + stream = model.stream(messages) |
| 3550 | + collected: list = [] |
| 3551 | + async for event in stream: |
| 3552 | + collected.append(event) |
| 3553 | + break # disconnect before the generator raises |
| 3554 | + |
| 3555 | + # Let the event loop process the done-callback and the thread task |
| 3556 | + await asyncio.sleep(0.01) |
| 3557 | + |
| 3558 | + # Verify we got the event before breaking |
| 3559 | + assert len(collected) == 1 |
| 3560 | + |
| 3561 | + # The critical assertion: no "Task exception was never retrieved" warning |
| 3562 | + assert "Task exception was never retrieved" not in caplog.text |
| 3563 | + # Also ensure no exception propagates to consumer |
| 3564 | + assert "exception was never retrieved" not in caplog.text.lower() |
| 3565 | + |
| 3566 | + |
| 3567 | +@pytest.mark.asyncio |
| 3568 | +async def test_stream_timeout_cancellation_does_not_leak( |
| 3569 | + bedrock_client, |
| 3570 | + model, |
| 3571 | + messages, |
| 3572 | + caplog, |
| 3573 | +): |
| 3574 | + """Applying asyncio.wait_for on BedrockModel.stream must not leak the inner task's exception.""" |
| 3575 | + caplog.set_level(logging.WARNING, logger="asyncio") |
| 3576 | + |
| 3577 | + # Make converse_stream yield slowly so wait_for fires first |
| 3578 | + import time |
| 3579 | + |
| 3580 | + def slow_stream(): |
| 3581 | + time.sleep(0.05) # simulate a slow network call |
| 3582 | + yield {"messageStart": {"role": "assistant"}} |
| 3583 | + time.sleep(0.05) |
| 3584 | + raise RuntimeError("boto3 timeout after consumer disconnected") |
| 3585 | + |
| 3586 | + bedrock_client.converse_stream.return_value = {"stream": slow_stream()} |
| 3587 | + |
| 3588 | + stream = model.stream(messages) |
| 3589 | + with pytest.raises(TimeoutError): |
| 3590 | + # Very short timeout — fires before the slow stream finishes |
| 3591 | + await asyncio.wait_for(stream.__anext__(), timeout=0.001) |
| 3592 | + |
| 3593 | + # Let event loop settle |
| 3594 | + await asyncio.sleep(0.01) |
| 3595 | + |
| 3596 | + # Critical: no orphaned-task warning |
| 3597 | + assert "Task exception was never retrieved" not in caplog.text |
| 3598 | + assert "exception was never retrieved" not in caplog.text.lower() |
0 commit comments