@@ -53,6 +53,24 @@ class TestOutputModel(pydantic.BaseModel):
5353 return TestOutputModel
5454
5555
56+ def generate_mock_stream_context (events , final_message = None ):
57+ mock_stream = unittest .mock .AsyncMock ()
58+
59+ async def mock_aiter (self ):
60+ for event in events :
61+ yield event
62+
63+ mock_stream .__aiter__ = mock_aiter
64+ if isinstance (final_message , Exception ):
65+ mock_stream .get_final_message .side_effect = final_message
66+ elif final_message :
67+ mock_stream .get_final_message .return_value = final_message
68+
69+ mock_context = unittest .mock .AsyncMock ()
70+ mock_context .__aenter__ .return_value = mock_stream
71+ return mock_context
72+
73+
5674def test__init__model_configs (anthropic_client , model_id , max_tokens ):
5775 _ = anthropic_client
5876
@@ -693,7 +711,7 @@ def test_format_chunk_unknown(model):
693711
694712
695713@pytest .mark .asyncio
696- async def test_stream (anthropic_client , model , agenerator , alist ):
714+ async def test_stream (anthropic_client , model , alist ):
697715 mock_event_1 = unittest .mock .Mock (
698716 type = "message_start" ,
699717 dict = lambda : {"type" : "message_start" },
@@ -714,9 +732,14 @@ async def test_stream(anthropic_client, model, agenerator, alist):
714732 ),
715733 )
716734
717- mock_context = unittest .mock .AsyncMock ()
718- mock_context .__aenter__ .return_value = agenerator ([mock_event_1 , mock_event_2 , mock_event_3 ])
719- anthropic_client .messages .stream .return_value = mock_context
735+ anthropic_client .messages .stream .return_value = generate_mock_stream_context (
736+ [mock_event_1 , mock_event_2 , mock_event_3 ],
737+ final_message = unittest .mock .Mock (
738+ usage = unittest .mock .Mock (
739+ model_dump = lambda : {"input_tokens" : 1 , "output_tokens" : 2 },
740+ )
741+ ),
742+ )
720743
721744 messages = [{"role" : "user" , "content" : [{"text" : "hello" }]}]
722745 response = model .stream (messages , None , None )
@@ -739,6 +762,42 @@ async def test_stream(anthropic_client, model, agenerator, alist):
739762 anthropic_client .messages .stream .assert_called_once_with (** expected_request )
740763
741764
765+ @pytest .mark .asyncio
766+ async def test_stream_early_termination (anthropic_client , model , alist , caplog ):
767+ caplog .set_level (logging .WARNING , logger = "strands.models.anthropic" )
768+ mock_event = unittest .mock .Mock (
769+ type = "message_start" ,
770+ model_dump = lambda : {"type" : "message_start" },
771+ )
772+
773+ anthropic_client .messages .stream .return_value = generate_mock_stream_context (
774+ [mock_event ],
775+ final_message = AssertionError ("message snapshot is not available" ),
776+ )
777+
778+ messages = [{"role" : "user" , "content" : [{"text" : "hello" }]}]
779+ tru_events = await alist (model .stream (messages , None , None ))
780+
781+ assert len (tru_events ) == 1
782+ assert "messageStart" in tru_events [0 ]
783+ assert "failed to retrieve message snapshot, usage metadata unavailable" in caplog .text
784+
785+
786+ @pytest .mark .asyncio
787+ async def test_stream_empty (anthropic_client , model , alist , caplog ):
788+ caplog .set_level (logging .WARNING , logger = "strands.models.anthropic" )
789+ anthropic_client .messages .stream .return_value = generate_mock_stream_context (
790+ [],
791+ final_message = AssertionError ("message snapshot is not available" ),
792+ )
793+
794+ messages = [{"role" : "user" , "content" : [{"text" : "hello" }]}]
795+ tru_events = await alist (model .stream (messages , None , None ))
796+
797+ assert tru_events == []
798+ assert "failed to retrieve message snapshot, usage metadata unavailable" in caplog .text
799+
800+
742801@pytest .mark .asyncio
743802async def test_stream_rate_limit_error (anthropic_client , model , alist ):
744803 anthropic_client .messages .stream .side_effect = anthropic .RateLimitError (
@@ -781,7 +840,7 @@ async def test_stream_bad_request_error(anthropic_client, model):
781840
782841
783842@pytest .mark .asyncio
784- async def test_structured_output (anthropic_client , model , test_output_model_cls , agenerator , alist ):
843+ async def test_structured_output (anthropic_client , model , test_output_model_cls , alist ):
785844 messages = [{"role" : "user" , "content" : [{"text" : "Generate a person" }]}]
786845
787846 events = [
@@ -817,18 +876,16 @@ async def test_structured_output(anthropic_client, model, test_output_model_cls,
817876 return_value = {"type" : "message_stop" , "message" : {"stop_reason" : "tool_use" }}
818877 ),
819878 ),
820- unittest .mock .Mock (
821- message = unittest .mock .Mock (
822- usage = unittest .mock .Mock (
823- model_dump = unittest .mock .Mock (return_value = {"input_tokens" : 0 , "output_tokens" : 0 })
824- ),
825- ),
826- ),
827879 ]
828880
829- mock_context = unittest .mock .AsyncMock ()
830- mock_context .__aenter__ .return_value = agenerator (events )
831- anthropic_client .messages .stream .return_value = mock_context
881+ anthropic_client .messages .stream .return_value = generate_mock_stream_context (
882+ events ,
883+ final_message = unittest .mock .Mock (
884+ usage = unittest .mock .Mock (
885+ model_dump = unittest .mock .Mock (return_value = {"input_tokens" : 0 , "output_tokens" : 0 })
886+ ),
887+ ),
888+ )
832889
833890 stream = model .structured_output (test_output_model_cls , messages )
834891 events = await alist (stream )
0 commit comments