Skip to content

Commit 2f9ffb1

Browse files
fix: handle premature stream termination for Anthropic (#1868) (#2047)
1 parent 46937d2 commit 2f9ffb1

File tree

2 files changed

+78
-17
lines changed

2 files changed

+78
-17
lines changed

src/strands/models/anthropic.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -419,8 +419,12 @@ async def stream(
419419
else:
420420
yield self.format_chunk(event.model_dump())
421421

422-
usage = event.message.usage # type: ignore
423-
yield self.format_chunk({"type": "metadata", "usage": usage.model_dump()})
422+
try:
423+
message_snapshot = await stream.get_final_message()
424+
except AssertionError as e:
425+
logger.warning("error=<%s> | failed to retrieve message snapshot, usage metadata unavailable", e)
426+
else:
427+
yield self.format_chunk({"type": "metadata", "usage": message_snapshot.usage.model_dump()})
424428

425429
except anthropic.RateLimitError as error:
426430
raise ModelThrottledException(str(error)) from error

tests/strands/models/test_anthropic.py

Lines changed: 72 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,24 @@ class TestOutputModel(pydantic.BaseModel):
5353
return TestOutputModel
5454

5555

56+
def generate_mock_stream_context(events, final_message=None):
    """Build a mock async context manager mimicking ``anthropic_client.messages.stream(...)``.

    Args:
        events: Iterable of mock stream events the mocked stream will yield.
        final_message: Value returned by ``get_final_message()``. If it is an
            Exception instance, ``get_final_message()`` raises it instead —
            used to simulate premature stream termination.

    Returns:
        An AsyncMock usable as ``async with``; its ``__aenter__`` resolves to
        the mock stream.
    """
    mock_stream = unittest.mock.AsyncMock()

    # Magic-method assignment on a mock: the function receives the mock as `self`
    # and must return an async iterator, here an async generator over `events`.
    async def mock_aiter(self):
        for event in events:
            yield event

    mock_stream.__aiter__ = mock_aiter
    if isinstance(final_message, Exception):
        # Simulate the SDK raising when the final message snapshot is unavailable.
        mock_stream.get_final_message.side_effect = final_message
    elif final_message is not None:
        # Explicit None check (not truthiness) so a falsy-but-valid final
        # message object is still honored.
        mock_stream.get_final_message.return_value = final_message

    mock_context = unittest.mock.AsyncMock()
    mock_context.__aenter__.return_value = mock_stream
    return mock_context
72+
73+
5674
def test__init__model_configs(anthropic_client, model_id, max_tokens):
5775
_ = anthropic_client
5876

@@ -693,7 +711,7 @@ def test_format_chunk_unknown(model):
693711

694712

695713
@pytest.mark.asyncio
696-
async def test_stream(anthropic_client, model, agenerator, alist):
714+
async def test_stream(anthropic_client, model, alist):
697715
mock_event_1 = unittest.mock.Mock(
698716
type="message_start",
699717
dict=lambda: {"type": "message_start"},
@@ -714,9 +732,14 @@ async def test_stream(anthropic_client, model, agenerator, alist):
714732
),
715733
)
716734

717-
mock_context = unittest.mock.AsyncMock()
718-
mock_context.__aenter__.return_value = agenerator([mock_event_1, mock_event_2, mock_event_3])
719-
anthropic_client.messages.stream.return_value = mock_context
735+
anthropic_client.messages.stream.return_value = generate_mock_stream_context(
736+
[mock_event_1, mock_event_2, mock_event_3],
737+
final_message=unittest.mock.Mock(
738+
usage=unittest.mock.Mock(
739+
model_dump=lambda: {"input_tokens": 1, "output_tokens": 2},
740+
)
741+
),
742+
)
720743

721744
messages = [{"role": "user", "content": [{"text": "hello"}]}]
722745
response = model.stream(messages, None, None)
@@ -739,6 +762,42 @@ async def test_stream(anthropic_client, model, agenerator, alist):
739762
anthropic_client.messages.stream.assert_called_once_with(**expected_request)
740763

741764

765+
@pytest.mark.asyncio
async def test_stream_early_termination(anthropic_client, model, alist, caplog):
    """A stream that dies before a final message logs a warning instead of raising."""
    caplog.set_level(logging.WARNING, logger="strands.models.anthropic")

    start_event = unittest.mock.Mock(
        type="message_start",
        model_dump=lambda: {"type": "message_start"},
    )
    anthropic_client.messages.stream.return_value = generate_mock_stream_context(
        [start_event],
        final_message=AssertionError("message snapshot is not available"),
    )

    response = model.stream([{"role": "user", "content": [{"text": "hello"}]}], None, None)
    tru_events = await alist(response)

    # Only the start chunk is emitted; the metadata chunk is skipped with a warning.
    assert len(tru_events) == 1
    assert "messageStart" in tru_events[0]
    assert "failed to retrieve message snapshot, usage metadata unavailable" in caplog.text
784+
785+
786+
@pytest.mark.asyncio
async def test_stream_empty(anthropic_client, model, alist, caplog):
    """A stream yielding no events produces no chunks and logs the snapshot warning."""
    caplog.set_level(logging.WARNING, logger="strands.models.anthropic")

    anthropic_client.messages.stream.return_value = generate_mock_stream_context(
        [],
        final_message=AssertionError("message snapshot is not available"),
    )

    tru_events = await alist(
        model.stream([{"role": "user", "content": [{"text": "hello"}]}], None, None)
    )

    assert tru_events == []
    assert "failed to retrieve message snapshot, usage metadata unavailable" in caplog.text
799+
800+
742801
@pytest.mark.asyncio
743802
async def test_stream_rate_limit_error(anthropic_client, model, alist):
744803
anthropic_client.messages.stream.side_effect = anthropic.RateLimitError(
@@ -781,7 +840,7 @@ async def test_stream_bad_request_error(anthropic_client, model):
781840

782841

783842
@pytest.mark.asyncio
784-
async def test_structured_output(anthropic_client, model, test_output_model_cls, agenerator, alist):
843+
async def test_structured_output(anthropic_client, model, test_output_model_cls, alist):
785844
messages = [{"role": "user", "content": [{"text": "Generate a person"}]}]
786845

787846
events = [
@@ -817,18 +876,16 @@ async def test_structured_output(anthropic_client, model, test_output_model_cls,
817876
return_value={"type": "message_stop", "message": {"stop_reason": "tool_use"}}
818877
),
819878
),
820-
unittest.mock.Mock(
821-
message=unittest.mock.Mock(
822-
usage=unittest.mock.Mock(
823-
model_dump=unittest.mock.Mock(return_value={"input_tokens": 0, "output_tokens": 0})
824-
),
825-
),
826-
),
827879
]
828880

829-
mock_context = unittest.mock.AsyncMock()
830-
mock_context.__aenter__.return_value = agenerator(events)
831-
anthropic_client.messages.stream.return_value = mock_context
881+
anthropic_client.messages.stream.return_value = generate_mock_stream_context(
882+
events,
883+
final_message=unittest.mock.Mock(
884+
usage=unittest.mock.Mock(
885+
model_dump=unittest.mock.Mock(return_value={"input_tokens": 0, "output_tokens": 0})
886+
),
887+
),
888+
)
832889

833890
stream = model.structured_output(test_output_model_cls, messages)
834891
events = await alist(stream)

0 commit comments

Comments (0)