Skip to content

Commit d1674eb

Browse files
committed
fix: handle premature stream termination for Anthropic (#1868)
1 parent 1682a0c commit d1674eb

File tree

2 files changed

+78
-17
lines changed

2 files changed

+78
-17
lines changed

src/strands/models/anthropic.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -409,8 +409,12 @@ async def stream(
409409
if event.type in AnthropicModel.EVENT_TYPES:
410410
yield self.format_chunk(event.model_dump())
411411

412-
usage = event.message.usage # type: ignore
413-
yield self.format_chunk({"type": "metadata", "usage": usage.model_dump()})
412+
try:
413+
message_snapshot = await stream.get_final_message()
414+
except AssertionError as e:
415+
logger.warning("error=<%s> | failed to retrieve message snapshot, usage metadata unavailable", e)
416+
else:
417+
yield self.format_chunk({"type": "metadata", "usage": message_snapshot.usage.model_dump()})
414418

415419
except anthropic.RateLimitError as error:
416420
raise ModelThrottledException(str(error)) from error

tests/strands/models/test_anthropic.py

Lines changed: 72 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,24 @@ class TestOutputModel(pydantic.BaseModel):
5252
return TestOutputModel
5353

5454

55+
def generate_mock_stream_context(events, final_message=None):
56+
mock_stream = unittest.mock.AsyncMock()
57+
58+
async def mock_aiter(self):
59+
for event in events:
60+
yield event
61+
62+
mock_stream.__aiter__ = mock_aiter
63+
if isinstance(final_message, Exception):
64+
mock_stream.get_final_message.side_effect = final_message
65+
elif final_message:
66+
mock_stream.get_final_message.return_value = final_message
67+
68+
mock_context = unittest.mock.AsyncMock()
69+
mock_context.__aenter__.return_value = mock_stream
70+
return mock_context
71+
72+
5573
def test__init__model_configs(anthropic_client, model_id, max_tokens):
5674
_ = anthropic_client
5775

@@ -692,7 +710,7 @@ def test_format_chunk_unknown(model):
692710

693711

694712
@pytest.mark.asyncio
695-
async def test_stream(anthropic_client, model, agenerator, alist):
713+
async def test_stream(anthropic_client, model, alist):
696714
mock_event_1 = unittest.mock.Mock(
697715
type="message_start",
698716
dict=lambda: {"type": "message_start"},
@@ -713,9 +731,14 @@ async def test_stream(anthropic_client, model, agenerator, alist):
713731
),
714732
)
715733

716-
mock_context = unittest.mock.AsyncMock()
717-
mock_context.__aenter__.return_value = agenerator([mock_event_1, mock_event_2, mock_event_3])
718-
anthropic_client.messages.stream.return_value = mock_context
734+
anthropic_client.messages.stream.return_value = generate_mock_stream_context(
735+
[mock_event_1, mock_event_2, mock_event_3],
736+
final_message=unittest.mock.Mock(
737+
usage=unittest.mock.Mock(
738+
model_dump=lambda: {"input_tokens": 1, "output_tokens": 2},
739+
)
740+
),
741+
)
719742

720743
messages = [{"role": "user", "content": [{"text": "hello"}]}]
721744
response = model.stream(messages, None, None)
@@ -738,6 +761,42 @@ async def test_stream(anthropic_client, model, agenerator, alist):
738761
anthropic_client.messages.stream.assert_called_once_with(**expected_request)
739762

740763

764+
@pytest.mark.asyncio
765+
async def test_stream_early_termination(anthropic_client, model, alist, caplog):
766+
caplog.set_level(logging.WARNING, logger="strands.models.anthropic")
767+
mock_event = unittest.mock.Mock(
768+
type="message_start",
769+
model_dump=lambda: {"type": "message_start"},
770+
)
771+
772+
anthropic_client.messages.stream.return_value = generate_mock_stream_context(
773+
[mock_event],
774+
final_message=AssertionError("message snapshot is not available"),
775+
)
776+
777+
messages = [{"role": "user", "content": [{"text": "hello"}]}]
778+
tru_events = await alist(model.stream(messages, None, None))
779+
780+
assert len(tru_events) == 1
781+
assert "messageStart" in tru_events[0]
782+
assert "failed to retrieve message snapshot, usage metadata unavailable" in caplog.text
783+
784+
785+
@pytest.mark.asyncio
786+
async def test_stream_empty(anthropic_client, model, alist, caplog):
787+
caplog.set_level(logging.WARNING, logger="strands.models.anthropic")
788+
anthropic_client.messages.stream.return_value = generate_mock_stream_context(
789+
[],
790+
final_message=AssertionError("message snapshot is not available"),
791+
)
792+
793+
messages = [{"role": "user", "content": [{"text": "hello"}]}]
794+
tru_events = await alist(model.stream(messages, None, None))
795+
796+
assert tru_events == []
797+
assert "failed to retrieve message snapshot, usage metadata unavailable" in caplog.text
798+
799+
741800
@pytest.mark.asyncio
742801
async def test_stream_rate_limit_error(anthropic_client, model, alist):
743802
anthropic_client.messages.stream.side_effect = anthropic.RateLimitError(
@@ -780,7 +839,7 @@ async def test_stream_bad_request_error(anthropic_client, model):
780839

781840

782841
@pytest.mark.asyncio
783-
async def test_structured_output(anthropic_client, model, test_output_model_cls, agenerator, alist):
842+
async def test_structured_output(anthropic_client, model, test_output_model_cls, alist):
784843
messages = [{"role": "user", "content": [{"text": "Generate a person"}]}]
785844

786845
events = [
@@ -815,18 +874,16 @@ async def test_structured_output(anthropic_client, model, test_output_model_cls,
815874
return_value={"type": "message_stop", "message": {"stop_reason": "tool_use"}}
816875
),
817876
),
818-
unittest.mock.Mock(
819-
message=unittest.mock.Mock(
820-
usage=unittest.mock.Mock(
821-
model_dump=unittest.mock.Mock(return_value={"input_tokens": 0, "output_tokens": 0})
822-
),
823-
),
824-
),
825877
]
826878

827-
mock_context = unittest.mock.AsyncMock()
828-
mock_context.__aenter__.return_value = agenerator(events)
829-
anthropic_client.messages.stream.return_value = mock_context
879+
anthropic_client.messages.stream.return_value = generate_mock_stream_context(
880+
events,
881+
final_message=unittest.mock.Mock(
882+
usage=unittest.mock.Mock(
883+
model_dump=unittest.mock.Mock(return_value={"input_tokens": 0, "output_tokens": 0})
884+
),
885+
),
886+
)
830887

831888
stream = model.structured_output(test_output_model_cls, messages)
832889
events = await alist(stream)

0 commit comments

Comments
 (0)