@@ -52,6 +52,21 @@ class TestOutputModel(pydantic.BaseModel):
5252 return TestOutputModel
5353
5454
55+ def generate_mock_stream (events , final_message = None ):
56+ mock_stream = unittest .mock .AsyncMock ()
57+
58+ async def mock_aiter (self ):
59+ for event in events :
60+ yield event
61+
62+ mock_stream .__aiter__ = mock_aiter
63+ if isinstance (final_message , Exception ):
64+ mock_stream .get_final_message .side_effect = final_message
65+ elif final_message :
66+ mock_stream .get_final_message .return_value = final_message
67+ return mock_stream
68+
69+
5570def test__init__model_configs (anthropic_client , model_id , max_tokens ):
5671 _ = anthropic_client
5772
@@ -692,7 +707,7 @@ def test_format_chunk_unknown(model):
692707
693708
694709@pytest .mark .asyncio
695- async def test_stream (anthropic_client , model , agenerator , alist ):
710+ async def test_stream (anthropic_client , model , alist ):
696711 mock_event_1 = unittest .mock .Mock (
697712 type = "message_start" ,
698713 dict = lambda : {"type" : "message_start" },
@@ -713,8 +728,17 @@ async def test_stream(anthropic_client, model, agenerator, alist):
713728 ),
714729 )
715730
731+ mock_stream = generate_mock_stream (
732+ [mock_event_1 , mock_event_2 , mock_event_3 ],
733+ final_message = unittest .mock .Mock (
734+ usage = unittest .mock .Mock (
735+ model_dump = lambda : {"input_tokens" : 1 , "output_tokens" : 2 },
736+ )
737+ ),
738+ )
739+
716740 mock_context = unittest .mock .AsyncMock ()
717- mock_context .__aenter__ .return_value = agenerator ([ mock_event_1 , mock_event_2 , mock_event_3 ])
741+ mock_context .__aenter__ .return_value = mock_stream
718742 anthropic_client .messages .stream .return_value = mock_context
719743
720744 messages = [{"role" : "user" , "content" : [{"text" : "hello" }]}]
@@ -738,6 +762,50 @@ async def test_stream(anthropic_client, model, agenerator, alist):
738762 anthropic_client .messages .stream .assert_called_once_with (** expected_request )
739763
740764
765+ @pytest .mark .asyncio
766+ async def test_stream_early_termination (anthropic_client , model , alist , caplog ):
767+ caplog .set_level (logging .WARNING , logger = "strands.models.anthropic" )
768+ mock_event = unittest .mock .Mock (
769+ type = "message_start" ,
770+ model_dump = lambda : {"type" : "message_start" },
771+ )
772+
773+ mock_stream = generate_mock_stream (
774+ [mock_event ],
775+ final_message = AssertionError ("message snapshot is not available" ),
776+ )
777+
778+ mock_context = unittest .mock .AsyncMock ()
779+ mock_context .__aenter__ .return_value = mock_stream
780+ anthropic_client .messages .stream .return_value = mock_context
781+
782+ messages = [{"role" : "user" , "content" : [{"text" : "hello" }]}]
783+ tru_events = await alist (model .stream (messages , None , None ))
784+
785+ assert len (tru_events ) == 1
786+ assert "messageStart" in tru_events [0 ]
787+ assert "failed to retrieve message snapshot, usage metadata unavailable" in caplog .text
788+
789+
790+ @pytest .mark .asyncio
791+ async def test_stream_empty (anthropic_client , model , alist , caplog ):
792+ caplog .set_level (logging .WARNING , logger = "strands.models.anthropic" )
793+ mock_stream = generate_mock_stream (
794+ [],
795+ final_message = AssertionError ("message snapshot is not available" ),
796+ )
797+
798+ mock_context = unittest .mock .AsyncMock ()
799+ mock_context .__aenter__ .return_value = mock_stream
800+ anthropic_client .messages .stream .return_value = mock_context
801+
802+ messages = [{"role" : "user" , "content" : [{"text" : "hello" }]}]
803+ tru_events = await alist (model .stream (messages , None , None ))
804+
805+ assert tru_events == []
806+ assert "failed to retrieve message snapshot, usage metadata unavailable" in caplog .text
807+
808+
741809@pytest .mark .asyncio
742810async def test_stream_rate_limit_error (anthropic_client , model , alist ):
743811 anthropic_client .messages .stream .side_effect = anthropic .RateLimitError (
@@ -780,7 +848,7 @@ async def test_stream_bad_request_error(anthropic_client, model):
780848
781849
782850@pytest .mark .asyncio
783- async def test_structured_output (anthropic_client , model , test_output_model_cls , agenerator , alist ):
851+ async def test_structured_output (anthropic_client , model , test_output_model_cls , alist ):
784852 messages = [{"role" : "user" , "content" : [{"text" : "Generate a person" }]}]
785853
786854 events = [
@@ -815,17 +883,19 @@ async def test_structured_output(anthropic_client, model, test_output_model_cls,
815883 return_value = {"type" : "message_stop" , "message" : {"stop_reason" : "tool_use" }}
816884 ),
817885 ),
818- unittest .mock .Mock (
819- message = unittest .mock .Mock (
820- usage = unittest .mock .Mock (
821- model_dump = unittest .mock .Mock (return_value = {"input_tokens" : 0 , "output_tokens" : 0 })
822- ),
886+ ]
887+
888+ mock_stream = generate_mock_stream (
889+ events ,
890+ final_message = unittest .mock .Mock (
891+ usage = unittest .mock .Mock (
892+ model_dump = unittest .mock .Mock (return_value = {"input_tokens" : 0 , "output_tokens" : 0 })
823893 ),
824894 ),
825- ]
895+ )
826896
827897 mock_context = unittest .mock .AsyncMock ()
828- mock_context .__aenter__ .return_value = agenerator ( events )
898+ mock_context .__aenter__ .return_value = mock_stream
829899 anthropic_client .messages .stream .return_value = mock_context
830900
831901 stream = model .structured_output (test_output_model_cls , messages )
0 commit comments