Skip to content

Commit 57f4450

Browse files
committed
fix: handle premature stream termination for Anthropic (#1868)
1 parent 1682a0c commit 57f4450

File tree

2 files changed

+86
-12
lines changed

2 files changed

+86
-12
lines changed

src/strands/models/anthropic.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -409,8 +409,12 @@ async def stream(
409409
if event.type in AnthropicModel.EVENT_TYPES:
410410
yield self.format_chunk(event.model_dump())
411411

412-
usage = event.message.usage # type: ignore
413-
yield self.format_chunk({"type": "metadata", "usage": usage.model_dump()})
412+
try:
413+
message_snapshot = await stream.get_final_message()
414+
except AssertionError as e:
415+
logger.warning("error=<%s> | failed to retrieve message snapshot, usage metadata unavailable", e)
416+
else:
417+
yield self.format_chunk({"type": "metadata", "usage": message_snapshot.usage.model_dump()})
414418

415419
except anthropic.RateLimitError as error:
416420
raise ModelThrottledException(str(error)) from error

tests/strands/models/test_anthropic.py

Lines changed: 80 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -52,6 +52,21 @@ class TestOutputModel(pydantic.BaseModel):
5252
return TestOutputModel
5353

5454

55+
def generate_mock_stream(events, final_message=None):
56+
mock_stream = unittest.mock.AsyncMock()
57+
58+
async def mock_aiter(self):
59+
for event in events:
60+
yield event
61+
62+
mock_stream.__aiter__ = mock_aiter
63+
if isinstance(final_message, Exception):
64+
mock_stream.get_final_message.side_effect = final_message
65+
elif final_message:
66+
mock_stream.get_final_message.return_value = final_message
67+
return mock_stream
68+
69+
5570
def test__init__model_configs(anthropic_client, model_id, max_tokens):
5671
_ = anthropic_client
5772

@@ -692,7 +707,7 @@ def test_format_chunk_unknown(model):
692707

693708

694709
@pytest.mark.asyncio
695-
async def test_stream(anthropic_client, model, agenerator, alist):
710+
async def test_stream(anthropic_client, model, alist):
696711
mock_event_1 = unittest.mock.Mock(
697712
type="message_start",
698713
dict=lambda: {"type": "message_start"},
@@ -713,8 +728,17 @@ async def test_stream(anthropic_client, model, agenerator, alist):
713728
),
714729
)
715730

731+
mock_stream = generate_mock_stream(
732+
[mock_event_1, mock_event_2, mock_event_3],
733+
final_message=unittest.mock.Mock(
734+
usage=unittest.mock.Mock(
735+
model_dump=lambda: {"input_tokens": 1, "output_tokens": 2},
736+
)
737+
),
738+
)
739+
716740
mock_context = unittest.mock.AsyncMock()
717-
mock_context.__aenter__.return_value = agenerator([mock_event_1, mock_event_2, mock_event_3])
741+
mock_context.__aenter__.return_value = mock_stream
718742
anthropic_client.messages.stream.return_value = mock_context
719743

720744
messages = [{"role": "user", "content": [{"text": "hello"}]}]
@@ -738,6 +762,50 @@ async def test_stream(anthropic_client, model, agenerator, alist):
738762
anthropic_client.messages.stream.assert_called_once_with(**expected_request)
739763

740764

765+
@pytest.mark.asyncio
766+
async def test_stream_early_termination(anthropic_client, model, alist, caplog):
767+
caplog.set_level(logging.WARNING, logger="strands.models.anthropic")
768+
mock_event = unittest.mock.Mock(
769+
type="message_start",
770+
model_dump=lambda: {"type": "message_start"},
771+
)
772+
773+
mock_stream = generate_mock_stream(
774+
[mock_event],
775+
final_message=AssertionError("message snapshot is not available"),
776+
)
777+
778+
mock_context = unittest.mock.AsyncMock()
779+
mock_context.__aenter__.return_value = mock_stream
780+
anthropic_client.messages.stream.return_value = mock_context
781+
782+
messages = [{"role": "user", "content": [{"text": "hello"}]}]
783+
tru_events = await alist(model.stream(messages, None, None))
784+
785+
assert len(tru_events) == 1
786+
assert "messageStart" in tru_events[0]
787+
assert "failed to retrieve message snapshot, usage metadata unavailable" in caplog.text
788+
789+
790+
@pytest.mark.asyncio
791+
async def test_stream_empty(anthropic_client, model, alist, caplog):
792+
caplog.set_level(logging.WARNING, logger="strands.models.anthropic")
793+
mock_stream = generate_mock_stream(
794+
[],
795+
final_message=AssertionError("message snapshot is not available"),
796+
)
797+
798+
mock_context = unittest.mock.AsyncMock()
799+
mock_context.__aenter__.return_value = mock_stream
800+
anthropic_client.messages.stream.return_value = mock_context
801+
802+
messages = [{"role": "user", "content": [{"text": "hello"}]}]
803+
tru_events = await alist(model.stream(messages, None, None))
804+
805+
assert tru_events == []
806+
assert "failed to retrieve message snapshot, usage metadata unavailable" in caplog.text
807+
808+
741809
@pytest.mark.asyncio
742810
async def test_stream_rate_limit_error(anthropic_client, model, alist):
743811
anthropic_client.messages.stream.side_effect = anthropic.RateLimitError(
@@ -780,7 +848,7 @@ async def test_stream_bad_request_error(anthropic_client, model):
780848

781849

782850
@pytest.mark.asyncio
783-
async def test_structured_output(anthropic_client, model, test_output_model_cls, agenerator, alist):
851+
async def test_structured_output(anthropic_client, model, test_output_model_cls, alist):
784852
messages = [{"role": "user", "content": [{"text": "Generate a person"}]}]
785853

786854
events = [
@@ -815,17 +883,19 @@ async def test_structured_output(anthropic_client, model, test_output_model_cls,
815883
return_value={"type": "message_stop", "message": {"stop_reason": "tool_use"}}
816884
),
817885
),
818-
unittest.mock.Mock(
819-
message=unittest.mock.Mock(
820-
usage=unittest.mock.Mock(
821-
model_dump=unittest.mock.Mock(return_value={"input_tokens": 0, "output_tokens": 0})
822-
),
886+
]
887+
888+
mock_stream = generate_mock_stream(
889+
events,
890+
final_message=unittest.mock.Mock(
891+
usage=unittest.mock.Mock(
892+
model_dump=unittest.mock.Mock(return_value={"input_tokens": 0, "output_tokens": 0})
823893
),
824894
),
825-
]
895+
)
826896

827897
mock_context = unittest.mock.AsyncMock()
828-
mock_context.__aenter__.return_value = agenerator(events)
898+
mock_context.__aenter__.return_value = mock_stream
829899
anthropic_client.messages.stream.return_value = mock_context
830900

831901
stream = model.structured_output(test_output_model_cls, messages)

0 commit comments

Comments
 (0)