Commit 6bd3f38

fix(google-genai): include cached_content_token_count in streaming responses (#3177)
Co-authored-by: bogdankostic <bogdankostic@web.de>
1 parent 324ad33 · commit 6bd3f38

2 files changed: 88 additions & 7 deletions

integrations/google_genai/src/haystack_integrations/components/generators/google_genai/chat/utils.py

Lines changed: 18 additions & 7 deletions
@@ -550,12 +550,9 @@ def _convert_google_genai_response_to_chatmessage(response: types.GenerateConten
         usage["thoughts_token_count"] = usage_metadata.thoughts_token_count

     # Add cached content token count if available (implicit or explicit context caching)
-    if (
-        usage_metadata
-        and hasattr(usage_metadata, "cached_content_token_count")
-        and usage_metadata.cached_content_token_count
-    ):
-        usage["cached_content_token_count"] = usage_metadata.cached_content_token_count
+    cached_content_token_count = getattr(usage_metadata, "cached_content_token_count", None) if usage_metadata else None
+    if cached_content_token_count is not None:
+        usage["cached_content_token_count"] = cached_content_token_count

     usage.update(_convert_usage_metadata_to_serializable(usage_metadata))

@@ -625,6 +622,11 @@ def _convert_google_chunk_to_streaming_chunk(
     if usage_metadata and hasattr(usage_metadata, "thoughts_token_count") and usage_metadata.thoughts_token_count:
         usage["thoughts_token_count"] = usage_metadata.thoughts_token_count

+    # Add cached content token count if available (context caching)
+    cached_content_token_count = getattr(usage_metadata, "cached_content_token_count", None) if usage_metadata else None
+    if cached_content_token_count is not None:
+        usage["cached_content_token_count"] = cached_content_token_count
+
     if candidate.content and candidate.content.parts:
         tc_index = -1
         for part_index, part in enumerate(candidate.content.parts):
@@ -717,6 +719,7 @@ def _aggregate_streaming_chunks_with_reasoning(chunks: list[StreamingChunk]) ->
     reasoning_text_parts: list[str] = []
     thought_signatures: list[dict[str, Any]] = []
     thoughts_token_count = None
+    cached_content_token_count = None

     for chunk in chunks:
         # Extract reasoning from the StreamingChunk.reasoning field
@@ -731,18 +734,26 @@ def _aggregate_streaming_chunks_with_reasoning(chunks: list[StreamingChunk]) ->
             # We'll keep the last set of signatures as they represent the complete state
             thought_signatures = signature_deltas

-        # Extract thinking token usage (from the last chunk that has it)
+        # Extract token usage metadata (from the last chunk that has it)
         if chunk.meta and "usage" in chunk.meta:
             chunk_usage = chunk.meta["usage"]
             if "thoughts_token_count" in chunk_usage:
                 thoughts_token_count = chunk_usage["thoughts_token_count"]
+            if "cached_content_token_count" in chunk_usage:
+                cached_content_token_count = chunk_usage["cached_content_token_count"]

     # Add thinking token count to usage if present
     if thoughts_token_count is not None and "usage" in message.meta:
         if message.meta["usage"] is None:
             message.meta["usage"] = {}
         message.meta["usage"]["thoughts_token_count"] = thoughts_token_count

+    # Add cached content token count to usage if present
+    if cached_content_token_count is not None and "usage" in message.meta:
+        if message.meta["usage"] is None:
+            message.meta["usage"] = {}
+        message.meta["usage"]["cached_content_token_count"] = cached_content_token_count
+
     # Add thought signatures to meta if present (for multi-turn context preservation)
     if thought_signatures:
         message.meta["thought_signatures"] = thought_signatures
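
The refactor in the non-streaming path above is not purely cosmetic: the old hasattr/truthiness check dropped the field whenever cached_content_token_count was falsy, so a legitimate count of 0 was silently omitted, while getattr with a None default plus an explicit `is not None` check keeps 0 and only skips a missing or None value. A minimal sketch of the pattern, using a hypothetical extract_cached_tokens helper and SimpleNamespace stand-ins for the SDK's usage metadata object:

    from types import SimpleNamespace

    def extract_cached_tokens(usage_metadata) -> dict:
        usage: dict = {}
        # getattr with a None default tolerates SDK objects that lack the
        # field; `is not None` preserves a legitimate count of 0 that the
        # old truthiness test would have dropped.
        cached = getattr(usage_metadata, "cached_content_token_count", None) if usage_metadata else None
        if cached is not None:
            usage["cached_content_token_count"] = cached
        return usage

    assert extract_cached_tokens(None) == {}
    assert extract_cached_tokens(SimpleNamespace()) == {}
    assert extract_cached_tokens(SimpleNamespace(cached_content_token_count=0)) == {"cached_content_token_count": 0}
    assert extract_cached_tokens(SimpleNamespace(cached_content_token_count=800)) == {"cached_content_token_count": 800}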

integrations/google_genai/tests/test_chat_generator_utils.py

Lines changed: 70 additions & 0 deletions
@@ -702,6 +702,76 @@ def test_aggregate_streaming_chunks_with_thought_signatures_and_thinking_tokens(
         assert "thought_signatures" in result.meta
         assert result.meta["thought_signatures"][0]["signature"] == "sig_xyz"

+    def test_convert_google_chunk_to_streaming_chunk_with_cached_tokens(self, monkeypatch):
+        """cached_content_token_count from usage_metadata is included in the streaming chunk's usage."""
+        monkeypatch.setenv("GOOGLE_API_KEY", "test-api-key")
+        component_info = ComponentInfo.from_component(GoogleGenAIChatGenerator())
+
+        mock_usage = Mock()
+        mock_usage.prompt_token_count = 1000
+        mock_usage.candidates_token_count = 10
+        mock_usage.total_token_count = 1010
+        mock_usage.thoughts_token_count = None
+        mock_usage.cached_content_token_count = 800
+
+        mock_part = Mock()
+        mock_part.text = "The answer is 4."
+        mock_part.function_call = None
+        mock_part.thought = False
+        mock_part.thought_signature = None
+        mock_content = Mock()
+        mock_content.parts = [mock_part]
+        mock_candidate = Mock()
+        mock_candidate.content = mock_content
+        mock_candidate.finish_reason = "STOP"
+
+        mock_chunk = Mock()
+        mock_chunk.candidates = [mock_candidate]
+        mock_chunk.usage_metadata = mock_usage
+
+        chunk = _convert_google_chunk_to_streaming_chunk(
+            chunk=mock_chunk,
+            index=0,
+            component_info=component_info,
+            model="gemini-2.5-flash",
+        )
+
+        assert chunk.meta["usage"]["prompt_tokens"] == 1000
+        assert chunk.meta["usage"]["completion_tokens"] == 10
+        assert chunk.meta["usage"]["total_tokens"] == 1010
+        assert chunk.meta["usage"]["cached_content_token_count"] == 800
+
+    def test_aggregate_streaming_chunks_with_cached_tokens(self, monkeypatch):
+        """cached_content_token_count from the final chunk is propagated to the aggregated message."""
+        monkeypatch.setenv("GOOGLE_API_KEY", "test-api-key")
+        component_info = ComponentInfo.from_component(GoogleGenAIChatGenerator())
+
+        chunk1 = StreamingChunk(
+            content="Hello",
+            component_info=component_info,
+            index=0,
+            meta={"usage": {"prompt_tokens": 1000, "completion_tokens": 5, "total_tokens": 1005}},
+        )
+        final_chunk = StreamingChunk(
+            content=" world",
+            component_info=component_info,
+            index=1,
+            meta={
+                "usage": {
+                    "prompt_tokens": 1000,
+                    "completion_tokens": 10,
+                    "total_tokens": 1010,
+                    "cached_content_token_count": 800,
+                },
+                "model": "gemini-2.5-flash",
+            },
+        )
+
+        result = _aggregate_streaming_chunks_with_reasoning([chunk1, final_chunk])
+
+        assert result.text == "Hello world"
+        assert result.meta["usage"]["cached_content_token_count"] == 800
+

 class TestConvertMessageToGoogleGenAI:
     def test_convert_message_to_google_genai_format_complex(self):
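
For context, a hypothetical end-to-end sketch (not part of this commit) of what the fix enables: with context caching active, the cached-token count now survives a streaming run and lands in the final message's usage meta. The generator construction and run() call assume the standard Haystack chat-generator interface; only the meta["usage"]["cached_content_token_count"] key is confirmed by the tests above.

    from haystack.dataclasses import ChatMessage, StreamingChunk
    from haystack_integrations.components.generators.google_genai import GoogleGenAIChatGenerator

    def on_chunk(chunk: StreamingChunk) -> None:
        # Stream tokens as they arrive; each chunk's meta["usage"] may now
        # carry cached_content_token_count as well.
        print(chunk.content, end="", flush=True)

    generator = GoogleGenAIChatGenerator(model="gemini-2.5-flash", streaming_callback=on_chunk)
    reply = generator.run(messages=[ChatMessage.from_user("What is 2 + 2?")])["replies"][0]
    print(reply.meta["usage"].get("cached_content_token_count"))  # e.g. 800 on a cache hit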
