21 changes: 19 additions & 2 deletions src/strands/models/bedrock.py
@@ -81,6 +81,7 @@ class BedrockConfig(BaseModelConfig, total=False):
cache_prompt: Cache point type for the system prompt (deprecated, use cache_config)
cache_config: Configuration for prompt caching. Use CacheConfig(strategy="auto") for automatic caching.
cache_tools: Cache point type for tools
cache_tools_ttl: Optional TTL duration for tool cache points (e.g. "5m", "1h")
guardrail_id: ID of the guardrail to apply
guardrail_trace: Guardrail trace mode. Defaults to enabled.
guardrail_version: Version of the guardrail to apply
@@ -115,6 +116,7 @@ class BedrockConfig(BaseModelConfig, total=False):
cache_prompt: str | None
cache_config: CacheConfig | None
cache_tools: str | None
cache_tools_ttl: str | None
guardrail_id: str | None
guardrail_trace: Literal["enabled", "disabled", "enabled_full"] | None
guardrail_stream_processing_mode: Literal["sync", "async"] | None
@@ -279,7 +281,18 @@ def _format_request(
for tool_spec in tool_specs
],
*(
[{"cachePoint": {"type": self.config["cache_tools"]}}]
[
{
"cachePoint": {
"type": self.config["cache_tools"],
**(
{"ttl": self.config["cache_tools_ttl"]}
if self.config.get("cache_tools_ttl")
else {}
),
}
}
]
if self.config.get("cache_tools")
else []
),
@@ -381,7 +394,11 @@ def _inject_cache_point(self, messages: list[dict[str, Any]]) -> None:
last_user_idx = msg_idx

if last_user_idx is not None and messages[last_user_idx].get("content"):
messages[last_user_idx]["content"].append({"cachePoint": {"type": "default"}})
cache_point: dict[str, Any] = {"type": "default"}
cache_config = self.config.get("cache_config")
if cache_config and cache_config.ttl:
cache_point["ttl"] = cache_config.ttl
messages[last_user_idx]["content"].append({"cachePoint": cache_point})
logger.debug("msg_idx=<%s> | added cache point to last user message", last_user_idx)

def _find_last_user_text_message_index(self, messages: Messages) -> int | None:
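Taken together, the bedrock.py changes let both cache checkpoints carry a TTL: the toolConfig cache point reads `cache_tools_ttl`, and the auto-injected message cache point reads `CacheConfig.ttl`. A minimal usage sketch, with the constructor arguments as exercised by the tests below (the model ID is illustrative):

from strands.models import BedrockModel, CacheConfig

model = BedrockModel(
    model_id="us.anthropic.claude-haiku-4-5-20251001-v1:0",
    cache_tools="default",                                # cachePoint after the tool list
    cache_tools_ttl="1h",                                 # TTL on the toolConfig checkpoint
    cache_config=CacheConfig(strategy="auto", ttl="1h"),  # TTL on the auto-injected message cache point
)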
3 changes: 3 additions & 0 deletions src/strands/models/model.py
@@ -202,9 +202,12 @@ class CacheConfig:
strategy: Caching strategy to use.
- "auto": Automatically detect model support and inject cachePoint to maximize cache coverage
- "anthropic": Inject cachePoint in Anthropic-compatible format without model support check
ttl: Optional TTL duration for cache entries (e.g. "5m", "1h").
When specified, auto-injected cache points will include this TTL value.
"""

strategy: Literal["auto", "anthropic"] = "auto"
ttl: str | None = None


class Model(abc.ABC):
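The new field keeps CacheConfig backward compatible: when ttl is omitted, no ttl key is emitted and behavior is unchanged. A quick sketch of the accepted shapes, per the docstring above:

CacheConfig()                                # strategy="auto", no TTL key in cache points
CacheConfig(strategy="auto", ttl="5m")       # auto-injected cache points include ttl="5m"
CacheConfig(strategy="anthropic", ttl="1h")  # Anthropic-format injection with a 1h TTL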
56 changes: 56 additions & 0 deletions tests/strands/models/test_bedrock.py
@@ -3409,3 +3409,59 @@ async def test_fallback_logs_debug(self, model_with_client, bedrock_client, mess
await model_with_client.count_tokens(messages=messages)

assert any("native token counting failed" in record.message for record in caplog.records)


def test_inject_cache_point_with_ttl(bedrock_client):
"""Test that _inject_cache_point includes TTL when cache_config has ttl set."""
model = BedrockModel(
model_id="us.anthropic.claude-sonnet-4-20250514-v1:0",
cache_config=CacheConfig(strategy="auto", ttl="5m"),
)

cleaned_messages = [
{"role": "user", "content": [{"text": "Hello"}]},
]

model._inject_cache_point(cleaned_messages)

cache_point = cleaned_messages[0]["content"][-1]["cachePoint"]
assert cache_point["type"] == "default"
assert cache_point["ttl"] == "5m"


def test_inject_cache_point_without_ttl(bedrock_client):
"""Test that _inject_cache_point omits TTL when cache_config has no ttl."""
model = BedrockModel(
model_id="us.anthropic.claude-sonnet-4-20250514-v1:0",
cache_config=CacheConfig(strategy="auto"),
)

cleaned_messages = [
{"role": "user", "content": [{"text": "Hello"}]},
]

model._inject_cache_point(cleaned_messages)

cache_point = cleaned_messages[0]["content"][-1]["cachePoint"]
assert cache_point["type"] == "default"
assert "ttl" not in cache_point


def test_format_request_cache_tools_with_ttl(model, messages, model_id, tool_spec, cache_type):
"""Test that cache_tools_ttl propagates into toolConfig cachePoint."""
model.update_config(cache_tools=cache_type, cache_tools_ttl="5m")

tru_request = model._format_request(messages, tool_specs=[tool_spec])

exp_cache_point = {"cachePoint": {"type": cache_type, "ttl": "5m"}}
assert tru_request["toolConfig"]["tools"][-1] == exp_cache_point


def test_format_request_cache_tools_without_ttl(model, messages, model_id, tool_spec, cache_type):
"""Test that toolConfig cachePoint omits TTL when cache_tools_ttl is not set."""
model.update_config(cache_tools=cache_type)

tru_request = model._format_request(messages, tool_specs=[tool_spec])

exp_cache_point = {"cachePoint": {"type": cache_type}}
assert tru_request["toolConfig"]["tools"][-1] == exp_cache_point
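
For orientation before the integration tests: Bedrock evaluates cache checkpoints in the order toolConfig -> system -> messages and requires TTLs to be non-increasing along that order (see the regression test below). With all three checkpoints configured at 1h, the relevant request fragments look roughly like this (shapes taken from the assertions above; abbreviated):

request = {
    "toolConfig": {"tools": [..., {"cachePoint": {"type": "default", "ttl": "1h"}}]},
    "system": [..., {"cachePoint": {"type": "default", "ttl": "1h"}}],
    "messages": [
        {"role": "user", "content": [..., {"cachePoint": {"type": "default", "ttl": "1h"}}]},
    ],
}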
137 changes: 136 additions & 1 deletion tests_integ/models/test_model_bedrock.py
@@ -6,9 +6,18 @@

import strands
from strands import Agent
from strands.models import BedrockModel
from strands.models import BedrockModel, CacheConfig
from strands.types.content import ContentBlock

# Model ID used for prompt-caching TTL integration tests. Per
# https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html
# the models that officially support 1h TTL on CachePoint are Claude Opus 4.5,
# Claude Haiku 4.5, and Claude Sonnet 4.5. Haiku 4.5 is the newest Haiku
# available and is preferred for CI due to lower latency and cost relative to
# the same-version Sonnet 4.5. Bump this when a newer Haiku is released that
# supports CachePoint TTL.
_CACHE_TTL_MODEL_ID = "us.anthropic.claude-haiku-4-5-20251001-v1:0"


@pytest.fixture
def system_prompt():
@@ -576,3 +585,129 @@ def calculator(expression: str) -> float:
agent('Search for "python" with tags ["programming", "language"] using the search tool.')

assert "search" in tools_called


def test_prompt_caching_cache_tools_ttl():
"""Test that cache_tools_ttl propagates into the auto-injected toolConfig cache point.

Verifies that BedrockModel(cache_tools="default", cache_tools_ttl="5m") produces a
Bedrock request with cachePoint.ttl on the toolConfig checkpoint, and that the call
completes without a ValidationException on the TTL field.

Note: we intentionally do not assert specific cacheWriteInputTokens on the toolConfig
prefix because Bedrock's tool-prefix cache threshold varies by model and region.
The critical behavior under test here is that the TTL field is accepted end-to-end.

Uses Claude Haiku 4.5 which supports TTL in CachePointBlock on Bedrock per
https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html
(Claude Opus 4.5, Claude Haiku 4.5, and Claude Sonnet 4.5 all support 1h TTL).
"""
model = BedrockModel(
model_id=_CACHE_TTL_MODEL_ID,
streaming=False,
cache_tools="default",
cache_tools_ttl="5m",
)

@strands.tool
def lookup_fact(topic: str) -> str:
"""Look up a fact about the given topic.

This tool is useful when you need authoritative information.
"""
return f"Fact about {topic}: example"

agent = Agent(
model=model,
tools=[lookup_fact],
load_tools_from_directory=False,
)

# The call must succeed — Bedrock must accept cachePoint.ttl on the toolConfig checkpoint
# without raising a ValidationException.
result = agent("Use the lookup_fact tool to look up 'python'.")
assert len(str(result)) > 0


def test_prompt_caching_cache_config_auto_with_ttl():
"""Test that CacheConfig(strategy="auto", ttl="5m") propagates TTL to the auto-injected message cache point.

Verifies that the cache point appended to the last user message by _inject_cache_point
carries the configured TTL, and that Bedrock accepts the request.

Uses Claude Haiku 4.5 which supports TTL in CachePointBlock on Bedrock per
https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html
"""
model = BedrockModel(
model_id=_CACHE_TTL_MODEL_ID,
streaming=False,
cache_config=CacheConfig(strategy="auto", ttl="5m"),
)

unique_id = str(uuid.uuid4())
# Minimum 4096 tokens required for caching with Haiku 4.5
large_message = f"Context for test {unique_id}: " + ("This is important context. " * 1000) + " What is 2+2?"

agent = Agent(
model=model,
load_tools_from_directory=False,
)

# First call: auto-injected cache point on the last user message must include ttl and be accepted
result1 = agent(large_message)
assert len(str(result1)) > 0

# Verify cache write occurred with auto-inject + ttl
assert result1.metrics.accumulated_usage.get("cacheWriteInputTokens", 0) > 0, (
"Expected cacheWriteInputTokens > 0 with CacheConfig(strategy='auto', ttl='5m')"
)


def test_prompt_caching_aligned_1h_ttl_across_checkpoints():
"""Regression test for Bedrock TTL non-increasing ordering rule (Issue #2121).

Bedrock processes cache checkpoints in order: toolConfig -> system -> messages,
and requires TTLs to be non-increasing. Before this change, cache_tools hardcoded
an implicit 5m TTL, so any 1h TTL on a later checkpoint would raise a
ValidationException.

This test sets 1h TTL on all three checkpoints simultaneously and verifies the
call succeeds.

Uses Claude Haiku 4.5 which supports 1h TTL per
https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html
"""
model = BedrockModel(
model_id=_CACHE_TTL_MODEL_ID,
streaming=False,
cache_tools="default",
cache_tools_ttl="1h",
cache_config=CacheConfig(strategy="auto", ttl="1h"),
)

# Timestamp-based uniqueness to avoid cache conflicts across CI runs
unique_id = str(int(time.time() * 1000000))
large_context = f"Background context for test {unique_id}: " + ("This is important context. " * 1000)

# User-supplied 1h cache point on the system prompt, so all three checkpoints (toolConfig, system, messages) sit at 1h
system_prompt_with_cache = [
{"text": large_context},
{"cachePoint": {"type": "default", "ttl": "1h"}},
{"text": "You are a helpful assistant."},
]

@strands.tool
def echo(value: str) -> str:
"""Echo the given value back."""
return value

agent = Agent(
model=model,
system_prompt=system_prompt_with_cache,
tools=[echo],
load_tools_from_directory=False,
)

# Must succeed without ValidationException on the non-increasing TTL rule
result = agent("What is 2+2?")
assert len(str(result)) > 0