From d5c94fc04dbe97c919541986693fa83d339cc05c Mon Sep 17 00:00:00 2001 From: anakin87 Date: Thu, 19 Feb 2026 16:06:56 +0100 Subject: [PATCH 1/5] feat: Google GenAI - add FileContent support + refactoring --- integrations/google_genai/pyproject.toml | 2 +- .../google_genai/chat/chat_generator.py | 740 +------------- .../generators/google_genai/chat/utils.py | 734 ++++++++++++++ .../google_genai/tests/test_chat_generator.py | 934 ++---------------- .../tests/test_chat_generator_utils.py | 806 +++++++++++++++ .../tests/test_files/sample_pdf_3.pdf | Bin 0 -> 19039 bytes 6 files changed, 1628 insertions(+), 1588 deletions(-) create mode 100644 integrations/google_genai/tests/test_chat_generator_utils.py create mode 100644 integrations/google_genai/tests/test_files/sample_pdf_3.pdf diff --git a/integrations/google_genai/pyproject.toml b/integrations/google_genai/pyproject.toml index b1b16959fb..49f4ae793a 100644 --- a/integrations/google_genai/pyproject.toml +++ b/integrations/google_genai/pyproject.toml @@ -25,7 +25,7 @@ classifiers = [ "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] -dependencies = ["haystack-ai>=2.23.0", "google-genai[aiohttp]>=1.51.0", "jsonref>=1.0.0"] +dependencies = ["haystack-ai>=2.24.1", "google-genai[aiohttp]>=1.51.0", "jsonref>=1.0.0"] [project.urls] Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/google_genai#readme" diff --git a/integrations/google_genai/src/haystack_integrations/components/generators/google_genai/chat/chat_generator.py b/integrations/google_genai/src/haystack_integrations/components/generators/google_genai/chat/chat_generator.py index 11c26970bb..01ee67c314 100644 --- a/integrations/google_genai/src/haystack_integrations/components/generators/google_genai/chat/chat_generator.py +++ b/integrations/google_genai/src/haystack_integrations/components/generators/google_genai/chat/chat_generator.py @@ -2,32 +2,13 @@ # # SPDX-License-Identifier: Apache-2.0 -import base64 -import json -from collections.abc import AsyncIterator, Iterator -from datetime import datetime, timezone from typing import Any, Literal from google.genai import types -from google.genai.types import GenerateContentResponseUsageMetadata, UsageMetadata -from haystack import logging -from haystack.components.generators.utils import _convert_streaming_chunks_to_chat_message from haystack.core.component import component from haystack.core.serialization import default_from_dict, default_to_dict -from haystack.dataclasses import ( - AsyncStreamingCallbackT, - ComponentInfo, - FinishReason, - ImageContent, - StreamingCallbackT, - StreamingChunk, - TextContent, - ToolCall, - ToolCallDelta, - ToolCallResult, - select_streaming_callback, -) -from haystack.dataclasses.chat_message import ChatMessage, ChatRole, ReasoningContent +from haystack.dataclasses import ComponentInfo, StreamingCallbackT, select_streaming_callback +from haystack.dataclasses.chat_message import ChatMessage, ChatRole from haystack.tools import ( ToolsType, _check_duplicate_tool_names, @@ -36,409 +17,16 @@ serialize_tools_or_toolset, ) from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable -from jsonref import replace_refs from haystack_integrations.components.common.google_genai.utils import _get_client -from haystack_integrations.components.generators.google_genai.chat.utils import remove_key_from_schema - -# Mapping from Google GenAI finish reasons to 
Haystack FinishReason values -FINISH_REASON_MAPPING: dict[str, FinishReason] = { - "STOP": "stop", - "MAX_TOKENS": "length", - "SAFETY": "content_filter", - "BLOCKLIST": "content_filter", - "PROHIBITED_CONTENT": "content_filter", - "SPII": "content_filter", - "IMAGE_SAFETY": "content_filter", -} - - -logger = logging.getLogger(__name__) - -# Google AI supported image MIME types based on documentation -# https://ai.google.dev/gemini-api/docs/image-understanding?lang=python -GOOGLE_AI_SUPPORTED_MIME_TYPES = { - "image/png": "png", - "image/jpeg": "jpeg", - "image/jpg": "jpeg", # Common alias - "image/webp": "webp", - "image/heic": "heic", - "image/heif": "heif", -} - - -def _convert_message_to_google_genai_format(message: ChatMessage) -> types.Content: - """ - Converts a Haystack ChatMessage to Google Gen AI Content format. - - :param message: The Haystack ChatMessage to convert. - :returns: Google Gen AI Content object. - """ - # Check if message has content - if not message._content: - msg = "A `ChatMessage` must contain at least one content part." - raise ValueError(msg) - - parts = [] - - # Check if this message has thought signatures from a previous response - # These need to be reconstructed in their original part structure - thought_signatures = message.meta.get("thought_signatures", []) if message.meta else [] - - # If we have thought signatures, we need to reconstruct the exact part structure - # from the previous assistant response to maintain multi-turn thinking context - if thought_signatures and message.is_from(ChatRole.ASSISTANT): - # Track which tool calls we've used (to handle multiple tool calls) - tool_call_index = 0 - - # Reconstruct parts with their original thought signatures - for sig_info in thought_signatures: - part_dict: dict[str, Any] = {} - - # Check what type of content this part had - if sig_info.get("has_text"): - # Find the corresponding text content - if sig_info.get("is_thought"): - # This was a thought part - find it in reasoning content - if message.reasoning: - part_dict["text"] = message.reasoning.reasoning_text - part_dict["thought"] = True - else: - # Regular text part - part_dict["text"] = message.text or "" - - if sig_info.get("has_function_call"): - # Find the corresponding tool call by index - if message.tool_calls and tool_call_index < len(message.tool_calls): - tool_call = message.tool_calls[tool_call_index] - part_dict["function_call"] = types.FunctionCall( - id=tool_call.id, name=tool_call.tool_name, args=tool_call.arguments - ) - tool_call_index += 1 # Move to next tool call for next part - - # Add the thought signature to preserve context - part_dict["thought_signature"] = sig_info["signature"] - - parts.append(types.Part(**part_dict)) - - # If we reconstructed from signatures, we're done - if parts: - role = "model" # Assistant messages with signatures are always from the model - return types.Content(role=role, parts=parts) - - # Standard processing for messages without thought signatures - for content_part in message._content: - if isinstance(content_part, TextContent): - # Only add text parts that are not empty to avoid unnecessary empty text parts - if content_part.text.strip(): - parts.append(types.Part(text=content_part.text)) - - elif isinstance(content_part, ImageContent): - if not message.is_from(ChatRole.USER): - msg = "Image content is only supported for user messages" - raise ValueError(msg) - - # Validate image MIME type and format - if not content_part.mime_type: - msg = "Image MIME type could not be determined. 
Please provide a valid image with detectable format." - raise ValueError(msg) - - if content_part.mime_type not in GOOGLE_AI_SUPPORTED_MIME_TYPES: - supported_types = list(GOOGLE_AI_SUPPORTED_MIME_TYPES.keys()) - msg = ( - f"Unsupported image MIME type: {content_part.mime_type}. " - f"Google AI supports the following MIME types: {supported_types}" - ) - raise ValueError(msg) - - # Use inline image data approach - try: - # ImageContent already has base64 data, decode it for bytes - image_bytes = base64.b64decode(content_part.base64_image) - - # Create Part using from_bytes method - image_part = types.Part.from_bytes(data=image_bytes, mime_type=content_part.mime_type) - parts.append(image_part) - - except Exception as e: - msg = f"Failed to process image data: {e}" - raise RuntimeError(msg) from e - - elif isinstance(content_part, ToolCall): - parts.append( - types.Part( - function_call=types.FunctionCall( - id=content_part.id, name=content_part.tool_name, args=content_part.arguments - ) - ) - ) - - elif isinstance(content_part, ToolCallResult): - if isinstance(content_part.result, str): - parts.append( - types.Part( - function_response=types.FunctionResponse( - id=content_part.origin.id, - name=content_part.origin.tool_name, - response={"result": content_part.result}, - ) - ) - ) - elif isinstance(content_part.result, list): - tool_call_result_parts: list[types.FunctionResponsePart] = [] - for item in content_part.result: - if isinstance(item, TextContent): - tool_call_result_parts.append( - types.FunctionResponsePart( - inline_data=types.FunctionResponseBlob( - data=item.text.encode("utf-8"), mime_type="text/plain" - ), - ) - ) - elif isinstance(item, ImageContent): - tool_call_result_parts.append( - types.FunctionResponsePart( - inline_data=types.FunctionResponseBlob( - data=base64.b64decode(item.base64_image), mime_type=item.mime_type - ), - ) - ) - else: - msg = ( - "Unsupported content type in tool call result list. " - "Only TextContent and ImageContent are supported." - ) - raise ValueError(msg) - parts.append( - types.Part( - function_response=types.FunctionResponse( - id=content_part.origin.id, - name=content_part.origin.tool_name, - parts=tool_call_result_parts, - # the response field is mandatory, but in this case the LLM just needs multimodal parts - response={"result": ""}, - ) - ) - ) - else: - msg = "Unsupported content type in tool call result" - raise ValueError(msg) - elif isinstance(content_part, ReasoningContent): - # Reasoning content is for human transparency only, not for maintaining LLM context - # Thought signatures (stored in message.meta) handle context preservation - # Leave this here so we don't implement reasoning content handling in the future accidentally - pass - - # Determine role - if message.is_from(ChatRole.USER) or message.tool_call_results: - role = "user" - elif message.is_from(ChatRole.ASSISTANT): - role = "model" - elif message.is_from(ChatRole.SYSTEM): - # System messages will be handled separately as system instruction - # When we convert a list of ChatMessage to be sent to google genai, - # we need to handle system messages separately as system instruction and we only take the first message - # as the system instruction - if it is present. 
- # - # If we find any additional system messages, we will treat them as user messages - role = "user" - else: - msg = f"Unsupported message role: {message._role}" - raise ValueError(msg) - - return types.Content(role=role, parts=parts) - - -def _sanitize_tool_schema(tool_schema: dict[str, Any]) -> dict[str, Any]: - """ - Sanitizes a tool schema to remove any keys that are not supported by Google Gen AI. - - Google Gen AI does not support additionalProperties, $schema, $defs, or $ref in the tool schema. - - :param tool_schema: The tool schema to sanitize. - :returns: The sanitized tool schema. - """ - # google Gemini does not support additionalProperties and $schema in the tool schema - sanitized_schema = remove_key_from_schema(tool_schema, "additionalProperties") - sanitized_schema = remove_key_from_schema(sanitized_schema, "$schema") - # expand $refs in the tool schema - expanded_schema = replace_refs(sanitized_schema) - # and remove the $defs key leaving the rest of the schema - final_schema = remove_key_from_schema(expanded_schema, "$defs") - - if not isinstance(final_schema, dict): - msg = "Tool schema must be a dictionary after sanitization" - raise ValueError(msg) - - return final_schema - - -def _convert_tools_to_google_genai_format(tools: ToolsType) -> list[types.Tool]: - """ - Converts a list of Haystack Tools, Toolsets, or a mix to Google Gen AI Tool format. - - :param tools: List of Haystack Tool and/or Toolset objects, or a single Toolset. - :returns: List of Google Gen AI Tool objects. - """ - # Flatten Tools and Toolsets into a single list of Tools - flattened_tools = flatten_tools_or_toolsets(tools) - - function_declarations: list[types.FunctionDeclaration] = [] - for tool in flattened_tools: - parameters = _sanitize_tool_schema(tool.parameters) - function_declarations.append( - types.FunctionDeclaration( - name=tool.name, description=tool.description, parameters=types.Schema(**parameters) - ) - ) - - # Return a single Tool object with all function declarations as in the Google GenAI docs - # we could also return multiple Tool objects, doesn't seem to make a difference - # revisit this decision - return [types.Tool(function_declarations=function_declarations)] - - -def _convert_usage_metadata_to_serializable( - usage_metadata: UsageMetadata | GenerateContentResponseUsageMetadata | None, -) -> dict[str, Any]: - """Build a JSON-serializable usage dict from a UsageMetadata object. - - Iterates over known UsageMetadata attribute names and adds each non-None value - in serialized form. 
Full list of fields: https://ai.google.dev/api/generate-content#UsageMetadata - """ - - def serialize(val: Any) -> Any: - if val is None: - return None - if isinstance(val, (str, int, float, bool)): - return val - if isinstance(val, list): - return [serialize(item) for item in val] - token_count = getattr(val, "token_count", None) or getattr(val, "tokenCount", None) - if hasattr(val, "modality") and token_count is not None: - mod = getattr(val, "modality", None) - mod_str = getattr(mod, "value", getattr(mod, "name", str(mod))) if mod is not None else None - return {"modality": mod_str, "token_count": token_count} - if hasattr(val, "name"): - return getattr(val, "value", getattr(val, "name", val)) - return val - - if not usage_metadata: - return {} - - _usage_attr_names = ( - "prompt_token_count", - "candidates_token_count", - "total_token_count", - "cache_tokens_details", - "candidates_tokens_details", - "prompt_tokens_details", - "tool_use_prompt_token_count", - "tool_use_prompt_tokens_details", - ) - result: dict[str, Any] = {} - for attr in _usage_attr_names: - val = getattr(usage_metadata, attr, None) - if val is not None: - result[attr] = serialize(val) - return result - - -def _convert_google_genai_response_to_chatmessage(response: types.GenerateContentResponse, model: str) -> ChatMessage: - """ - Converts a Google Gen AI response to a Haystack ChatMessage. - - :param response: The response from Google Gen AI. - :param model: The model name. - :returns: A Haystack ChatMessage. - """ - text_parts = [] - tool_calls = [] - reasoning_parts = [] - thought_signatures = [] # Store thought signatures for multi-turn context - - # Extract text, function calls, thoughts, and thought signatures from response - finish_reason = None - if response.candidates: - candidate = response.candidates[0] - finish_reason = getattr(candidate, "finish_reason", None) - if candidate.content is not None and candidate.content.parts is not None: - for i, part in enumerate(candidate.content.parts): - # Check for thought signature on this part - if hasattr(part, "thought_signature") and part.thought_signature: - # Store the thought signature with its part index for reconstruction - thought_signatures.append( - { - "part_index": i, - "signature": part.thought_signature, - "has_text": part.text is not None, - "has_function_call": part.function_call is not None, - "is_thought": hasattr(part, "thought") and part.thought, - } - ) - - if part.text is not None and not (hasattr(part, "thought") and part.thought): - text_parts.append(part.text) - if part.function_call is not None: - tool_call = ToolCall( - tool_name=part.function_call.name or "", - arguments=dict(part.function_call.args) if part.function_call.args else {}, - id=part.function_call.id, - ) - tool_calls.append(tool_call) - # Handle thought parts for Gemini 2.5 series - if hasattr(part, "thought") and part.thought: - # Extract thought content - if part.text: - reasoning_parts.append(part.text) - - # Combine text parts - text = " ".join(text_parts) if text_parts else "" - - usage_metadata = response.usage_metadata - - # Create usage metadata including thinking tokens if available - usage = { - "prompt_tokens": getattr(usage_metadata, "prompt_token_count", 0), - "completion_tokens": getattr(usage_metadata, "candidates_token_count", 0), - "total_tokens": getattr(usage_metadata, "total_token_count", 0), - } - - # Add thinking token count if available - if usage_metadata and hasattr(usage_metadata, "thoughts_token_count") and usage_metadata.thoughts_token_count: - 
usage["thoughts_token_count"] = usage_metadata.thoughts_token_count - - # Add cached content token count if available (implicit or explicit context caching) - if ( - usage_metadata - and hasattr(usage_metadata, "cached_content_token_count") - and usage_metadata.cached_content_token_count - ): - usage["cached_content_token_count"] = usage_metadata.cached_content_token_count - - usage.update(_convert_usage_metadata_to_serializable(usage_metadata)) - - # Create meta with reasoning content and thought signatures if available - meta: dict[str, Any] = { - "model": model, - "finish_reason": FINISH_REASON_MAPPING.get(finish_reason or ""), - "usage": usage, - } - - # Add thought signatures to meta if present (for multi-turn context preservation) - if thought_signatures: - meta["thought_signatures"] = thought_signatures - - # Create ReasoningContent object if there are reasoning parts - reasoning_content = None - if reasoning_parts: - reasoning_text = " ".join(reasoning_parts) - reasoning_content = ReasoningContent(reasoning_text=reasoning_text) - - # Create ChatMessage - message = ChatMessage.from_assistant(text=text, tool_calls=tool_calls, meta=meta, reasoning=reasoning_content) - - return message +from haystack_integrations.components.generators.google_genai.chat.utils import ( + _convert_google_genai_response_to_chatmessage, + _convert_message_to_google_genai_format, + _convert_tools_to_google_genai_format, + _handle_streaming_response, + _handle_streaming_response_async, + _process_thinking_config, +) @component @@ -649,294 +237,6 @@ def from_dict(cls, data: dict[str, Any]) -> "GoogleGenAIChatGenerator": init_params["streaming_callback"] = deserialize_callable(init_params["streaming_callback"]) return default_from_dict(cls, data) - def _convert_google_chunk_to_streaming_chunk( - self, - chunk: types.GenerateContentResponse, - index: int, - component_info: ComponentInfo, - ) -> StreamingChunk: - """ - Convert a chunk from Google Gen AI to a Haystack StreamingChunk. - - :param chunk: The chunk from Google Gen AI. - :param index: The index of the chunk. - :returns: A StreamingChunk object. 
- """ - content = "" - tool_calls: list[ToolCallDelta] = [] - finish_reason = None - reasoning_deltas: list[dict[str, str]] = [] - thought_signature_deltas: list[dict[str, Any]] = [] # Track thought signatures in streaming - - if chunk.candidates: - candidate = chunk.candidates[0] - finish_reason = getattr(candidate, "finish_reason", None) - - usage_metadata = chunk.usage_metadata - - usage = { - "prompt_tokens": getattr(usage_metadata, "prompt_token_count", 0) if usage_metadata else 0, - "completion_tokens": getattr(usage_metadata, "candidates_token_count", 0) if usage_metadata else 0, - "total_tokens": getattr(usage_metadata, "total_token_count", 0) if usage_metadata else 0, - } - - # Add thinking token count if available - if usage_metadata and hasattr(usage_metadata, "thoughts_token_count") and usage_metadata.thoughts_token_count: - usage["thoughts_token_count"] = usage_metadata.thoughts_token_count - - if candidate.content and candidate.content.parts: - tc_index = -1 - for part_index, part in enumerate(candidate.content.parts): - # Check for thought signature on this part (for multi-turn context) - if hasattr(part, "thought_signature") and part.thought_signature: - thought_signature_deltas.append( - { - "part_index": part_index, - "signature": part.thought_signature, - "has_text": part.text is not None, - "has_function_call": part.function_call is not None, - "is_thought": hasattr(part, "thought") and part.thought, - } - ) - - if part.text is not None and not (hasattr(part, "thought") and part.thought): - content += part.text - - elif part.function_call: - tc_index += 1 - tool_calls.append( - ToolCallDelta( - # Google GenAI does not provide index, but it is required for tool calls - index=tc_index, - id=part.function_call.id, - tool_name=part.function_call.name or "", - arguments=json.dumps(part.function_call.args) if part.function_call.args else None, - ) - ) - - # Handle thought parts for Gemini 2.5 series - elif hasattr(part, "thought") and part.thought: - thought_delta = { - "type": "reasoning", - "content": part.text if part.text else "", - } - reasoning_deltas.append(thought_delta) - - # start is only used by print_streaming_chunk. We try to make a reasonable assumption here but it should not be - # a problem if we change it in the future. - start = index == 0 or len(tool_calls) > 0 - - # Create meta with reasoning deltas and thought signatures if available - meta: dict[str, Any] = { - "received_at": datetime.now(timezone.utc).isoformat(), - "model": self._model, - "usage": usage, - } - - # Add reasoning deltas to meta if available - if reasoning_deltas: - meta["reasoning_deltas"] = reasoning_deltas - - # Add thought signature deltas to meta if available (for multi-turn context) - if thought_signature_deltas: - meta["thought_signature_deltas"] = thought_signature_deltas - - return StreamingChunk( - content="" if tool_calls else content, # prioritize tool calls over content when both are present - tool_calls=tool_calls, - component_info=component_info, - index=index, - start=start, - finish_reason=FINISH_REASON_MAPPING.get(finish_reason or ""), - meta=meta, - ) - - @staticmethod - def _aggregate_streaming_chunks_with_reasoning(chunks: list[StreamingChunk]) -> ChatMessage: - """ - Aggregate streaming chunks into a final ChatMessage with reasoning content and thought signatures. - - This method extends the standard streaming chunk aggregation to handle Google GenAI's - specific reasoning content, thinking token usage, and thought signatures for multi-turn context. 
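Roughly, the reasoning-delta aggregation works like this (the two chunk metas below are hypothetical stand-ins for what streaming produces):

```python
# Rough illustration of how reasoning deltas collected in chunk meta are joined
# into one reasoning text. The chunk metas are made up for the example.
chunk_metas = [
    {"reasoning_deltas": [{"type": "reasoning", "content": "Check the weather tool. "}]},
    {"reasoning_deltas": [{"type": "reasoning", "content": "Paris was requested."}]},
]
reasoning_text_parts: list[str] = []
for meta in chunk_metas:
    for delta in meta.get("reasoning_deltas", []):
        if delta.get("type") == "reasoning":
            reasoning_text_parts.append(delta.get("content", ""))
assert "".join(reasoning_text_parts) == "Check the weather tool. Paris was requested."
```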
- - :param chunks: List of streaming chunks to aggregate. - :returns: Final ChatMessage with aggregated content, reasoning, and thought signatures. - """ - - # Use the generic aggregator for standard content (text, tool calls, basic meta) - message = _convert_streaming_chunks_to_chat_message(chunks) - - # Now enhance with Google-specific features: reasoning content, thinking token usage, and thought signatures - reasoning_text_parts: list[str] = [] - thought_signatures: list[dict[str, Any]] = [] - thoughts_token_count = None - - for chunk in chunks: - # Extract reasoning deltas - if chunk.meta and "reasoning_deltas" in chunk.meta: - reasoning_deltas = chunk.meta["reasoning_deltas"] - if isinstance(reasoning_deltas, list): - for delta in reasoning_deltas: - if delta.get("type") == "reasoning": - reasoning_text_parts.append(delta.get("content", "")) - - # Extract thought signature deltas (for multi-turn context preservation) - if chunk.meta and "thought_signature_deltas" in chunk.meta: - signature_deltas = chunk.meta["thought_signature_deltas"] - if isinstance(signature_deltas, list): - # Aggregate thought signatures - they should come from the final chunks - # We'll keep the last set of signatures as they represent the complete state - thought_signatures = signature_deltas - - # Extract thinking token usage (from the last chunk that has it) - if chunk.meta and "usage" in chunk.meta: - chunk_usage = chunk.meta["usage"] - if "thoughts_token_count" in chunk_usage: - thoughts_token_count = chunk_usage["thoughts_token_count"] - - # Add thinking token count to usage if present - if thoughts_token_count is not None and "usage" in message.meta: - if message.meta["usage"] is None: - message.meta["usage"] = {} - message.meta["usage"]["thoughts_token_count"] = thoughts_token_count - - # Add thought signatures to meta if present (for multi-turn context preservation) - if thought_signatures: - message.meta["thought_signatures"] = thought_signatures - - # If we have reasoning content, reconstruct the message to include it - # Note: ChatMessage doesn't support adding reasoning after creation, reconstruction is necessary - if reasoning_text_parts: - reasoning_content = ReasoningContent(reasoning_text="".join(reasoning_text_parts)) - return ChatMessage.from_assistant( - text=message.text, tool_calls=message.tool_calls, meta=message.meta, reasoning=reasoning_content - ) - - return message - - def _handle_streaming_response( - self, response_stream: Iterator[types.GenerateContentResponse], streaming_callback: StreamingCallbackT - ) -> dict[str, list[ChatMessage]]: - """ - Handle streaming response from Google Gen AI generate_content_stream. - :param response_stream: The streaming response from generate_content_stream. - :param streaming_callback: The callback function for streaming chunks. - :returns: A dictionary with the replies. 
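For orientation, this streaming path is reached from user code by passing a streaming callback to the component. A minimal sketch, assuming `GOOGLE_API_KEY` is set and the integration's usual public import path:

```python
# Minimal usage sketch that exercises the streaming path (assumes GOOGLE_API_KEY is set).
from haystack.dataclasses import ChatMessage, StreamingChunk

from haystack_integrations.components.generators.google_genai import GoogleGenAIChatGenerator


def on_chunk(chunk: StreamingChunk) -> None:
    print(chunk.content, end="", flush=True)


generator = GoogleGenAIChatGenerator(streaming_callback=on_chunk)
result = generator.run(messages=[ChatMessage.from_user("Tell me a one-line joke.")])
print(result["replies"][0].text)
```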
- """ - component_info = ComponentInfo.from_component(self) - - try: - chunks = [] - - chunk = None - for i, chunk in enumerate(response_stream): - streaming_chunk = self._convert_google_chunk_to_streaming_chunk( - chunk=chunk, index=i, component_info=component_info - ) - chunks.append(streaming_chunk) - - # Stream the chunk - streaming_callback(streaming_chunk) - - # Use custom aggregation that supports reasoning content - message = GoogleGenAIChatGenerator._aggregate_streaming_chunks_with_reasoning(chunks) - return {"replies": [message]} - - except Exception as e: - msg = f"Error in streaming response: {e}" - raise RuntimeError(msg) from e - - async def _handle_streaming_response_async( - self, response_stream: AsyncIterator[types.GenerateContentResponse], streaming_callback: AsyncStreamingCallbackT - ) -> dict[str, list[ChatMessage]]: - """ - Handle async streaming response from Google Gen AI generate_content_stream. - :param response_stream: The async streaming response from generate_content_stream. - :param streaming_callback: The async callback function for streaming chunks. - :returns: A dictionary with the replies. - """ - component_info = ComponentInfo.from_component(self) - - try: - chunks = [] - - i = 0 - chunk = None - async for chunk in response_stream: - i += 1 - - streaming_chunk = self._convert_google_chunk_to_streaming_chunk( - chunk=chunk, index=i, component_info=component_info - ) - chunks.append(streaming_chunk) - - # Stream the chunk - await streaming_callback(streaming_chunk) - - # Use custom aggregation that supports reasoning content - message = GoogleGenAIChatGenerator._aggregate_streaming_chunks_with_reasoning(chunks) - return {"replies": [message]} - - except Exception as e: - msg = f"Error in async streaming response: {e}" - raise RuntimeError(msg) from e - - @staticmethod - def _process_thinking_config(generation_kwargs: dict[str, Any]) -> dict[str, Any]: - """ - Process thinking configuration from generation_kwargs. - - :param generation_kwargs: The generation configuration dictionary. - :returns: Updated generation_kwargs with thinking_config if applicable. - """ - if "thinking_budget" in generation_kwargs: - thinking_budget = generation_kwargs.pop("thinking_budget") - - # Basic type validation - if not isinstance(thinking_budget, int): - logger.warning( - f"Invalid thinking_budget type: {type(thinking_budget)}. Expected int, using dynamic allocation." - ) - # fall back to default: dynamic thinking budget allocation - thinking_budget = -1 - - # Create thinking config - thinking_config = types.ThinkingConfig(thinking_budget=thinking_budget, include_thoughts=True) - generation_kwargs["thinking_config"] = thinking_config - - if "thinking_level" in generation_kwargs: - thinking_level = generation_kwargs.pop("thinking_level") - - # Basic type validation - if not isinstance(thinking_level, str): - logger.warning( - f"Invalid thinking_level type: {type(thinking_level).__name__}. Expected str, " - f"falling back to THINKING_LEVEL_UNSPECIFIED." - ) - thinking_level = types.ThinkingLevel.THINKING_LEVEL_UNSPECIFIED - else: - # Convert to uppercase for case-insensitive matching - thinking_level_upper = thinking_level.upper() - - # Check if the uppercase value is a valid ThinkingLevel enum member - valid_levels = [level.value for level in types.ThinkingLevel] - if thinking_level_upper not in valid_levels: - logger.warning( - f"Invalid thinking_level value: '{thinking_level}'. " - f"Must be one of: {valid_levels} (case-insensitive). 
" - "Falling back to THINKING_LEVEL_UNSPECIFIED." - ) - thinking_level = types.ThinkingLevel.THINKING_LEVEL_UNSPECIFIED - else: - # Parse valid string to ThinkingLevel enum - thinking_level = types.ThinkingLevel(thinking_level_upper) - - # Create thinking config with level - thinking_config = types.ThinkingConfig(thinking_level=thinking_level, include_thoughts=True) - generation_kwargs["thinking_config"] = thinking_config - - return generation_kwargs - @component.output_types(replies=list[ChatMessage]) def run( self, @@ -971,7 +271,7 @@ def run( tools = tools or self._tools # Process thinking configuration - generation_kwargs = GoogleGenAIChatGenerator._process_thinking_config(generation_kwargs) + generation_kwargs = _process_thinking_config(generation_kwargs) # Select appropriate streaming callback streaming_callback = select_streaming_callback( @@ -1019,7 +319,13 @@ def run( contents=contents, config=config, ) - return self._handle_streaming_response(response_stream, streaming_callback) + component_info = ComponentInfo.from_component(self) + return _handle_streaming_response( + component_info=component_info, + response_stream=response_stream, + streaming_callback=streaming_callback, + model=self._model, + ) else: # Use non-streaming response = self._client.models.generate_content( @@ -1081,7 +387,7 @@ async def run_async( tools = tools or self._tools # Process thinking configuration - generation_kwargs = GoogleGenAIChatGenerator._process_thinking_config(generation_kwargs) + generation_kwargs = _process_thinking_config(generation_kwargs) # Select appropriate streaming callback streaming_callback = select_streaming_callback( @@ -1129,7 +435,13 @@ async def run_async( contents=contents, config=config, ) - return await self._handle_streaming_response_async(response_stream, streaming_callback) + component_info = ComponentInfo.from_component(self) + return await _handle_streaming_response_async( + component_info=component_info, + response_stream=response_stream, + streaming_callback=streaming_callback, + model=self._model, + ) else: # Use non-streaming response = await self._client.aio.models.generate_content( diff --git a/integrations/google_genai/src/haystack_integrations/components/generators/google_genai/chat/utils.py b/integrations/google_genai/src/haystack_integrations/components/generators/google_genai/chat/utils.py index b256d7dd15..366d74f86b 100644 --- a/integrations/google_genai/src/haystack_integrations/components/generators/google_genai/chat/utils.py +++ b/integrations/google_genai/src/haystack_integrations/components/generators/google_genai/chat/utils.py @@ -2,8 +2,116 @@ # # SPDX-License-Identifier: Apache-2.0 +import base64 +import json +from collections.abc import AsyncIterator, Iterator +from datetime import datetime, timezone from typing import Any +from google.genai import types +from google.genai.types import GenerateContentResponseUsageMetadata, UsageMetadata +from haystack import logging +from haystack.components.generators.utils import _convert_streaming_chunks_to_chat_message +from haystack.dataclasses import ( + AsyncStreamingCallbackT, + ComponentInfo, + FileContent, + FinishReason, + ImageContent, + StreamingCallbackT, + StreamingChunk, + TextContent, + ToolCall, + ToolCallDelta, + ToolCallResult, +) +from haystack.dataclasses.chat_message import ChatMessage, ChatRole, ReasoningContent +from haystack.tools import ( + ToolsType, + flatten_tools_or_toolsets, +) +from jsonref import replace_refs + +logger = logging.getLogger(__name__) + +# Mapping from Google GenAI 
finish reasons to Haystack FinishReason values +FINISH_REASON_MAPPING: dict[str, FinishReason] = { + "STOP": "stop", + "MAX_TOKENS": "length", + "SAFETY": "content_filter", + "BLOCKLIST": "content_filter", + "PROHIBITED_CONTENT": "content_filter", + "SPII": "content_filter", + "IMAGE_SAFETY": "content_filter", +} + +# Google GenAI supported image MIME types based on documentation +# https://ai.google.dev/gemini-api/docs/image-understanding?lang=python#supported-formats +GOOGLE_GENAI_SUPPORTED_MIME_TYPES = { + "image/png": "png", + "image/jpeg": "jpeg", + "image/jpg": "jpeg", # Common alias + "image/webp": "webp", + "image/heic": "heic", + "image/heif": "heif", +} + + +def _process_thinking_config(generation_kwargs: dict[str, Any]) -> dict[str, Any]: + """ + Process thinking configuration from generation_kwargs. + + :param generation_kwargs: The generation configuration dictionary. + :returns: Updated generation_kwargs with thinking_config if applicable. + """ + if "thinking_budget" in generation_kwargs: + thinking_budget = generation_kwargs.pop("thinking_budget") + + # Basic type validation + if not isinstance(thinking_budget, int): + logger.warning( + f"Invalid thinking_budget type: {type(thinking_budget)}. Expected int, using dynamic allocation." + ) + # fall back to default: dynamic thinking budget allocation + thinking_budget = -1 + + # Create thinking config + thinking_config = types.ThinkingConfig(thinking_budget=thinking_budget, include_thoughts=True) + generation_kwargs["thinking_config"] = thinking_config + + if "thinking_level" in generation_kwargs: + thinking_level = generation_kwargs.pop("thinking_level") + + # Basic type validation + if not isinstance(thinking_level, str): + logger.warning( + f"Invalid thinking_level type: {type(thinking_level).__name__}. Expected str, " + f"falling back to THINKING_LEVEL_UNSPECIFIED." + ) + thinking_level = types.ThinkingLevel.THINKING_LEVEL_UNSPECIFIED + else: + # Convert to uppercase for case-insensitive matching + thinking_level_upper = thinking_level.upper() + + # Check if the uppercase value is a valid ThinkingLevel enum member + valid_levels = [level.value for level in types.ThinkingLevel] + if thinking_level_upper not in valid_levels: + logger.warning( + f"Invalid thinking_level value: '{thinking_level}'. " + f"Must be one of: {valid_levels} (case-insensitive). " + "Falling back to THINKING_LEVEL_UNSPECIFIED." + ) + thinking_level = types.ThinkingLevel.THINKING_LEVEL_UNSPECIFIED + else: + # Parse valid string to ThinkingLevel enum + thinking_level = types.ThinkingLevel(thinking_level_upper) + + # Create thinking config with level + thinking_config = types.ThinkingConfig(thinking_level=thinking_level, include_thoughts=True) + generation_kwargs["thinking_config"] = thinking_config + + return generation_kwargs + def remove_key_from_schema( schema: dict[str, Any] | list[Any] | Any, target_key: str @@ -29,3 +137,629 @@ def remove_key_from_schema( return [remove_key_from_schema(item, target_key) for item in schema] return schema + + +def _sanitize_tool_schema(tool_schema: dict[str, Any]) -> dict[str, Any]: + """ + Sanitizes a tool schema to remove any keys that are not supported by Google Gen AI. + + Google Gen AI does not support additionalProperties, $schema, $defs, or $ref in the tool schema. + + :param tool_schema: The tool schema to sanitize. + :returns: The sanitized tool schema. 
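For instance, given a hypothetical schema that uses `$schema`, `$defs`/`$ref`, and `additionalProperties`, the sanitization inlines the reference and drops the unsupported keys:

```python
# Illustrative input (hypothetical): a schema with keys Google GenAI rejects.
from haystack_integrations.components.generators.google_genai.chat.utils import _sanitize_tool_schema

schema = {
    "$schema": "http://json-schema.org/draft-07/schema#",
    "type": "object",
    "properties": {"city": {"$ref": "#/$defs/City"}},
    "$defs": {"City": {"type": "string"}},
    "additionalProperties": False,
}
# The $ref is expanded inline via jsonref, then $schema/$defs/additionalProperties are removed.
sanitized = _sanitize_tool_schema(schema)
assert sanitized == {"type": "object", "properties": {"city": {"type": "string"}}}
```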
+ """ + # google Gemini does not support additionalProperties and $schema in the tool schema + sanitized_schema = remove_key_from_schema(tool_schema, "additionalProperties") + sanitized_schema = remove_key_from_schema(sanitized_schema, "$schema") + # expand $refs in the tool schema + expanded_schema = replace_refs(sanitized_schema) + # and remove the $defs key leaving the rest of the schema + final_schema = remove_key_from_schema(expanded_schema, "$defs") + + if not isinstance(final_schema, dict): + msg = "Tool schema must be a dictionary after sanitization" + raise ValueError(msg) + + return final_schema + + +def _convert_message_to_google_genai_format(message: ChatMessage) -> types.Content: + """ + Converts a Haystack ChatMessage to Google Gen AI Content format. + + :param message: The Haystack ChatMessage to convert. + :returns: Google Gen AI Content object. + """ + # Check if message has content + if not message._content: + msg = "A `ChatMessage` must contain at least one content part." + raise ValueError(msg) + + parts = [] + + # Check if this message has thought signatures from a previous response + # These need to be reconstructed in their original part structure + thought_signatures = message.meta.get("thought_signatures", []) if message.meta else [] + + # If we have thought signatures, we need to reconstruct the exact part structure + # from the previous assistant response to maintain multi-turn thinking context + if thought_signatures and message.is_from(ChatRole.ASSISTANT): + # Track which tool calls we've used (to handle multiple tool calls) + tool_call_index = 0 + + # Reconstruct parts with their original thought signatures + for sig_info in thought_signatures: + part_dict: dict[str, Any] = {} + + # Check what type of content this part had + if sig_info.get("has_text"): + # Find the corresponding text content + if sig_info.get("is_thought"): + # This was a thought part - find it in reasoning content + if message.reasoning: + part_dict["text"] = message.reasoning.reasoning_text + part_dict["thought"] = True + else: + # Regular text part + part_dict["text"] = message.text or "" + + if sig_info.get("has_function_call"): + # Find the corresponding tool call by index + if message.tool_calls and tool_call_index < len(message.tool_calls): + tool_call = message.tool_calls[tool_call_index] + part_dict["function_call"] = types.FunctionCall( + id=tool_call.id, name=tool_call.tool_name, args=tool_call.arguments + ) + tool_call_index += 1 # Move to next tool call for next part + + # Add the thought signature to preserve context + part_dict["thought_signature"] = sig_info["signature"] + + parts.append(types.Part(**part_dict)) + + # If we reconstructed from signatures, we're done + if parts: + role = "model" # Assistant messages with signatures are always from the model + return types.Content(role=role, parts=parts) + + # Standard processing for messages without thought signatures + for content_part in message._content: + if isinstance(content_part, TextContent): + # Only add text parts that are not empty to avoid unnecessary empty text parts + if content_part.text.strip(): + parts.append(types.Part(text=content_part.text)) + + elif isinstance(content_part, (ImageContent, FileContent)): + cls_name = content_part.__class__.__name__ + + if not message.is_from(ChatRole.USER): + msg = f"{cls_name} is only supported for user messages" + raise ValueError(msg) + + # Validate image MIME type and format + if not content_part.mime_type: + msg = f"Mime type is required to use {cls_name} with 
GoogleGenAIChatGenerator" + raise ValueError(msg) + + if ( + isinstance(content_part, ImageContent) + and content_part.mime_type not in GOOGLE_GENAI_SUPPORTED_MIME_TYPES + ): + supported_types = list(GOOGLE_GENAI_SUPPORTED_MIME_TYPES.keys()) + msg = ( + f"Unsupported image MIME type: {content_part.mime_type}. " + f"Google AI supports the following MIME types: {supported_types}" + ) + raise ValueError(msg) + + # Use inline data approach + try: + base64_data = ( + content_part.base64_data if isinstance(content_part, FileContent) else content_part.base64_image + ) + bytes_data = base64.b64decode(base64_data) + + # Create Part using from_bytes method + file_part = types.Part.from_bytes(data=bytes_data, mime_type=content_part.mime_type) + parts.append(file_part) + + except Exception as e: + msg = f"Failed to process {cls_name} data: {e}" + raise RuntimeError(msg) from e + + elif isinstance(content_part, ToolCall): + parts.append( + types.Part( + function_call=types.FunctionCall( + id=content_part.id, name=content_part.tool_name, args=content_part.arguments + ) + ) + ) + + elif isinstance(content_part, ToolCallResult): + if isinstance(content_part.result, str): + parts.append( + types.Part( + function_response=types.FunctionResponse( + id=content_part.origin.id, + name=content_part.origin.tool_name, + response={"result": content_part.result}, + ) + ) + ) + elif isinstance(content_part.result, list): + tool_call_result_parts: list[types.FunctionResponsePart] = [] + for item in content_part.result: + if isinstance(item, TextContent): + tool_call_result_parts.append( + types.FunctionResponsePart( + inline_data=types.FunctionResponseBlob( + data=item.text.encode("utf-8"), mime_type="text/plain" + ), + ) + ) + elif isinstance(item, ImageContent): + tool_call_result_parts.append( + types.FunctionResponsePart( + inline_data=types.FunctionResponseBlob( + data=base64.b64decode(item.base64_image), mime_type=item.mime_type + ), + ) + ) + else: + msg = ( + "Unsupported content type in tool call result list. " + "Only TextContent and ImageContent are supported." + ) + raise ValueError(msg) + parts.append( + types.Part( + function_response=types.FunctionResponse( + id=content_part.origin.id, + name=content_part.origin.tool_name, + parts=tool_call_result_parts, + # the response field is mandatory, but in this case the LLM just needs multimodal parts + response={"result": ""}, + ) + ) + ) + else: + msg = "Unsupported content type in tool call result" + raise ValueError(msg) + elif isinstance(content_part, ReasoningContent): + # Reasoning content is for human transparency only, not for maintaining LLM context + # Thought signatures (stored in message.meta) handle context preservation + # Leave this here so we don't implement reasoning content handling in the future accidentally + pass + + # Determine role + if message.is_from(ChatRole.USER) or message.tool_call_results: + role = "user" + elif message.is_from(ChatRole.ASSISTANT): + role = "model" + elif message.is_from(ChatRole.SYSTEM): + # System messages will be handled separately as system instruction + # When we convert a list of ChatMessage to be sent to google genai, + # we need to handle system messages separately as system instruction and we only take the first message + # as the system instruction - if it is present. 
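The `FileContent` branch above is the new capability in this PR. A usage sketch for sending a PDF (assuming Haystack's `ChatMessage.from_user(content_parts=...)` API and a `FileContent(base64_data=..., mime_type=...)` constructor; the fixture path is the one this patch adds):

```python
# Sketch: building a user message with the new FileContent support.
import base64

from haystack.dataclasses import ChatMessage, FileContent, TextContent

with open("integrations/google_genai/tests/test_files/sample_pdf_3.pdf", "rb") as f:
    pdf_b64 = base64.b64encode(f.read()).decode("utf-8")

message = ChatMessage.from_user(
    content_parts=[
        TextContent("Summarize this document in one sentence."),
        FileContent(base64_data=pdf_b64, mime_type="application/pdf"),
    ]
)
# _convert_message_to_google_genai_format(message) then yields a Content with
# role "user" and an inline-data part carrying the PDF bytes.
```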
+ # + # If we find any additional system messages, we will treat them as user messages + role = "user" + else: + msg = f"Unsupported message role: {message._role}" + raise ValueError(msg) + + return types.Content(role=role, parts=parts) + + +def _convert_tools_to_google_genai_format(tools: ToolsType) -> list[types.Tool]: + """ + Converts a list of Haystack Tools, Toolsets, or a mix to Google Gen AI Tool format. + + :param tools: List of Haystack Tool and/or Toolset objects, or a single Toolset. + :returns: List of Google Gen AI Tool objects. + """ + # Flatten Tools and Toolsets into a single list of Tools + flattened_tools = flatten_tools_or_toolsets(tools) + + function_declarations: list[types.FunctionDeclaration] = [] + for tool in flattened_tools: + parameters = _sanitize_tool_schema(tool.parameters) + function_declarations.append( + types.FunctionDeclaration( + name=tool.name, description=tool.description, parameters=types.Schema(**parameters) + ) + ) + + # Return a single Tool object with all function declarations as in the Google GenAI docs + # we could also return multiple Tool objects, doesn't seem to make a difference + # revisit this decision + return [types.Tool(function_declarations=function_declarations)] + + +def _convert_usage_metadata_to_serializable( + usage_metadata: UsageMetadata | GenerateContentResponseUsageMetadata | None, +) -> dict[str, Any]: + """Build a JSON-serializable usage dict from a UsageMetadata object. + + Iterates over known UsageMetadata attribute names and adds each non-None value + in serialized form. Full list of fields: https://ai.google.dev/api/generate-content#UsageMetadata + """ + + def serialize(val: Any) -> Any: + if val is None: + return None + if isinstance(val, (str, int, float, bool)): + return val + if isinstance(val, list): + return [serialize(item) for item in val] + token_count = getattr(val, "token_count", None) or getattr(val, "tokenCount", None) + if hasattr(val, "modality") and token_count is not None: + mod = getattr(val, "modality", None) + mod_str = getattr(mod, "value", getattr(mod, "name", str(mod))) if mod is not None else None + return {"modality": mod_str, "token_count": token_count} + if hasattr(val, "name"): + return getattr(val, "value", getattr(val, "name", val)) + return val + + if not usage_metadata: + return {} + + _usage_attr_names = ( + "prompt_token_count", + "candidates_token_count", + "total_token_count", + "cache_tokens_details", + "candidates_tokens_details", + "prompt_tokens_details", + "tool_use_prompt_token_count", + "tool_use_prompt_tokens_details", + ) + result: dict[str, Any] = {} + for attr in _usage_attr_names: + val = getattr(usage_metadata, attr, None) + if val is not None: + result[attr] = serialize(val) + return result + + +def _convert_google_genai_response_to_chatmessage(response: types.GenerateContentResponse, model: str) -> ChatMessage: + """ + Converts a Google Gen AI response to a Haystack ChatMessage. + + :param response: The response from Google Gen AI. + :param model: The model name. + :returns: A Haystack ChatMessage. 
+ """ + text_parts = [] + tool_calls = [] + reasoning_parts = [] + thought_signatures = [] # Store thought signatures for multi-turn context + + # Extract text, function calls, thoughts, and thought signatures from response + finish_reason = None + if response.candidates: + candidate = response.candidates[0] + finish_reason = getattr(candidate, "finish_reason", None) + if candidate.content is not None and candidate.content.parts is not None: + for i, part in enumerate(candidate.content.parts): + # Check for thought signature on this part + if hasattr(part, "thought_signature") and part.thought_signature: + # Store the thought signature with its part index for reconstruction + thought_signatures.append( + { + "part_index": i, + "signature": part.thought_signature, + "has_text": part.text is not None, + "has_function_call": part.function_call is not None, + "is_thought": hasattr(part, "thought") and part.thought, + } + ) + + if part.text is not None and not (hasattr(part, "thought") and part.thought): + text_parts.append(part.text) + if part.function_call is not None: + tool_call = ToolCall( + tool_name=part.function_call.name or "", + arguments=dict(part.function_call.args) if part.function_call.args else {}, + id=part.function_call.id, + ) + tool_calls.append(tool_call) + # Handle thought parts for Gemini 2.5 series + if hasattr(part, "thought") and part.thought: + # Extract thought content + if part.text: + reasoning_parts.append(part.text) + + # Combine text parts + text = " ".join(text_parts) if text_parts else "" + + usage_metadata = response.usage_metadata + + # Create usage metadata including thinking tokens if available + usage = { + "prompt_tokens": getattr(usage_metadata, "prompt_token_count", 0), + "completion_tokens": getattr(usage_metadata, "candidates_token_count", 0), + "total_tokens": getattr(usage_metadata, "total_token_count", 0), + } + + # Add thinking token count if available + if usage_metadata and hasattr(usage_metadata, "thoughts_token_count") and usage_metadata.thoughts_token_count: + usage["thoughts_token_count"] = usage_metadata.thoughts_token_count + + # Add cached content token count if available (implicit or explicit context caching) + if ( + usage_metadata + and hasattr(usage_metadata, "cached_content_token_count") + and usage_metadata.cached_content_token_count + ): + usage["cached_content_token_count"] = usage_metadata.cached_content_token_count + + usage.update(_convert_usage_metadata_to_serializable(usage_metadata)) + + # Create meta with reasoning content and thought signatures if available + meta: dict[str, Any] = { + "model": model, + "finish_reason": FINISH_REASON_MAPPING.get(finish_reason or ""), + "usage": usage, + } + + # Add thought signatures to meta if present (for multi-turn context preservation) + if thought_signatures: + meta["thought_signatures"] = thought_signatures + + # Create ReasoningContent object if there are reasoning parts + reasoning_content = None + if reasoning_parts: + reasoning_text = " ".join(reasoning_parts) + reasoning_content = ReasoningContent(reasoning_text=reasoning_text) + + # Create ChatMessage + message = ChatMessage.from_assistant(text=text, tool_calls=tool_calls, meta=meta, reasoning=reasoning_content) + + return message + + +def _convert_google_chunk_to_streaming_chunk( + chunk: types.GenerateContentResponse, + index: int, + component_info: ComponentInfo, + model: str, +) -> StreamingChunk: + """ + Convert a chunk from Google Gen AI to a Haystack StreamingChunk. + + :param chunk: The chunk from Google Gen AI. 
+ :param index: The index of the chunk. + :param component_info: The component info. + :param model: The model name. + :returns: A StreamingChunk object. + """ + content = "" + tool_calls: list[ToolCallDelta] = [] + finish_reason = None + reasoning_deltas: list[dict[str, str]] = [] + thought_signature_deltas: list[dict[str, Any]] = [] # Track thought signatures in streaming + + if chunk.candidates: + candidate = chunk.candidates[0] + finish_reason = getattr(candidate, "finish_reason", None) + + usage_metadata = chunk.usage_metadata + + usage = { + "prompt_tokens": getattr(usage_metadata, "prompt_token_count", 0) if usage_metadata else 0, + "completion_tokens": getattr(usage_metadata, "candidates_token_count", 0) if usage_metadata else 0, + "total_tokens": getattr(usage_metadata, "total_token_count", 0) if usage_metadata else 0, + } + + # Add thinking token count if available + if usage_metadata and hasattr(usage_metadata, "thoughts_token_count") and usage_metadata.thoughts_token_count: + usage["thoughts_token_count"] = usage_metadata.thoughts_token_count + + if candidate.content and candidate.content.parts: + tc_index = -1 + for part_index, part in enumerate(candidate.content.parts): + # Check for thought signature on this part (for multi-turn context) + if hasattr(part, "thought_signature") and part.thought_signature: + thought_signature_deltas.append( + { + "part_index": part_index, + "signature": part.thought_signature, + "has_text": part.text is not None, + "has_function_call": part.function_call is not None, + "is_thought": hasattr(part, "thought") and part.thought, + } + ) + + if part.text is not None and not (hasattr(part, "thought") and part.thought): + content += part.text + + elif part.function_call: + tc_index += 1 + tool_calls.append( + ToolCallDelta( + # Google GenAI does not provide index, but it is required for tool calls + index=tc_index, + id=part.function_call.id, + tool_name=part.function_call.name or "", + arguments=json.dumps(part.function_call.args) if part.function_call.args else None, + ) + ) + + # Handle thought parts for Gemini 2.5 series + elif hasattr(part, "thought") and part.thought: + thought_delta = { + "type": "reasoning", + "content": part.text if part.text else "", + } + reasoning_deltas.append(thought_delta) + + # start is only used by print_streaming_chunk. We try to make a reasonable assumption here but it should not be + # a problem if we change it in the future. + start = index == 0 or len(tool_calls) > 0 + + # Create meta with reasoning deltas and thought signatures if available + meta: dict[str, Any] = { + "received_at": datetime.now(timezone.utc).isoformat(), + "model": model, + "usage": usage, + } + + # Add reasoning deltas to meta if available + if reasoning_deltas: + meta["reasoning_deltas"] = reasoning_deltas + + # Add thought signature deltas to meta if available (for multi-turn context) + if thought_signature_deltas: + meta["thought_signature_deltas"] = thought_signature_deltas + + return StreamingChunk( + content="" if tool_calls else content, # prioritize tool calls over content when both are present + tool_calls=tool_calls, + component_info=component_info, + index=index, + start=start, + finish_reason=FINISH_REASON_MAPPING.get(finish_reason or ""), + meta=meta, + ) + + +def _aggregate_streaming_chunks_with_reasoning(chunks: list[StreamingChunk]) -> ChatMessage: + """ + Aggregate streaming chunks into a final ChatMessage with reasoning content and thought signatures. 
+ + This method extends the standard streaming chunk aggregation to handle Google GenAI's + specific reasoning content, thinking token usage, and thought signatures for multi-turn context. + + :param chunks: List of streaming chunks to aggregate. + :returns: Final ChatMessage with aggregated content, reasoning, and thought signatures. + """ + + # Use the generic aggregator for standard content (text, tool calls, basic meta) + message = _convert_streaming_chunks_to_chat_message(chunks) + + # Now enhance with Google-specific features: reasoning content, thinking token usage, and thought signatures + reasoning_text_parts: list[str] = [] + thought_signatures: list[dict[str, Any]] = [] + thoughts_token_count = None + + for chunk in chunks: + # Extract reasoning deltas + if chunk.meta and "reasoning_deltas" in chunk.meta: + reasoning_deltas = chunk.meta["reasoning_deltas"] + if isinstance(reasoning_deltas, list): + for delta in reasoning_deltas: + if delta.get("type") == "reasoning": + reasoning_text_parts.append(delta.get("content", "")) + + # Extract thought signature deltas (for multi-turn context preservation) + if chunk.meta and "thought_signature_deltas" in chunk.meta: + signature_deltas = chunk.meta["thought_signature_deltas"] + if isinstance(signature_deltas, list): + # Aggregate thought signatures - they should come from the final chunks + # We'll keep the last set of signatures as they represent the complete state + thought_signatures = signature_deltas + + # Extract thinking token usage (from the last chunk that has it) + if chunk.meta and "usage" in chunk.meta: + chunk_usage = chunk.meta["usage"] + if "thoughts_token_count" in chunk_usage: + thoughts_token_count = chunk_usage["thoughts_token_count"] + + # Add thinking token count to usage if present + if thoughts_token_count is not None and "usage" in message.meta: + if message.meta["usage"] is None: + message.meta["usage"] = {} + message.meta["usage"]["thoughts_token_count"] = thoughts_token_count + + # Add thought signatures to meta if present (for multi-turn context preservation) + if thought_signatures: + message.meta["thought_signatures"] = thought_signatures + + # If we have reasoning content, reconstruct the message to include it + # Note: ChatMessage doesn't support adding reasoning after creation, reconstruction is necessary + if reasoning_text_parts: + reasoning_content = ReasoningContent(reasoning_text="".join(reasoning_text_parts)) + return ChatMessage.from_assistant( + text=message.text, tool_calls=message.tool_calls, meta=message.meta, reasoning=reasoning_content + ) + + return message + + +def _handle_streaming_response( + component_info: ComponentInfo, + response_stream: Iterator[types.GenerateContentResponse], + streaming_callback: StreamingCallbackT, + model: str, +) -> dict[str, list[ChatMessage]]: + """ + Handle streaming response from Google Gen AI generate_content_stream. + :param component_info: The component info. + :param response_stream: The streaming response from generate_content_stream. + :param streaming_callback: The callback function for streaming chunks. + :param model: The model name. + :returns: A dictionary with the replies. 
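End to end, streaming combined with Gemini 2.5 thinking can be exercised as follows (a sketch; assumes `GOOGLE_API_KEY` is set, and `thinking_budget` is translated into a `types.ThinkingConfig` by `_process_thinking_config` above):

```python
# Sketch: streaming with thinking enabled (assumes GOOGLE_API_KEY is set).
from haystack.components.generators.utils import print_streaming_chunk
from haystack.dataclasses import ChatMessage

from haystack_integrations.components.generators.google_genai import GoogleGenAIChatGenerator

generator = GoogleGenAIChatGenerator(
    model="gemini-2.5-flash", streaming_callback=print_streaming_chunk
)
result = generator.run(
    messages=[ChatMessage.from_user("What is 12 * 13? Think step by step.")],
    generation_kwargs={"thinking_budget": 1024},  # -1 would request dynamic allocation
)
reply = result["replies"][0]
if reply.reasoning:  # assembled from the reasoning deltas collected during streaming
    print(reply.reasoning.reasoning_text)
```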
+ """ + + try: + chunks = [] + + chunk = None + for i, chunk in enumerate(response_stream): + streaming_chunk = _convert_google_chunk_to_streaming_chunk( + chunk=chunk, index=i, component_info=component_info, model=model + ) + chunks.append(streaming_chunk) + + # Stream the chunk + streaming_callback(streaming_chunk) + + # Use custom aggregation that supports reasoning content + message = _aggregate_streaming_chunks_with_reasoning(chunks) + return {"replies": [message]} + + except Exception as e: + msg = f"Error in streaming response: {e}" + raise RuntimeError(msg) from e + + +async def _handle_streaming_response_async( + component_info: ComponentInfo, + response_stream: AsyncIterator[types.GenerateContentResponse], + streaming_callback: AsyncStreamingCallbackT, + model: str, +) -> dict[str, list[ChatMessage]]: + """ + Handle async streaming response from Google Gen AI generate_content_stream. + :param component_info: The component info. + :param response_stream: The async streaming response from generate_content_stream. + :param streaming_callback: The async callback function for streaming chunks. + :param model: The model name. + :returns: A dictionary with the replies. + """ + + try: + chunks = [] + + i = 0 + chunk = None + async for chunk in response_stream: + i += 1 + + streaming_chunk = _convert_google_chunk_to_streaming_chunk( + chunk=chunk, index=i, component_info=component_info, model=model + ) + chunks.append(streaming_chunk) + + # Stream the chunk + await streaming_callback(streaming_chunk) + + # Use custom aggregation that supports reasoning content + message = _aggregate_streaming_chunks_with_reasoning(chunks) + return {"replies": [message]} + + except Exception as e: + msg = f"Error in async streaming response: {e}" + raise RuntimeError(msg) from e diff --git a/integrations/google_genai/tests/test_chat_generator.py b/integrations/google_genai/tests/test_chat_generator.py index ad6d7e95a5..58c99936da 100644 --- a/integrations/google_genai/tests/test_chat_generator.py +++ b/integrations/google_genai/tests/test_chat_generator.py @@ -3,18 +3,16 @@ # SPDX-License-Identifier: Apache-2.0 import asyncio -import base64 import os -from unittest.mock import Mock import pytest -from google.genai import types from haystack.components.agents import Agent from haystack.components.generators.utils import print_streaming_chunk from haystack.dataclasses import ( ChatMessage, ChatRole, ComponentInfo, + FileContent, ImageContent, ReasoningContent, StreamingChunk, @@ -26,9 +24,6 @@ from haystack_integrations.components.generators.google_genai.chat.chat_generator import ( GoogleGenAIChatGenerator, - _convert_google_genai_response_to_chatmessage, - _convert_message_to_google_genai_format, - _convert_usage_metadata_to_serializable, ) @@ -61,402 +56,7 @@ def tools(): ] -class TestStreamingChunkConversion: - def test_convert_google_chunk_to_streaming_chunk_text_only(self, monkeypatch): - monkeypatch.setenv("GOOGLE_API_KEY", "test-api-key") - component = GoogleGenAIChatGenerator() - component_info = ComponentInfo.from_component(component) - - mock_chunk = Mock() - mock_candidate = Mock() - mock_candidate.finish_reason = "STOP" - mock_chunk.candidates = [mock_candidate] - - mock_content = Mock() - mock_content.parts = [] - mock_part = Mock() - mock_part.text = "Hello, world!" 
- mock_part.function_call = None - # Explicitly set thought=False to simulate a regular (non-thought) part - mock_part.thought = False - mock_content.parts.append(mock_part) - mock_candidate.content = mock_content - - chunk = component._convert_google_chunk_to_streaming_chunk( - chunk=mock_chunk, - index=0, - component_info=component_info, - ) - - assert chunk.content == "Hello, world!" - assert chunk.tool_calls == [] - assert chunk.finish_reason == "stop" - assert chunk.index == 0 - assert "received_at" in chunk.meta - assert chunk.component_info == component_info - - def test_convert_google_chunk_to_streaming_chunk_tool_call(self, monkeypatch): - monkeypatch.setenv("GOOGLE_API_KEY", "test-api-key") - component = GoogleGenAIChatGenerator() - component_info = ComponentInfo.from_component(component) - - mock_chunk = Mock() - mock_candidate = Mock() - mock_candidate.finish_reason = "STOP" - mock_chunk.candidates = [mock_candidate] - - mock_content = Mock() - mock_content.parts = [] - mock_part = Mock() - mock_part.text = None - mock_function_call = Mock() - mock_function_call.name = "weather" - mock_function_call.args = {"city": "Paris"} - mock_function_call.id = "call_123" - mock_part.function_call = mock_function_call - mock_content.parts.append(mock_part) - mock_candidate.content = mock_content - - chunk = component._convert_google_chunk_to_streaming_chunk( - chunk=mock_chunk, index=0, component_info=component_info - ) - - assert chunk.content == "" - assert chunk.tool_calls is not None - assert len(chunk.tool_calls) == 1 - assert chunk.tool_calls[0].tool_name == "weather" - assert chunk.tool_calls[0].arguments == '{"city": "Paris"}' - assert chunk.tool_calls[0].id == "call_123" - assert chunk.finish_reason == "stop" - assert chunk.index == 0 - assert "received_at" in chunk.meta - assert chunk.component_info == component_info - - def test_convert_google_chunk_to_streaming_chunk_mixed_content(self, monkeypatch): - monkeypatch.setenv("GOOGLE_API_KEY", "test-api-key") - component = GoogleGenAIChatGenerator() - component_info = ComponentInfo.from_component(component) - - mock_chunk = Mock() - mock_candidate = Mock() - mock_candidate.finish_reason = "STOP" - mock_chunk.candidates = [mock_candidate] - - mock_content = Mock() - mock_content.parts = [] - - mock_text_part = Mock() - mock_text_part.text = "I'll check the weather for you." 
- mock_text_part.function_call = None - # Explicitly set thought=False to simulate a regular (non-thought) part - mock_text_part.thought = False - mock_content.parts.append(mock_text_part) - - mock_tool_part = Mock() - mock_tool_part.text = None - mock_function_call = Mock() - mock_function_call.name = "weather" - mock_function_call.args = {"city": "London"} - mock_function_call.id = "call_456" - mock_tool_part.function_call = mock_function_call - mock_content.parts.append(mock_tool_part) - - mock_candidate.content = mock_content - - chunk = component._convert_google_chunk_to_streaming_chunk( - chunk=mock_chunk, index=0, component_info=component_info - ) - - # When both text and tool calls are present, tool calls are prioritized - assert chunk.content == "" - assert chunk.tool_calls is not None - assert len(chunk.tool_calls) == 1 - assert chunk.tool_calls[0].tool_name == "weather" - assert chunk.tool_calls[0].arguments == '{"city": "London"}' - assert chunk.finish_reason == "stop" - assert chunk.component_info == component_info - - def test_convert_google_chunk_to_streaming_chunk_empty_parts(self, monkeypatch): - monkeypatch.setenv("GOOGLE_API_KEY", "test-api-key") - component = GoogleGenAIChatGenerator() - component_info = ComponentInfo.from_component(component) - - mock_chunk = Mock() - mock_candidate = Mock() - mock_content = Mock() - mock_content.parts = [] - mock_candidate.content = mock_content - mock_chunk.candidates = [mock_candidate] - - chunk = component._convert_google_chunk_to_streaming_chunk( - chunk=mock_chunk, index=0, component_info=component_info - ) - - assert chunk.content == "" - assert chunk.tool_calls == [] - assert chunk.component_info == component_info - - def test_convert_google_chunk_to_streaming_chunk_real_example(self, monkeypatch): - monkeypatch.setenv("GOOGLE_API_KEY", "test-api-key") - component = GoogleGenAIChatGenerator() - component_info = ComponentInfo.from_component(component) - - # Chunk 1: Text only - chunk1_parts = [ - types.Part( - text="I'll get the weather information for Paris and Berlin", function_call=None, function_response=None - ) - ] - chunk1_content = types.Content(role="model", parts=chunk1_parts) - chunk1_candidate = types.Candidate( - content=chunk1_content, - finish_reason=None, - index=None, - safety_ratings=None, - citation_metadata=None, - grounding_metadata=None, - finish_message=None, - token_count=None, - logprobs_result=None, - avg_logprobs=None, - url_context_metadata=None, - ) - chunk1_usage = types.GenerateContentResponseUsageMetadata( - prompt_token_count=217, candidates_token_count=None, total_token_count=217 - ) - chunk1 = types.GenerateContentResponse( - candidates=[chunk1_candidate], - usage_metadata=chunk1_usage, - model_version="gemini-2.5-flash", - response_id=None, - create_time=None, - prompt_feedback=None, - automatic_function_calling_history=None, - parsed=None, - ) - - streaming_chunk1 = component._convert_google_chunk_to_streaming_chunk( - chunk=chunk1, index=0, component_info=component_info - ) - assert streaming_chunk1.content == "I'll get the weather information for Paris and Berlin" - assert streaming_chunk1.tool_calls == [] - assert streaming_chunk1.finish_reason is None - assert streaming_chunk1.index == 0 - assert "received_at" in streaming_chunk1.meta - assert streaming_chunk1.meta["model"] == "gemini-2.5-flash" - assert "usage" in streaming_chunk1.meta - assert streaming_chunk1.meta["usage"]["prompt_tokens"] == 217 - assert streaming_chunk1.meta["usage"]["completion_tokens"] is None - assert 
streaming_chunk1.meta["usage"]["total_tokens"] == 217 - assert streaming_chunk1.component_info == component_info - - # Chunk 2: Text only - chunk2_parts = [ - types.Part(text=" and present it in a structured format.", function_call=None, function_response=None) - ] - chunk2_content = types.Content(role="model", parts=chunk2_parts) - chunk2_candidate = types.Candidate( - content=chunk2_content, - finish_reason=None, - index=None, - safety_ratings=None, - citation_metadata=None, - grounding_metadata=None, - finish_message=None, - token_count=None, - logprobs_result=None, - avg_logprobs=None, - url_context_metadata=None, - ) - chunk2_usage = types.GenerateContentResponseUsageMetadata( - prompt_token_count=217, candidates_token_count=None, total_token_count=217 - ) - chunk2 = types.GenerateContentResponse( - candidates=[chunk2_candidate], - usage_metadata=chunk2_usage, - model_version="gemini-2.5-flash", - response_id=None, - create_time=None, - prompt_feedback=None, - automatic_function_calling_history=None, - parsed=None, - ) - - streaming_chunk2 = component._convert_google_chunk_to_streaming_chunk( - chunk=chunk2, index=1, component_info=component_info - ) - assert streaming_chunk2.content == " and present it in a structured format." - assert streaming_chunk2.tool_calls == [] - assert streaming_chunk2.finish_reason is None - assert streaming_chunk2.index == 1 - assert "received_at" in streaming_chunk2.meta - assert streaming_chunk2.meta["model"] == "gemini-2.5-flash" - assert "usage" in streaming_chunk2.meta - assert streaming_chunk2.meta["usage"]["prompt_tokens"] == 217 - assert streaming_chunk2.meta["usage"]["completion_tokens"] is None - assert streaming_chunk2.meta["usage"]["total_tokens"] == 217 - assert streaming_chunk2.component_info == component_info - - # Chunk 3: Multiple tool calls (6 function calls) for 2 cities with 3 tools each - fc1 = types.FunctionCall(id=None, name="get_weather", args={"city": "Paris"}) - fc2 = types.FunctionCall(id=None, name="get_population", args={"city": "Paris"}) - fc3 = types.FunctionCall(id=None, name="get_time", args={"city": "Paris"}) - fc4 = types.FunctionCall(id=None, name="get_weather", args={"city": "Berlin"}) - fc5 = types.FunctionCall(id=None, name="get_population", args={"city": "Berlin"}) - fc6 = types.FunctionCall(id=None, name="get_time", args={"city": "Berlin"}) - - parts = [ - types.Part(text=None, function_call=fc1, function_response=None), - types.Part(text=None, function_call=fc2, function_response=None), - types.Part(text=None, function_call=fc3, function_response=None), - types.Part(text=None, function_call=fc4, function_response=None), - types.Part(text=None, function_call=fc5, function_response=None), - types.Part(text=None, function_call=fc6, function_response=None), - ] - - content = types.Content(role="model", parts=parts) - candidate = types.Candidate( - content=content, - finish_reason=types.FinishReason.STOP, - index=None, - safety_ratings=None, - citation_metadata=None, - grounding_metadata=None, - finish_message=None, - token_count=None, - logprobs_result=None, - avg_logprobs=None, - url_context_metadata=None, - ) - - usage_metadata = types.GenerateContentResponseUsageMetadata( - prompt_token_count=144, candidates_token_count=121, total_token_count=265 - ) - chunk = types.GenerateContentResponse( - candidates=[candidate], - usage_metadata=usage_metadata, - model_version="gemini-2.5-flash", - response_id=None, - create_time=None, - prompt_feedback=None, - automatic_function_calling_history=None, - parsed=None, - ) - - 
streaming_chunk = component._convert_google_chunk_to_streaming_chunk( - chunk=chunk, index=2, component_info=component_info - ) - assert streaming_chunk.content == "" - assert streaming_chunk.tool_calls is not None - assert len(streaming_chunk.tool_calls) == 6 - assert streaming_chunk.finish_reason == "stop" - assert streaming_chunk.index == 2 - assert "received_at" in streaming_chunk.meta - assert streaming_chunk.meta["model"] == "gemini-2.5-flash" - assert streaming_chunk.component_info == component_info - assert "usage" in streaming_chunk.meta - assert streaming_chunk.meta["usage"]["prompt_tokens"] == 144 - assert streaming_chunk.meta["usage"]["completion_tokens"] == 121 - assert streaming_chunk.meta["usage"]["total_tokens"] == 265 - - assert streaming_chunk.tool_calls[0].tool_name == "get_weather" - assert streaming_chunk.tool_calls[0].arguments == '{"city": "Paris"}' - assert streaming_chunk.tool_calls[0].id is None - assert streaming_chunk.tool_calls[0].index == 0 - - assert streaming_chunk.tool_calls[1].tool_name == "get_population" - assert streaming_chunk.tool_calls[1].arguments == '{"city": "Paris"}' - assert streaming_chunk.tool_calls[1].id is None - assert streaming_chunk.tool_calls[1].index == 1 - - assert streaming_chunk.tool_calls[2].tool_name == "get_time" - assert streaming_chunk.tool_calls[2].arguments == '{"city": "Paris"}' - assert streaming_chunk.tool_calls[2].id is None - assert streaming_chunk.tool_calls[2].index == 2 - - assert streaming_chunk.tool_calls[3].tool_name == "get_weather" - assert streaming_chunk.tool_calls[3].arguments == '{"city": "Berlin"}' - assert streaming_chunk.tool_calls[3].id is None - assert streaming_chunk.tool_calls[3].index == 3 - - assert streaming_chunk.tool_calls[4].tool_name == "get_population" - assert streaming_chunk.tool_calls[4].arguments == '{"city": "Berlin"}' - assert streaming_chunk.tool_calls[4].id is None - assert streaming_chunk.tool_calls[4].index == 4 - - assert streaming_chunk.tool_calls[5].tool_name == "get_time" - assert streaming_chunk.tool_calls[5].arguments == '{"city": "Berlin"}' - assert streaming_chunk.tool_calls[5].id is None - assert streaming_chunk.tool_calls[5].index == 5 - - -def test_convert_google_genai_response_to_chatmessage_parses_cached_tokens(monkeypatch): - """When the API response includes cached_content_token_count in usage_metadata, it is parsed into meta['usage'].""" - monkeypatch.setenv("GOOGLE_API_KEY", "test-api-key") - - # Minimal candidate with one text part - mock_part = Mock() - mock_part.text = "Four." 
- mock_part.function_call = None - mock_part.thought = False - mock_part.thought_signature = None - mock_content = Mock() - mock_content.parts = [mock_part] - mock_candidate = Mock() - mock_candidate.content = mock_content - mock_candidate.finish_reason = "STOP" - - # Usage metadata including cached tokens (as returned by API when cache is used) - mock_usage = Mock() - mock_usage.prompt_token_count = 1000 - mock_usage.candidates_token_count = 5 - mock_usage.total_token_count = 1005 - mock_usage.cached_content_token_count = 800 - mock_usage.thoughts_token_count = None - - mock_response = Mock() - mock_response.candidates = [mock_candidate] - mock_response.usage_metadata = mock_usage - - message = _convert_google_genai_response_to_chatmessage(mock_response, "gemini-2.5-flash") - - assert message.meta is not None - assert "usage" in message.meta - usage = message.meta["usage"] - assert usage["prompt_tokens"] == 1000 - assert usage["completion_tokens"] == 5 - assert usage["total_tokens"] == 1005 - assert usage["cached_content_token_count"] == 800 - - -def test_convert_usage_metadata_to_serializable(): - """_convert_usage_metadata_to_serializable builds a serializable dict from a UsageMetadata object.""" - assert _convert_usage_metadata_to_serializable(None) == {} - assert _convert_usage_metadata_to_serializable(False) == {} - - usage_metadata = types.GenerateContentResponseUsageMetadata( - prompt_token_count=100, - candidates_token_count=5, - total_token_count=105, - ) - result = _convert_usage_metadata_to_serializable(usage_metadata) - assert result["prompt_token_count"] == 100 - assert result["candidates_token_count"] == 5 - assert result["total_token_count"] == 105 - assert len(result) == 3 - - # Serialization of zero and composite types (ModalityTokenCount, lists) - modality_token_count = types.ModalityTokenCount(modality=types.Modality.TEXT, tokenCount=100) - usage_with_details = types.GenerateContentResponseUsageMetadata( - prompt_token_count=0, - candidates_tokens_details=[modality_token_count], - ) - result2 = _convert_usage_metadata_to_serializable(usage_with_details) - assert result2["prompt_token_count"] == 0 - assert result2["candidates_tokens_details"] == [{"modality": "TEXT", "token_count": 100}] - - -class TestGoogleGenAIChatGenerator: +class TestGoogleGenAIChatGeneratorInitSerDe: def test_init_default(self, monkeypatch): monkeypatch.setenv("GOOGLE_API_KEY", "test-api-key") component = GoogleGenAIChatGenerator() @@ -504,28 +104,6 @@ def test_init_with_toolset(self, tools, monkeypatch): generator = GoogleGenAIChatGenerator(tools=toolset) assert generator._tools == toolset - def test_to_dict_with_toolset(self, tools, monkeypatch): - monkeypatch.setenv("GOOGLE_API_KEY", "test-api-key") - toolset = Toolset(tools) - generator = GoogleGenAIChatGenerator(tools=toolset) - data = generator.to_dict() - - assert data["init_parameters"]["tools"]["type"] == "haystack.tools.toolset.Toolset" - assert "tools" in data["init_parameters"]["tools"]["data"] - assert len(data["init_parameters"]["tools"]["data"]["tools"]) == len(tools) - - def test_from_dict_with_toolset(self, tools, monkeypatch): - monkeypatch.setenv("GOOGLE_API_KEY", "test-api-key") - toolset = Toolset(tools) - component = GoogleGenAIChatGenerator(tools=toolset) - data = component.to_dict() - - deserialized_component = GoogleGenAIChatGenerator.from_dict(data) - - assert isinstance(deserialized_component._tools, Toolset) - assert len(deserialized_component._tools) == len(tools) - assert all(isinstance(tool, Tool) for tool in 
deserialized_component._tools) - def test_init_with_mixed_tools_and_toolsets(self, monkeypatch): """Test initialization with a mixed list of Tools and Toolsets.""" monkeypatch.setenv("GOOGLE_API_KEY", "test-api-key") @@ -561,6 +139,28 @@ def test_init_with_mixed_tools_and_toolsets(self, monkeypatch): assert isinstance(generator._tools[1], Toolset) assert isinstance(generator._tools[2], Tool) + def test_to_dict_with_toolset(self, tools, monkeypatch): + monkeypatch.setenv("GOOGLE_API_KEY", "test-api-key") + toolset = Toolset(tools) + generator = GoogleGenAIChatGenerator(tools=toolset) + data = generator.to_dict() + + assert data["init_parameters"]["tools"]["type"] == "haystack.tools.toolset.Toolset" + assert "tools" in data["init_parameters"]["tools"]["data"] + assert len(data["init_parameters"]["tools"]["data"]["tools"]) == len(tools) + + def test_from_dict_with_toolset(self, tools, monkeypatch): + monkeypatch.setenv("GOOGLE_API_KEY", "test-api-key") + toolset = Toolset(tools) + component = GoogleGenAIChatGenerator(tools=toolset) + data = component.to_dict() + + deserialized_component = GoogleGenAIChatGenerator.from_dict(data) + + assert isinstance(deserialized_component._tools, Toolset) + assert len(deserialized_component._tools) == len(tools) + assert all(isinstance(tool, Tool) for tool in deserialized_component._tools) + def test_serde_with_mixed_tools_and_toolsets(self, monkeypatch): """Test serialization/deserialization with mixed Tools and Toolsets.""" monkeypatch.setenv("GOOGLE_API_KEY", "test-api-key") @@ -599,267 +199,12 @@ def test_serde_with_mixed_tools_and_toolsets(self, monkeypatch): assert len(restored._tools[1]) == 1 -class TestMessagesConversion: - def test_convert_message_to_google_genai_format_complex(self): - """ - Test that the GoogleGenAIChatGenerator can convert a complex sequence of ChatMessages to Google GenAI format. - In particular, we check that different tool results are handled properly in sequence. - """ - messages = [ - ChatMessage.from_system("You are good assistant"), - ChatMessage.from_user("What's the weather like in Paris? And how much is 2+2?"), - ChatMessage.from_assistant( - text="", - tool_calls=[ - ToolCall(id="123", tool_name="weather", arguments={"city": "Paris"}), - ToolCall(id="456", tool_name="math", arguments={"expression": "2+2"}), - ], - ), - ChatMessage.from_tool( - tool_result="22° C", origin=ToolCall(id="123", tool_name="weather", arguments={"city": "Paris"}) - ), - ChatMessage.from_tool( - tool_result="4", origin=ToolCall(id="456", tool_name="math", arguments={"expression": "2+2"}) - ), - ] - - # Test system message handling (should be handled separately in Google GenAI) - system_message = messages[0] - assert system_message.is_from(ChatRole.SYSTEM) - - # Test user message conversion - user_message = messages[1] - google_content = _convert_message_to_google_genai_format(user_message) - assert google_content.role == "user" - assert len(google_content.parts) == 1 - assert google_content.parts[0].text == "What's the weather like in Paris? And how much is 2+2?" 
- - # Test assistant message with tool calls - assistant_message = messages[2] - google_content = _convert_message_to_google_genai_format(assistant_message) - assert google_content.role == "model" - assert len(google_content.parts) == 2 - assert google_content.parts[0].function_call.name == "weather" - assert google_content.parts[0].function_call.args == {"city": "Paris"} - assert google_content.parts[1].function_call.name == "math" - assert google_content.parts[1].function_call.args == {"expression": "2+2"} - - # Test tool result messages - tool_result_1 = messages[3] - google_content = _convert_message_to_google_genai_format(tool_result_1) - assert google_content.role == "user" - assert len(google_content.parts) == 1 - assert google_content.parts[0].function_response.name == "weather" - assert google_content.parts[0].function_response.response == {"result": "22° C"} - - tool_result_2 = messages[4] - google_content = _convert_message_to_google_genai_format(tool_result_2) - assert google_content.role == "user" - assert len(google_content.parts) == 1 - assert google_content.parts[0].function_response.name == "math" - assert google_content.parts[0].function_response.response == {"result": "4"} - - def test_convert_message_to_google_genai_format_with_single_image(self, test_files_path): - """Test converting a message with a single image to Google GenAI format.""" - apple_path = test_files_path / "apple.jpg" - apple_content = ImageContent.from_file_path(apple_path, size=(100, 100)) - - message = ChatMessage.from_user(content_parts=["What's in this image?", apple_content]) - - google_content = _convert_message_to_google_genai_format(message) - - assert google_content.role == "user" - assert len(google_content.parts) == 2 - - # First part should be text - assert google_content.parts[0].text == "What's in this image?" - - # Second part should be image data - assert hasattr(google_content.parts[1], "inline_data") - assert google_content.parts[1].inline_data is not None - assert google_content.parts[1].inline_data.mime_type == "image/jpeg" - assert google_content.parts[1].inline_data.data is not None - assert len(google_content.parts[1].inline_data.data) > 0 - - def test_convert_message_to_google_genai_format_with_multiple_images(self, test_files_path): - """Test converting a message with multiple images in mixed content to Google GenAI format.""" - apple_path = test_files_path / "apple.jpg" - banana_path = test_files_path / "banana.png" - - apple_content = ImageContent.from_file_path(apple_path, size=(100, 100)) - banana_content = ImageContent.from_file_path(banana_path, size=(100, 100)) - - message = ChatMessage.from_user( - content_parts=[ - "Compare these fruits. First:", - apple_content, - "Second:", - banana_content, - "Which is healthier?", - ] - ) - - google_content = _convert_message_to_google_genai_format(message) - - assert google_content.role == "user" - assert len(google_content.parts) == 5 - - # Verify the exact order is preserved - assert google_content.parts[0].text == "Compare these fruits. 
First:" - - # First image (apple) - assert hasattr(google_content.parts[1], "inline_data") - assert google_content.parts[1].inline_data.mime_type == "image/jpeg" - assert google_content.parts[1].inline_data.data is not None - - assert google_content.parts[2].text == "Second:" - - # Second image (banana) - assert hasattr(google_content.parts[3], "inline_data") - assert google_content.parts[3].inline_data.mime_type == "image/png" - assert google_content.parts[3].inline_data.data is not None - - assert google_content.parts[4].text == "Which is healthier?" - - def test_convert_message_to_google_genai_format_image_with_minimal_text(self, test_files_path): - """Test converting a message with minimal text and image to Google GenAI format.""" - apple_path = test_files_path / "apple.jpg" - apple_content = ImageContent.from_file_path(apple_path, size=(100, 100)) - - # Haystack requires at least one textual part for user messages, so we use minimal text - message = ChatMessage.from_user(content_parts=["", apple_content]) - - google_content = _convert_message_to_google_genai_format(message) - - assert google_content.role == "user" - # Empty text should be filtered out by our implementation, leaving only the image - assert len(google_content.parts) == 1 - - # Should only have the image part (empty text filtered out) - assert hasattr(google_content.parts[0], "inline_data") - assert google_content.parts[0].inline_data.mime_type == "image/jpeg" - assert google_content.parts[0].inline_data.data is not None - - def test_convert_message_to_google_genai_format_with_thought_signatures(self): - """Test converting an assistant message with thought signatures for multi-turn context preservation.""" - - # Create an assistant message with tool calls and thought signatures in meta - tool_call = ToolCall(id="call_123", tool_name="weather", arguments={"city": "Paris"}) - - # Thought signatures are stored in meta when thinking is enabled with tools - # They must be base64 encoded as per the API requirements - thought_signatures = [ - { - "part_index": 0, - "signature": base64.b64encode(b"encrypted_mock_thought_signature_1").decode("utf-8"), - "has_text": True, - "has_function_call": False, - "is_thought": False, - }, - { - "part_index": 1, - "signature": base64.b64encode(b"encrypted_mock_thought_signature_2").decode("utf-8"), - "has_text": False, - "has_function_call": True, - "is_thought": False, - }, - ] - - message = ChatMessage.from_assistant( - text="I'll check the weather for you", - tool_calls=[tool_call], - meta={"thought_signatures": thought_signatures}, - ) - - google_content = _convert_message_to_google_genai_format(message) - - assert google_content.role == "model" - assert len(google_content.parts) == 2 - - # First part should have text and its thought signature - assert google_content.parts[0].text == "I'll check the weather for you" - # thought_signature is returned as bytes from the API - assert google_content.parts[0].thought_signature == b"encrypted_mock_thought_signature_1" - - # Second part should have function call and its thought signature - assert google_content.parts[1].function_call.name == "weather" - assert google_content.parts[1].function_call.args == {"city": "Paris"} - assert google_content.parts[1].thought_signature == b"encrypted_mock_thought_signature_2" - - def test_convert_message_to_google_genai_format_with_reasoning_content(self): - """Test that ReasoningContent is properly skipped during conversion.""" - # ReasoningContent is for human transparency only, not sent to the API - 
reasoning = ReasoningContent(reasoning_text="Of Life, the Universe and Everything...") - - # Create a message with both text and reasoning content - message = ChatMessage.from_assistant(text="Forty-two", reasoning=reasoning) - - google_content = _convert_message_to_google_genai_format(message) - - assert google_content.role == "model" - assert len(google_content.parts) == 1 - - # Only the text should be included, reasoning content should be skipped - assert google_content.parts[0].text == "Forty-two" - # Verify no thought part was created (reasoning is not sent to API) - assert not hasattr(google_content.parts[0], "thought") or not google_content.parts[0].thought - - def test_convert_message_to_google_genai_format_tool_message(self): - tool_call = ToolCall(id="123", tool_name="weather", arguments={"city": "Paris"}) - message = ChatMessage.from_tool(tool_result="22° C", origin=tool_call) - google_content = _convert_message_to_google_genai_format(message) - assert google_content.role == "user" - assert len(google_content.parts) == 1 - assert isinstance(google_content.parts[0].function_response, types.FunctionResponse) - assert google_content.parts[0].function_response.id == "123" - assert google_content.parts[0].function_response.name == "weather" - assert google_content.parts[0].function_response.response == {"result": "22° C"} - assert google_content.parts[0].function_response.parts is None - - def test_convert_message_to_google_genai_format_image_in_tool_result(self): - tool_call = ToolCall(id="123", tool_name="image_retriever", arguments={}) - - base64_str = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8/x8AAwMCAO+ip1sAAAAASUVORK5CYII=" - tool_result = [TextContent("Here is the image"), ImageContent(base64_image=base64_str, mime_type="image/jpeg")] - - message = ChatMessage.from_tool(tool_result=tool_result, origin=tool_call) - google_content = _convert_message_to_google_genai_format(message) - assert google_content.role == "user" - assert len(google_content.parts) == 1 - assert isinstance(google_content.parts[0].function_response, types.FunctionResponse) - assert google_content.parts[0].function_response.id == "123" - assert google_content.parts[0].function_response.name == "image_retriever" - assert google_content.parts[0].function_response.response == {"result": ""} - assert len(google_content.parts[0].function_response.parts) == 2 - assert isinstance(google_content.parts[0].function_response.parts[0], types.FunctionResponsePart) - assert google_content.parts[0].function_response.parts[0].inline_data.mime_type == "text/plain" - assert google_content.parts[0].function_response.parts[0].inline_data.data == b"Here is the image" - assert isinstance(google_content.parts[0].function_response.parts[1], types.FunctionResponsePart) - assert google_content.parts[0].function_response.parts[1].inline_data.mime_type == "image/jpeg" - assert google_content.parts[0].function_response.parts[1].inline_data.data == base64.b64decode(base64_str) - assert len(google_content.parts[0].function_response.parts[1].inline_data.data) > 0 - - def test_convert_message_to_google_genai_format_invalid_tool_result_type(self): - tool_call = ToolCall(id="123", tool_name="weather", arguments={"city": "Paris"}) - message = ChatMessage.from_tool(tool_result=256, origin=tool_call) - with pytest.raises(ValueError, match="Unsupported content type in tool call result"): - _convert_message_to_google_genai_format(message) - - message = ChatMessage.from_tool(tool_result=[TextContent("This is supported"), 256], 
origin=tool_call) - with pytest.raises( - ValueError, - match=( - r"Unsupported content type in tool call result list. " - "Only TextContent and ImageContent are supported." - ), - ): - _convert_message_to_google_genai_format(message) - - @pytest.mark.skipif( - not os.environ.get("GOOGLE_API_KEY", None), - reason="Export an env var called GOOGLE_API_KEY containing the Google API key to run this test.", - ) - @pytest.mark.integration +@pytest.mark.skipif( + not os.environ.get("GOOGLE_API_KEY", None), + reason="Export an env var called GOOGLE_API_KEY containing the Google API key to run this test.", +) +@pytest.mark.integration +class TestGoogleGenAIChatGeneratorInference: def test_live_run(self) -> None: chat_messages = [ChatMessage.from_user("What's the capital of France")] component = GoogleGenAIChatGenerator() @@ -870,15 +215,6 @@ def test_live_run(self) -> None: assert "gemini-2.5-flash" in message.meta["model"] assert message.meta["finish_reason"] == "stop" - @pytest.mark.skipif( - not os.environ.get("GOOGLE_API_KEY", None), - reason="Export an env var called GOOGLE_API_KEY containing the Google API key to run this test.", - ) - @pytest.mark.skipif( - not os.environ.get("GOOGLE_API_KEY", None), - reason="Export an env var called GOOGLE_API_KEY containing the Google API key to run this test.", - ) - @pytest.mark.integration def test_run_with_multiple_images_mixed_content(self, test_files_path): """Test that multiple images with interleaved text maintain proper ordering.""" client = GoogleGenAIChatGenerator() @@ -920,11 +256,28 @@ def test_run_with_multiple_images_mixed_content(self, test_files_path): f"Apple should be mentioned before banana in the response. Got: {first_reply.text}" ) - @pytest.mark.skipif( - not os.environ.get("GOOGLE_API_KEY", None), - reason="Export an env var called GOOGLE_API_KEY containing the Google API key to run this test.", - ) - @pytest.mark.integration + def test_live_run_with_file_content(self, test_files_path): + pdf_path = test_files_path / "sample_pdf_3.pdf" + + file_content = FileContent.from_file_path(file_path=pdf_path) + + chat_messages = [ + ChatMessage.from_user( + content_parts=[file_content, "Is this document a paper about LLMs? Respond with 'yes' or 'no' only."] + ) + ] + + generator = GoogleGenAIChatGenerator() + results = generator.run(chat_messages) + + assert len(results["replies"]) == 1 + message: ChatMessage = results["replies"][0] + + assert message.is_from(ChatRole.ASSISTANT) + + assert message.text + assert "no" in message.text.lower() + def test_live_run_streaming(self): component = GoogleGenAIChatGenerator() component_info = ComponentInfo.from_component(component) @@ -951,11 +304,6 @@ def __call__(self, chunk: StreamingChunk) -> None: assert message.text and "paris" in message.text.lower(), "Response does not contain Paris" assert message.meta["finish_reason"] == "stop" - @pytest.mark.skipif( - not os.environ.get("GOOGLE_API_KEY", None), - reason="Export an env var called GOOGLE_API_KEY containing the Google API key to run this test.", - ) - @pytest.mark.integration def test_live_run_with_tools_streaming(self, tools): """ Integration test that the GoogleGenAIChatGenerator component can run with tools and streaming. 
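
For quick reference, here is a minimal usage sketch of the FileContent support this patch adds, mirroring test_live_run_with_file_content above; the PDF path and prompt are illustrative placeholders, and GOOGLE_API_KEY is read from the environment.

# Sketch only, not part of the patch: send a local PDF to Gemini via FileContent.
from haystack.dataclasses import ChatMessage, FileContent

from haystack_integrations.components.generators.google_genai.chat.chat_generator import (
    GoogleGenAIChatGenerator,
)

# Placeholder path: any local PDF works the same way as tests/test_files/sample_pdf_3.pdf
file_content = FileContent.from_file_path(file_path="sample.pdf")
messages = [
    ChatMessage.from_user(content_parts=[file_content, "Summarize this document in one sentence."])
]

generator = GoogleGenAIChatGenerator()  # defaults to gemini-2.5-flash, reads GOOGLE_API_KEY
print(generator.run(messages)["replies"][0].text)
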
@@ -989,11 +337,6 @@ def test_live_run_with_tools_streaming(self, tools): assert tool_call.tool_name == "weather" assert tool_call.arguments == {"city": "Paris"} - @pytest.mark.skipif( - not os.environ.get("GOOGLE_API_KEY", None), - reason="Export an env var called GOOGLE_API_KEY containing the Google API key to run this test.", - ) - @pytest.mark.integration def test_live_run_with_toolset(self, tools): """Test that GoogleGenAIChatGenerator can run with a Toolset.""" toolset = Toolset(tools) @@ -1029,11 +372,6 @@ def test_live_run_with_toolset(self, tools): "Response does not contain Paris or weather" ) - @pytest.mark.skipif( - not os.environ.get("GOOGLE_API_KEY", None), - reason="Export an env var called GOOGLE_API_KEY containing the Google API key to run this test.", - ) - @pytest.mark.integration def test_live_run_with_parallel_tools(self, tools): """ Integration test that the GoogleGenAIChatGenerator component can run with parallel tools. @@ -1081,11 +419,6 @@ def test_live_run_with_parallel_tools(self, tools): # Check that the response mentions both temperature readings assert "22" in final_message.text and "15" in final_message.text - @pytest.mark.skipif( - not os.environ.get("GOOGLE_API_KEY", None), - reason="Export an env var called GOOGLE_API_KEY containing the Google API key to run this test.", - ) - @pytest.mark.integration def test_live_run_with_mixed_tools(self): """ Integration test that verifies GoogleGenAIChatGenerator works with mixed Tool and Toolset. @@ -1177,18 +510,15 @@ def test_live_run_with_mixed_tools(self): assert "paris" in final_message.text.lower() assert "berlin" in final_message.text.lower() - @pytest.mark.skipif( - not os.environ.get("GOOGLE_API_KEY", None), - reason="Export an env var called GOOGLE_API_KEY containing the Google API key to run this test.", - ) - @pytest.mark.integration def test_live_run_with_thinking(self): """ Integration test for the thinking feature with a model that supports it. """ # We use a model that supports the thinking feature chat_messages = [ChatMessage.from_user("Why is the sky blue? Explain in one sentence.")] - component = GoogleGenAIChatGenerator(model="gemini-2.5-pro", generation_kwargs={"thinking_budget": -1}) + component = GoogleGenAIChatGenerator( + model="gemini-3-flash-preview", generation_kwargs={"thinking_level": "low"} + ) results = component.run(chat_messages) assert len(results["replies"]) == 1 @@ -1206,11 +536,6 @@ def test_live_run_with_thinking(self): assert message.meta["usage"]["thoughts_token_count"] is not None assert message.meta["usage"]["thoughts_token_count"] > 0 - @pytest.mark.skipif( - not os.environ.get("GOOGLE_API_KEY", None), - reason="Export an env var called GOOGLE_API_KEY containing the Google API key to run this test.", - ) - @pytest.mark.integration def test_live_run_with_thinking_and_tools_multi_turn(self, tools): """ Integration test for thought signatures preservation in multi-turn conversations with tools. 
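
As a hedged sketch of the thinking configuration used in the hunks below: _process_thinking_config (exercised in test_chat_generator_utils.py later in this patch) pops thinking_level or thinking_budget out of generation_kwargs and repacks it as a google.genai types.ThinkingConfig under the thinking_config key. This relies only on behavior asserted by the tests in this patch.

from google.genai import types

from haystack_integrations.components.generators.google_genai.chat.utils import (
    _process_thinking_config,
)

# New-style option: a thinking level (used with Gemini 3 models in these tests)
kwargs = _process_thinking_config({"thinking_level": "low", "temperature": 0.7})
assert "thinking_level" not in kwargs
assert kwargs["thinking_config"].thinking_level == types.ThinkingLevel.LOW
assert kwargs["temperature"] == 0.7  # unrelated kwargs pass through untouched

# Old-style option: a token budget, where -1 means dynamic allocation
kwargs = _process_thinking_config({"thinking_budget": -1})
assert kwargs["thinking_config"].thinking_budget == -1
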
@@ -1218,9 +543,9 @@ def test_live_run_with_thinking_and_tools_multi_turn(self, tools):
         """
         # Use a model that supports thinking with tools
         component = GoogleGenAIChatGenerator(
-            model="gemini-2.5-pro",
+            model="gemini-3-flash-preview",
             tools=tools,
-            generation_kwargs={"thinking_budget": -1},  # Dynamic allocation
+            generation_kwargs={"thinking_level": "low"},  # Low thinking effort
         )
 
         # First turn: Ask about the weather
@@ -1266,11 +591,6 @@ def test_live_run_with_thinking_and_tools_multi_turn(self, tools):
         # The model should maintain context from previous turns
         assert "22" in second_response.text or "sunny" in second_response.text.lower()
 
-    @pytest.mark.skipif(
-        not os.environ.get("GOOGLE_API_KEY", None),
-        reason="Export an env var called GOOGLE_API_KEY containing the Google API key to run this test.",
-    )
-    @pytest.mark.integration
     def test_live_run_with_thinking_unsupported_model_fails_fast(self):
         """
         Integration test to verify that thinking configuration fails fast with unsupported models.
@@ -1290,11 +610,6 @@ def test_live_run_with_thinking_unsupported_model_fails_fast(self):
         assert "thinking_budget" in error_message or "thinking features" in error_message
         assert "Try removing" in error_message or "use a different model" in error_message
 
-    @pytest.mark.integration
-    @pytest.mark.skipif(
-        not os.environ.get("GOOGLE_API_KEY", None),
-        reason="Export an env var called GOOGLE_API_KEY containing the Google API key to run this test.",
-    )
     def test_live_run_agent_with_images_in_tool_result(self, test_files_path):
         def retrieve_image():
             return [
@@ -1325,7 +640,7 @@ def retrieve_image():
 )
 @pytest.mark.integration
 @pytest.mark.asyncio
-class TestAsyncGoogleGenAIChatGenerator:
+class TestAsyncGoogleGenAIChatGeneratorInference:
     """Test class for async functionality of GoogleGenAIChatGenerator."""
 
     async def test_live_run_async(self) -> None:
@@ -1386,7 +701,9 @@ async def test_live_run_async_with_thinking(self):
         Async integration test for the thinking feature.
         """
         chat_messages = [ChatMessage.from_user("Why is the sky blue? 
Explain in one sentence.")] - component = GoogleGenAIChatGenerator(model="gemini-2.5-pro", generation_kwargs={"thinking_budget": -1}) + component = GoogleGenAIChatGenerator( + model="gemini-3-flash-preview", generation_kwargs={"thinking_level": "low"} + ) results = await component.run_async(chat_messages) assert len(results["replies"]) == 1 @@ -1445,132 +762,3 @@ async def test_concurrent_async_calls(self): assert result["replies"][0].text assert result["replies"][0].meta["model"] assert result["replies"][0].meta["finish_reason"] == "stop" - - -def test_aggregate_streaming_chunks_with_reasoning(monkeypatch): - """Test the _aggregate_streaming_chunks_with_reasoning method for reasoning content aggregation.""" - monkeypatch.setenv("GOOGLE_API_KEY", "test-api-key") - component = GoogleGenAIChatGenerator() - component_info = ComponentInfo.from_component(component) - - # Create mock streaming chunks with reasoning content - chunk1 = Mock() - chunk1.content = "Hello" - chunk1.tool_calls = [] - chunk1.meta = {"usage": {"prompt_tokens": 10, "completion_tokens": 5}} - chunk1.component_info = component_info - chunk1.reasoning = None - - chunk2 = Mock() - chunk2.content = " world" - chunk2.tool_calls = [] - chunk2.meta = {"usage": {"prompt_tokens": 10, "completion_tokens": 8}} - chunk2.component_info = component_info - chunk2.reasoning = None - - # Mock the final chunk with reasoning - final_chunk = Mock() - final_chunk.content = "" - final_chunk.tool_calls = [] - final_chunk.meta = { - "usage": {"prompt_tokens": 10, "completion_tokens": 13, "thoughts_token_count": 5}, - "model": "gemini-2.5-pro", - } - final_chunk.component_info = component_info - final_chunk.reasoning = ReasoningContent(reasoning_text="I should greet the user politely") - - # Add reasoning deltas to the final chunk meta (this is how the real method works) - final_chunk.meta["reasoning_deltas"] = [{"type": "reasoning", "content": "I should greet the user politely"}] - - # Test aggregation - result = GoogleGenAIChatGenerator._aggregate_streaming_chunks_with_reasoning([chunk1, chunk2, final_chunk]) - - # Verify the aggregated message - assert result.text == "Hello world" - assert result.tool_calls == [] - assert result.reasoning is not None - assert result.reasoning.reasoning_text == "I should greet the user politely" - assert result.meta["usage"]["prompt_tokens"] == 10 - assert result.meta["usage"]["completion_tokens"] == 13 - assert result.meta["usage"]["thoughts_token_count"] == 5 - assert result.meta["model"] == "gemini-2.5-pro" - - -def test_process_thinking_budget(monkeypatch): - """Test the _process_thinking_config method with different thinking_budget values.""" - monkeypatch.setenv("GOOGLE_API_KEY", "test-api-key") - - # Test valid thinking_budget values - generation_kwargs = {"thinking_budget": 1024, "temperature": 0.7} - result = GoogleGenAIChatGenerator._process_thinking_config(generation_kwargs.copy()) - - # thinking_budget should be moved to thinking_config - assert "thinking_budget" not in result - assert "thinking_config" in result - assert result["thinking_config"].thinking_budget == 1024 - # Other kwargs should be preserved - assert result["temperature"] == 0.7 - - # Test dynamic allocation (-1) - generation_kwargs = {"thinking_budget": -1} - result = GoogleGenAIChatGenerator._process_thinking_config(generation_kwargs.copy()) - assert result["thinking_config"].thinking_budget == -1 - - # Test zero (disable thinking) - generation_kwargs = {"thinking_budget": 0} - result = 
GoogleGenAIChatGenerator._process_thinking_config(generation_kwargs.copy()) - assert result["thinking_config"].thinking_budget == 0 - - # Test large value - generation_kwargs = {"thinking_budget": 24576} - result = GoogleGenAIChatGenerator._process_thinking_config(generation_kwargs.copy()) - assert result["thinking_config"].thinking_budget == 24576 - - # Test when thinking_budget is not present - generation_kwargs = {"temperature": 0.5} - result = GoogleGenAIChatGenerator._process_thinking_config(generation_kwargs.copy()) - assert result == generation_kwargs # No changes - - # Test invalid type (should fall back to dynamic) - generation_kwargs = {"thinking_budget": "invalid", "temperature": 0.5} - result = GoogleGenAIChatGenerator._process_thinking_config(generation_kwargs.copy()) - assert result["thinking_config"].thinking_budget == -1 # Dynamic allocation - assert result["temperature"] == 0.5 - - -def test_process_thinking_level(monkeypatch): - """Test the _process_thinking_config method with different thinking_level values.""" - monkeypatch.setenv("GOOGLE_API_KEY", "test-api-key") - component = GoogleGenAIChatGenerator() - - # Test valid thinking_level values - generation_kwargs = {"thinking_level": "high", "temperature": 0.7} - result = component._process_thinking_config(generation_kwargs.copy()) - - # thinking_level should be moved to thinking_config - assert "thinking_level" not in result - assert "thinking_config" in result - assert result["thinking_config"].thinking_level == types.ThinkingLevel.HIGH - # Other kwargs should be preserved - assert result["temperature"] == 0.7 - - # Test THINKING_LEVEL_LOW in upper case - generation_kwargs = {"thinking_level": "LOW"} - result = GoogleGenAIChatGenerator._process_thinking_config(generation_kwargs.copy()) - assert result["thinking_config"].thinking_level == types.ThinkingLevel.LOW - - # Test THINKING_LEVEL_UNSPECIFIED - generation_kwargs = {"thinking_level": "test"} - result = GoogleGenAIChatGenerator._process_thinking_config(generation_kwargs.copy()) - assert result["thinking_config"].thinking_level == types.ThinkingLevel.THINKING_LEVEL_UNSPECIFIED - - # Test when thinking_level is not present - generation_kwargs = {"temperature": 0.5} - result = GoogleGenAIChatGenerator._process_thinking_config(generation_kwargs.copy()) - assert result == generation_kwargs # No changes - - # Test invalid type (should fall back to THINKING_LEVEL_UNSPECIFIED) - generation_kwargs = {"thinking_level": 123, "temperature": 0.5} - result = GoogleGenAIChatGenerator._process_thinking_config(generation_kwargs.copy()) - assert result["thinking_config"].thinking_level == types.ThinkingLevel.THINKING_LEVEL_UNSPECIFIED - assert result["temperature"] == 0.5 diff --git a/integrations/google_genai/tests/test_chat_generator_utils.py b/integrations/google_genai/tests/test_chat_generator_utils.py new file mode 100644 index 0000000000..04b2abbd1c --- /dev/null +++ b/integrations/google_genai/tests/test_chat_generator_utils.py @@ -0,0 +1,806 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import base64 +from unittest.mock import Mock + +import pytest +from google.genai import types +from haystack.dataclasses import ( + ChatMessage, + ChatRole, + ComponentInfo, + FileContent, + ImageContent, + ReasoningContent, + TextContent, + ToolCall, +) + +from haystack_integrations.components.generators.google_genai.chat.chat_generator import ( + GoogleGenAIChatGenerator, +) + +from 
haystack_integrations.components.generators.google_genai.chat.utils import (
+    _aggregate_streaming_chunks_with_reasoning,
+    _convert_google_chunk_to_streaming_chunk,
+    _convert_google_genai_response_to_chatmessage,
+    _convert_message_to_google_genai_format,
+    _convert_usage_metadata_to_serializable,
+    _process_thinking_config,
+)
+
+
+def test_process_thinking_budget():
+    """Test the _process_thinking_config function with different thinking_budget values."""
+
+    # Test valid thinking_budget values
+    generation_kwargs = {"thinking_budget": 1024, "temperature": 0.7}
+    result = _process_thinking_config(generation_kwargs.copy())
+
+    # thinking_budget should be moved to thinking_config
+    assert "thinking_budget" not in result
+    assert "thinking_config" in result
+    assert result["thinking_config"].thinking_budget == 1024
+    # Other kwargs should be preserved
+    assert result["temperature"] == 0.7
+
+    # Test dynamic allocation (-1)
+    generation_kwargs = {"thinking_budget": -1}
+    result = _process_thinking_config(generation_kwargs.copy())
+    assert result["thinking_config"].thinking_budget == -1
+
+    # Test zero (disable thinking)
+    generation_kwargs = {"thinking_budget": 0}
+    result = _process_thinking_config(generation_kwargs.copy())
+    assert result["thinking_config"].thinking_budget == 0
+
+    # Test large value
+    generation_kwargs = {"thinking_budget": 24576}
+    result = _process_thinking_config(generation_kwargs.copy())
+    assert result["thinking_config"].thinking_budget == 24576
+
+    # Test when thinking_budget is not present
+    generation_kwargs = {"temperature": 0.5}
+    result = _process_thinking_config(generation_kwargs.copy())
+    assert result == generation_kwargs  # No changes
+
+    # Test invalid type (should fall back to dynamic)
+    generation_kwargs = {"thinking_budget": "invalid", "temperature": 0.5}
+    result = _process_thinking_config(generation_kwargs.copy())
+    assert result["thinking_config"].thinking_budget == -1  # Dynamic allocation
+    assert result["temperature"] == 0.5
+
+
+def test_process_thinking_level():
+    """Test the _process_thinking_config function with different thinking_level values."""
+
+    # Test valid thinking_level values
+    generation_kwargs = {"thinking_level": "high", "temperature": 0.7}
+    result = _process_thinking_config(generation_kwargs.copy())
+
+    # thinking_level should be moved to thinking_config
+    assert "thinking_level" not in result
+    assert "thinking_config" in result
+    assert result["thinking_config"].thinking_level == types.ThinkingLevel.HIGH
+    # Other kwargs should be preserved
+    assert result["temperature"] == 0.7
+
+    # Test thinking_level provided in upper case
+    generation_kwargs = {"thinking_level": "LOW"}
+    result = _process_thinking_config(generation_kwargs.copy())
+    assert result["thinking_config"].thinking_level == types.ThinkingLevel.LOW
+
+    # Test unknown value falls back to THINKING_LEVEL_UNSPECIFIED
+    generation_kwargs = {"thinking_level": "test"}
+    result = _process_thinking_config(generation_kwargs.copy())
+    assert result["thinking_config"].thinking_level == types.ThinkingLevel.THINKING_LEVEL_UNSPECIFIED
+
+    # Test when thinking_level is not present
+    generation_kwargs = {"temperature": 0.5}
+    result = _process_thinking_config(generation_kwargs.copy())
+    assert result == generation_kwargs  # No changes
+
+    # Test invalid type (should fall back to THINKING_LEVEL_UNSPECIFIED)
+    generation_kwargs = {"thinking_level": 123, "temperature": 0.5}
+    result = _process_thinking_config(generation_kwargs.copy())
+    assert result["thinking_config"].thinking_level == 
types.ThinkingLevel.THINKING_LEVEL_UNSPECIFIED + assert result["temperature"] == 0.5 + + +class TestStreamingChunkConversion: + def test_convert_google_chunk_to_streaming_chunk_text_only(self, monkeypatch): + monkeypatch.setenv("GOOGLE_API_KEY", "test-api-key") + component = GoogleGenAIChatGenerator() + component_info = ComponentInfo.from_component(component) + + mock_chunk = Mock() + mock_candidate = Mock() + mock_candidate.finish_reason = "STOP" + mock_chunk.candidates = [mock_candidate] + + mock_content = Mock() + mock_content.parts = [] + mock_part = Mock() + mock_part.text = "Hello, world!" + mock_part.function_call = None + # Explicitly set thought=False to simulate a regular (non-thought) part + mock_part.thought = False + mock_content.parts.append(mock_part) + mock_candidate.content = mock_content + + chunk = _convert_google_chunk_to_streaming_chunk( + chunk=mock_chunk, + index=0, + component_info=component_info, + model="gemini-2.5-flash", + ) + + assert chunk.content == "Hello, world!" + assert chunk.tool_calls == [] + assert chunk.finish_reason == "stop" + assert chunk.index == 0 + assert "received_at" in chunk.meta + assert chunk.component_info == component_info + + def test_convert_google_chunk_to_streaming_chunk_tool_call(self, monkeypatch): + monkeypatch.setenv("GOOGLE_API_KEY", "test-api-key") + component = GoogleGenAIChatGenerator() + component_info = ComponentInfo.from_component(component) + + mock_chunk = Mock() + mock_candidate = Mock() + mock_candidate.finish_reason = "STOP" + mock_chunk.candidates = [mock_candidate] + + mock_content = Mock() + mock_content.parts = [] + mock_part = Mock() + mock_part.text = None + mock_function_call = Mock() + mock_function_call.name = "weather" + mock_function_call.args = {"city": "Paris"} + mock_function_call.id = "call_123" + mock_part.function_call = mock_function_call + mock_content.parts.append(mock_part) + mock_candidate.content = mock_content + + chunk = _convert_google_chunk_to_streaming_chunk( + chunk=mock_chunk, index=0, component_info=component_info, model="gemini-2.5-flash" + ) + + assert chunk.content == "" + assert chunk.tool_calls is not None + assert len(chunk.tool_calls) == 1 + assert chunk.tool_calls[0].tool_name == "weather" + assert chunk.tool_calls[0].arguments == '{"city": "Paris"}' + assert chunk.tool_calls[0].id == "call_123" + assert chunk.finish_reason == "stop" + assert chunk.index == 0 + assert "received_at" in chunk.meta + assert chunk.component_info == component_info + + def test_convert_google_chunk_to_streaming_chunk_mixed_content(self, monkeypatch): + monkeypatch.setenv("GOOGLE_API_KEY", "test-api-key") + component = GoogleGenAIChatGenerator() + component_info = ComponentInfo.from_component(component) + + mock_chunk = Mock() + mock_candidate = Mock() + mock_candidate.finish_reason = "STOP" + mock_chunk.candidates = [mock_candidate] + + mock_content = Mock() + mock_content.parts = [] + + mock_text_part = Mock() + mock_text_part.text = "I'll check the weather for you." 
+ mock_text_part.function_call = None + # Explicitly set thought=False to simulate a regular (non-thought) part + mock_text_part.thought = False + mock_content.parts.append(mock_text_part) + + mock_tool_part = Mock() + mock_tool_part.text = None + mock_function_call = Mock() + mock_function_call.name = "weather" + mock_function_call.args = {"city": "London"} + mock_function_call.id = "call_456" + mock_tool_part.function_call = mock_function_call + mock_content.parts.append(mock_tool_part) + + mock_candidate.content = mock_content + + chunk = _convert_google_chunk_to_streaming_chunk( + chunk=mock_chunk, index=0, component_info=component_info, model="gemini-2.5-flash" + ) + + # When both text and tool calls are present, tool calls are prioritized + assert chunk.content == "" + assert chunk.tool_calls is not None + assert len(chunk.tool_calls) == 1 + assert chunk.tool_calls[0].tool_name == "weather" + assert chunk.tool_calls[0].arguments == '{"city": "London"}' + assert chunk.finish_reason == "stop" + assert chunk.component_info == component_info + + def test_convert_google_chunk_to_streaming_chunk_empty_parts(self, monkeypatch): + monkeypatch.setenv("GOOGLE_API_KEY", "test-api-key") + component = GoogleGenAIChatGenerator() + component_info = ComponentInfo.from_component(component) + + mock_chunk = Mock() + mock_candidate = Mock() + mock_content = Mock() + mock_content.parts = [] + mock_candidate.content = mock_content + mock_chunk.candidates = [mock_candidate] + + chunk = _convert_google_chunk_to_streaming_chunk( + chunk=mock_chunk, index=0, component_info=component_info, model="gemini-2.5-flash" + ) + + assert chunk.content == "" + assert chunk.tool_calls == [] + assert chunk.component_info == component_info + + def test_convert_google_chunk_to_streaming_chunk_real_example(self, monkeypatch): + monkeypatch.setenv("GOOGLE_API_KEY", "test-api-key") + component = GoogleGenAIChatGenerator() + component_info = ComponentInfo.from_component(component) + + # Chunk 1: Text only + chunk1_parts = [ + types.Part( + text="I'll get the weather information for Paris and Berlin", function_call=None, function_response=None + ) + ] + chunk1_content = types.Content(role="model", parts=chunk1_parts) + chunk1_candidate = types.Candidate( + content=chunk1_content, + finish_reason=None, + index=None, + safety_ratings=None, + citation_metadata=None, + grounding_metadata=None, + finish_message=None, + token_count=None, + logprobs_result=None, + avg_logprobs=None, + url_context_metadata=None, + ) + chunk1_usage = types.GenerateContentResponseUsageMetadata( + prompt_token_count=217, candidates_token_count=None, total_token_count=217 + ) + chunk1 = types.GenerateContentResponse( + candidates=[chunk1_candidate], + usage_metadata=chunk1_usage, + model_version="gemini-2.5-flash", + response_id=None, + create_time=None, + prompt_feedback=None, + automatic_function_calling_history=None, + parsed=None, + ) + + streaming_chunk1 = _convert_google_chunk_to_streaming_chunk( + chunk=chunk1, index=0, component_info=component_info, model="gemini-2.5-flash" + ) + assert streaming_chunk1.content == "I'll get the weather information for Paris and Berlin" + assert streaming_chunk1.tool_calls == [] + assert streaming_chunk1.finish_reason is None + assert streaming_chunk1.index == 0 + assert "received_at" in streaming_chunk1.meta + assert streaming_chunk1.meta["model"] == "gemini-2.5-flash" + assert "usage" in streaming_chunk1.meta + assert streaming_chunk1.meta["usage"]["prompt_tokens"] == 217 + assert 
streaming_chunk1.meta["usage"]["completion_tokens"] is None + assert streaming_chunk1.meta["usage"]["total_tokens"] == 217 + assert streaming_chunk1.component_info == component_info + + # Chunk 2: Text only + chunk2_parts = [ + types.Part(text=" and present it in a structured format.", function_call=None, function_response=None) + ] + chunk2_content = types.Content(role="model", parts=chunk2_parts) + chunk2_candidate = types.Candidate( + content=chunk2_content, + finish_reason=None, + index=None, + safety_ratings=None, + citation_metadata=None, + grounding_metadata=None, + finish_message=None, + token_count=None, + logprobs_result=None, + avg_logprobs=None, + url_context_metadata=None, + ) + chunk2_usage = types.GenerateContentResponseUsageMetadata( + prompt_token_count=217, candidates_token_count=None, total_token_count=217 + ) + chunk2 = types.GenerateContentResponse( + candidates=[chunk2_candidate], + usage_metadata=chunk2_usage, + model_version="gemini-2.5-flash", + response_id=None, + create_time=None, + prompt_feedback=None, + automatic_function_calling_history=None, + parsed=None, + ) + + streaming_chunk2 = _convert_google_chunk_to_streaming_chunk( + chunk=chunk2, index=1, component_info=component_info, model="gemini-2.5-flash" + ) + assert streaming_chunk2.content == " and present it in a structured format." + assert streaming_chunk2.tool_calls == [] + assert streaming_chunk2.finish_reason is None + assert streaming_chunk2.index == 1 + assert "received_at" in streaming_chunk2.meta + assert streaming_chunk2.meta["model"] == "gemini-2.5-flash" + assert "usage" in streaming_chunk2.meta + assert streaming_chunk2.meta["usage"]["prompt_tokens"] == 217 + assert streaming_chunk2.meta["usage"]["completion_tokens"] is None + assert streaming_chunk2.meta["usage"]["total_tokens"] == 217 + assert streaming_chunk2.component_info == component_info + + # Chunk 3: Multiple tool calls (6 function calls) for 2 cities with 3 tools each + fc1 = types.FunctionCall(id=None, name="get_weather", args={"city": "Paris"}) + fc2 = types.FunctionCall(id=None, name="get_population", args={"city": "Paris"}) + fc3 = types.FunctionCall(id=None, name="get_time", args={"city": "Paris"}) + fc4 = types.FunctionCall(id=None, name="get_weather", args={"city": "Berlin"}) + fc5 = types.FunctionCall(id=None, name="get_population", args={"city": "Berlin"}) + fc6 = types.FunctionCall(id=None, name="get_time", args={"city": "Berlin"}) + + parts = [ + types.Part(text=None, function_call=fc1, function_response=None), + types.Part(text=None, function_call=fc2, function_response=None), + types.Part(text=None, function_call=fc3, function_response=None), + types.Part(text=None, function_call=fc4, function_response=None), + types.Part(text=None, function_call=fc5, function_response=None), + types.Part(text=None, function_call=fc6, function_response=None), + ] + + content = types.Content(role="model", parts=parts) + candidate = types.Candidate( + content=content, + finish_reason=types.FinishReason.STOP, + index=None, + safety_ratings=None, + citation_metadata=None, + grounding_metadata=None, + finish_message=None, + token_count=None, + logprobs_result=None, + avg_logprobs=None, + url_context_metadata=None, + ) + + usage_metadata = types.GenerateContentResponseUsageMetadata( + prompt_token_count=144, candidates_token_count=121, total_token_count=265 + ) + chunk = types.GenerateContentResponse( + candidates=[candidate], + usage_metadata=usage_metadata, + model_version="gemini-2.5-flash", + response_id=None, + create_time=None, + 
prompt_feedback=None, + automatic_function_calling_history=None, + parsed=None, + ) + + streaming_chunk = _convert_google_chunk_to_streaming_chunk( + chunk=chunk, index=2, component_info=component_info, model="gemini-2.5-flash" + ) + assert streaming_chunk.content == "" + assert streaming_chunk.tool_calls is not None + assert len(streaming_chunk.tool_calls) == 6 + assert streaming_chunk.finish_reason == "stop" + assert streaming_chunk.index == 2 + assert "received_at" in streaming_chunk.meta + assert streaming_chunk.meta["model"] == "gemini-2.5-flash" + assert streaming_chunk.component_info == component_info + assert "usage" in streaming_chunk.meta + assert streaming_chunk.meta["usage"]["prompt_tokens"] == 144 + assert streaming_chunk.meta["usage"]["completion_tokens"] == 121 + assert streaming_chunk.meta["usage"]["total_tokens"] == 265 + + assert streaming_chunk.tool_calls[0].tool_name == "get_weather" + assert streaming_chunk.tool_calls[0].arguments == '{"city": "Paris"}' + assert streaming_chunk.tool_calls[0].id is None + assert streaming_chunk.tool_calls[0].index == 0 + + assert streaming_chunk.tool_calls[1].tool_name == "get_population" + assert streaming_chunk.tool_calls[1].arguments == '{"city": "Paris"}' + assert streaming_chunk.tool_calls[1].id is None + assert streaming_chunk.tool_calls[1].index == 1 + + assert streaming_chunk.tool_calls[2].tool_name == "get_time" + assert streaming_chunk.tool_calls[2].arguments == '{"city": "Paris"}' + assert streaming_chunk.tool_calls[2].id is None + assert streaming_chunk.tool_calls[2].index == 2 + + assert streaming_chunk.tool_calls[3].tool_name == "get_weather" + assert streaming_chunk.tool_calls[3].arguments == '{"city": "Berlin"}' + assert streaming_chunk.tool_calls[3].id is None + assert streaming_chunk.tool_calls[3].index == 3 + + assert streaming_chunk.tool_calls[4].tool_name == "get_population" + assert streaming_chunk.tool_calls[4].arguments == '{"city": "Berlin"}' + assert streaming_chunk.tool_calls[4].id is None + assert streaming_chunk.tool_calls[4].index == 4 + + assert streaming_chunk.tool_calls[5].tool_name == "get_time" + assert streaming_chunk.tool_calls[5].arguments == '{"city": "Berlin"}' + assert streaming_chunk.tool_calls[5].id is None + assert streaming_chunk.tool_calls[5].index == 5 + + def test_aggregate_streaming_chunks_with_reasoning(self): + """Test the _aggregate_streaming_chunks_with_reasoning method for reasoning content aggregation.""" + + # Create mock streaming chunks with reasoning content + chunk1 = Mock() + chunk1.content = "Hello" + chunk1.tool_calls = [] + chunk1.meta = {"usage": {"prompt_tokens": 10, "completion_tokens": 5}} + chunk1.reasoning = None + + chunk2 = Mock() + chunk2.content = " world" + chunk2.tool_calls = [] + chunk2.meta = {"usage": {"prompt_tokens": 10, "completion_tokens": 8}} + chunk2.reasoning = None + + # Mock the final chunk with reasoning + final_chunk = Mock() + final_chunk.content = "" + final_chunk.tool_calls = [] + final_chunk.meta = { + "usage": {"prompt_tokens": 10, "completion_tokens": 13, "thoughts_token_count": 5}, + "model": "gemini-2.5-pro", + } + final_chunk.reasoning = ReasoningContent(reasoning_text="I should greet the user politely") + + # Add reasoning deltas to the final chunk meta (this is how the real method works) + final_chunk.meta["reasoning_deltas"] = [{"type": "reasoning", "content": "I should greet the user politely"}] + + # Test aggregation + result = _aggregate_streaming_chunks_with_reasoning([chunk1, chunk2, final_chunk]) + + # Verify the aggregated 
message + assert result.text == "Hello world" + assert result.tool_calls == [] + assert result.reasoning is not None + assert result.reasoning.reasoning_text == "I should greet the user politely" + assert result.meta["usage"]["prompt_tokens"] == 10 + assert result.meta["usage"]["completion_tokens"] == 13 + assert result.meta["usage"]["thoughts_token_count"] == 5 + assert result.meta["model"] == "gemini-2.5-pro" + + +class TestConvertMessageToGoogleGenAI: + def test_convert_message_to_google_genai_format_complex(self): + """ + Test that the GoogleGenAIChatGenerator can convert a complex sequence of ChatMessages to Google GenAI format. + In particular, we check that different tool results are handled properly in sequence. + """ + messages = [ + ChatMessage.from_system("You are good assistant"), + ChatMessage.from_user("What's the weather like in Paris? And how much is 2+2?"), + ChatMessage.from_assistant( + text="", + tool_calls=[ + ToolCall(id="123", tool_name="weather", arguments={"city": "Paris"}), + ToolCall(id="456", tool_name="math", arguments={"expression": "2+2"}), + ], + ), + ChatMessage.from_tool( + tool_result="22° C", origin=ToolCall(id="123", tool_name="weather", arguments={"city": "Paris"}) + ), + ChatMessage.from_tool( + tool_result="4", origin=ToolCall(id="456", tool_name="math", arguments={"expression": "2+2"}) + ), + ] + + # Test system message handling (should be handled separately in Google GenAI) + system_message = messages[0] + assert system_message.is_from(ChatRole.SYSTEM) + + # Test user message conversion + user_message = messages[1] + google_content = _convert_message_to_google_genai_format(user_message) + assert google_content.role == "user" + assert len(google_content.parts) == 1 + assert google_content.parts[0].text == "What's the weather like in Paris? And how much is 2+2?" 
+ + # Test assistant message with tool calls + assistant_message = messages[2] + google_content = _convert_message_to_google_genai_format(assistant_message) + assert google_content.role == "model" + assert len(google_content.parts) == 2 + assert google_content.parts[0].function_call.name == "weather" + assert google_content.parts[0].function_call.args == {"city": "Paris"} + assert google_content.parts[1].function_call.name == "math" + assert google_content.parts[1].function_call.args == {"expression": "2+2"} + + # Test tool result messages + tool_result_1 = messages[3] + google_content = _convert_message_to_google_genai_format(tool_result_1) + assert google_content.role == "user" + assert len(google_content.parts) == 1 + assert google_content.parts[0].function_response.name == "weather" + assert google_content.parts[0].function_response.response == {"result": "22° C"} + + tool_result_2 = messages[4] + google_content = _convert_message_to_google_genai_format(tool_result_2) + assert google_content.role == "user" + assert len(google_content.parts) == 1 + assert google_content.parts[0].function_response.name == "math" + assert google_content.parts[0].function_response.response == {"result": "4"} + + def test_convert_message_to_google_genai_format_with_multiple_images(self, test_files_path): + """Test converting a message with multiple images in mixed content to Google GenAI format.""" + apple_path = test_files_path / "apple.jpg" + banana_path = test_files_path / "banana.png" + + apple_content = ImageContent.from_file_path(apple_path, size=(100, 100)) + banana_content = ImageContent.from_file_path(banana_path, size=(100, 100)) + + message = ChatMessage.from_user( + content_parts=[ + "Compare these fruits. First:", + apple_content, + "Second:", + banana_content, + "Which is healthier?", + ] + ) + + google_content = _convert_message_to_google_genai_format(message) + + assert google_content.role == "user" + assert len(google_content.parts) == 5 + + # Verify the exact order is preserved + assert google_content.parts[0].text == "Compare these fruits. First:" + + # First image (apple) + assert hasattr(google_content.parts[1], "inline_data") + assert google_content.parts[1].inline_data.mime_type == "image/jpeg" + assert google_content.parts[1].inline_data.data is not None + + assert google_content.parts[2].text == "Second:" + + # Second image (banana) + assert hasattr(google_content.parts[3], "inline_data") + assert google_content.parts[3].inline_data.mime_type == "image/png" + assert google_content.parts[3].inline_data.data is not None + + assert google_content.parts[4].text == "Which is healthier?" 
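The tests above pin down the contract of `_convert_message_to_google_genai_format`: text parts become `types.Part(text=...)`, images become `inline_data` parts, and the original ordering is preserved. For readers who want to exercise the converter outside the test suite, here is a minimal sketch; the image path is a stand-in for any local JPEG, and all imports mirror the ones used in this test module:

```python
from haystack.dataclasses import ChatMessage, ImageContent

from haystack_integrations.components.generators.google_genai.chat.utils import (
    _convert_message_to_google_genai_format,
)

# Mixed text + image user message, as exercised in the tests above.
image = ImageContent.from_file_path("apple.jpg", size=(100, 100))  # stand-in path
message = ChatMessage.from_user(content_parts=["Describe this fruit:", image])

content = _convert_message_to_google_genai_format(message)
print(content.role)  # "user"
# Text parts have .text set; image parts carry the decoded bytes in .inline_data.
print([part.inline_data.mime_type for part in content.parts if part.inline_data])  # ["image/jpeg"]
```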
+ + def test_convert_message_to_google_genai_format_image_with_minimal_text(self, test_files_path): + """Test converting a message with minimal text and image to Google GenAI format.""" + apple_path = test_files_path / "apple.jpg" + apple_content = ImageContent.from_file_path(apple_path, size=(100, 100)) + + # Haystack requires at least one textual part for user messages, so we use minimal text + message = ChatMessage.from_user(content_parts=["", apple_content]) + + google_content = _convert_message_to_google_genai_format(message) + + assert google_content.role == "user" + # Empty text should be filtered out by our implementation, leaving only the image + assert len(google_content.parts) == 1 + + # Should only have the image part (empty text filtered out) + assert hasattr(google_content.parts[0], "inline_data") + assert google_content.parts[0].inline_data.mime_type == "image/jpeg" + assert google_content.parts[0].inline_data.data is not None + + def test_convert_message_to_google_genai_file_content(self, test_files_path): + file_path = test_files_path / "sample_pdf_3.pdf" + file_content = FileContent.from_file_path(file_path) + message = ChatMessage.from_user(content_parts=["Describe this document:", file_content]) + google_content = _convert_message_to_google_genai_format(message) + assert google_content.role == "user" + assert len(google_content.parts) == 2 + + assert google_content.parts[0].text == "Describe this document:" + assert hasattr(google_content.parts[1], "inline_data") + assert google_content.parts[1].inline_data.mime_type == "application/pdf" + assert google_content.parts[1].inline_data.data is not None + + def test_convert_message_to_google_genai_file_content_in_assistant_message(self, test_files_path): + file_path = test_files_path / "sample_pdf_3.pdf" + file_content = FileContent.from_file_path(file_path) + message = ChatMessage.from_assistant("This is a document") + message._content.append(file_content) + + with pytest.raises(ValueError, match="FileContent is only supported for user messages"): + _convert_message_to_google_genai_format(message) + + def test_convert_message_to_google_genai_format_with_thought_signatures(self): + """Test converting an assistant message with thought signatures for multi-turn context preservation.""" + + # Create an assistant message with tool calls and thought signatures in meta + tool_call = ToolCall(id="call_123", tool_name="weather", arguments={"city": "Paris"}) + + # Thought signatures are stored in meta when thinking is enabled with tools + # They must be base64 encoded as per the API requirements + thought_signatures = [ + { + "part_index": 0, + "signature": base64.b64encode(b"encrypted_mock_thought_signature_1").decode("utf-8"), + "has_text": True, + "has_function_call": False, + "is_thought": False, + }, + { + "part_index": 1, + "signature": base64.b64encode(b"encrypted_mock_thought_signature_2").decode("utf-8"), + "has_text": False, + "has_function_call": True, + "is_thought": False, + }, + ] + + message = ChatMessage.from_assistant( + text="I'll check the weather for you", + tool_calls=[tool_call], + meta={"thought_signatures": thought_signatures}, + ) + + google_content = _convert_message_to_google_genai_format(message) + + assert google_content.role == "model" + assert len(google_content.parts) == 2 + + # First part should have text and its thought signature + assert google_content.parts[0].text == "I'll check the weather for you" + # thought_signature is returned as bytes from the API + assert 
google_content.parts[0].thought_signature == b"encrypted_mock_thought_signature_1" + + # Second part should have function call and its thought signature + assert google_content.parts[1].function_call.name == "weather" + assert google_content.parts[1].function_call.args == {"city": "Paris"} + assert google_content.parts[1].thought_signature == b"encrypted_mock_thought_signature_2" + + def test_convert_message_to_google_genai_format_with_reasoning_content(self): + """Test that ReasoningContent is properly skipped during conversion.""" + # ReasoningContent is for human transparency only, not sent to the API + reasoning = ReasoningContent(reasoning_text="Of Life, the Universe and Everything...") + + # Create a message with both text and reasoning content + message = ChatMessage.from_assistant(text="Forty-two", reasoning=reasoning) + + google_content = _convert_message_to_google_genai_format(message) + + assert google_content.role == "model" + assert len(google_content.parts) == 1 + + # Only the text should be included, reasoning content should be skipped + assert google_content.parts[0].text == "Forty-two" + # Verify no thought part was created (reasoning is not sent to API) + assert not hasattr(google_content.parts[0], "thought") or not google_content.parts[0].thought + + def test_convert_message_to_google_genai_format_tool_message(self): + tool_call = ToolCall(id="123", tool_name="weather", arguments={"city": "Paris"}) + message = ChatMessage.from_tool(tool_result="22° C", origin=tool_call) + google_content = _convert_message_to_google_genai_format(message) + assert google_content.role == "user" + assert len(google_content.parts) == 1 + assert isinstance(google_content.parts[0].function_response, types.FunctionResponse) + assert google_content.parts[0].function_response.id == "123" + assert google_content.parts[0].function_response.name == "weather" + assert google_content.parts[0].function_response.response == {"result": "22° C"} + assert google_content.parts[0].function_response.parts is None + + def test_convert_message_to_google_genai_format_image_in_tool_result(self): + tool_call = ToolCall(id="123", tool_name="image_retriever", arguments={}) + + base64_str = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8/x8AAwMCAO+ip1sAAAAASUVORK5CYII=" + tool_result = [TextContent("Here is the image"), ImageContent(base64_image=base64_str, mime_type="image/jpeg")] + + message = ChatMessage.from_tool(tool_result=tool_result, origin=tool_call) + google_content = _convert_message_to_google_genai_format(message) + assert google_content.role == "user" + assert len(google_content.parts) == 1 + assert isinstance(google_content.parts[0].function_response, types.FunctionResponse) + assert google_content.parts[0].function_response.id == "123" + assert google_content.parts[0].function_response.name == "image_retriever" + assert google_content.parts[0].function_response.response == {"result": ""} + assert len(google_content.parts[0].function_response.parts) == 2 + assert isinstance(google_content.parts[0].function_response.parts[0], types.FunctionResponsePart) + assert google_content.parts[0].function_response.parts[0].inline_data.mime_type == "text/plain" + assert google_content.parts[0].function_response.parts[0].inline_data.data == b"Here is the image" + assert isinstance(google_content.parts[0].function_response.parts[1], types.FunctionResponsePart) + assert google_content.parts[0].function_response.parts[1].inline_data.mime_type == "image/jpeg" + assert 
google_content.parts[0].function_response.parts[1].inline_data.data == base64.b64decode(base64_str) + assert len(google_content.parts[0].function_response.parts[1].inline_data.data) > 0 + + def test_convert_message_to_google_genai_format_invalid_tool_result_type(self): + tool_call = ToolCall(id="123", tool_name="weather", arguments={"city": "Paris"}) + message = ChatMessage.from_tool(tool_result=256, origin=tool_call) + with pytest.raises(ValueError, match="Unsupported content type in tool call result"): + _convert_message_to_google_genai_format(message) + + message = ChatMessage.from_tool(tool_result=[TextContent("This is supported"), 256], origin=tool_call) + with pytest.raises( + ValueError, + match=( + r"Unsupported content type in tool call result list. " + "Only TextContent and ImageContent are supported." + ), + ): + _convert_message_to_google_genai_format(message) + + +class TestConvertGoogleGenAIToMessage: + def test_convert_google_genai_response_to_chatmessage_parses_cached_tokens(self): + """ + When the API response includes cached_content_token_count in usage_metadata, + it is parsed into meta['usage']. + """ + + # Minimal candidate with one text part + mock_part = Mock() + mock_part.text = "Four." + mock_part.function_call = None + mock_part.thought = False + mock_part.thought_signature = None + mock_content = Mock() + mock_content.parts = [mock_part] + mock_candidate = Mock() + mock_candidate.content = mock_content + mock_candidate.finish_reason = "STOP" + + # Usage metadata including cached tokens (as returned by API when cache is used) + mock_usage = Mock() + mock_usage.prompt_token_count = 1000 + mock_usage.candidates_token_count = 5 + mock_usage.total_token_count = 1005 + mock_usage.cached_content_token_count = 800 + mock_usage.thoughts_token_count = None + + mock_response = Mock() + mock_response.candidates = [mock_candidate] + mock_response.usage_metadata = mock_usage + + message = _convert_google_genai_response_to_chatmessage(mock_response, "gemini-2.5-flash") + + assert message.meta is not None + assert "usage" in message.meta + usage = message.meta["usage"] + assert usage["prompt_tokens"] == 1000 + assert usage["completion_tokens"] == 5 + assert usage["total_tokens"] == 1005 + assert usage["cached_content_token_count"] == 800 + + def test_convert_usage_metadata_to_serializable(self): + """_convert_usage_metadata_to_serializable builds a serializable dict from a UsageMetadata object.""" + assert _convert_usage_metadata_to_serializable(None) == {} + assert _convert_usage_metadata_to_serializable(False) == {} + + usage_metadata = types.GenerateContentResponseUsageMetadata( + prompt_token_count=100, + candidates_token_count=5, + total_token_count=105, + ) + result = _convert_usage_metadata_to_serializable(usage_metadata) + assert result["prompt_token_count"] == 100 + assert result["candidates_token_count"] == 5 + assert result["total_token_count"] == 105 + assert len(result) == 3 + + # Serialization of zero and composite types (ModalityTokenCount, lists) + modality_token_count = types.ModalityTokenCount(modality=types.Modality.TEXT, tokenCount=100) + usage_with_details = types.GenerateContentResponseUsageMetadata( + prompt_token_count=0, + candidates_tokens_details=[modality_token_count], + ) + result2 = _convert_usage_metadata_to_serializable(usage_with_details) + assert result2["prompt_token_count"] == 0 + assert result2["candidates_tokens_details"] == [{"modality": "TEXT", "token_count": 100}] diff --git a/integrations/google_genai/tests/test_files/sample_pdf_3.pdf 
b/integrations/google_genai/tests/test_files/sample_pdf_3.pdf new file mode 100644 index 0000000000000000000000000000000000000000..c0d07eaa688f0d605a30580f8ca1b3d06d4b66a6 GIT binary patch literal 19039 zcmaHS1yCi;(k0H|I=Eci-QC^Y-Q5|S!QI{68C)*zFu1$BJA)0feDA&eU+l)l-l*uR z&OUjnGrLblcST<^MNx4&CVEy_vibAVU05anBf!Df8kU!rLDkCD)(k*t=3!*#XlrKR zYGiC{Mg_~D=QT0T|>Rob8Nk{|jRKFGxh_AJD&{vLXORSO!_?e{D4YO!O@O z{&B!E$jJhjzK*E~0bHHk%>Li3?EhsIx3YEpI&}tdTO(I9Q8N<<(=Q)cGkXhHO8^HG z7e7D1#nsu&$PU&c`%=@+2A2u`!>^}54_b+CG|?6-$gKK(gRYLSt~sP0ap~PZ{mQIM zx7=G^x&t9q(79}cVl<9WN19R8-(g0;x$x$tZa$2d5WI1!Dcvhu0k5U7*J{?LVW$NALMrV7}l0fzfB{?w^ZL*gtj zSnt|22;Ou!!ezx5rq;yN8_nH-2b`Ra`eHg-P#QcQ6p~3h7qkNO7<>mslqrQz!xLod znqWE=&Ts-|-fKxy2C<$h`CIYtQD8co9Vmz0c+s%twjWGu|dvB3tfk4zVHLO{J9leGno-gu#&V0s@v*f zv7TNp_Yby~B>qmBQcnW3j&_uDG+Q1-P4rHvVW_S}g*EE5#Y>ZscGQfOX)fzTzMYTo z%cxc*gO;#+)&xZus_zyo$l#^Al7G2BPt~6NG*5mJ9->2`q-n$DQ{2tml2?#ruH-M+!Sr`l*@_2y?Vik^EYWkCl zGuAnOr@gXeQ0?{bDA~x)8llVsS%p)(wA8Z$!j0ZaqumuY$*J2CrSnEgGvG4)+E^&t zIBToli-Xg=FQaH;l2k~;D6_CXjZ%o#M*gbsNWBW(exwpI`5D@EqK@`VM(yz4KD#lvtD`VdUswZulF_Rjr#xOXQc|Mt_f*%O@G zN!|88D*Vr#{U1Ft|5p^Mo{nY!21O%_FD)pWxj48vo0z!(zAC_%)J;^(Tmd@&QaOXF znTIQYLCWq6DEyyZzA(nag|^J{G)hfSOy6;z(2B9{!%%ExP!gx zKS#w`0j&SF{bQ9gGqo}jcJKh`Fn$@?xYz-ltju48B421Tdsi30KLY+@Qu&g&iL0WK z^S_K=s{i*8!0{jY|3voB?*A@f{}Y+ZKXh(ivwsy887os4fX+V&P}ciz?cw~t-9(IB zjcgq({+Y(&D}euFDaXzKXB&tB2!v!A2kFiD}6J?d8w0H=p1Y?gv@53#E$ z2c~zh2M*a995tTDT+eOrhu`2c{!A{r4@1HRp!+B5SUt{1Y~w+3F9L74Kkq<(^4^XO zHDY%A>TQaEr1lhjR<6#c(Lr%w)xW{Y)=ug>Zl;VJo?(JQwSXnZ{X5wQ;vK;h?Pp7JFsTLf8B(iKm1V#*T-k$+GM9fsTJ!@2 zL=Qf1!j;u((sr3z^#f&Q2|KM{`^7axw&Uu1B$%qLZwt-F(Lx*v=c@hthGcLt!vwIJ zp#e4+KTnsj$bW)To(Q7C6k{`#_5V-=?SX(CV51Bi)m*zzR~PO9d6Pb9WM*(Q@-uTB zvBWCvJN;EwyX+y^YD0~rRKT${_fRr#A?SGICW9&H_&7KGf}B0iX+i~oDQn>O9b!5a zCh@p{C}aC7-Ms2~Ylr1#XGCC{!F0NH7;(+h%jT~w0=)SH4v&qgf=gtnp>)V;b?pua zg)zDV*!wNHVk}1hHCBWscp7nR08DsNK86#l?$kJ1u$w7~b3cHOG{w-V`#|LxTvzH6 zDk9S~p+P3t#f#s1;)Dr#fWJ0=>k>;sI=QcyYdAeu@JhTmZS~~mZyng|@z@1RPDb=Y z?DZ<6@MyARLtKiza@KxbO?Zs~Saqn-C%F)Xkq0jkt?=wF=1+%fRKIId z3dr>yUjhGJoeQ$ZL@0xk)E-4YR;Y!oQdIXWaEy%Y4qFY&v~!I$iydEHmGQS<_hQFU zxjJ3b#U8)<6_e=$WK^b$S)I6?9lODtBLF|9=ZfKCc=NV_==1JFs%^CRLt4T>nN)J3 z<1i7>gFU&KV?k_c39XENQrddZ$50Bu+5VRIu}40s6+ z)XfqsZh}18Ayc0QRTW0)rL^>c5%InTC1TM9I!2^SW!Aoi8mU?IYR)?;Z9y#>$ywtx zR#XorwZzl}Fs#5tW=u14&}eC?MP%lO1xscXlai3&#Ki^z?Dt+ z<4hPt_>2zJ=_uQ!F-?$PLLqV2KA?LHeyfR$e}>&$zF1}AxMhp7PqBwuN^(W;+#*y5vh{bQ+aPE@iU8s_g1}AR{JuNAoy_>V@7rrrpRS*M|jrC6Jm^ z$eX}9?`Ycwg)?{NHO*3{ZhbRA;n1RjA5hMf97OYGN+CQ`QYQ@DERK%O=bBp>EFbh+ zxa%^6Onn8g4#;=v?5c5!PRvCNrVhf~2r6W|RmBrt$Fmm{`MODvnQ%H+KwIZU)f@DL zZubeH6898Dkv6K*4i%rHT^G2YMYW$Yd%(K*NqtB;^n0h=gmUsSp}Y0sJY(;(zhUr` z=_`v4$B#Ba>xmM6Pc0liAGJ85&DLoh0M~q_AeDAkX5jECmJ+exsCS;95AGmNw%NHTdPX`->Plx-fmYUSo3jT7! 
[... base85-encoded contents of sample_pdf_3.pdf (19039 bytes) truncated ...]
zxji+@p_68*x^ARiZyu8_EezC^`)`G(&s>5&UrU;3iqW9K~$!ER6HenJ$&74?|mJxyCpeATO&Eu(4;UP7L>0znjW*rycWGE5!A1s z%!Qao z;2u(5z4R88BzMdC@bYa&fm)xaYgD^D-X z>?hq}E-a;W@j8^7wSsETBZ!7wn;e3KgYIJ0{JW1J5{1w7&h7Z5k!k)E-q~V&fDii&{{lR?|5x zPIroy$2hHWtz^lQN|FQ@$iN_#?fhSToO?LbSs2H4f3mtzJ*{;4wIPY|yUg!0tI|Xm zjS-R+qlRWQGKLx2B)Zs5p;WqwXp>4NxmL7mQ;L*I7tuvePi5s%YPXv-?E9N;xc ze$V@!bKY~__xX{n2M$nni=uNwPfiiF@YX4H&vN?$ypp!{XV=Ysv8ZLGK&Z zGs($Y`{m^3nfw3PFDQL-cv#xma_)t-=|jnvMqPq+lf##$83{Ms!MJJh0ess!>Rh4w zsPQ2S-{%{pN%ZdT)8rJrxU!m);pR5JF>mmBq_b(?8r$k_(I)a`lPH?^${~K#YF|g* zn-po0nWv)c8_a*3DLL^^GqG6TPQT8Q_|G@xHnJHZ5Kt${}toYM(md$cR1u2aYvA zNPFY+pnmw!c9ZTU9Yt?`4R*Wgo?_fqR;#G+@|C$5vge0^2HM3!a#(^McbjCd!1IW0h|`gcGOJ7uVl&i9fI5KQ!NK zJ2*9_pj@SFT`bIhdhL1Sq#NbeL=!s;%kLI1nWU*d{7ZS4qrxV?`EJPaVSWuM8%IzB zEDgp?YYZ+ZE2?o#dT>aiQs?FR&93@T%X>1eJ9b7<`s3yuCV^dl7r(ptk`O$0o4LEf zXpW0t+o<|3=ijUhO-f51Rk&rgJZm>@>XKAlt)4Jx(zu3~!>o}Bc3WDA;>W`HQtbql zK+O(vm1z>kzv4mQLwj&LJ&6IEStfXCQ&hh(ZEiiX%0icor&QN z5FAHHlHy<_P5{?|06QG#U=+=x#V~>R0`mZ@;oc!OilR6?9**(E2*n^2uZJg%AvCZe zzN|!32+f1VKdpr@Z~}%PI4HaVCPo+xVTc~)4D1A0m@m^kkJ-ZvMZTU69?ZcTAyGEO z&+pWOi}8OT2Li=;NI4PRTCCzmoh)jpsB@RgC8;(G^=4N1$j{=IJHt#87GuLUnEKy4 zo*0*Nn!lJ*ccRI6ApA5wkBuu+M6XcJ(Z%z z&kf{mDvX&L<5wElPOdS|%#?&W+FY~D|oku$^J8ypmA<;!AT9iF`Z$)AkVE&1pN zal;yuI@-9k8EI<^t2Vxiaflk@UST*BvCL|3p6{;>ePWy}*Ce@JIk4-WT;>mWvQ9!;v|4NlqrVlku=tF6L8VE&_1c}IyFB%3w6Luh>pEV4{08-Lh zLopO0>!o1~NdXq6mxiJk7n;L28jb^Rq_>8#<<&=HpwWC|M-bnh1*1sOtK1OQ@`$_RvIr9hX`H5In8v5ih0AyGx>?hj*V6xTO1 Jo9pJN|1a#IQK0|; literal 0 HcmV?d00001 From 095d47cf2ad9bfbcddb3e0738d42aac260a9e83d Mon Sep 17 00:00:00 2001 From: anakin87 Date: Thu, 19 Feb 2026 16:29:16 +0100 Subject: [PATCH 2/5] revert some changes --- .../google_genai/chat/chat_generator.py | 114 +++++++++++++++--- .../generators/google_genai/chat/utils.py | 80 ------------ .../google_genai/tests/test_chat_generator.py | 8 -- .../tests/test_chat_generator_utils.py | 1 - 4 files changed, 97 insertions(+), 106 deletions(-) diff --git a/integrations/google_genai/src/haystack_integrations/components/generators/google_genai/chat/chat_generator.py b/integrations/google_genai/src/haystack_integrations/components/generators/google_genai/chat/chat_generator.py index 01ee67c314..fdbab2dade 100644 --- a/integrations/google_genai/src/haystack_integrations/components/generators/google_genai/chat/chat_generator.py +++ b/integrations/google_genai/src/haystack_integrations/components/generators/google_genai/chat/chat_generator.py @@ -2,12 +2,14 @@ # # SPDX-License-Identifier: Apache-2.0 +from collections.abc import AsyncIterator from typing import Any, Literal from google.genai import types +from haystack import logging from haystack.core.component import component from haystack.core.serialization import default_from_dict, default_to_dict -from haystack.dataclasses import ComponentInfo, StreamingCallbackT, select_streaming_callback +from haystack.dataclasses import AsyncStreamingCallbackT, ComponentInfo, StreamingCallbackT, select_streaming_callback from haystack.dataclasses.chat_message import ChatMessage, ChatRole from haystack.tools import ( ToolsType, @@ -23,11 +25,11 @@ _convert_google_genai_response_to_chatmessage, _convert_message_to_google_genai_format, _convert_tools_to_google_genai_format, - _handle_streaming_response, - _handle_streaming_response_async, _process_thinking_config, ) +logger = logging.getLogger(__name__) + @component class GoogleGenAIChatGenerator: @@ -237,6 +239,96 @@ def from_dict(cls, data: dict[str, Any]) -> "GoogleGenAIChatGenerator": init_params["streaming_callback"] = 
deserialize_callable(init_params["streaming_callback"]) return default_from_dict(cls, data) + async def _handle_streaming_response_async( + self, response_stream: AsyncIterator[types.GenerateContentResponse], streaming_callback: AsyncStreamingCallbackT + ) -> dict[str, list[ChatMessage]]: + """ + Handle async streaming response from Google Gen AI generate_content_stream. + :param response_stream: The async streaming response from generate_content_stream. + :param streaming_callback: The async callback function for streaming chunks. + :returns: A dictionary with the replies. + """ + component_info = ComponentInfo.from_component(self) + + try: + chunks = [] + + i = 0 + chunk = None + async for chunk in response_stream: + i += 1 + + streaming_chunk = self._convert_google_chunk_to_streaming_chunk( + chunk=chunk, index=i, component_info=component_info + ) + chunks.append(streaming_chunk) + + # Stream the chunk + await streaming_callback(streaming_chunk) + + # Use custom aggregation that supports reasoning content + message = GoogleGenAIChatGenerator._aggregate_streaming_chunks_with_reasoning(chunks) + return {"replies": [message]} + + except Exception as e: + msg = f"Error in async streaming response: {e}" + raise RuntimeError(msg) from e + + @staticmethod + def _process_thinking_config(generation_kwargs: dict[str, Any]) -> dict[str, Any]: + """ + Process thinking configuration from generation_kwargs. + :param generation_kwargs: The generation configuration dictionary. + :returns: Updated generation_kwargs with thinking_config if applicable. + """ + if "thinking_budget" in generation_kwargs: + thinking_budget = generation_kwargs.pop("thinking_budget") + + # Basic type validation + if not isinstance(thinking_budget, int): + logger.warning( + f"Invalid thinking_budget type: {type(thinking_budget)}. Expected int, using dynamic allocation." + ) + # fall back to default: dynamic thinking budget allocation + thinking_budget = -1 + + # Create thinking config + thinking_config = types.ThinkingConfig(thinking_budget=thinking_budget, include_thoughts=True) + generation_kwargs["thinking_config"] = thinking_config + + if "thinking_level" in generation_kwargs: + thinking_level = generation_kwargs.pop("thinking_level") + + # Basic type validation + if not isinstance(thinking_level, str): + logger.warning( + f"Invalid thinking_level type: {type(thinking_level).__name__}. Expected str, " + f"falling back to THINKING_LEVEL_UNSPECIFIED." + ) + thinking_level = types.ThinkingLevel.THINKING_LEVEL_UNSPECIFIED + else: + # Convert to uppercase for case-insensitive matching + thinking_level_upper = thinking_level.upper() + + # Check if the uppercase value is a valid ThinkingLevel enum member + valid_levels = [level.value for level in types.ThinkingLevel] + if thinking_level_upper not in valid_levels: + logger.warning( + f"Invalid thinking_level value: '{thinking_level}'. " + f"Must be one of: {valid_levels} (case-insensitive). " + "Falling back to THINKING_LEVEL_UNSPECIFIED." 
+ ) + thinking_level = types.ThinkingLevel.THINKING_LEVEL_UNSPECIFIED + else: + # Parse valid string to ThinkingLevel enum + thinking_level = types.ThinkingLevel(thinking_level_upper) + + # Create thinking config with level + thinking_config = types.ThinkingConfig(thinking_level=thinking_level, include_thoughts=True) + generation_kwargs["thinking_config"] = thinking_config + + return generation_kwargs + @component.output_types(replies=list[ChatMessage]) def run( self, @@ -319,13 +411,7 @@ def run( contents=contents, config=config, ) - component_info = ComponentInfo.from_component(self) - return _handle_streaming_response( - component_info=component_info, - response_stream=response_stream, - streaming_callback=streaming_callback, - model=self._model, - ) + return self._handle_streaming_response(response_stream, streaming_callback) else: # Use non-streaming response = self._client.models.generate_content( @@ -435,13 +521,7 @@ async def run_async( contents=contents, config=config, ) - component_info = ComponentInfo.from_component(self) - return await _handle_streaming_response_async( - component_info=component_info, - response_stream=response_stream, - streaming_callback=streaming_callback, - model=self._model, - ) + return await self._handle_streaming_response_async(response_stream, streaming_callback) else: # Use non-streaming response = await self._client.aio.models.generate_content( diff --git a/integrations/google_genai/src/haystack_integrations/components/generators/google_genai/chat/utils.py b/integrations/google_genai/src/haystack_integrations/components/generators/google_genai/chat/utils.py index 366d74f86b..3b35438e92 100644 --- a/integrations/google_genai/src/haystack_integrations/components/generators/google_genai/chat/utils.py +++ b/integrations/google_genai/src/haystack_integrations/components/generators/google_genai/chat/utils.py @@ -4,7 +4,6 @@ import base64 import json -from collections.abc import AsyncIterator, Iterator from datetime import datetime, timezone from typing import Any @@ -13,12 +12,10 @@ from haystack import logging from haystack.components.generators.utils import _convert_streaming_chunks_to_chat_message from haystack.dataclasses import ( - AsyncStreamingCallbackT, ComponentInfo, FileContent, FinishReason, ImageContent, - StreamingCallbackT, StreamingChunk, TextContent, ToolCall, @@ -686,80 +683,3 @@ def _aggregate_streaming_chunks_with_reasoning(chunks: list[StreamingChunk]) -> ) return message - - -def _handle_streaming_response( - component_info: ComponentInfo, - response_stream: Iterator[types.GenerateContentResponse], - streaming_callback: StreamingCallbackT, - model: str, -) -> dict[str, list[ChatMessage]]: - """ - Handle streaming response from Google Gen AI generate_content_stream. - :param component_info: The component info. - :param response_stream: The streaming response from generate_content_stream. - :param streaming_callback: The callback function for streaming chunks. - :param model: The model name. - :returns: A dictionary with the replies. 
- """ - - try: - chunks = [] - - chunk = None - for i, chunk in enumerate(response_stream): - streaming_chunk = _convert_google_chunk_to_streaming_chunk( - chunk=chunk, index=i, component_info=component_info, model=model - ) - chunks.append(streaming_chunk) - - # Stream the chunk - streaming_callback(streaming_chunk) - - # Use custom aggregation that supports reasoning content - message = _aggregate_streaming_chunks_with_reasoning(chunks) - return {"replies": [message]} - - except Exception as e: - msg = f"Error in streaming response: {e}" - raise RuntimeError(msg) from e - - -async def _handle_streaming_response_async( - component_info: ComponentInfo, - response_stream: AsyncIterator[types.GenerateContentResponse], - streaming_callback: AsyncStreamingCallbackT, - model: str, -) -> dict[str, list[ChatMessage]]: - """ - Handle async streaming response from Google Gen AI generate_content_stream. - :param component_info: The component info. - :param response_stream: The async streaming response from generate_content_stream. - :param streaming_callback: The async callback function for streaming chunks. - :param model: The model name. - :returns: A dictionary with the replies. - """ - - try: - chunks = [] - - i = 0 - chunk = None - async for chunk in response_stream: - i += 1 - - streaming_chunk = _convert_google_chunk_to_streaming_chunk( - chunk=chunk, index=i, component_info=component_info, model=model - ) - chunks.append(streaming_chunk) - - # Stream the chunk - await streaming_callback(streaming_chunk) - - # Use custom aggregation that supports reasoning content - message = _aggregate_streaming_chunks_with_reasoning(chunks) - return {"replies": [message]} - - except Exception as e: - msg = f"Error in async streaming response: {e}" - raise RuntimeError(msg) from e diff --git a/integrations/google_genai/tests/test_chat_generator.py b/integrations/google_genai/tests/test_chat_generator.py index 58c99936da..f6d0f52e7c 100644 --- a/integrations/google_genai/tests/test_chat_generator.py +++ b/integrations/google_genai/tests/test_chat_generator.py @@ -27,14 +27,6 @@ ) -@pytest.fixture -def chat_messages(): - return [ - ChatMessage.from_system("You are a helpful assistant"), - ChatMessage.from_user("What's the capital of France"), - ] - - def weather(city: str): """Get weather information for a city.""" return f"Weather in {city}: 22°C, sunny" diff --git a/integrations/google_genai/tests/test_chat_generator_utils.py b/integrations/google_genai/tests/test_chat_generator_utils.py index 04b2abbd1c..719d7a3360 100644 --- a/integrations/google_genai/tests/test_chat_generator_utils.py +++ b/integrations/google_genai/tests/test_chat_generator_utils.py @@ -21,7 +21,6 @@ from haystack_integrations.components.generators.google_genai.chat.chat_generator import ( GoogleGenAIChatGenerator, ) - from haystack_integrations.components.generators.google_genai.chat.utils import ( _aggregate_streaming_chunks_with_reasoning, _convert_google_chunk_to_streaming_chunk, From 813cee2b6770babc591d6b89e3f567a5cfd109e8 Mon Sep 17 00:00:00 2001 From: anakin87 Date: Thu, 19 Feb 2026 16:41:00 +0100 Subject: [PATCH 3/5] clean more --- .../google_genai/chat/chat_generator.py | 97 ++++++++----------- 1 file changed, 38 insertions(+), 59 deletions(-) diff --git a/integrations/google_genai/src/haystack_integrations/components/generators/google_genai/chat/chat_generator.py b/integrations/google_genai/src/haystack_integrations/components/generators/google_genai/chat/chat_generator.py index fdbab2dade..715cd2a564 100644 --- 
a/integrations/google_genai/src/haystack_integrations/components/generators/google_genai/chat/chat_generator.py +++ b/integrations/google_genai/src/haystack_integrations/components/generators/google_genai/chat/chat_generator.py @@ -2,7 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 -from collections.abc import AsyncIterator +from collections.abc import AsyncIterator, Iterator from typing import Any, Literal from google.genai import types @@ -22,6 +22,8 @@ from haystack_integrations.components.common.google_genai.utils import _get_client from haystack_integrations.components.generators.google_genai.chat.utils import ( + _aggregate_streaming_chunks_with_reasoning, + _convert_google_chunk_to_streaming_chunk, _convert_google_genai_response_to_chatmessage, _convert_message_to_google_genai_format, _convert_tools_to_google_genai_format, @@ -239,6 +241,38 @@ def from_dict(cls, data: dict[str, Any]) -> "GoogleGenAIChatGenerator": init_params["streaming_callback"] = deserialize_callable(init_params["streaming_callback"]) return default_from_dict(cls, data) + def _handle_streaming_response( + self, response_stream: Iterator[types.GenerateContentResponse], streaming_callback: StreamingCallbackT + ) -> dict[str, list[ChatMessage]]: + """ + Handle streaming response from Google Gen AI generate_content_stream. + :param response_stream: The streaming response from generate_content_stream. + :param streaming_callback: The callback function for streaming chunks. + :returns: A dictionary with the replies. + """ + component_info = ComponentInfo.from_component(self) + + try: + chunks = [] + + chunk = None + for i, chunk in enumerate(response_stream): + streaming_chunk = _convert_google_chunk_to_streaming_chunk( + chunk=chunk, index=i, component_info=component_info, model=self._model + ) + chunks.append(streaming_chunk) + + # Stream the chunk + streaming_callback(streaming_chunk) + + # Use custom aggregation that supports reasoning content + message = _aggregate_streaming_chunks_with_reasoning(chunks) + return {"replies": [message]} + + except Exception as e: + msg = f"Error in streaming response: {e}" + raise RuntimeError(msg) from e + async def _handle_streaming_response_async( self, response_stream: AsyncIterator[types.GenerateContentResponse], streaming_callback: AsyncStreamingCallbackT ) -> dict[str, list[ChatMessage]]: @@ -258,8 +292,8 @@ async def _handle_streaming_response_async( async for chunk in response_stream: i += 1 - streaming_chunk = self._convert_google_chunk_to_streaming_chunk( - chunk=chunk, index=i, component_info=component_info + streaming_chunk = _convert_google_chunk_to_streaming_chunk( + chunk=chunk, index=i, component_info=component_info, model=self._model ) chunks.append(streaming_chunk) @@ -267,68 +301,13 @@ async def _handle_streaming_response_async( await streaming_callback(streaming_chunk) # Use custom aggregation that supports reasoning content - message = GoogleGenAIChatGenerator._aggregate_streaming_chunks_with_reasoning(chunks) + message = _aggregate_streaming_chunks_with_reasoning(chunks) return {"replies": [message]} except Exception as e: msg = f"Error in async streaming response: {e}" raise RuntimeError(msg) from e - @staticmethod - def _process_thinking_config(generation_kwargs: dict[str, Any]) -> dict[str, Any]: - """ - Process thinking configuration from generation_kwargs. - :param generation_kwargs: The generation configuration dictionary. - :returns: Updated generation_kwargs with thinking_config if applicable. 
- """ - if "thinking_budget" in generation_kwargs: - thinking_budget = generation_kwargs.pop("thinking_budget") - - # Basic type validation - if not isinstance(thinking_budget, int): - logger.warning( - f"Invalid thinking_budget type: {type(thinking_budget)}. Expected int, using dynamic allocation." - ) - # fall back to default: dynamic thinking budget allocation - thinking_budget = -1 - - # Create thinking config - thinking_config = types.ThinkingConfig(thinking_budget=thinking_budget, include_thoughts=True) - generation_kwargs["thinking_config"] = thinking_config - - if "thinking_level" in generation_kwargs: - thinking_level = generation_kwargs.pop("thinking_level") - - # Basic type validation - if not isinstance(thinking_level, str): - logger.warning( - f"Invalid thinking_level type: {type(thinking_level).__name__}. Expected str, " - f"falling back to THINKING_LEVEL_UNSPECIFIED." - ) - thinking_level = types.ThinkingLevel.THINKING_LEVEL_UNSPECIFIED - else: - # Convert to uppercase for case-insensitive matching - thinking_level_upper = thinking_level.upper() - - # Check if the uppercase value is a valid ThinkingLevel enum member - valid_levels = [level.value for level in types.ThinkingLevel] - if thinking_level_upper not in valid_levels: - logger.warning( - f"Invalid thinking_level value: '{thinking_level}'. " - f"Must be one of: {valid_levels} (case-insensitive). " - "Falling back to THINKING_LEVEL_UNSPECIFIED." - ) - thinking_level = types.ThinkingLevel.THINKING_LEVEL_UNSPECIFIED - else: - # Parse valid string to ThinkingLevel enum - thinking_level = types.ThinkingLevel(thinking_level_upper) - - # Create thinking config with level - thinking_config = types.ThinkingConfig(thinking_level=thinking_level, include_thoughts=True) - generation_kwargs["thinking_config"] = thinking_config - - return generation_kwargs - @component.output_types(replies=list[ChatMessage]) def run( self, From 10140a51be2a0adcf259562191597b438519d94a Mon Sep 17 00:00:00 2001 From: anakin87 Date: Tue, 24 Feb 2026 12:31:50 +0100 Subject: [PATCH 4/5] implement suggestions --- .../generators/google_genai/chat/chat_generator.py | 13 ++++++++++++- .../generators/google_genai/chat/utils.py | 5 ++--- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/integrations/google_genai/src/haystack_integrations/components/generators/google_genai/chat/chat_generator.py b/integrations/google_genai/src/haystack_integrations/components/generators/google_genai/chat/chat_generator.py index 715cd2a564..85c474198a 100644 --- a/integrations/google_genai/src/haystack_integrations/components/generators/google_genai/chat/chat_generator.py +++ b/integrations/google_genai/src/haystack_integrations/components/generators/google_genai/chat/chat_generator.py @@ -138,6 +138,18 @@ def weather_function(city: str): messages = [ChatMessage.from_user("What's the weather in Paris?")] response = chat_generator_with_tools.run(messages=messages) ``` + + ### Usage example with FileContent embedded in a ChatMessage + + ```python + from haystack.dataclasses import ChatMessage, FileContent + from haystack_integrations.components.generators.google_genai import GoogleGenAIChatGenerator + + file_content = FileContent.from_url("https://arxiv.org/pdf/2309.08632") + chat_message = ChatMessage.from_user(content_parts=[file_content, "Summarize this paper in 100 words."]) + chat_generator = GoogleGenAIChatGenerator() + response = chat_generator.run(messages=[chat_message]) + ``` """ def __init__( @@ -255,7 +267,6 @@ def _handle_streaming_response( try: chunks = 
[] - chunk = None for i, chunk in enumerate(response_stream): streaming_chunk = _convert_google_chunk_to_streaming_chunk( chunk=chunk, index=i, component_info=component_info, model=self._model diff --git a/integrations/google_genai/src/haystack_integrations/components/generators/google_genai/chat/utils.py b/integrations/google_genai/src/haystack_integrations/components/generators/google_genai/chat/utils.py index 3b35438e92..5e66ad0724 100644 --- a/integrations/google_genai/src/haystack_integrations/components/generators/google_genai/chat/utils.py +++ b/integrations/google_genai/src/haystack_integrations/components/generators/google_genai/chat/utils.py @@ -233,9 +233,9 @@ def _convert_message_to_google_genai_format(message: ChatMessage) -> types.Conte msg = f"{cls_name} is only supported for user messages" raise ValueError(msg) - # Validate image MIME type and format + # MIME type validation: must be provided and, for images, one of the supported types if not content_part.mime_type: - msg = f"Mime type is required to use {cls_name} with GoogleGenAIChatGenerator" + msg = f"MIME type is required to use {cls_name} with GoogleGenAIChatGenerator" raise ValueError(msg) if ( @@ -256,7 +256,6 @@ def _convert_message_to_google_genai_format(message: ChatMessage) -> types.Conte ) bytes_data = base64.b64decode(base64_data) - # Create Part using from_bytes method file_part = types.Part.from_bytes(data=bytes_data, mime_type=content_part.mime_type) parts.append(file_part) From 15ab505a4246ed5a1e15d3d479027da331d01b2e Mon Sep 17 00:00:00 2001 From: anakin87 Date: Tue, 24 Feb 2026 12:39:52 +0100 Subject: [PATCH 5/5] improved test --- .../google_genai/tests/test_chat_generator.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/integrations/google_genai/tests/test_chat_generator.py b/integrations/google_genai/tests/test_chat_generator.py index f6d0f52e7c..2a15ba4c1d 100644 --- a/integrations/google_genai/tests/test_chat_generator.py +++ b/integrations/google_genai/tests/test_chat_generator.py @@ -268,7 +268,23 @@ def test_live_run_with_file_content(self, test_files_path): assert message.is_from(ChatRole.ASSISTANT) assert message.text - assert "no" in message.text.lower() + indicates_no = any( + phrase in message.text.lower() + for phrase in ( + "no", + "nope", + "not about", + "not a paper about", + "it is not", + "it's not", + "the answer is no", + "does not", + "doesn't", + "negative", + ) + ) + + assert indicates_no is True def test_live_run_streaming(self): component = GoogleGenAIChatGenerator()
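Taken together, patches 2-5 move the streaming handlers back onto the component while keeping chunk conversion and aggregation in `utils.py`. A minimal usage sketch of the resulting streaming path follows, assuming `GOOGLE_API_KEY` is set and that the constructor accepts the usual Haystack `model` and `streaming_callback` parameters, as the serialization code in patch 2 suggests:

```python
from haystack.dataclasses import ChatMessage, StreamingChunk

from haystack_integrations.components.generators.google_genai import GoogleGenAIChatGenerator


def print_chunk(chunk: StreamingChunk) -> None:
    # Each chunk carries incremental text plus meta such as model name and usage counts.
    print(chunk.content, end="", flush=True)


chat_generator = GoogleGenAIChatGenerator(model="gemini-2.5-flash", streaming_callback=print_chunk)
response = chat_generator.run(messages=[ChatMessage.from_user("What's Natural Language Processing?")])
print(response["replies"][0].meta.get("usage"))  # aggregated by _aggregate_streaming_chunks_with_reasoning
```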