4040
4141
4242# Haystack to Bedrock util methods
43- def _format_tools (tools : list [Tool ] | None = None ) -> dict [str , Any ] | None :
43+ def _format_tools (
44+ tools : list [Tool ] | None = None , tools_cachepoint_config : dict [str , dict [str , str ]] | None = None
45+ ) -> dict [str , Any ] | None :
4446 """
4547 Format Haystack Tool(s) to Amazon Bedrock toolConfig format.
4648
@@ -51,13 +53,16 @@ def _format_tools(tools: list[Tool] | None = None) -> dict[str, Any] | None:
5153 if not tools :
5254 return None
5355
54- tool_specs = []
56+ tool_specs : list [ dict [ str , Any ]] = []
5557 for tool in tools :
5658 tool_specs .append (
5759 {"toolSpec" : {"name" : tool .name , "description" : tool .description , "inputSchema" : {"json" : tool .parameters }}}
5860 )
5961
60- return {"tools" : tool_specs } if tool_specs else None
62+ if tools_cachepoint_config :
63+ tool_specs .append ({"cachePoint" : tools_cachepoint_config })
64+
65+ return {"tools" : tool_specs }
6166
6267
6368def _convert_image_content_to_bedrock_format (image_content : ImageContent ) -> dict [str , Any ]:
@@ -181,20 +186,23 @@ def _repair_tool_result_messages(bedrock_formatted_messages: list[dict[str, Any]
181186 original_idx = None
182187 for tool_call_id in tool_call_ids :
183188 for idx , tool_result in tool_result_messages :
184- tool_result_contents = [c for c in tool_result ["content" ] if "toolResult" in c ]
189+ tool_result_contents = [c for c in tool_result ["content" ] if "toolResult" in c or "cachePoint" in c ]
185190 for content in tool_result_contents :
186- if content ["toolResult" ]["toolUseId" ] == tool_call_id :
191+ if "toolResult" in content and content ["toolResult" ]["toolUseId" ] == tool_call_id :
187192 regrouped_tool_result .append (content )
188193 # Keep track of the original index of the last tool result message
189194 original_idx = idx
195+ elif "cachePoint" in content and content not in regrouped_tool_result :
196+ regrouped_tool_result .append (content )
197+
190198 if regrouped_tool_result and original_idx is not None :
191199 repaired_tool_result_prompts .append ((original_idx , {"role" : "user" , "content" : regrouped_tool_result }))
192200
193201 # Remove the tool result messages from bedrock_formatted_messages
194202 bedrock_formatted_messages_minus_tool_results : list [tuple [int , Any ]] = []
195203 for idx , msg in enumerate (bedrock_formatted_messages ):
196- # Assumes the content of tool result messages only contains ' toolResult': {...} objects (e.g. no 'text' )
197- if msg .get ("content" ) and "toolResult" not in msg ["content" ][ 0 ] :
204+ # Filter out messages that contain toolResult (they are handled by repaired_tool_result_prompts )
205+ if msg .get ("content" ) and not any ( "toolResult" in c for c in msg ["content" ]) :
198206 bedrock_formatted_messages_minus_tool_results .append ((idx , msg ))
199207
200208 # Add the repaired tool result messages and sort to maintain the correct order
@@ -251,6 +259,32 @@ def _format_text_image_message(message: ChatMessage) -> dict[str, Any]:
251259 return {"role" : message .role .value , "content" : bedrock_content_blocks }
252260
253261
262+ def _validate_and_format_cache_point (cache_point : dict [str , str ] | None ) -> dict [str , dict [str , str ]] | None :
263+ """
264+ Validate and format a cache point dictionary.
265+
266+ Schema available at https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_CachePointBlock.html
267+
268+ :param cache_point: Cache point dictionary to validate and format.
269+ :returns: Dictionary in Bedrock cachePoint format or None if no cache point is provided.
270+ :raises ValueError: If cache point is not valid.
271+ """
272+ if not cache_point :
273+ return None
274+
275+ if "type" not in cache_point or cache_point ["type" ] != "default" :
276+ err_msg = "Cache point must have a 'type' key with value 'default'."
277+ raise ValueError (err_msg )
278+ if not set (cache_point ).issubset ({"type" , "ttl" }):
279+ err_msg = "Cache point can only contain 'type' and 'ttl' keys."
280+ raise ValueError (err_msg )
281+ if "ttl" in cache_point and cache_point ["ttl" ] not in ("5m" , "1h" ):
282+ err_msg = "Cache point 'ttl' must be one of '5m', '1h'."
283+ raise ValueError (err_msg )
284+
285+ return {"cachePoint" : cache_point }
286+
287+
254288def _format_messages (messages : list [ChatMessage ]) -> tuple [list [dict [str , Any ]], list [dict [str , Any ]]]:
255289 """
256290 Format a list of Haystack ChatMessages to the format expected by Bedrock API.
@@ -264,21 +298,30 @@ def _format_messages(messages: list[ChatMessage]) -> tuple[list[dict[str, Any]],
264298 non_system_messages is a list of properly formatted message dictionaries.
265299 """
266300 # Separate system messages, tool calls, and tool results
267- system_prompts = []
301+ system_prompts : list [ dict [ str , Any ]] = []
268302 bedrock_formatted_messages = []
269303 for msg in messages :
304+ cache_point = _validate_and_format_cache_point (msg .meta .get ("cachePoint" ))
270305 if msg .is_from (ChatRole .SYSTEM ):
271306 # Assuming system messages can only contain text
272307 # Don't need to track idx since system_messages are handled separately
273308 system_prompts .append ({"text" : msg .text })
274- elif msg .tool_calls :
275- bedrock_formatted_messages .append (_format_tool_call_message (msg ))
309+ if cache_point :
310+ system_prompts .append (cache_point )
311+ continue
312+
313+ if msg .tool_calls :
314+ formatted_msg = _format_tool_call_message (msg )
276315 elif msg .tool_call_results :
277- bedrock_formatted_messages . append ( _format_tool_result_message (msg ) )
316+ formatted_msg = _format_tool_result_message (msg )
278317 else :
279- bedrock_formatted_messages .append (_format_text_image_message (msg ))
318+ formatted_msg = _format_text_image_message (msg )
319+ if cache_point :
320+ formatted_msg ["content" ].append (cache_point )
321+ bedrock_formatted_messages .append (formatted_msg )
280322
281323 repaired_bedrock_formatted_messages = _repair_tool_result_messages (bedrock_formatted_messages )
324+
282325 return system_prompts , repaired_bedrock_formatted_messages
283326
284327
@@ -310,6 +353,9 @@ def _parse_completion_response(response_body: dict[str, Any], model: str) -> lis
310353 "prompt_tokens" : response_body .get ("usage" , {}).get ("inputTokens" , 0 ),
311354 "completion_tokens" : response_body .get ("usage" , {}).get ("outputTokens" , 0 ),
312355 "total_tokens" : response_body .get ("usage" , {}).get ("totalTokens" , 0 ),
356+ "cache_read_input_tokens" : response_body .get ("usage" , {}).get ("cacheReadInputTokens" , 0 ),
357+ "cache_write_input_tokens" : response_body .get ("usage" , {}).get ("cacheWriteInputTokens" , 0 ),
358+ "cache_details" : response_body .get ("usage" , {}).get ("CacheDetails" , {}),
313359 },
314360 }
315361 # guardrail trace
@@ -461,6 +507,9 @@ def _convert_event_to_streaming_chunk(
461507 "prompt_tokens" : usage .get ("inputTokens" , 0 ),
462508 "completion_tokens" : usage .get ("outputTokens" , 0 ),
463509 "total_tokens" : usage .get ("totalTokens" , 0 ),
510+ "cache_read_input_tokens" : usage .get ("cacheReadInputTokens" , 0 ),
511+ "cache_write_input_tokens" : usage .get ("cacheWriteInputTokens" , 0 ),
512+ "cache_details" : usage .get ("cacheDetails" , {}),
464513 }
465514 if "trace" in event_meta :
466515 chunk_meta ["trace" ] = event_meta ["trace" ]
0 commit comments