@@ -52,6 +52,24 @@ class TestOutputModel(pydantic.BaseModel):
5252 return TestOutputModel
5353
5454
55+ def generate_mock_stream_context (events , final_message = None ):
56+ mock_stream = unittest .mock .AsyncMock ()
57+
58+ async def mock_aiter (self ):
59+ for event in events :
60+ yield event
61+
62+ mock_stream .__aiter__ = mock_aiter
63+ if isinstance (final_message , Exception ):
64+ mock_stream .get_final_message .side_effect = final_message
65+ elif final_message :
66+ mock_stream .get_final_message .return_value = final_message
67+
68+ mock_context = unittest .mock .AsyncMock ()
69+ mock_context .__aenter__ .return_value = mock_stream
70+ return mock_context
71+
72+
5573def test__init__model_configs (anthropic_client , model_id , max_tokens ):
5674 _ = anthropic_client
5775
@@ -692,7 +710,7 @@ def test_format_chunk_unknown(model):
692710
693711
694712@pytest .mark .asyncio
695- async def test_stream (anthropic_client , model , agenerator , alist ):
713+ async def test_stream (anthropic_client , model , alist ):
696714 mock_event_1 = unittest .mock .Mock (
697715 type = "message_start" ,
698716 dict = lambda : {"type" : "message_start" },
@@ -713,9 +731,14 @@ async def test_stream(anthropic_client, model, agenerator, alist):
713731 ),
714732 )
715733
716- mock_context = unittest .mock .AsyncMock ()
717- mock_context .__aenter__ .return_value = agenerator ([mock_event_1 , mock_event_2 , mock_event_3 ])
718- anthropic_client .messages .stream .return_value = mock_context
734+ anthropic_client .messages .stream .return_value = generate_mock_stream_context (
735+ [mock_event_1 , mock_event_2 , mock_event_3 ],
736+ final_message = unittest .mock .Mock (
737+ usage = unittest .mock .Mock (
738+ model_dump = lambda : {"input_tokens" : 1 , "output_tokens" : 2 },
739+ )
740+ ),
741+ )
719742
720743 messages = [{"role" : "user" , "content" : [{"text" : "hello" }]}]
721744 response = model .stream (messages , None , None )
@@ -738,6 +761,42 @@ async def test_stream(anthropic_client, model, agenerator, alist):
738761 anthropic_client .messages .stream .assert_called_once_with (** expected_request )
739762
740763
764+ @pytest .mark .asyncio
765+ async def test_stream_early_termination (anthropic_client , model , alist , caplog ):
766+ caplog .set_level (logging .WARNING , logger = "strands.models.anthropic" )
767+ mock_event = unittest .mock .Mock (
768+ type = "message_start" ,
769+ model_dump = lambda : {"type" : "message_start" },
770+ )
771+
772+ anthropic_client .messages .stream .return_value = generate_mock_stream_context (
773+ [mock_event ],
774+ final_message = AssertionError ("message snapshot is not available" ),
775+ )
776+
777+ messages = [{"role" : "user" , "content" : [{"text" : "hello" }]}]
778+ tru_events = await alist (model .stream (messages , None , None ))
779+
780+ assert len (tru_events ) == 1
781+ assert "messageStart" in tru_events [0 ]
782+ assert "failed to retrieve message snapshot, usage metadata unavailable" in caplog .text
783+
784+
785+ @pytest .mark .asyncio
786+ async def test_stream_empty (anthropic_client , model , alist , caplog ):
787+ caplog .set_level (logging .WARNING , logger = "strands.models.anthropic" )
788+ anthropic_client .messages .stream .return_value = generate_mock_stream_context (
789+ [],
790+ final_message = AssertionError ("message snapshot is not available" ),
791+ )
792+
793+ messages = [{"role" : "user" , "content" : [{"text" : "hello" }]}]
794+ tru_events = await alist (model .stream (messages , None , None ))
795+
796+ assert tru_events == []
797+ assert "failed to retrieve message snapshot, usage metadata unavailable" in caplog .text
798+
799+
741800@pytest .mark .asyncio
742801async def test_stream_rate_limit_error (anthropic_client , model , alist ):
743802 anthropic_client .messages .stream .side_effect = anthropic .RateLimitError (
@@ -780,7 +839,7 @@ async def test_stream_bad_request_error(anthropic_client, model):
780839
781840
782841@pytest .mark .asyncio
783- async def test_structured_output (anthropic_client , model , test_output_model_cls , agenerator , alist ):
842+ async def test_structured_output (anthropic_client , model , test_output_model_cls , alist ):
784843 messages = [{"role" : "user" , "content" : [{"text" : "Generate a person" }]}]
785844
786845 events = [
@@ -815,18 +874,16 @@ async def test_structured_output(anthropic_client, model, test_output_model_cls,
815874 return_value = {"type" : "message_stop" , "message" : {"stop_reason" : "tool_use" }}
816875 ),
817876 ),
818- unittest .mock .Mock (
819- message = unittest .mock .Mock (
820- usage = unittest .mock .Mock (
821- model_dump = unittest .mock .Mock (return_value = {"input_tokens" : 0 , "output_tokens" : 0 })
822- ),
823- ),
824- ),
825877 ]
826878
827- mock_context = unittest .mock .AsyncMock ()
828- mock_context .__aenter__ .return_value = agenerator (events )
829- anthropic_client .messages .stream .return_value = mock_context
879+ anthropic_client .messages .stream .return_value = generate_mock_stream_context (
880+ events ,
881+ final_message = unittest .mock .Mock (
882+ usage = unittest .mock .Mock (
883+ model_dump = unittest .mock .Mock (return_value = {"input_tokens" : 0 , "output_tokens" : 0 })
884+ ),
885+ ),
886+ )
830887
831888 stream = model .structured_output (test_output_model_cls , messages )
832889 events = await alist (stream )
0 commit comments