55from botocore .eventstream import EventStream
66from haystack .dataclasses import StreamingChunk , SyncStreamingCallbackT
77
# Normalized usage key -> name of the HTTP response header that carries the count.
# Header names are spelled in lower case because the lookup in
# ``_usage_from_response_metadata`` runs against lower-cased header keys.
_USAGE_HEADER_MAP: dict[str, str] = {
    "input_tokens": "x-amzn-bedrock-input-token-count",
    "output_tokens": "x-amzn-bedrock-output-token-count",
    "cache_read_input_tokens": "x-amzn-bedrock-cache-read-input-token-count",
    "cache_write_input_tokens": "x-amzn-bedrock-cache-write-input-token-count",
}
14+
# Normalized usage key -> field name read from a streamed ``usage`` payload.
# Only the cache-write counter is renamed: it is read from
# ``cache_creation_input_tokens`` and stored as ``cache_write_input_tokens``.
_USAGE_FIELD_MAP: dict[str, str] = {
    "input_tokens": "input_tokens",
    "output_tokens": "output_tokens",
    "cache_read_input_tokens": "cache_read_input_tokens",
    "cache_write_input_tokens": "cache_creation_input_tokens",
}
21+
22+
def _set_usage_value(usage: dict[str, int], key: str, value: Any) -> None:
    """
    Sets a usage value coerced to int, ignoring values that are None or not int-convertible.

    A value counts as not int-convertible when ``int(value)`` raises ``TypeError``,
    ``ValueError`` or ``OverflowError``. In every such case ``usage`` is left untouched,
    so a previously stored value for ``key`` is preserved.

    :param usage: The usage dictionary to update in place.
    :param key: The destination key.
    :param value: The raw value to coerce and store.
    """
    if value is None:
        return
    try:
        coerced = int(value)
    except (TypeError, ValueError, OverflowError):
        # OverflowError covers infinite floats: int(float("inf")) does not raise ValueError.
        return
    usage[key] = coerced
37+
38+
def _apply_usage(usage: dict[str, int], source: dict[str, Any], field_map: dict[str, str]) -> None:
    """
    Transfers the mapped entries of ``source`` into ``usage``.

    Every ``(destination, source)`` pair of ``field_map`` is looked up in ``source`` and the
    result is handed to ``_set_usage_value``, which decides whether it is stored.

    :param usage: The usage dictionary to update in place.
    :param source: The source dictionary holding raw usage values.
    :param field_map: A mapping from destination key to source key.
    """
    for destination_key, source_key in field_map.items():
        raw_value = source.get(source_key)
        _set_usage_value(usage, destination_key, raw_value)
49+
50+
def _usage_from_response_metadata(metadata: dict[str, Any]) -> dict[str, int]:
    """
    Builds a normalized token usage dictionary from the HTTP headers of a Bedrock response.

    The headers are read from ``metadata["HTTPHeaders"]``, falling back to
    ``metadata["http_headers"]`` when the first entry is missing or empty.

    :param metadata: The Bedrock response metadata dictionary.
    :returns: A normalized usage dictionary, or an empty dictionary when no usage headers are present.
    """
    raw_headers = metadata.get("HTTPHeaders")
    if not raw_headers:
        raw_headers = metadata.get("http_headers")
    if not raw_headers:
        raw_headers = {}
    if not isinstance(raw_headers, dict):
        return {}

    # Lower-case the header names so they match the lower-case names in _USAGE_HEADER_MAP.
    lowered_headers = {}
    for name, value in raw_headers.items():
        lowered_headers[str(name).lower()] = value

    result: dict[str, int] = {}
    _apply_usage(result, lowered_headers, _USAGE_HEADER_MAP)
    return result
66+
67+
def _merge_usage(metadata: dict[str, Any], usage: dict[str, int]) -> None:
    """
    Folds ``usage`` into ``metadata["usage"]``, with the new values taking precedence.

    An existing ``metadata["usage"]`` that is not a dictionary is discarded. Whenever the
    merged result holds both ``input_tokens`` and ``output_tokens``, ``total_tokens`` is
    set to their sum. An empty ``usage`` leaves ``metadata`` untouched.

    :param metadata: The metadata dictionary to update in place.
    :param usage: The normalized usage dictionary to merge in.
    """
    if usage:
        previous = metadata.get("usage")
        # Work on a copy so the dictionary previously stored in metadata is not mutated.
        combined = dict(previous) if isinstance(previous, dict) else {}
        combined.update(usage)
        if {"input_tokens", "output_tokens"} <= combined.keys():
            combined["total_tokens"] = combined["input_tokens"] + combined["output_tokens"]
        metadata["usage"] = combined
87+
888
989class BedrockModelAdapter (ABC ):
1090 """
@@ -54,6 +134,20 @@ def get_stream_responses(self, stream: EventStream, streaming_callback: SyncStre
54134 :param streaming_callback: The handler for the streaming response.
55135 :returns: A list of string responses.
56136 """
137+ responses , _ = self .get_stream_responses_and_metadata (stream , streaming_callback )
138+ return responses
139+
140+ def get_stream_responses_and_metadata (
141+ self , stream : EventStream , streaming_callback : SyncStreamingCallbackT
142+ ) -> tuple [list [str ], dict [str , Any ]]:
143+ """
144+ Extracts both the responses and normalized metadata from the Amazon Bedrock streaming response.
145+
146+ :param stream: The streaming response from the Amazon Bedrock request.
147+ :param streaming_callback: The handler for the streaming response.
148+ :returns: A tuple of ``(responses, metadata)`` where ``responses`` is a list of string
149+ responses and ``metadata`` is a dictionary that may contain a normalized ``usage`` block.
150+ """
57151 streaming_chunks : list [StreamingChunk ] = []
58152 for event in stream :
59153 chunk = event .get ("chunk" )
@@ -64,7 +158,37 @@ def get_stream_responses(self, stream: EventStream, streaming_callback: SyncStre
64158 streaming_callback (streaming_chunk )
65159
66160 responses = ["" .join (streaming_chunk .content for streaming_chunk in streaming_chunks ).lstrip ()]
67- return responses
161+ metadata = self ._extract_streaming_metadata (streaming_chunks )
162+ return responses , metadata
163+
def _extract_streaming_metadata(self, streaming_chunks: list[StreamingChunk]) -> dict[str, Any]:
    """
    Collects normalized usage metadata from the chunks of a Bedrock streaming response.

    This default implementation targets Anthropic Claude Messages API stream events, which
    expose input usage in ``message_start.message.usage`` and output usage in
    ``message_delta.usage``. Chunks are processed in order, so a later chunk overwrites a
    counter reported by an earlier one.

    :param streaming_chunks: The streaming chunks emitted during the response.
    :returns: A metadata dictionary with a ``usage`` block, or an empty dictionary when no
        usage information is present.
    """
    collected: dict[str, int] = {}

    for chunk in streaming_chunks:
        chunk_meta = chunk.meta
        if not isinstance(chunk_meta, dict):
            continue

        # Usage nested under the ``message`` entry is applied first.
        start_message = chunk_meta.get("message")
        if isinstance(start_message, dict):
            start_usage = start_message.get("usage")
            if isinstance(start_usage, dict):
                _apply_usage(collected, start_usage, _USAGE_FIELD_MAP)

        # Usage at the top level of the chunk meta is applied second and wins on conflicts.
        delta_usage = chunk_meta.get("usage")
        if isinstance(delta_usage, dict):
            _apply_usage(collected, delta_usage, _USAGE_FIELD_MAP)

    if not collected:
        return {}
    return {"usage": collected}
68192
69193 def _get_params (self , inference_kwargs : dict [str , Any ], default_params : dict [str , Any ]) -> dict [str , Any ]:
70194 """
0 commit comments