Skip to content

Commit c50457d

Browse files
austinmwlizradway
andauthored
fix: preserve guardrail_latest_message wrapping after tool execution (#1658)
Co-authored-by: Liz <91279165+lizradway@users.noreply.github.com>
1 parent 2c83216 commit c50457d

2 files changed

Lines changed: 292 additions & 8 deletions

File tree

src/strands/models/bedrock.py

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,23 @@ def _inject_cache_point(self, messages: list[dict[str, Any]]) -> None:
363363
messages[last_assistant_idx]["content"].append({"cachePoint": {"type": "default"}})
364364
logger.debug("msg_idx=<%s> | added cache point to last assistant message", last_assistant_idx)
365365

366+
def _find_last_user_text_message_index(self, messages: Messages) -> int | None:
367+
"""Find the index of the last user message containing text or image content.
368+
369+
This is used for guardrail_latest_message to ensure that guardContent wrapping
370+
targets the correct message even when toolResult messages follow.
371+
372+
Args:
373+
messages: List of messages to search
374+
375+
Returns:
376+
Index of the last user message with text/image content, or None if not found
377+
"""
378+
for idx, msg in reversed(list(enumerate(messages))):
379+
if msg["role"] == "user" and any("text" in cb or "image" in cb for cb in msg.get("content", [])):
380+
return idx
381+
return None
382+
366383
def _format_bedrock_messages(self, messages: Messages) -> list[dict[str, Any]]:
367384
"""Format messages for Bedrock API compatibility.
368385
@@ -391,7 +408,12 @@ def _format_bedrock_messages(self, messages: Messages) -> list[dict[str, Any]]:
391408
filtered_unknown_members = False
392409
dropped_deepseek_reasoning_content = False
393410

394-
guardrail_latest_message = self.config.get("guardrail_latest_message", False)
411+
# Pre-compute the index of the last user message containing text or image content.
412+
# This ensures guardContent wrapping is maintained across tool execution cycles, where
413+
# the final message in the list is a toolResult (role=user) rather than text/image content.
414+
last_user_text_idx = None
415+
if self.config.get("guardrail_latest_message", False):
416+
last_user_text_idx = self._find_last_user_text_message_index(messages)
395417

396418
for idx, message in enumerate(messages):
397419
cleaned_content: list[dict[str, Any]] = []
@@ -413,13 +435,8 @@ def _format_bedrock_messages(self, messages: Messages) -> list[dict[str, Any]]:
413435
if formatted_content is None:
414436
continue
415437

416-
# Wrap text or image content in guardrailContent if this is the last user message
417-
if (
418-
guardrail_latest_message
419-
and idx == len(messages) - 1
420-
and message["role"] == "user"
421-
and ("text" in formatted_content or "image" in formatted_content)
422-
):
438+
# Wrap text or image content in guardContent if this is the last user text/image message
439+
if idx == last_user_text_idx and ("text" in formatted_content or "image" in formatted_content):
423440
if "text" in formatted_content:
424441
formatted_content = {"guardContent": {"text": {"text": formatted_content["text"]}}}
425442
elif "image" in formatted_content:

tests/strands/models/test_bedrock.py

Lines changed: 267 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2405,6 +2405,183 @@ async def test_format_request_with_guardrail_latest_message(model):
24052405
assert formatted_messages[2]["content"][1]["guardContent"]["image"]["format"] == "png"
24062406

24072407

2408+
@pytest.mark.asyncio
2409+
async def test_format_request_with_guardrail_latest_message_after_tool_use(model):
2410+
"""Test that guardContent wraps the last user text message even when a toolResult follows it."""
2411+
model.update_config(
2412+
guardrail_id="test-guardrail",
2413+
guardrail_version="DRAFT",
2414+
guardrail_latest_message=True,
2415+
)
2416+
2417+
messages = [
2418+
{"role": "user", "content": [{"text": "First message"}]},
2419+
{"role": "assistant", "content": [{"text": "First response"}]},
2420+
{"role": "user", "content": [{"text": "what is the standard deduction?"}]},
2421+
{
2422+
"role": "assistant",
2423+
"content": [
2424+
{
2425+
"toolUse": {
2426+
"toolUseId": "tool-1",
2427+
"name": "knowledge_base",
2428+
"input": {"query": "standard deduction"},
2429+
}
2430+
}
2431+
],
2432+
},
2433+
{
2434+
"role": "user",
2435+
"content": [
2436+
{
2437+
"toolResult": {
2438+
"toolUseId": "tool-1",
2439+
"content": [{"text": "The standard deduction for 2024 is $14,600."}],
2440+
"status": "success",
2441+
}
2442+
}
2443+
],
2444+
},
2445+
]
2446+
2447+
request = model._format_request(messages)
2448+
formatted_messages = request["messages"]
2449+
2450+
assert len(formatted_messages) == 5
2451+
2452+
# Earlier user message should NOT be wrapped
2453+
assert "text" in formatted_messages[0]["content"][0]
2454+
assert formatted_messages[0]["content"][0]["text"] == "First message"
2455+
2456+
# Last user message with text content should be wrapped, even though a toolResult comes after
2457+
assert "guardContent" in formatted_messages[2]["content"][0]
2458+
assert formatted_messages[2]["content"][0]["guardContent"]["text"]["text"] == "what is the standard deduction?"
2459+
2460+
# toolResult-only user message should NOT be wrapped
2461+
assert "toolResult" in formatted_messages[4]["content"][0]
2462+
assert "guardContent" not in formatted_messages[4]["content"][0]
2463+
2464+
2465+
@pytest.mark.asyncio
2466+
async def test_format_request_with_guardrail_latest_message_wraps_final_user_text(model):
2467+
"""Test that guardContent wraps the last user message when it contains text content."""
2468+
model.update_config(
2469+
guardrail_id="test-guardrail",
2470+
guardrail_version="DRAFT",
2471+
guardrail_latest_message=True,
2472+
)
2473+
2474+
messages = [
2475+
{"role": "user", "content": [{"text": "First message"}]},
2476+
{"role": "assistant", "content": [{"text": "First response"}]},
2477+
{"role": "user", "content": [{"text": "Tell me about taxes"}]},
2478+
]
2479+
2480+
request = model._format_request(messages)
2481+
formatted_messages = request["messages"]
2482+
2483+
assert "guardContent" in formatted_messages[2]["content"][0]
2484+
assert formatted_messages[2]["content"][0]["guardContent"]["text"]["text"] == "Tell me about taxes"
2485+
2486+
2487+
@pytest.mark.asyncio
2488+
async def test_format_request_with_guardrail_multiple_sequential_tool_calls(model):
2489+
"""Test guardContent with multiple tool calls in sequence (no new user input between)."""
2490+
model.update_config(
2491+
guardrail_id="test-guardrail",
2492+
guardrail_version="DRAFT",
2493+
guardrail_latest_message=True,
2494+
)
2495+
2496+
messages = [
2497+
{"role": "user", "content": [{"text": "First question"}]},
2498+
{"role": "assistant", "content": [{"toolUse": {"toolUseId": "t1", "name": "tool1", "input": {}}}]},
2499+
{
2500+
"role": "user",
2501+
"content": [{"toolResult": {"toolUseId": "t1", "content": [{"text": "Result 1"}], "status": "success"}}],
2502+
},
2503+
{"role": "assistant", "content": [{"toolUse": {"toolUseId": "t2", "name": "tool2", "input": {}}}]},
2504+
{
2505+
"role": "user",
2506+
"content": [{"toolResult": {"toolUseId": "t2", "content": [{"text": "Result 2"}], "status": "success"}}],
2507+
},
2508+
]
2509+
2510+
request = model._format_request(messages)
2511+
formatted_messages = request["messages"]
2512+
2513+
# Should wrap the first user text message, not the toolResults
2514+
assert "guardContent" in formatted_messages[0]["content"][0]
2515+
assert formatted_messages[0]["content"][0]["guardContent"]["text"]["text"] == "First question"
2516+
2517+
# toolResults should not be wrapped
2518+
assert "toolResult" in formatted_messages[2]["content"][0]
2519+
assert "guardContent" not in formatted_messages[2]["content"][0]
2520+
assert "toolResult" in formatted_messages[4]["content"][0]
2521+
assert "guardContent" not in formatted_messages[4]["content"][0]
2522+
2523+
2524+
@pytest.mark.asyncio
2525+
async def test_format_request_with_guardrail_image_before_tool_result(model):
2526+
"""Test guardContent wraps image content even when toolResult follows."""
2527+
model.update_config(
2528+
guardrail_id="test-guardrail",
2529+
guardrail_version="DRAFT",
2530+
guardrail_latest_message=True,
2531+
)
2532+
2533+
messages = [
2534+
{"role": "user", "content": [{"image": {"format": "png", "source": {"bytes": b"fake"}}}]},
2535+
{"role": "assistant", "content": [{"toolUse": {"toolUseId": "t1", "name": "vision", "input": {}}}]},
2536+
{
2537+
"role": "user",
2538+
"content": [{"toolResult": {"toolUseId": "t1", "content": [{"text": "I see a cat"}], "status": "success"}}],
2539+
},
2540+
]
2541+
2542+
request = model._format_request(messages)
2543+
formatted_messages = request["messages"]
2544+
2545+
# Image should be wrapped even though toolResult comes after
2546+
assert "guardContent" in formatted_messages[0]["content"][0]
2547+
assert "image" in formatted_messages[0]["content"][0]["guardContent"]
2548+
2549+
2550+
@pytest.mark.asyncio
2551+
async def test_format_request_with_guardrail_multiple_tool_results_same_message(model):
2552+
"""Test guardContent with multiple parallel tool calls (multiple toolResults in one message)."""
2553+
model.update_config(
2554+
guardrail_id="test-guardrail",
2555+
guardrail_version="DRAFT",
2556+
guardrail_latest_message=True,
2557+
)
2558+
2559+
messages = [
2560+
{"role": "user", "content": [{"text": "Question requiring multiple tools"}]},
2561+
{
2562+
"role": "assistant",
2563+
"content": [
2564+
{"toolUse": {"toolUseId": "t1", "name": "tool1", "input": {}}},
2565+
{"toolUse": {"toolUseId": "t2", "name": "tool2", "input": {}}},
2566+
],
2567+
},
2568+
{
2569+
"role": "user",
2570+
"content": [
2571+
{"toolResult": {"toolUseId": "t1", "content": [{"text": "Result 1"}], "status": "success"}},
2572+
{"toolResult": {"toolUseId": "t2", "content": [{"text": "Result 2"}], "status": "success"}},
2573+
],
2574+
},
2575+
]
2576+
2577+
request = model._format_request(messages)
2578+
formatted_messages = request["messages"]
2579+
2580+
# Should wrap the question
2581+
assert "guardContent" in formatted_messages[0]["content"][0]
2582+
assert formatted_messages[0]["content"][0]["guardContent"]["text"]["text"] == "Question requiring multiple tools"
2583+
2584+
24082585
def test_supports_caching_true_for_claude(bedrock_client):
24092586
"""Test that supports_caching returns True for Claude models."""
24102587
model = BedrockModel(model_id="us.anthropic.claude-sonnet-4-20250514-v1:0")
@@ -2514,3 +2691,93 @@ def test_inject_cache_point_strips_existing_cache_points(bedrock_client):
25142691
# New cache point should be at end of last assistant message
25152692
assert len(cleaned_messages[3]["content"]) == 2
25162693
assert "cachePoint" in cleaned_messages[3]["content"][-1]
2694+
2695+
2696+
def test_find_last_user_text_message_index_no_user_messages(bedrock_client):
2697+
"""Test _find_last_user_text_message_index returns None when no user text messages exist."""
2698+
model = BedrockModel(model_id="test-model")
2699+
2700+
messages = [
2701+
{"role": "assistant", "content": [{"text": "hello"}]},
2702+
]
2703+
2704+
assert model._find_last_user_text_message_index(messages) is None
2705+
2706+
2707+
def test_find_last_user_text_message_index_only_tool_results(bedrock_client):
2708+
"""Test _find_last_user_text_message_index returns None when user messages only have toolResult."""
2709+
model = BedrockModel(model_id="test-model")
2710+
2711+
messages = [
2712+
{
2713+
"role": "user",
2714+
"content": [{"toolResult": {"toolUseId": "t1", "content": [{"text": "result"}]}}],
2715+
},
2716+
]
2717+
2718+
assert model._find_last_user_text_message_index(messages) is None
2719+
2720+
2721+
def test_find_last_user_text_message_index_returns_last_text_message(bedrock_client):
2722+
"""Test _find_last_user_text_message_index returns the index of the last user message with text."""
2723+
model = BedrockModel(model_id="test-model")
2724+
2725+
messages = [
2726+
{"role": "user", "content": [{"text": "First question"}]},
2727+
{"role": "assistant", "content": [{"text": "Response"}]},
2728+
{"role": "user", "content": [{"text": "Second question"}]},
2729+
]
2730+
2731+
assert model._find_last_user_text_message_index(messages) == 2
2732+
2733+
2734+
def test_find_last_user_text_message_index_skips_tool_result_messages(bedrock_client):
2735+
"""Test _find_last_user_text_message_index skips toolResult-only user messages."""
2736+
model = BedrockModel(model_id="test-model")
2737+
2738+
messages = [
2739+
{"role": "user", "content": [{"text": "Question"}]},
2740+
{"role": "assistant", "content": [{"toolUse": {"toolUseId": "t1", "name": "tool", "input": {}}}]},
2741+
{
2742+
"role": "user",
2743+
"content": [{"toolResult": {"toolUseId": "t1", "content": [{"text": "Result"}]}}],
2744+
},
2745+
]
2746+
2747+
assert model._find_last_user_text_message_index(messages) == 0
2748+
2749+
2750+
def test_find_last_user_text_message_index_finds_image_message(bedrock_client):
2751+
"""Test _find_last_user_text_message_index finds user messages with image content."""
2752+
model = BedrockModel(model_id="test-model")
2753+
2754+
messages = [
2755+
{"role": "user", "content": [{"image": {"format": "png", "source": {"bytes": b"fake"}}}]},
2756+
{"role": "assistant", "content": [{"toolUse": {"toolUseId": "t1", "name": "vision", "input": {}}}]},
2757+
{
2758+
"role": "user",
2759+
"content": [{"toolResult": {"toolUseId": "t1", "content": [{"text": "Result"}]}}],
2760+
},
2761+
]
2762+
2763+
assert model._find_last_user_text_message_index(messages) == 0
2764+
2765+
2766+
def test_find_last_user_text_message_index_empty_messages(bedrock_client):
2767+
"""Test _find_last_user_text_message_index returns None for empty message list."""
2768+
model = BedrockModel(model_id="test-model")
2769+
2770+
assert model._find_last_user_text_message_index([]) is None
2771+
2772+
2773+
def test_guardrail_latest_message_disabled_does_not_wrap(model):
2774+
"""Test that guardContent wrapping is skipped when guardrail_latest_message is not set."""
2775+
messages = [
2776+
{"role": "user", "content": [{"text": "Hello"}]},
2777+
]
2778+
2779+
request = model._format_request(messages)
2780+
formatted = request["messages"][0]["content"][0]
2781+
2782+
assert "text" in formatted
2783+
assert "guardContent" not in formatted

0 commit comments

Comments
 (0)