Skip to content

Commit b0fc796

Browse files
authored
fix: place cache point on last user message instead of assistant (#1821)
1 parent bfe9d02 commit b0fc796

File tree

2 files changed

+64
-25
lines changed

2 files changed

+64
-25
lines changed

src/strands/models/bedrock.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -339,15 +339,15 @@ def _get_additional_request_fields(self, tool_choice: ToolChoice | None) -> dict
339339
return {"additionalModelRequestFields": additional_fields}
340340

341341
def _inject_cache_point(self, messages: list[dict[str, Any]]) -> None:
342-
"""Inject a cache point at the end of the last assistant message.
342+
"""Inject a cache point at the end of the last user message.
343343
344344
Args:
345345
messages: List of messages to inject cache point into (modified in place).
346346
"""
347347
if not messages:
348348
return
349349

350-
last_assistant_idx: int | None = None
350+
last_user_idx: int | None = None
351351
for msg_idx, msg in enumerate(messages):
352352
content = msg.get("content", [])
353353
for block_idx, block in reversed(list(enumerate(content))):
@@ -358,12 +358,12 @@ def _inject_cache_point(self, messages: list[dict[str, Any]]) -> None:
358358
msg_idx,
359359
block_idx,
360360
)
361-
if msg.get("role") == "assistant":
362-
last_assistant_idx = msg_idx
361+
if msg.get("role") == "user":
362+
last_user_idx = msg_idx
363363

364-
if last_assistant_idx is not None and messages[last_assistant_idx].get("content"):
365-
messages[last_assistant_idx]["content"].append({"cachePoint": {"type": "default"}})
366-
logger.debug("msg_idx=<%s> | added cache point to last assistant message", last_assistant_idx)
364+
if last_user_idx is not None and messages[last_user_idx].get("content"):
365+
messages[last_user_idx]["content"].append({"cachePoint": {"type": "default"}})
366+
logger.debug("msg_idx=<%s> | added cache point to last user message", last_user_idx)
367367

368368
def _find_last_user_text_message_index(self, messages: Messages) -> int | None:
369369
"""Find the index of the last user message containing text or image content.

tests/strands/models/test_bedrock.py

Lines changed: 57 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2597,8 +2597,8 @@ def test_cache_strategy_none_for_non_claude(bedrock_client):
25972597
assert model._cache_strategy is None
25982598

25992599

2600-
def test_inject_cache_point_adds_to_last_assistant(bedrock_client):
2601-
"""Test that _inject_cache_point adds cache point to last assistant message."""
2600+
def test_inject_cache_point_adds_to_last_user(bedrock_client):
2601+
"""Test that _inject_cache_point adds cache point to last user message."""
26022602
model = BedrockModel(
26032603
model_id="us.anthropic.claude-sonnet-4-20250514-v1:0", cache_config=CacheConfig(strategy="auto")
26042604
)
@@ -2611,13 +2611,14 @@ def test_inject_cache_point_adds_to_last_assistant(bedrock_client):
26112611

26122612
model._inject_cache_point(cleaned_messages)
26132613

2614-
assert len(cleaned_messages[1]["content"]) == 2
2615-
assert "cachePoint" in cleaned_messages[1]["content"][-1]
2616-
assert cleaned_messages[1]["content"][-1]["cachePoint"]["type"] == "default"
2614+
assert len(cleaned_messages[2]["content"]) == 2
2615+
assert "cachePoint" in cleaned_messages[2]["content"][-1]
2616+
assert cleaned_messages[2]["content"][-1]["cachePoint"]["type"] == "default"
2617+
assert len(cleaned_messages[1]["content"]) == 1
26172618

26182619

2619-
def test_inject_cache_point_no_assistant_message(bedrock_client):
2620-
"""Test that _inject_cache_point does nothing when no assistant message exists."""
2620+
def test_inject_cache_point_single_user_message(bedrock_client):
2621+
"""Test that _inject_cache_point adds cache point to single user message."""
26212622
model = BedrockModel(
26222623
model_id="us.anthropic.claude-sonnet-4-20250514-v1:0", cache_config=CacheConfig(strategy="auto")
26232624
)
@@ -2629,6 +2630,39 @@ def test_inject_cache_point_no_assistant_message(bedrock_client):
26292630
model._inject_cache_point(cleaned_messages)
26302631

26312632
assert len(cleaned_messages) == 1
2633+
assert len(cleaned_messages[0]["content"]) == 2
2634+
assert "cachePoint" in cleaned_messages[0]["content"][-1]
2635+
2636+
2637+
def test_inject_cache_point_empty_messages(bedrock_client):
2638+
"""Test that _inject_cache_point handles empty messages list."""
2639+
model = BedrockModel(
2640+
model_id="us.anthropic.claude-sonnet-4-20250514-v1:0", cache_config=CacheConfig(strategy="auto")
2641+
)
2642+
2643+
cleaned_messages = []
2644+
model._inject_cache_point(cleaned_messages)
2645+
2646+
assert cleaned_messages == []
2647+
2648+
2649+
def test_inject_cache_point_with_tool_result_last_user(bedrock_client):
2650+
"""Test that cache point is added to last user message even when it contains toolResult."""
2651+
model = BedrockModel(
2652+
model_id="us.anthropic.claude-sonnet-4-20250514-v1:0", cache_config=CacheConfig(strategy="auto")
2653+
)
2654+
2655+
cleaned_messages = [
2656+
{"role": "user", "content": [{"text": "Use the tool"}]},
2657+
{"role": "assistant", "content": [{"toolUse": {"toolUseId": "t1", "name": "test_tool", "input": {}}}]},
2658+
{"role": "user", "content": [{"toolResult": {"toolUseId": "t1", "content": [{"text": "Result"}]}}]},
2659+
]
2660+
2661+
model._inject_cache_point(cleaned_messages)
2662+
2663+
assert len(cleaned_messages[2]["content"]) == 2
2664+
assert "cachePoint" in cleaned_messages[2]["content"][-1]
2665+
assert cleaned_messages[2]["content"][-1]["cachePoint"]["type"] == "default"
26322666
assert len(cleaned_messages[0]["content"]) == 1
26332667

26342668

@@ -2643,6 +2677,8 @@ def test_inject_cache_point_skipped_for_non_claude(bedrock_client):
26432677

26442678
formatted = model._format_bedrock_messages(messages)
26452679

2680+
assert len(formatted[0]["content"]) == 1
2681+
assert "cachePoint" not in formatted[0]["content"][0]
26462682
assert len(formatted[1]["content"]) == 1
26472683
assert "cachePoint" not in formatted[1]["content"][0]
26482684

@@ -2664,8 +2700,8 @@ def test_format_bedrock_messages_does_not_mutate_original(bedrock_client):
26642700
formatted = model._format_bedrock_messages(original_messages)
26652701

26662702
assert original_messages == messages_before
2667-
assert "cachePoint" not in original_messages[1]["content"][-1]
2668-
assert "cachePoint" in formatted[1]["content"][-1]
2703+
assert "cachePoint" not in original_messages[2]["content"][-1]
2704+
assert "cachePoint" in formatted[2]["content"][-1]
26692705

26702706

26712707
def test_inject_cache_point_strips_existing_cache_points(bedrock_client):
@@ -2685,12 +2721,13 @@ def test_inject_cache_point_strips_existing_cache_points(bedrock_client):
26852721
model._inject_cache_point(cleaned_messages)
26862722

26872723
# All old cache points should be stripped
2688-
assert len(cleaned_messages[0]["content"]) == 1 # user: only text
2724+
assert len(cleaned_messages[0]["content"]) == 1 # first user: only text
26892725
assert len(cleaned_messages[1]["content"]) == 1 # first assistant: only text
2726+
assert len(cleaned_messages[3]["content"]) == 1 # last assistant: only text
26902727

2691-
# New cache point should be at end of last assistant message
2692-
assert len(cleaned_messages[3]["content"]) == 2
2693-
assert "cachePoint" in cleaned_messages[3]["content"][-1]
2728+
# New cache point should be at end of last user message
2729+
assert len(cleaned_messages[2]["content"]) == 2
2730+
assert "cachePoint" in cleaned_messages[2]["content"][-1]
26942731

26952732

26962733
def test_inject_cache_point_anthropic_strategy_skips_model_check(bedrock_client):
@@ -2707,9 +2744,10 @@ def test_inject_cache_point_anthropic_strategy_skips_model_check(bedrock_client)
27072744

27082745
formatted = model._format_bedrock_messages(messages)
27092746

2710-
assert len(formatted[1]["content"]) == 2
2711-
assert "cachePoint" in formatted[1]["content"][-1]
2712-
assert formatted[1]["content"][-1]["cachePoint"]["type"] == "default"
2747+
assert len(formatted[0]["content"]) == 2
2748+
assert "cachePoint" in formatted[0]["content"][-1]
2749+
assert formatted[0]["content"][-1]["cachePoint"]["type"] == "default"
2750+
assert len(formatted[1]["content"]) == 1
27132751

27142752

27152753
def test_inject_cache_point_auto_strategy_resolves_to_anthropic_for_claude(bedrock_client):
@@ -2725,8 +2763,9 @@ def test_inject_cache_point_auto_strategy_resolves_to_anthropic_for_claude(bedro
27252763

27262764
formatted = model._format_bedrock_messages(messages)
27272765

2728-
assert len(formatted[1]["content"]) == 2
2729-
assert "cachePoint" in formatted[1]["content"][-1]
2766+
assert len(formatted[0]["content"]) == 2
2767+
assert "cachePoint" in formatted[0]["content"][-1]
2768+
assert len(formatted[1]["content"]) == 1
27302769

27312770

27322771
def test_find_last_user_text_message_index_no_user_messages(bedrock_client):

0 commit comments

Comments
 (0)