Skip to content

Commit 46ce50b

Browse files
authored
feat(bedrock): add TTL support to auto-injected tool and system/user cache points (#2232)
1 parent 1232230 commit 46ce50b

5 files changed

Lines changed: 248 additions & 12 deletions

File tree

src/strands/models/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,15 @@
77

88
from . import bedrock, model
99
from .bedrock import BedrockModel
10-
from .model import BaseModelConfig, CacheConfig, Model
10+
from .model import BaseModelConfig, CacheConfig, CacheToolsConfig, Model
1111

1212
__all__ = [
1313
"bedrock",
1414
"model",
1515
"BaseModelConfig",
1616
"BedrockModel",
1717
"CacheConfig",
18+
"CacheToolsConfig",
1819
"Model",
1920
]
2021

src/strands/models/bedrock.py

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
from ._defaults import resolve_config_metadata
3535
from ._strict_schema import ensure_strict_json_schema
3636
from ._validation import validate_config_keys
37-
from .model import BaseModelConfig, CacheConfig, Model
37+
from .model import BaseModelConfig, CacheConfig, CacheToolsConfig, Model
3838

3939
logger = logging.getLogger(__name__)
4040

@@ -90,7 +90,8 @@ class BedrockConfig(BaseModelConfig, total=False):
9090
additional_response_field_paths: Additional response field paths to extract
9191
cache_prompt: Cache point type for the system prompt (deprecated, use cache_config)
9292
cache_config: Configuration for prompt caching. Use CacheConfig(strategy="auto") for automatic caching.
93-
cache_tools: Cache point type for tools
93+
cache_tools: Cache point type for tools. Pass a string (e.g. "default") for the default 5m TTL,
94+
or a CacheToolsConfig instance to set both type and TTL (e.g. "1h").
9495
guardrail_id: ID of the guardrail to apply
9596
guardrail_trace: Guardrail trace mode. Defaults to enabled.
9697
guardrail_version: Version of the guardrail to apply
@@ -127,7 +128,7 @@ class BedrockConfig(BaseModelConfig, total=False):
127128
additional_response_field_paths: list[str] | None
128129
cache_prompt: str | None
129130
cache_config: CacheConfig | None
130-
cache_tools: str | None
131+
cache_tools: str | CacheToolsConfig | None
131132
guardrail_id: str | None
132133
guardrail_trace: Literal["enabled", "disabled", "enabled_full"] | None
133134
guardrail_stream_processing_mode: Literal["sync", "async"] | None
@@ -292,11 +293,7 @@ def _format_request(
292293
}
293294
for tool_spec in tool_specs
294295
],
295-
*(
296-
[{"cachePoint": {"type": self.config["cache_tools"]}}]
297-
if self.config.get("cache_tools")
298-
else []
299-
),
296+
*self._build_tools_cache_point(),
300297
],
301298
**({"toolChoice": tool_choice if tool_choice else {"auto": {}}}),
302299
}
@@ -371,6 +368,25 @@ def _get_additional_request_fields(self, tool_choice: ToolChoice | None) -> dict
371368

372369
return {"additionalModelRequestFields": additional_fields}
373370

371+
def _build_tools_cache_point(self) -> list[dict[str, Any]]:
    """Produce the optional cache point entry that trails the tool specs in ``toolConfig.tools``.

    The ``cache_tools`` config value may be either a plain string (used as the cache
    point type, with no TTL key emitted) or a ``CacheToolsConfig`` instance (type plus
    an optional TTL).

    Returns:
        A one-element list holding the ``cachePoint`` block, or an empty list when
        ``cache_tools`` is unset or falsy, so the result can be splatted into the tools list.
    """
    setting = self.config.get("cache_tools")
    if not setting:
        return []

    if not isinstance(setting, CacheToolsConfig):
        # String form: the value itself is the cache point type.
        return [{"cachePoint": {"type": setting}}]

    block: dict[str, Any] = {"type": setting.type}
    # Only emit "ttl" when one was configured; a missing/empty TTL is left off the block.
    if setting.ttl:
        block["ttl"] = setting.ttl
    return [{"cachePoint": block}]
389+
374390
def _inject_cache_point(self, messages: list[dict[str, Any]]) -> None:
375391
"""Inject a cache point at the end of the last user message.
376392
@@ -395,7 +411,11 @@ def _inject_cache_point(self, messages: list[dict[str, Any]]) -> None:
395411
last_user_idx = msg_idx
396412

397413
if last_user_idx is not None and messages[last_user_idx].get("content"):
398-
messages[last_user_idx]["content"].append({"cachePoint": {"type": "default"}})
414+
cache_point: dict[str, Any] = {"type": "default"}
415+
cache_config = self.config.get("cache_config")
416+
if cache_config and cache_config.ttl:
417+
cache_point["ttl"] = cache_config.ttl
418+
messages[last_user_idx]["content"].append({"cachePoint": cache_point})
399419
logger.debug("msg_idx=<%s> | added cache point to last user message", last_user_idx)
400420

401421
def _find_last_user_text_message_index(self, messages: Messages) -> int | None:

src/strands/models/model.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,9 +134,25 @@ class CacheConfig:
134134
strategy: Caching strategy to use.
135135
- "auto": Automatically detect model support and inject cachePoint to maximize cache coverage
136136
- "anthropic": Inject cachePoint in Anthropic-compatible format without model support check
137+
ttl: Optional TTL duration for cache entries (e.g. "5m", "1h").
138+
When specified, auto-injected cache points will include this TTL value.
137139
"""
138140

139141
strategy: Literal["auto", "anthropic"] = "auto"
142+
ttl: str | None = None
143+
144+
145+
@dataclass
146+
class CacheToolsConfig:
147+
"""Configuration for the toolConfig cache point.
148+
149+
Attributes:
150+
type: Cache point type (e.g. "default").
151+
ttl: Optional TTL duration for the cache entry (e.g. "5m", "1h").
152+
"""
153+
154+
type: str = "default"
155+
ttl: str | None = None
140156

141157

142158
class Model(abc.ABC):

tests/strands/models/test_bedrock.py

Lines changed: 67 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
import strands
1616
from strands import _exception_notes
17-
from strands.models import BedrockModel, CacheConfig
17+
from strands.models import BedrockModel, CacheConfig, CacheToolsConfig
1818
from strands.models.bedrock import (
1919
DEFAULT_BEDROCK_MODEL_ID,
2020
DEFAULT_BEDROCK_REGION,
@@ -3554,3 +3554,69 @@ async def test_skip_native_api_by_default(self, bedrock_client, model_id, messag
35543554
bedrock_client.count_tokens.assert_not_called()
35553555
assert isinstance(result, int)
35563556
assert result >= 0
3557+
3558+
3559+
def test_inject_cache_point_with_ttl(bedrock_client):
    """A TTL configured on CacheConfig is carried onto the cache point added by _inject_cache_point."""
    bedrock_model = BedrockModel(
        model_id="us.anthropic.claude-sonnet-4-20250514-v1:0",
        cache_config=CacheConfig(strategy="auto", ttl="5m"),
    )
    conversation = [{"role": "user", "content": [{"text": "Hello"}]}]

    bedrock_model._inject_cache_point(conversation)

    # The injected block is expected to be the final content entry of the user message.
    injected = conversation[0]["content"][-1]["cachePoint"]
    assert injected["type"] == "default"
    assert injected["ttl"] == "5m"
3575+
3576+
3577+
def test_inject_cache_point_without_ttl(bedrock_client):
    """With no TTL on CacheConfig, the cache point added by _inject_cache_point has no "ttl" key."""
    bedrock_model = BedrockModel(
        model_id="us.anthropic.claude-sonnet-4-20250514-v1:0",
        cache_config=CacheConfig(strategy="auto"),
    )
    conversation = [{"role": "user", "content": [{"text": "Hello"}]}]

    bedrock_model._inject_cache_point(conversation)

    # The injected block is expected to be the final content entry of the user message.
    injected = conversation[0]["content"][-1]["cachePoint"]
    assert injected["type"] == "default"
    assert "ttl" not in injected
3593+
3594+
3595+
def test_format_request_cache_tools_config_with_ttl(model, messages, model_id, tool_spec, cache_type):
    """A CacheToolsConfig carrying a TTL yields a toolConfig cachePoint with both type and ttl."""
    model.update_config(cache_tools=CacheToolsConfig(type=cache_type, ttl="5m"))

    request = model._format_request(messages, tool_specs=[tool_spec])

    # The cache point is the trailing entry of the tools list.
    tru_cache_point = request["toolConfig"]["tools"][-1]
    exp_cache_point = {"cachePoint": {"type": cache_type, "ttl": "5m"}}
    assert tru_cache_point == exp_cache_point
3603+
3604+
3605+
def test_format_request_cache_tools_config_without_ttl(model, messages, model_id, tool_spec, cache_type):
    """A CacheToolsConfig with no TTL yields a toolConfig cachePoint that contains only the type."""
    model.update_config(cache_tools=CacheToolsConfig(type=cache_type))

    request = model._format_request(messages, tool_specs=[tool_spec])

    # The cache point is the trailing entry of the tools list.
    tru_cache_point = request["toolConfig"]["tools"][-1]
    exp_cache_point = {"cachePoint": {"type": cache_type}}
    assert tru_cache_point == exp_cache_point
3613+
3614+
3615+
def test_format_request_cache_tools_string_backward_compat(model, messages, model_id, tool_spec, cache_type):
    """The legacy string form of cache_tools still yields a cachePoint that contains only the type."""
    model.update_config(cache_tools=cache_type)

    request = model._format_request(messages, tool_specs=[tool_spec])

    # The cache point is the trailing entry of the tools list.
    tru_cache_point = request["toolConfig"]["tools"][-1]
    exp_cache_point = {"cachePoint": {"type": cache_type}}
    assert tru_cache_point == exp_cache_point

tests_integ/models/test_model_bedrock.py

Lines changed: 134 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,18 @@
66

77
import strands
88
from strands import Agent
9-
from strands.models import BedrockModel
9+
from strands.models import BedrockModel, CacheConfig, CacheToolsConfig
1010
from strands.types.content import ContentBlock
1111

12+
# Model ID used for prompt-caching TTL integration tests. Per
13+
# https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html
14+
# the models that officially support 1h TTL on CachePoint are Claude Opus 4.5,
15+
# Claude Haiku 4.5, and Claude Sonnet 4.5. Haiku 4.5 is the newest Haiku
16+
# available and is preferred for CI due to lower latency and cost relative to
17+
# the same-version Sonnet 4.5. Bump this when a newer Haiku is released that
18+
# supports CachePoint TTL.
19+
_CACHE_TTL_MODEL_ID = "us.anthropic.claude-haiku-4-5-20251001-v1:0"
20+
1221

1322
@pytest.fixture
1423
def system_prompt():
@@ -561,3 +570,127 @@ def calculator(expression: str) -> float:
561570
agent('Search for "python" with tags ["programming", "language"] using the search tool.')
562571

563572
assert "search" in tools_called
573+
574+
575+
def test_prompt_caching_cache_tools_ttl():
    """End-to-end check that a TTL on CacheToolsConfig is accepted on the toolConfig cache point.

    Builds BedrockModel(cache_tools=CacheToolsConfig(type="default", ttl="5m")) and runs a
    tool-using prompt. The test passes when the call completes and returns a non-empty
    result, i.e. Bedrock did not reject the request with a ValidationException over the
    TTL field.

    Note: no assertion is made on cacheWriteInputTokens for the toolConfig prefix because
    Bedrock's tool-prefix cache threshold varies by model and region. What matters here is
    only that the TTL field is accepted end-to-end.

    Uses Claude Haiku 4.5, which supports TTL in CachePointBlock on Bedrock per
    https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html
    (Claude Opus 4.5, Claude Haiku 4.5, and Claude Sonnet 4.5 all support 1h TTL).
    """

    @strands.tool
    def lookup_fact(topic: str) -> str:
        """Look up a fact about the given topic.

        This tool is useful when you need authoritative information.
        """
        return f"Fact about {topic}: example"

    ttl_model = BedrockModel(
        model_id=_CACHE_TTL_MODEL_ID,
        streaming=False,
        cache_tools=CacheToolsConfig(type="default", ttl="5m"),
    )
    agent = Agent(
        model=ttl_model,
        tools=[lookup_fact],
        load_tools_from_directory=False,
    )

    # Success criterion: the request goes through, meaning cachePoint.ttl on the
    # toolConfig checkpoint did not trigger a ValidationException.
    response = agent("Use the lookup_fact tool to look up 'python'.")
    assert len(str(response)) > 0
614+
615+
616+
def test_prompt_caching_cache_config_auto_with_ttl():
    """End-to-end check that CacheConfig(strategy="auto", ttl="5m") puts the TTL on the message cache point.

    The cache point that _inject_cache_point appends to the last user message should carry
    the configured TTL, Bedrock should accept the request, and the first call should report
    a cache write.

    Uses Claude Haiku 4.5, which supports TTL in CachePointBlock on Bedrock per
    https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html
    """
    # A fresh UUID makes the prompt prefix unique to this run.
    run_tag = str(uuid.uuid4())
    # Minimum 4096 tokens required for caching with Haiku 4.5, hence the repeated filler.
    filler = "This is important context. " * 1000
    large_message = f"Context for test {run_tag}: " + filler + " What is 2+2?"

    agent = Agent(
        model=BedrockModel(
            model_id=_CACHE_TTL_MODEL_ID,
            streaming=False,
            cache_config=CacheConfig(strategy="auto", ttl="5m"),
        ),
        load_tools_from_directory=False,
    )

    # First call: the auto-injected cache point on the last user message includes the TTL
    # and must be accepted by Bedrock.
    first_response = agent(large_message)
    assert len(str(first_response)) > 0

    # A cache write on this call shows auto-injection still works with a TTL attached.
    cache_writes = first_response.metrics.accumulated_usage.get("cacheWriteInputTokens", 0)
    assert cache_writes > 0, "Expected cacheWriteInputTokens > 0 with CacheConfig(strategy='auto', ttl='5m')"
648+
649+
650+
def test_prompt_caching_aligned_1h_ttl_across_checkpoints():
    """Regression test for Bedrock's non-increasing TTL ordering rule (Issue #2121).

    Bedrock processes cache checkpoints in the order toolConfig -> system -> messages and
    requires their TTLs to be non-increasing. Before this change, cache_tools hardcoded an
    implicit 5m TTL, so any 1h TTL on a later checkpoint would raise a ValidationException.

    Here all three checkpoints are given a 1h TTL at once, and the test passes when the
    call completes with a non-empty result.

    Uses Claude Haiku 4.5, which supports 1h TTL per
    https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html
    """

    @strands.tool
    def echo(value: str) -> str:
        """Echo the given value back."""
        return value

    # Timestamp-based (microsecond) tag so the cached prefix differs between CI runs.
    unique_id = str(int(time.time() * 1000000))
    large_context = f"Background context for test {unique_id}: " + ("This is important context. " * 1000)

    # User-supplied 1h cache point inside the system prompt. Together with cache_tools
    # (toolConfig) and cache_config (messages) below, every checkpoint is at 1h.
    system_prompt_with_cache = [
        {"text": large_context},
        {"cachePoint": {"type": "default", "ttl": "1h"}},
        {"text": "You are a helpful assistant."},
    ]

    aligned_model = BedrockModel(
        model_id=_CACHE_TTL_MODEL_ID,
        streaming=False,
        cache_tools=CacheToolsConfig(type="default", ttl="1h"),
        cache_config=CacheConfig(strategy="auto", ttl="1h"),
    )
    agent = Agent(
        model=aligned_model,
        system_prompt=system_prompt_with_cache,
        tools=[echo],
        load_tools_from_directory=False,
    )

    # Must succeed without a ValidationException on the non-increasing TTL rule.
    response = agent("What is 2+2?")
    assert len(str(response)) > 0

0 commit comments

Comments
 (0)