Commit a5f0224

Prompt caching with application inference profiles (#281)

* fix(bedrock): Resolve inference profile ARNs for cachePoint support
* feat(version): bump version to 0.5.7-wip1 and update template.yaml permissions

1 parent 7456202

11 files changed: 524 additions & 5 deletions


CHANGELOG.md
Lines changed: 6 additions & 0 deletions

@@ -9,6 +9,12 @@ SPDX-License-Identifier: MIT-0
 
 - **Configuration Version in Metering Database** — Added `config_version` field to the metering database to enable cost tracking and analytics per configuration version. The metering Glue table now includes a `config_version` column, and all metering Parquet files store the configuration version used for each document. Enables Athena queries to compare costs across different configurations, support A/B testing analytics, and optimize per-version costs. Documents without a config version default to "default".
 
+### Fixed
+
+- **Application Inference Profile IAM permissions** — Added `application-inference-profile/*` ARN pattern to `bedrock:InvokeModel` IAM policies across all templates (root, appsync, multi-doc-discovery, and sample templates). PR #236 previously fixed only `patterns/unified/template.yaml`; this completes the fix for all Lambda execution roles. Also added `bedrock:GetInferenceProfile` read permission to support prompt caching resolution. ([#272](https://github.com/aws-solutions-library-samples/accelerated-intelligent-document-processing-on-aws/issues/272))
+
+- **Prompt caching with application inference profiles** — Fixed `<<CACHEPOINT>>` tags being stripped when using Bedrock application inference profile ARNs as model IDs. The cachepoint check now resolves inference profile ARNs to their underlying foundation model via the `GetInferenceProfile` API, enabling prompt caching for profiles that wrap supported models (Claude, Nova). Results are cached to avoid repeated API calls, with graceful fallback if the API call fails. ([#272](https://github.com/aws-solutions-library-samples/accelerated-intelligent-document-processing-on-aws/issues/272))
+
 ## [0.5.6]
 
 ### Added
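Both fixes above hinge on a single control-plane call: `GetInferenceProfile` returns the foundation model(s) behind a profile ARN, whose base names can then be checked against the cachePoint allow-list. A minimal standalone sketch of that resolution (the profile ARN, account ID, and region are placeholders; credentials and error handling are omitted):

```python
import boto3

# Placeholder application inference profile ARN (hypothetical account/profile ID).
profile_arn = (
    "arn:aws:bedrock:us-east-1:123456789012:"
    "application-inference-profile/abc123example"
)

# Control-plane client ("bedrock", not "bedrock-runtime").
bedrock = boto3.client("bedrock", region_name="us-east-1")
response = bedrock.get_inference_profile(inferenceProfileIdentifier=profile_arn)

# Each entry's modelArn has the form:
# arn:aws:bedrock:<region>::foundation-model/<base-model-name>
first_model_arn = response["models"][0]["modelArn"]
base_model_name = first_model_arn.split("foundation-model/")[-1]
print(base_model_name)  # e.g. "anthropic.claude-sonnet-4-6"
```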

VERSION
Lines changed: 1 addition & 1 deletion

@@ -1 +1 @@
-0.5.6
+0.5.7-wip1

lib/idp_common_pkg/idp_common/bedrock/client.py
Lines changed: 120 additions & 2 deletions
@@ -66,6 +66,10 @@ class _RequestsConnectTimeout(Exception):
 DEFAULT_MAX_BACKOFF = 300  # 5 minutes
 
 
+# Base model names that support cachePoint (without region prefix)
+# Used to check inference profiles by resolving their underlying foundation model
+_CACHEPOINT_BASE_MODELS = set()
+
 # Models that support cachePoint functionality
 CACHEPOINT_SUPPORTED_MODELS = [
     "us.anthropic.claude-3-5-haiku-20241022-v1:0",
@@ -111,6 +115,21 @@ class _RequestsConnectTimeout(Exception):
     "global.anthropic.claude-opus-4-6-v1:1m",
 ]
 
+# Build set of base model names (without region/tier prefixes) for inference profile resolution.
+# e.g., "us.anthropic.claude-sonnet-4-6" -> "anthropic.claude-sonnet-4-6"
+# and "eu.amazon.nova-2-lite-v1:0:priority" -> "amazon.nova-2-lite-v1:0"
+for _model_id in CACHEPOINT_SUPPORTED_MODELS:
+    _parts = _model_id.split(".", 1)
+    if len(_parts) == 2 and _parts[0] in ("us", "eu", "global"):
+        _base = _parts[1]
+        # Strip tier suffixes (:priority, :flex) but keep version suffixes (:0, :1m)
+        if _base.endswith(":priority") or _base.endswith(":flex"):
+            _base = _base.rsplit(":", 1)[0]
+        _CACHEPOINT_BASE_MODELS.add(_base)
+
+# Module-level cache for inference profile -> cachepoint support resolution
+_inference_profile_cachepoint_cache: Dict[str, bool] = {}
+
 
 class BedrockClient:
     """Client for interacting with Amazon Bedrock models and custom Lambda hooks."""
@@ -139,6 +158,7 @@ def __init__(
         self.max_backoff = max_backoff
         self.metrics_enabled = metrics_enabled
         self._client = None
+        self._bedrock_control_client = None
         self._lambda_client = None
         self._s3_client = None
 
@@ -164,6 +184,15 @@ def lambda_client(self):
             )
         return self._lambda_client
 
+    @property
+    def bedrock_control_client(self):
+        """Lazy-loaded Bedrock control plane client for GetInferenceProfile etc."""
+        if self._bedrock_control_client is None:
+            self._bedrock_control_client = boto3.client(
+                "bedrock", region_name=self.region
+            )
+        return self._bedrock_control_client
+
     @property
     def s3_client(self):
         """Lazy-loaded S3 client for LambdaHook image uploads."""
@@ -173,6 +202,93 @@ def s3_client(self):
             )
         return self._s3_client
 
+    def _is_model_cachepoint_supported(self, model_id: str) -> bool:
+        """
+        Check if a model supports cachePoint, including inference profile resolution.
+
+        For standard model IDs (e.g., "us.anthropic.claude-sonnet-4-6"), checks
+        the CACHEPOINT_SUPPORTED_MODELS list directly.
+
+        For inference profile ARNs (containing "inference-profile" or
+        "application-inference-profile"), resolves the underlying foundation
+        model via the GetInferenceProfile API and checks if that base model
+        supports cachePoint. Results are cached to avoid repeated API calls.
+
+        Args:
+            model_id: Bedrock model ID or inference profile ARN
+
+        Returns:
+            True if the model (or underlying model for inference profiles) supports cachePoint
+        """
+        # Fast path: direct match against the known list
+        if model_id in CACHEPOINT_SUPPORTED_MODELS:
+            return True
+
+        # Check if this is an inference profile ARN
+        if "inference-profile" not in model_id:
+            return False
+
+        # Check module-level cache
+        if model_id in _inference_profile_cachepoint_cache:
+            cached = _inference_profile_cachepoint_cache[model_id]
+            logger.debug(
+                f"Inference profile cachepoint support (cached): {model_id} -> {cached}"
+            )
+            return cached
+
+        # Resolve the inference profile to its underlying foundation model
+        try:
+            response = self.bedrock_control_client.get_inference_profile(
+                inferenceProfileIdentifier=model_id
+            )
+            models = response.get("models", [])
+            if not models:
+                logger.warning(
+                    f"Inference profile {model_id} has no models listed. "
+                    "Cannot determine cachePoint support."
+                )
+                _inference_profile_cachepoint_cache[model_id] = False
+                return False
+
+            # Extract the base model name from the first model's ARN.
+            # Model ARN format: arn:aws:bedrock:<region>::foundation-model/<base-model-name>
+            # e.g., "arn:aws:bedrock:us-east-1::foundation-model/anthropic.claude-sonnet-4-6"
+            first_model_arn = models[0].get("modelArn", "")
+            if "foundation-model/" in first_model_arn:
+                base_model_name = first_model_arn.split("foundation-model/")[-1]
+            else:
+                logger.warning(
+                    f"Cannot parse foundation model from ARN: {first_model_arn}"
+                )
+                _inference_profile_cachepoint_cache[model_id] = False
+                return False
+
+            supported = base_model_name in _CACHEPOINT_BASE_MODELS
+            _inference_profile_cachepoint_cache[model_id] = supported
+
+            logger.info(
+                f"Resolved inference profile {model_id} -> "
+                f"foundation model '{base_model_name}' -> "
+                f"cachePoint {'supported' if supported else 'not supported'}"
+            )
+            return supported
+
+        except ClientError as e:
+            error_code = e.response["Error"]["Code"]
+            logger.warning(
+                f"Failed to resolve inference profile {model_id} for cachePoint check "
+                f"({error_code}): {e}. Disabling cachePoint for this model."
+            )
+            _inference_profile_cachepoint_cache[model_id] = False
+            return False
+        except Exception as e:
+            logger.warning(
+                f"Unexpected error resolving inference profile {model_id} "
+                f"for cachePoint check: {e}. Disabling cachePoint for this model."
+            )
+            _inference_profile_cachepoint_cache[model_id] = False
+            return False
+
     def __call__(
         self,
         model_id: str,
@@ -375,7 +491,7 @@ def invoke_model(
         )
 
         if has_cachepoint_tags:
-            if model_id in CACHEPOINT_SUPPORTED_MODELS:
+            if self._is_model_cachepoint_supported(model_id):
                 # Process content for cachePoint tags with supported model
                 processed_content = self._preprocess_content_for_cachepoint(content)
                 logger.info(
@@ -394,7 +510,9 @@
                         clean_text = item["text"].replace("<<CACHEPOINT>>", "")
                         processed_content.append({"text": clean_text})
                         logger.warning(
-                            f"Removed <<CACHEPOINT>> tags for unsupported model: {model_id}. CachePoint is only supported for: {', '.join(CACHEPOINT_SUPPORTED_MODELS)}"
+                            f"Removed <<CACHEPOINT>> tags for unsupported model: {model_id}. "
+                            "CachePoint is supported for standard cross-region inference profiles "
+                            "and application inference profiles that wrap supported foundation models."
                         )
                 else:
                     # Pass through unchanged
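As a reviewer aid, the base-name normalization that populates `_CACHEPOINT_BASE_MODELS` above can be exercised in isolation. This is a standalone copy of that transform for illustration only; the helper name `to_base_model_name` is ours, not part of the package:

```python
from typing import Optional


def to_base_model_name(model_id: str) -> Optional[str]:
    """Mirror of the module-level loop: strip region prefix and tier suffix."""
    parts = model_id.split(".", 1)
    if len(parts) != 2 or parts[0] not in ("us", "eu", "global"):
        return None  # not a region-prefixed model ID
    base = parts[1]
    # Strip tier suffixes (:priority, :flex) but keep version suffixes (:0, :1m)
    if base.endswith(":priority") or base.endswith(":flex"):
        base = base.rsplit(":", 1)[0]
    return base


assert to_base_model_name("us.anthropic.claude-sonnet-4-6") == "anthropic.claude-sonnet-4-6"
assert to_base_model_name("eu.amazon.nova-2-lite-v1:0:priority") == "amazon.nova-2-lite-v1:0"
assert to_base_model_name("anthropic.claude-sonnet-4-6") is None  # no region prefix
```

With the module-level cache, `_is_model_cachepoint_supported` pays the `GetInferenceProfile` round trip at most once per distinct profile ARN per process.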