diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py
index d535bbc51..b96d00a2f 100644
--- a/src/strands/models/bedrock.py
+++ b/src/strands/models/bedrock.py
@@ -81,6 +81,7 @@ class BedrockConfig(BaseModelConfig, total=False):
         cache_prompt: Cache point type for the system prompt (deprecated, use cache_config)
         cache_config: Configuration for prompt caching. Use CacheConfig(strategy="auto") for automatic caching.
         cache_tools: Cache point type for tools
+        cache_tools_ttl: Optional TTL duration for tool cache points (e.g. "5m", "1h")
         guardrail_id: ID of the guardrail to apply
         guardrail_trace: Guardrail trace mode. Defaults to enabled.
         guardrail_version: Version of the guardrail to apply
@@ -115,6 +116,7 @@ class BedrockConfig(BaseModelConfig, total=False):
     cache_prompt: str | None
     cache_config: CacheConfig | None
     cache_tools: str | None
+    cache_tools_ttl: str | None
     guardrail_id: str | None
     guardrail_trace: Literal["enabled", "disabled", "enabled_full"] | None
     guardrail_stream_processing_mode: Literal["sync", "async"] | None
@@ -279,7 +281,18 @@ def _format_request(
                     for tool_spec in tool_specs
                 ],
                 *(
-                    [{"cachePoint": {"type": self.config["cache_tools"]}}]
+                    [
+                        {
+                            "cachePoint": {
+                                "type": self.config["cache_tools"],
+                                **(
+                                    {"ttl": self.config["cache_tools_ttl"]}
+                                    if self.config.get("cache_tools_ttl")
+                                    else {}
+                                ),
+                            }
+                        }
+                    ]
                     if self.config.get("cache_tools")
                     else []
                 ),
@@ -381,7 +394,11 @@ def _inject_cache_point(self, messages: list[dict[str, Any]]) -> None:
                 last_user_idx = msg_idx

         if last_user_idx is not None and messages[last_user_idx].get("content"):
-            messages[last_user_idx]["content"].append({"cachePoint": {"type": "default"}})
+            cache_point: dict[str, Any] = {"type": "default"}
+            cache_config = self.config.get("cache_config")
+            if cache_config and cache_config.ttl:
+                cache_point["ttl"] = cache_config.ttl
+            messages[last_user_idx]["content"].append({"cachePoint": cache_point})
             logger.debug("msg_idx=<%s> | added cache point to last user message", last_user_idx)

     def _find_last_user_text_message_index(self, messages: Messages) -> int | None:
diff --git a/src/strands/models/model.py b/src/strands/models/model.py
index e5b15ebaa..d41c5d159 100644
--- a/src/strands/models/model.py
+++ b/src/strands/models/model.py
@@ -202,9 +202,12 @@ class CacheConfig:
         strategy: Caching strategy to use.
             - "auto": Automatically detect model support and inject cachePoint to maximize cache coverage
            - "anthropic": Inject cachePoint in Anthropic-compatible format without model support check
+        ttl: Optional TTL duration for cache entries (e.g. "5m", "1h").
+            When specified, auto-injected cache points will include this TTL value.
     """

     strategy: Literal["auto", "anthropic"] = "auto"
+    ttl: str | None = None


 class Model(abc.ABC):
diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py
index a80ca091e..a48b3c556 100644
--- a/tests/strands/models/test_bedrock.py
+++ b/tests/strands/models/test_bedrock.py
@@ -3409,3 +3409,59 @@ async def test_fallback_logs_debug(self, model_with_client, bedrock_client, mess
         await model_with_client.count_tokens(messages=messages)

         assert any("native token counting failed" in record.message for record in caplog.records)
+
+
+def test_inject_cache_point_with_ttl(bedrock_client):
+    """Test that _inject_cache_point includes TTL when cache_config has ttl set."""
+    model = BedrockModel(
+        model_id="us.anthropic.claude-sonnet-4-20250514-v1:0",
+        cache_config=CacheConfig(strategy="auto", ttl="5m"),
+    )
+
+    cleaned_messages = [
+        {"role": "user", "content": [{"text": "Hello"}]},
+    ]
+
+    model._inject_cache_point(cleaned_messages)
+
+    cache_point = cleaned_messages[0]["content"][-1]["cachePoint"]
+    assert cache_point["type"] == "default"
+    assert cache_point["ttl"] == "5m"
+
+
+def test_inject_cache_point_without_ttl(bedrock_client):
+    """Test that _inject_cache_point omits TTL when cache_config has no ttl."""
+    model = BedrockModel(
+        model_id="us.anthropic.claude-sonnet-4-20250514-v1:0",
+        cache_config=CacheConfig(strategy="auto"),
+    )
+
+    cleaned_messages = [
+        {"role": "user", "content": [{"text": "Hello"}]},
+    ]
+
+    model._inject_cache_point(cleaned_messages)
+
+    cache_point = cleaned_messages[0]["content"][-1]["cachePoint"]
+    assert cache_point["type"] == "default"
+    assert "ttl" not in cache_point
+
+
+def test_format_request_cache_tools_with_ttl(model, messages, model_id, tool_spec, cache_type):
+    """Test that cache_tools_ttl propagates into toolConfig cachePoint."""
+    model.update_config(cache_tools=cache_type, cache_tools_ttl="5m")
+
+    tru_request = model._format_request(messages, tool_specs=[tool_spec])
+
+    exp_cache_point = {"cachePoint": {"type": cache_type, "ttl": "5m"}}
+    assert tru_request["toolConfig"]["tools"][-1] == exp_cache_point
+
+
+def test_format_request_cache_tools_without_ttl(model, messages, model_id, tool_spec, cache_type):
+    """Test that toolConfig cachePoint omits TTL when cache_tools_ttl is not set."""
+    model.update_config(cache_tools=cache_type)
+
+    tru_request = model._format_request(messages, tool_specs=[tool_spec])
+
+    exp_cache_point = {"cachePoint": {"type": cache_type}}
+    assert tru_request["toolConfig"]["tools"][-1] == exp_cache_point
diff --git a/tests_integ/models/test_model_bedrock.py b/tests_integ/models/test_model_bedrock.py
index 73d67f414..7b481adb7 100644
--- a/tests_integ/models/test_model_bedrock.py
+++ b/tests_integ/models/test_model_bedrock.py
@@ -6,9 +6,18 @@

 import strands
 from strands import Agent
-from strands.models import BedrockModel
+from strands.models import BedrockModel, CacheConfig
 from strands.types.content import ContentBlock

+# Model ID used for prompt-caching TTL integration tests. Per
+# https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html
+# the models that officially support 1h TTL on CachePoint are Claude Opus 4.5,
+# Claude Haiku 4.5, and Claude Sonnet 4.5. Haiku 4.5 is the newest Haiku
+# available and is preferred for CI due to lower latency and cost relative to
+# the same-version Sonnet 4.5. Bump this when a newer Haiku is released that
+# supports CachePoint TTL.
+_CACHE_TTL_MODEL_ID = "us.anthropic.claude-haiku-4-5-20251001-v1:0"
+

 @pytest.fixture
 def system_prompt():
@@ -576,3 +585,129 @@ def calculator(expression: str) -> float:
     agent('Search for "python" with tags ["programming", "language"] using the search tool.')

     assert "search" in tools_called
+
+
+def test_prompt_caching_cache_tools_ttl():
+    """Test that cache_tools_ttl propagates into the auto-injected toolConfig cache point.
+
+    Verifies that BedrockModel(cache_tools="default", cache_tools_ttl="5m") produces a
+    Bedrock request with cachePoint.ttl on the toolConfig checkpoint, and that the call
+    completes without a ValidationException on the TTL field.
+
+    Note: we intentionally do not assert specific cacheWriteInputTokens on the toolConfig
+    prefix because Bedrock's tool-prefix cache threshold varies by model and region.
+    The critical behavior under test here is that the TTL field is accepted end-to-end.
+
+    Uses Claude Haiku 4.5 which supports TTL in CachePointBlock on Bedrock per
+    https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html
+    (Claude Opus 4.5, Claude Haiku 4.5, and Claude Sonnet 4.5 all support 1h TTL).
+    """
+    model = BedrockModel(
+        model_id=_CACHE_TTL_MODEL_ID,
+        streaming=False,
+        cache_tools="default",
+        cache_tools_ttl="5m",
+    )
+
+    @strands.tool
+    def lookup_fact(topic: str) -> str:
+        """Look up a fact about the given topic.
+
+        This tool is useful when you need authoritative information.
+        """
+        return f"Fact about {topic}: example"
+
+    agent = Agent(
+        model=model,
+        tools=[lookup_fact],
+        load_tools_from_directory=False,
+    )
+
+    # The call must succeed — Bedrock must accept cachePoint.ttl on the toolConfig checkpoint
+    # without raising a ValidationException.
+    result = agent("Use the lookup_fact tool to look up 'python'.")
+    assert len(str(result)) > 0
+
+
+def test_prompt_caching_cache_config_auto_with_ttl():
+    """Test that CacheConfig(strategy="auto", ttl="5m") propagates TTL to the auto-injected message cache point.
+
+    Verifies that the cache point appended to the last user message by _inject_cache_point
+    carries the configured TTL, and that Bedrock accepts the request.
+
+    Uses Claude Haiku 4.5 which supports TTL in CachePointBlock on Bedrock per
+    https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html
+    """
+    model = BedrockModel(
+        model_id=_CACHE_TTL_MODEL_ID,
+        streaming=False,
+        cache_config=CacheConfig(strategy="auto", ttl="5m"),
+    )
+
+    unique_id = str(uuid.uuid4())
+    # Minimum 4096 tokens required for caching with Haiku 4.5
+    large_message = f"Context for test {unique_id}: " + ("This is important context. " * 1000) + " What is 2+2?"
+
+    agent = Agent(
+        model=model,
+        load_tools_from_directory=False,
+    )
+
+    # First call: auto-injected cache point on the last user message must include ttl and be accepted
+    result1 = agent(large_message)
+    assert len(str(result1)) > 0
+
+    # Verify cache write occurred with auto-inject + ttl
+    assert result1.metrics.accumulated_usage.get("cacheWriteInputTokens", 0) > 0, (
+        "Expected cacheWriteInputTokens > 0 with CacheConfig(strategy='auto', ttl='5m')"
+    )
+
+
+def test_prompt_caching_aligned_1h_ttl_across_checkpoints():
+    """Regression test for Bedrock TTL non-increasing ordering rule (Issue #2121).
+
+    Bedrock processes cache checkpoints in order: toolConfig -> system -> messages,
+    and requires TTLs to be non-increasing. Before this change, cache_tools hardcoded
+    an implicit 5m TTL, so any 1h TTL on a later checkpoint would raise a
+    ValidationException.
+
+    This test sets 1h TTL on all three checkpoints simultaneously and verifies the
+    call succeeds.
+
+    Uses Claude Haiku 4.5 which supports 1h TTL per
+    https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html
+    """
+    model = BedrockModel(
+        model_id=_CACHE_TTL_MODEL_ID,
+        streaming=False,
+        cache_tools="default",
+        cache_tools_ttl="1h",
+        cache_config=CacheConfig(strategy="auto", ttl="1h"),
+    )
+
+    # Timestamp-based uniqueness to avoid cache conflicts across CI runs
+    unique_id = str(int(time.time() * 1000000))
+    large_context = f"Background context for test {unique_id}: " + ("This is important context. " * 1000)
+
+    # User-supplied 1h cache point on system prompt — third checkpoint also at 1h
+    system_prompt_with_cache = [
+        {"text": large_context},
+        {"cachePoint": {"type": "default", "ttl": "1h"}},
+        {"text": "You are a helpful assistant."},
+    ]
+
+    @strands.tool
+    def echo(value: str) -> str:
+        """Echo the given value back."""
+        return value
+
+    agent = Agent(
+        model=model,
+        system_prompt=system_prompt_with_cache,
+        tools=[echo],
+        load_tools_from_directory=False,
+    )
+
+    # Must succeed without ValidationException on the non-increasing TTL rule
+    result = agent("What is 2+2?")
+    assert len(str(result)) > 0