Skip to content

Commit 093f006

Browse files
authored
fix: normalize token usage conversion in AmazonBedrockGenerator (#3247)
1 parent 813db45 commit 093f006

4 files changed

Lines changed: 324 additions & 2 deletions

File tree

integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/adapters.py

Lines changed: 125 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,86 @@
55
from botocore.eventstream import EventStream
66
from haystack.dataclasses import StreamingChunk, SyncStreamingCallbackT
77

8+
_USAGE_HEADER_MAP = {
9+
"input_tokens": "x-amzn-bedrock-input-token-count",
10+
"output_tokens": "x-amzn-bedrock-output-token-count",
11+
"cache_read_input_tokens": "x-amzn-bedrock-cache-read-input-token-count",
12+
"cache_write_input_tokens": "x-amzn-bedrock-cache-write-input-token-count",
13+
}
14+
15+
_USAGE_FIELD_MAP = {
16+
"input_tokens": "input_tokens",
17+
"output_tokens": "output_tokens",
18+
"cache_read_input_tokens": "cache_read_input_tokens",
19+
"cache_write_input_tokens": "cache_creation_input_tokens",
20+
}
21+
22+
23+
def _set_usage_value(usage: dict[str, int], key: str, value: Any) -> None:
24+
"""
25+
Sets a usage value coerced to int, ignoring values that are None or not int-convertible.
26+
27+
:param usage: The usage dictionary to update in place.
28+
:param key: The destination key.
29+
:param value: The raw value to coerce and store.
30+
"""
31+
if value is None:
32+
return
33+
try:
34+
usage[key] = int(value)
35+
except (TypeError, ValueError):
36+
return
37+
38+
39+
def _apply_usage(usage: dict[str, int], source: dict[str, Any], field_map: dict[str, str]) -> None:
40+
"""
41+
Copies usage values from a source dictionary into the usage dictionary using the given field map.
42+
43+
:param usage: The usage dictionary to update in place.
44+
:param source: The source dictionary holding raw usage values.
45+
:param field_map: A mapping from destination key to source key.
46+
"""
47+
for dst, src in field_map.items():
48+
_set_usage_value(usage, dst, source.get(src))
49+
50+
51+
def _usage_from_response_metadata(metadata: dict[str, Any]) -> dict[str, int]:
52+
"""
53+
Extracts normalized token usage from Bedrock InvokeModel ResponseMetadata HTTP headers.
54+
55+
:param metadata: The Bedrock response metadata dictionary.
56+
:returns: A normalized usage dictionary, or an empty dictionary when no usage headers are present.
57+
"""
58+
headers = metadata.get("HTTPHeaders") or metadata.get("http_headers") or {}
59+
if not isinstance(headers, dict):
60+
return {}
61+
62+
normalized_headers = {str(key).lower(): value for key, value in headers.items()}
63+
usage: dict[str, int] = {}
64+
_apply_usage(usage, normalized_headers, _USAGE_HEADER_MAP)
65+
return usage
66+
67+
68+
def _merge_usage(metadata: dict[str, Any], usage: dict[str, int]) -> None:
69+
"""
70+
Merges a usage dictionary into the metadata under the ``usage`` key.
71+
72+
Recomputes ``total_tokens`` after merging when both ``input_tokens`` and ``output_tokens``
73+
are present, so partial usage from multiple sources is summed correctly.
74+
75+
:param metadata: The metadata dictionary to update in place.
76+
:param usage: The normalized usage dictionary to merge in.
77+
"""
78+
if not usage:
79+
return
80+
81+
existing_usage = metadata.get("usage")
82+
base = existing_usage if isinstance(existing_usage, dict) else {}
83+
merged_usage = {**base, **usage}
84+
if "input_tokens" in merged_usage and "output_tokens" in merged_usage:
85+
merged_usage["total_tokens"] = merged_usage["input_tokens"] + merged_usage["output_tokens"]
86+
metadata["usage"] = merged_usage
87+
888

989
class BedrockModelAdapter(ABC):
1090
"""
@@ -54,6 +134,20 @@ def get_stream_responses(self, stream: EventStream, streaming_callback: SyncStre
54134
:param streaming_callback: The handler for the streaming response.
55135
:returns: A list of string responses.
56136
"""
137+
responses, _ = self.get_stream_responses_and_metadata(stream, streaming_callback)
138+
return responses
139+
140+
def get_stream_responses_and_metadata(
141+
self, stream: EventStream, streaming_callback: SyncStreamingCallbackT
142+
) -> tuple[list[str], dict[str, Any]]:
143+
"""
144+
Extracts both the responses and normalized metadata from the Amazon Bedrock streaming response.
145+
146+
:param stream: The streaming response from the Amazon Bedrock request.
147+
:param streaming_callback: The handler for the streaming response.
148+
:returns: A tuple of ``(responses, metadata)`` where ``responses`` is a list of string
149+
responses and ``metadata`` is a dictionary that may contain a normalized ``usage`` block.
150+
"""
57151
streaming_chunks: list[StreamingChunk] = []
58152
for event in stream:
59153
chunk = event.get("chunk")
@@ -64,7 +158,37 @@ def get_stream_responses(self, stream: EventStream, streaming_callback: SyncStre
64158
streaming_callback(streaming_chunk)
65159

66160
responses = ["".join(streaming_chunk.content for streaming_chunk in streaming_chunks).lstrip()]
67-
return responses
161+
metadata = self._extract_streaming_metadata(streaming_chunks)
162+
return responses, metadata
163+
164+
def _extract_streaming_metadata(self, streaming_chunks: list[StreamingChunk]) -> dict[str, Any]:
165+
"""
166+
Extracts normalized metadata from Bedrock streaming chunks.
167+
168+
The default implementation handles Anthropic Claude Messages API stream events, which
169+
expose input usage in ``message_start.message.usage`` and output usage in
170+
``message_delta.usage``.
171+
172+
:param streaming_chunks: The streaming chunks emitted during the response.
173+
:returns: A metadata dictionary with a ``usage`` block, or an empty dictionary when no
174+
usage information is present.
175+
"""
176+
usage: dict[str, int] = {}
177+
178+
for streaming_chunk in streaming_chunks:
179+
meta = streaming_chunk.meta
180+
if not isinstance(meta, dict):
181+
continue
182+
message = meta.get("message")
183+
chunk_usage = meta.get("usage")
184+
if message is None and chunk_usage is None:
185+
continue
186+
if isinstance(message, dict) and isinstance(message.get("usage"), dict):
187+
_apply_usage(usage, message["usage"], _USAGE_FIELD_MAP)
188+
if isinstance(chunk_usage, dict):
189+
_apply_usage(usage, chunk_usage, _USAGE_FIELD_MAP)
190+
191+
return {"usage": usage} if usage else {}
68192

69193
def _get_params(self, inference_kwargs: dict[str, Any], default_params: dict[str, Any]) -> dict[str, Any]:
70194
"""

integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
CohereCommandRAdapter,
2626
MetaLlamaAdapter,
2727
MistralAdapter,
28+
_merge_usage,
29+
_usage_from_response_metadata,
2830
)
2931

3032
logger = logging.getLogger(__name__)
@@ -215,6 +217,7 @@ def run(
215217
generation_kwargs["stream"] = streaming_callback is not None
216218

217219
body = self.model_adapter.prepare_body(prompt=prompt, **generation_kwargs)
220+
stream_metadata: dict[str, Any] = {}
218221
try:
219222
if streaming_callback:
220223
response = self.client.invoke_model_with_response_stream(
@@ -224,7 +227,7 @@ def run(
224227
contentType="application/json",
225228
)
226229
response_stream = response["body"]
227-
replies = self.model_adapter.get_stream_responses(
230+
replies, stream_metadata = self.model_adapter.get_stream_responses_and_metadata(
228231
stream=response_stream, streaming_callback=streaming_callback
229232
)
230233
else:
@@ -238,6 +241,8 @@ def run(
238241
replies = self.model_adapter.get_responses(response_body=response_body)
239242

240243
metadata = response.get("ResponseMetadata", {})
244+
_merge_usage(metadata, _usage_from_response_metadata(metadata))
245+
_merge_usage(metadata, stream_metadata.get("usage", {}))
241246

242247
except ClientError as exception:
243248
msg = f"Could not perform inference for Amazon Bedrock model {self.model} due to:\n{exception}"

integrations/amazon_bedrock/tests/test_generator.py

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import json
2+
from io import BytesIO
13
from typing import Any
24
from unittest.mock import MagicMock, call
35

@@ -336,6 +338,137 @@ def test_run_client_error(mock_boto3_session):
336338
generator.run("Hello")
337339

338340

341+
def test_run_non_streaming_normalizes_usage_from_headers(mock_boto3_session):
342+
generator = AmazonBedrockGenerator(model="anthropic.claude-v2")
343+
mock_client = mock_boto3_session.return_value.client.return_value
344+
mock_client.invoke_model.return_value = {
345+
"body": BytesIO(json.dumps({"content": [{"type": "text", "text": "ok"}]}).encode()),
346+
"ResponseMetadata": {
347+
"HTTPHeaders": {
348+
"x-amzn-bedrock-input-token-count": "20",
349+
"x-amzn-bedrock-output-token-count": "10",
350+
"x-amzn-bedrock-cache-read-input-token-count": "0",
351+
"x-amzn-bedrock-cache-write-input-token-count": "0",
352+
}
353+
},
354+
}
355+
356+
result = generator.run("hi")
357+
358+
assert result["replies"] == ["ok"]
359+
assert result["meta"]["usage"] == {
360+
"input_tokens": 20,
361+
"output_tokens": 10,
362+
"total_tokens": 30,
363+
"cache_read_input_tokens": 0,
364+
"cache_write_input_tokens": 0,
365+
}
366+
367+
368+
def test_run_non_streaming_without_usage_headers_omits_usage(mock_boto3_session):
369+
generator = AmazonBedrockGenerator(model="anthropic.claude-v2")
370+
mock_client = mock_boto3_session.return_value.client.return_value
371+
mock_client.invoke_model.return_value = {
372+
"body": BytesIO(json.dumps({"content": [{"type": "text", "text": "ok"}]}).encode()),
373+
"ResponseMetadata": {"HTTPHeaders": {}},
374+
}
375+
376+
result = generator.run("hi")
377+
378+
assert "usage" not in result["meta"]
379+
380+
381+
def test_run_streaming_normalizes_anthropic_usage(mock_boto3_session):
382+
generator = AmazonBedrockGenerator(model="anthropic.claude-v2")
383+
mock_client = mock_boto3_session.return_value.client.return_value
384+
385+
stream_body = MagicMock()
386+
stream_body.__iter__.return_value = [
387+
{
388+
"chunk": {
389+
"bytes": json.dumps(
390+
{
391+
"type": "message_start",
392+
"message": {"usage": {"input_tokens": 20, "output_tokens": 1}},
393+
}
394+
).encode()
395+
}
396+
},
397+
{"chunk": {"bytes": json.dumps({"type": "content_block_delta", "delta": {"text": "ok"}}).encode()}},
398+
{"chunk": {"bytes": json.dumps({"type": "message_delta", "usage": {"output_tokens": 10}}).encode()}},
399+
]
400+
mock_client.invoke_model_with_response_stream.return_value = {
401+
"body": stream_body,
402+
"ResponseMetadata": {"RequestId": "req-1"},
403+
}
404+
405+
result = generator.run("hi", streaming_callback=lambda chunk: None)
406+
407+
assert result["replies"] == ["ok"]
408+
assert result["meta"]["usage"]["input_tokens"] == 20
409+
assert result["meta"]["usage"]["output_tokens"] == 10
410+
assert result["meta"]["usage"]["total_tokens"] == 30
411+
412+
413+
def test_run_streaming_with_cache_usage(mock_boto3_session):
414+
generator = AmazonBedrockGenerator(model="anthropic.claude-v2")
415+
mock_client = mock_boto3_session.return_value.client.return_value
416+
417+
stream_body = MagicMock()
418+
stream_body.__iter__.return_value = [
419+
{
420+
"chunk": {
421+
"bytes": json.dumps(
422+
{
423+
"type": "message_start",
424+
"message": {
425+
"usage": {
426+
"input_tokens": 5,
427+
"output_tokens": 1,
428+
"cache_read_input_tokens": 100,
429+
"cache_creation_input_tokens": 50,
430+
}
431+
},
432+
}
433+
).encode()
434+
}
435+
},
436+
{"chunk": {"bytes": json.dumps({"type": "message_delta", "usage": {"output_tokens": 7}}).encode()}},
437+
]
438+
mock_client.invoke_model_with_response_stream.return_value = {
439+
"body": stream_body,
440+
"ResponseMetadata": {},
441+
}
442+
443+
result = generator.run("hi", streaming_callback=lambda chunk: None)
444+
445+
assert result["meta"]["usage"] == {
446+
"input_tokens": 5,
447+
"output_tokens": 7,
448+
"total_tokens": 12,
449+
"cache_read_input_tokens": 100,
450+
"cache_write_input_tokens": 50,
451+
}
452+
453+
454+
def test_run_streaming_without_usage_omits_usage(mock_boto3_session):
455+
generator = AmazonBedrockGenerator(model="anthropic.claude-v2")
456+
mock_client = mock_boto3_session.return_value.client.return_value
457+
458+
stream_body = MagicMock()
459+
stream_body.__iter__.return_value = [
460+
{"chunk": {"bytes": b'{"delta": {"text": "ok"}}'}},
461+
]
462+
mock_client.invoke_model_with_response_stream.return_value = {
463+
"body": stream_body,
464+
"ResponseMetadata": {"RequestId": "req-1"},
465+
}
466+
467+
result = generator.run("hi", streaming_callback=lambda chunk: None)
468+
469+
assert "usage" not in result["meta"]
470+
471+
339472
def test_from_dict_with_streaming_callback(mock_boto3_session):
340473
data = {
341474
"type": "haystack_integrations.components.generators.amazon_bedrock.generator.AmazonBedrockGenerator",
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
import os
2+
3+
import pytest
4+
from haystack.utils import Secret
5+
6+
from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockGenerator
7+
8+
MODELS_TO_TEST = [
9+
"global.anthropic.claude-haiku-4-5-20251001-v1:0",
10+
]
11+
12+
13+
def _generator(model: str) -> AmazonBedrockGenerator:
14+
return AmazonBedrockGenerator(
15+
model=model,
16+
max_length=64,
17+
aws_region_name=Secret.from_token(os.environ["AWS_REGION"]),
18+
)
19+
20+
21+
def _assert_usage(usage: dict) -> None:
22+
assert isinstance(usage["input_tokens"], int) and usage["input_tokens"] > 0
23+
assert isinstance(usage["output_tokens"], int) and usage["output_tokens"] > 0
24+
assert usage["total_tokens"] == usage["input_tokens"] + usage["output_tokens"]
25+
26+
27+
@pytest.mark.integration
28+
@pytest.mark.skipif(
29+
not os.getenv("AWS_BEARER_TOKEN_BEDROCK") or not os.getenv("AWS_REGION"),
30+
reason="AWS_BEARER_TOKEN_BEDROCK and AWS_REGION must be set",
31+
)
32+
class TestAmazonBedrockGeneratorInference:
33+
@pytest.mark.parametrize("model", MODELS_TO_TEST)
34+
def test_run_non_streaming_normalizes_usage(self, model: str) -> None:
35+
generator = _generator(model)
36+
result = generator.run("What is the capital of France? Reply in one word.")
37+
38+
assert result["replies"], "No replies received"
39+
assert isinstance(result["replies"][0], str) and result["replies"][0]
40+
41+
meta = result["meta"]
42+
assert "usage" in meta, f"meta does not contain a normalized 'usage' block: {meta}"
43+
_assert_usage(meta["usage"])
44+
45+
@pytest.mark.parametrize("model", MODELS_TO_TEST)
46+
def test_run_streaming_normalizes_usage(self, model: str) -> None:
47+
generator = _generator(model)
48+
chunks: list = []
49+
result = generator.run(
50+
"What is the capital of France? Reply in one word.",
51+
streaming_callback=chunks.append,
52+
)
53+
54+
assert chunks, "Streaming callback was not invoked"
55+
assert result["replies"], "No replies received"
56+
assert isinstance(result["replies"][0], str) and result["replies"][0]
57+
58+
meta = result["meta"]
59+
assert "usage" in meta, f"meta does not contain a normalized 'usage' block: {meta}"
60+
_assert_usage(meta["usage"])

0 commit comments

Comments
 (0)