Skip to content

Commit 38c1ab6

Browse files
atian8179 and Unshure authored
fix: override end_turn stop reason when streaming response contains toolUse blocks (#1827)
Co-authored-by: atian8179 <atian8179@users.noreply.github.com> Co-authored-by: Nicholas Clegg <ncclegg@amazon.com>
1 parent ae28397 commit 38c1ab6

File tree

4 files changed

+104
-80
lines changed

4 files changed

+104
-80
lines changed

src/strands/event_loop/streaming.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -324,16 +324,31 @@ def handle_content_block_stop(state: dict[str, Any]) -> dict[str, Any]:
324324
return state
325325

326326

327-
def handle_message_stop(event: MessageStopEvent) -> StopReason:
327+
def handle_message_stop(event: MessageStopEvent, content: list[dict[str, Any]]) -> StopReason:
328328
"""Handles the end of a message by returning the stop reason.
329329
330+
Some models return "end_turn" even when tool calls are present, which prevents the event loop from processing
331+
those tool calls. This function overrides to "tool_use" so tool execution proceeds correctly.
332+
330333
Args:
331334
event: Stop event.
335+
content: The message content blocks accumulated during streaming.
332336
333337
Returns:
334338
The reason for stopping the stream.
335339
"""
336-
return event["stopReason"]
340+
stop_reason = event["stopReason"]
341+
342+
if stop_reason == "end_turn" and any("toolUse" in item for item in content):
343+
logger.warning(
344+
"original_stop_reason=<%s>, new_stop_reason=<%s> | "
345+
"overriding stop reason due to toolUse blocks in response",
346+
"end_turn",
347+
"tool_use",
348+
)
349+
stop_reason = "tool_use"
350+
351+
return stop_reason
337352

338353

339354
def handle_redact_content(event: RedactContentEvent, state: dict[str, Any]) -> None:
@@ -427,7 +442,7 @@ async def process_stream(
427442
elif "contentBlockStop" in chunk:
428443
state = handle_content_block_stop(state)
429444
elif "messageStop" in chunk:
430-
stop_reason = handle_message_stop(chunk["messageStop"])
445+
stop_reason = handle_message_stop(chunk["messageStop"], state["message"].get("content", []))
431446
elif "metadata" in chunk:
432447
time_to_first_byte_ms = (
433448
int(1000 * (first_byte_time - start_time)) if (start_time and first_byte_time) else None

src/strands/models/bedrock.py

Lines changed: 2 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -823,8 +823,6 @@ def _stream(
823823
logger.debug("got response from model")
824824
if streaming:
825825
response = self.client.converse_stream(**request)
826-
# Track tool use events to fix stopReason for streaming responses
827-
has_tool_use = False
828826
for chunk in response["stream"]:
829827
if (
830828
"metadata" in chunk
@@ -836,24 +834,7 @@ def _stream(
836834
for event in self._generate_redaction_events():
837835
callback(event)
838836

839-
# Track if we see tool use events
840-
if "contentBlockStart" in chunk and chunk["contentBlockStart"].get("start", {}).get("toolUse"):
841-
has_tool_use = True
842-
843-
# Fix stopReason for streaming responses that contain tool use
844-
if (
845-
has_tool_use
846-
and "messageStop" in chunk
847-
and (message_stop := chunk["messageStop"]).get("stopReason") == "end_turn"
848-
):
849-
# Create corrected chunk with tool_use stopReason
850-
modified_chunk = chunk.copy()
851-
modified_chunk["messageStop"] = message_stop.copy()
852-
modified_chunk["messageStop"]["stopReason"] = "tool_use"
853-
logger.warning("Override stop reason from end_turn to tool_use")
854-
callback(modified_chunk)
855-
else:
856-
callback(chunk)
837+
callback(chunk)
857838

858839
else:
859840
response = self.client.converse(**request)
@@ -992,17 +973,9 @@ def _convert_non_streaming_to_streaming(self, response: dict[str, Any]) -> Itera
992973
yield {"contentBlockStop": {}}
993974

994975
# Yield messageStop event
995-
# Fix stopReason for models that return end_turn when they should return tool_use on non-streaming side
996-
current_stop_reason = response["stopReason"]
997-
if current_stop_reason == "end_turn":
998-
message_content = response["output"]["message"]["content"]
999-
if any("toolUse" in content for content in message_content):
1000-
current_stop_reason = "tool_use"
1001-
logger.warning("Override stop reason from end_turn to tool_use")
1002-
1003976
yield {
1004977
"messageStop": {
1005-
"stopReason": current_stop_reason,
978+
"stopReason": response["stopReason"],
1006979
"additionalModelResponseFields": response.get("additionalModelResponseFields"),
1007980
}
1008981
}

tests/strands/event_loop/test_streaming.py

Lines changed: 84 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -530,12 +530,30 @@ def test_handle_content_block_stop(state, exp_updated_state):
530530
def test_handle_message_stop():
531531
event: MessageStopEvent = {"stopReason": "end_turn"}
532532

533-
tru_reason = strands.event_loop.streaming.handle_message_stop(event)
533+
tru_reason = strands.event_loop.streaming.handle_message_stop(event, [])
534534
exp_reason = "end_turn"
535535

536536
assert tru_reason == exp_reason
537537

538538

539+
def test_handle_message_stop_overrides_end_turn_when_tool_use_present():
540+
event: MessageStopEvent = {"stopReason": "end_turn"}
541+
content = [{"toolUse": {"toolUseId": "t1", "name": "myTool", "input": {}}}]
542+
543+
tru_reason = strands.event_loop.streaming.handle_message_stop(event, content)
544+
545+
assert tru_reason == "tool_use"
546+
547+
548+
def test_handle_message_stop_keeps_tool_use_unchanged():
549+
event: MessageStopEvent = {"stopReason": "tool_use"}
550+
content = [{"toolUse": {"toolUseId": "t1", "name": "myTool", "input": {}}}]
551+
552+
tru_reason = strands.event_loop.streaming.handle_message_stop(event, content)
553+
554+
assert tru_reason == "tool_use"
555+
556+
539557
def test_extract_usage_metrics():
540558
event = {
541559
"usage": {"inputTokens": 0, "outputTokens": 0, "totalTokens": 0},
@@ -1334,3 +1352,68 @@ async def test_stream_messages_normalizes_messages(agenerator, alist):
13341352
{"content": [{"toolUse": {"name": "INVALID_TOOL_NAME"}}], "role": "assistant"},
13351353
{"content": [{"toolUse": {"name": "INVALID_TOOL_NAME"}}], "role": "assistant"},
13361354
]
1355+
1356+
1357+
@pytest.mark.asyncio
1358+
async def test_process_stream_overrides_end_turn_when_tool_use_present(agenerator, alist):
1359+
response = [
1360+
{"messageStart": {"role": "assistant"}},
1361+
{"contentBlockStart": {"contentBlockIndex": 0, "start": {"toolUse": {"toolUseId": "t1", "name": "myTool"}}}},
1362+
{"contentBlockDelta": {"delta": {"toolUse": {"input": '{"key": "val"}'}}, "contentBlockIndex": 0}},
1363+
{"contentBlockStop": {"contentBlockIndex": 0}},
1364+
{"messageStop": {"stopReason": "end_turn"}},
1365+
{
1366+
"metadata": {
1367+
"usage": {"inputTokens": 10, "outputTokens": 20, "totalTokens": 30},
1368+
"metrics": {"latencyMs": 100},
1369+
}
1370+
},
1371+
]
1372+
1373+
stream = strands.event_loop.streaming.process_stream(agenerator(response))
1374+
last_event = cast(ModelStopReason, (await alist(stream))[-1])
1375+
1376+
assert last_event["stop"][0] == "tool_use"
1377+
1378+
1379+
@pytest.mark.asyncio
1380+
async def test_process_stream_keeps_end_turn_when_no_tool_use(agenerator, alist):
1381+
response = [
1382+
{"messageStart": {"role": "assistant"}},
1383+
{"contentBlockDelta": {"delta": {"text": "Hello!"}, "contentBlockIndex": 0}},
1384+
{"contentBlockStop": {"contentBlockIndex": 0}},
1385+
{"messageStop": {"stopReason": "end_turn"}},
1386+
{
1387+
"metadata": {
1388+
"usage": {"inputTokens": 10, "outputTokens": 20, "totalTokens": 30},
1389+
"metrics": {"latencyMs": 100},
1390+
}
1391+
},
1392+
]
1393+
1394+
stream = strands.event_loop.streaming.process_stream(agenerator(response))
1395+
last_event = cast(ModelStopReason, (await alist(stream))[-1])
1396+
1397+
assert last_event["stop"][0] == "end_turn"
1398+
1399+
1400+
@pytest.mark.asyncio
1401+
async def test_process_stream_keeps_tool_use_stop_reason_unchanged(agenerator, alist):
1402+
response = [
1403+
{"messageStart": {"role": "assistant"}},
1404+
{"contentBlockStart": {"contentBlockIndex": 0, "start": {"toolUse": {"toolUseId": "t1", "name": "myTool"}}}},
1405+
{"contentBlockDelta": {"delta": {"toolUse": {"input": "{}"}}, "contentBlockIndex": 0}},
1406+
{"contentBlockStop": {"contentBlockIndex": 0}},
1407+
{"messageStop": {"stopReason": "tool_use"}},
1408+
{
1409+
"metadata": {
1410+
"usage": {"inputTokens": 10, "outputTokens": 20, "totalTokens": 30},
1411+
"metrics": {"latencyMs": 100},
1412+
}
1413+
},
1414+
]
1415+
1416+
stream = strands.event_loop.streaming.process_stream(agenerator(response))
1417+
last_event = cast(ModelStopReason, (await alist(stream))[-1])
1418+
1419+
assert last_event["stop"][0] == "tool_use"

tests/strands/models/test_bedrock.py

Lines changed: 0 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1565,53 +1565,6 @@ async def test_stream_logging(bedrock_client, model, messages, caplog, alist):
15651565
assert "finished streaming response from model" in log_text
15661566

15671567

1568-
@pytest.mark.asyncio
1569-
async def test_stream_stop_reason_override_streaming(bedrock_client, model, messages, alist):
1570-
"""Test that stopReason is overridden from end_turn to tool_use in streaming mode when tool use is detected."""
1571-
bedrock_client.converse_stream.return_value = {
1572-
"stream": [
1573-
{"messageStart": {"role": "assistant"}},
1574-
{"contentBlockStart": {"start": {"toolUse": {"toolUseId": "123", "name": "test_tool"}}}},
1575-
{"contentBlockDelta": {"delta": {"test": {"input": '{"param": "value"}'}}}},
1576-
{"contentBlockStop": {}},
1577-
{"messageStop": {"stopReason": "end_turn"}},
1578-
]
1579-
}
1580-
1581-
response = model.stream(messages)
1582-
events = await alist(response)
1583-
1584-
# Find the messageStop event
1585-
message_stop_event = next(event for event in events if "messageStop" in event)
1586-
1587-
# Verify stopReason was overridden to tool_use
1588-
assert message_stop_event["messageStop"]["stopReason"] == "tool_use"
1589-
1590-
1591-
@pytest.mark.asyncio
1592-
async def test_stream_stop_reason_override_non_streaming(bedrock_client, alist, messages):
1593-
"""Test that stopReason is overridden from end_turn to tool_use in non-streaming mode when tool use is detected."""
1594-
bedrock_client.converse.return_value = {
1595-
"output": {
1596-
"message": {
1597-
"role": "assistant",
1598-
"content": [{"toolUse": {"toolUseId": "123", "name": "test_tool", "input": {"param": "value"}}}],
1599-
}
1600-
},
1601-
"stopReason": "end_turn",
1602-
}
1603-
1604-
model = BedrockModel(model_id="test-model", streaming=False)
1605-
response = model.stream(messages)
1606-
events = await alist(response)
1607-
1608-
# Find the messageStop event
1609-
message_stop_event = next(event for event in events if "messageStop" in event)
1610-
1611-
# Verify stopReason was overridden to tool_use
1612-
assert message_stop_event["messageStop"]["stopReason"] == "tool_use"
1613-
1614-
16151568
def test_format_request_cleans_tool_result_content_blocks(model, model_id):
16161569
messages = [
16171570
{

0 commit comments

Comments (0)