diff --git a/integrations/google_genai/pyproject.toml b/integrations/google_genai/pyproject.toml index b1b16959fb..49f4ae793a 100644 --- a/integrations/google_genai/pyproject.toml +++ b/integrations/google_genai/pyproject.toml @@ -25,7 +25,7 @@ classifiers = [ "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] -dependencies = ["haystack-ai>=2.23.0", "google-genai[aiohttp]>=1.51.0", "jsonref>=1.0.0"] +dependencies = ["haystack-ai>=2.24.1", "google-genai[aiohttp]>=1.51.0", "jsonref>=1.0.0"] [project.urls] Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/google_genai#readme" diff --git a/integrations/google_genai/src/haystack_integrations/components/generators/google_genai/chat/chat_generator.py b/integrations/google_genai/src/haystack_integrations/components/generators/google_genai/chat/chat_generator.py index 11c26970bb..85c474198a 100644 --- a/integrations/google_genai/src/haystack_integrations/components/generators/google_genai/chat/chat_generator.py +++ b/integrations/google_genai/src/haystack_integrations/components/generators/google_genai/chat/chat_generator.py @@ -2,32 +2,15 @@ # # SPDX-License-Identifier: Apache-2.0 -import base64 -import json from collections.abc import AsyncIterator, Iterator -from datetime import datetime, timezone from typing import Any, Literal from google.genai import types -from google.genai.types import GenerateContentResponseUsageMetadata, UsageMetadata from haystack import logging -from haystack.components.generators.utils import _convert_streaming_chunks_to_chat_message from haystack.core.component import component from haystack.core.serialization import default_from_dict, default_to_dict -from haystack.dataclasses import ( - AsyncStreamingCallbackT, - ComponentInfo, - FinishReason, - ImageContent, - StreamingCallbackT, - StreamingChunk, - TextContent, - ToolCall, - ToolCallDelta, - ToolCallResult, - select_streaming_callback, -) -from haystack.dataclasses.chat_message import ChatMessage, ChatRole, ReasoningContent +from haystack.dataclasses import AsyncStreamingCallbackT, ComponentInfo, StreamingCallbackT, select_streaming_callback +from haystack.dataclasses.chat_message import ChatMessage, ChatRole from haystack.tools import ( ToolsType, _check_duplicate_tool_names, @@ -36,410 +19,19 @@ serialize_tools_or_toolset, ) from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable -from jsonref import replace_refs from haystack_integrations.components.common.google_genai.utils import _get_client -from haystack_integrations.components.generators.google_genai.chat.utils import remove_key_from_schema - -# Mapping from Google GenAI finish reasons to Haystack FinishReason values -FINISH_REASON_MAPPING: dict[str, FinishReason] = { - "STOP": "stop", - "MAX_TOKENS": "length", - "SAFETY": "content_filter", - "BLOCKLIST": "content_filter", - "PROHIBITED_CONTENT": "content_filter", - "SPII": "content_filter", - "IMAGE_SAFETY": "content_filter", -} - +from haystack_integrations.components.generators.google_genai.chat.utils import ( + _aggregate_streaming_chunks_with_reasoning, + _convert_google_chunk_to_streaming_chunk, + _convert_google_genai_response_to_chatmessage, + _convert_message_to_google_genai_format, + _convert_tools_to_google_genai_format, + _process_thinking_config, +) logger = logging.getLogger(__name__) -# Google AI supported image MIME types based on documentation -# 
https://ai.google.dev/gemini-api/docs/image-understanding?lang=python -GOOGLE_AI_SUPPORTED_MIME_TYPES = { - "image/png": "png", - "image/jpeg": "jpeg", - "image/jpg": "jpeg", # Common alias - "image/webp": "webp", - "image/heic": "heic", - "image/heif": "heif", -} - - -def _convert_message_to_google_genai_format(message: ChatMessage) -> types.Content: - """ - Converts a Haystack ChatMessage to Google Gen AI Content format. - - :param message: The Haystack ChatMessage to convert. - :returns: Google Gen AI Content object. - """ - # Check if message has content - if not message._content: - msg = "A `ChatMessage` must contain at least one content part." - raise ValueError(msg) - - parts = [] - - # Check if this message has thought signatures from a previous response - # These need to be reconstructed in their original part structure - thought_signatures = message.meta.get("thought_signatures", []) if message.meta else [] - - # If we have thought signatures, we need to reconstruct the exact part structure - # from the previous assistant response to maintain multi-turn thinking context - if thought_signatures and message.is_from(ChatRole.ASSISTANT): - # Track which tool calls we've used (to handle multiple tool calls) - tool_call_index = 0 - - # Reconstruct parts with their original thought signatures - for sig_info in thought_signatures: - part_dict: dict[str, Any] = {} - - # Check what type of content this part had - if sig_info.get("has_text"): - # Find the corresponding text content - if sig_info.get("is_thought"): - # This was a thought part - find it in reasoning content - if message.reasoning: - part_dict["text"] = message.reasoning.reasoning_text - part_dict["thought"] = True - else: - # Regular text part - part_dict["text"] = message.text or "" - - if sig_info.get("has_function_call"): - # Find the corresponding tool call by index - if message.tool_calls and tool_call_index < len(message.tool_calls): - tool_call = message.tool_calls[tool_call_index] - part_dict["function_call"] = types.FunctionCall( - id=tool_call.id, name=tool_call.tool_name, args=tool_call.arguments - ) - tool_call_index += 1 # Move to next tool call for next part - - # Add the thought signature to preserve context - part_dict["thought_signature"] = sig_info["signature"] - - parts.append(types.Part(**part_dict)) - - # If we reconstructed from signatures, we're done - if parts: - role = "model" # Assistant messages with signatures are always from the model - return types.Content(role=role, parts=parts) - - # Standard processing for messages without thought signatures - for content_part in message._content: - if isinstance(content_part, TextContent): - # Only add text parts that are not empty to avoid unnecessary empty text parts - if content_part.text.strip(): - parts.append(types.Part(text=content_part.text)) - - elif isinstance(content_part, ImageContent): - if not message.is_from(ChatRole.USER): - msg = "Image content is only supported for user messages" - raise ValueError(msg) - - # Validate image MIME type and format - if not content_part.mime_type: - msg = "Image MIME type could not be determined. Please provide a valid image with detectable format." - raise ValueError(msg) - - if content_part.mime_type not in GOOGLE_AI_SUPPORTED_MIME_TYPES: - supported_types = list(GOOGLE_AI_SUPPORTED_MIME_TYPES.keys()) - msg = ( - f"Unsupported image MIME type: {content_part.mime_type}. 
" - f"Google AI supports the following MIME types: {supported_types}" - ) - raise ValueError(msg) - - # Use inline image data approach - try: - # ImageContent already has base64 data, decode it for bytes - image_bytes = base64.b64decode(content_part.base64_image) - - # Create Part using from_bytes method - image_part = types.Part.from_bytes(data=image_bytes, mime_type=content_part.mime_type) - parts.append(image_part) - - except Exception as e: - msg = f"Failed to process image data: {e}" - raise RuntimeError(msg) from e - - elif isinstance(content_part, ToolCall): - parts.append( - types.Part( - function_call=types.FunctionCall( - id=content_part.id, name=content_part.tool_name, args=content_part.arguments - ) - ) - ) - - elif isinstance(content_part, ToolCallResult): - if isinstance(content_part.result, str): - parts.append( - types.Part( - function_response=types.FunctionResponse( - id=content_part.origin.id, - name=content_part.origin.tool_name, - response={"result": content_part.result}, - ) - ) - ) - elif isinstance(content_part.result, list): - tool_call_result_parts: list[types.FunctionResponsePart] = [] - for item in content_part.result: - if isinstance(item, TextContent): - tool_call_result_parts.append( - types.FunctionResponsePart( - inline_data=types.FunctionResponseBlob( - data=item.text.encode("utf-8"), mime_type="text/plain" - ), - ) - ) - elif isinstance(item, ImageContent): - tool_call_result_parts.append( - types.FunctionResponsePart( - inline_data=types.FunctionResponseBlob( - data=base64.b64decode(item.base64_image), mime_type=item.mime_type - ), - ) - ) - else: - msg = ( - "Unsupported content type in tool call result list. " - "Only TextContent and ImageContent are supported." - ) - raise ValueError(msg) - parts.append( - types.Part( - function_response=types.FunctionResponse( - id=content_part.origin.id, - name=content_part.origin.tool_name, - parts=tool_call_result_parts, - # the response field is mandatory, but in this case the LLM just needs multimodal parts - response={"result": ""}, - ) - ) - ) - else: - msg = "Unsupported content type in tool call result" - raise ValueError(msg) - elif isinstance(content_part, ReasoningContent): - # Reasoning content is for human transparency only, not for maintaining LLM context - # Thought signatures (stored in message.meta) handle context preservation - # Leave this here so we don't implement reasoning content handling in the future accidentally - pass - - # Determine role - if message.is_from(ChatRole.USER) or message.tool_call_results: - role = "user" - elif message.is_from(ChatRole.ASSISTANT): - role = "model" - elif message.is_from(ChatRole.SYSTEM): - # System messages will be handled separately as system instruction - # When we convert a list of ChatMessage to be sent to google genai, - # we need to handle system messages separately as system instruction and we only take the first message - # as the system instruction - if it is present. - # - # If we find any additional system messages, we will treat them as user messages - role = "user" - else: - msg = f"Unsupported message role: {message._role}" - raise ValueError(msg) - - return types.Content(role=role, parts=parts) - - -def _sanitize_tool_schema(tool_schema: dict[str, Any]) -> dict[str, Any]: - """ - Sanitizes a tool schema to remove any keys that are not supported by Google Gen AI. - - Google Gen AI does not support additionalProperties, $schema, $defs, or $ref in the tool schema. - - :param tool_schema: The tool schema to sanitize. 
- :returns: The sanitized tool schema. - """ - # google Gemini does not support additionalProperties and $schema in the tool schema - sanitized_schema = remove_key_from_schema(tool_schema, "additionalProperties") - sanitized_schema = remove_key_from_schema(sanitized_schema, "$schema") - # expand $refs in the tool schema - expanded_schema = replace_refs(sanitized_schema) - # and remove the $defs key leaving the rest of the schema - final_schema = remove_key_from_schema(expanded_schema, "$defs") - - if not isinstance(final_schema, dict): - msg = "Tool schema must be a dictionary after sanitization" - raise ValueError(msg) - - return final_schema - - -def _convert_tools_to_google_genai_format(tools: ToolsType) -> list[types.Tool]: - """ - Converts a list of Haystack Tools, Toolsets, or a mix to Google Gen AI Tool format. - - :param tools: List of Haystack Tool and/or Toolset objects, or a single Toolset. - :returns: List of Google Gen AI Tool objects. - """ - # Flatten Tools and Toolsets into a single list of Tools - flattened_tools = flatten_tools_or_toolsets(tools) - - function_declarations: list[types.FunctionDeclaration] = [] - for tool in flattened_tools: - parameters = _sanitize_tool_schema(tool.parameters) - function_declarations.append( - types.FunctionDeclaration( - name=tool.name, description=tool.description, parameters=types.Schema(**parameters) - ) - ) - - # Return a single Tool object with all function declarations as in the Google GenAI docs - # we could also return multiple Tool objects, doesn't seem to make a difference - # revisit this decision - return [types.Tool(function_declarations=function_declarations)] - - -def _convert_usage_metadata_to_serializable( - usage_metadata: UsageMetadata | GenerateContentResponseUsageMetadata | None, -) -> dict[str, Any]: - """Build a JSON-serializable usage dict from a UsageMetadata object. - - Iterates over known UsageMetadata attribute names and adds each non-None value - in serialized form. Full list of fields: https://ai.google.dev/api/generate-content#UsageMetadata - """ - - def serialize(val: Any) -> Any: - if val is None: - return None - if isinstance(val, (str, int, float, bool)): - return val - if isinstance(val, list): - return [serialize(item) for item in val] - token_count = getattr(val, "token_count", None) or getattr(val, "tokenCount", None) - if hasattr(val, "modality") and token_count is not None: - mod = getattr(val, "modality", None) - mod_str = getattr(mod, "value", getattr(mod, "name", str(mod))) if mod is not None else None - return {"modality": mod_str, "token_count": token_count} - if hasattr(val, "name"): - return getattr(val, "value", getattr(val, "name", val)) - return val - - if not usage_metadata: - return {} - - _usage_attr_names = ( - "prompt_token_count", - "candidates_token_count", - "total_token_count", - "cache_tokens_details", - "candidates_tokens_details", - "prompt_tokens_details", - "tool_use_prompt_token_count", - "tool_use_prompt_tokens_details", - ) - result: dict[str, Any] = {} - for attr in _usage_attr_names: - val = getattr(usage_metadata, attr, None) - if val is not None: - result[attr] = serialize(val) - return result - - -def _convert_google_genai_response_to_chatmessage(response: types.GenerateContentResponse, model: str) -> ChatMessage: - """ - Converts a Google Gen AI response to a Haystack ChatMessage. - - :param response: The response from Google Gen AI. - :param model: The model name. - :returns: A Haystack ChatMessage. 
- """ - text_parts = [] - tool_calls = [] - reasoning_parts = [] - thought_signatures = [] # Store thought signatures for multi-turn context - - # Extract text, function calls, thoughts, and thought signatures from response - finish_reason = None - if response.candidates: - candidate = response.candidates[0] - finish_reason = getattr(candidate, "finish_reason", None) - if candidate.content is not None and candidate.content.parts is not None: - for i, part in enumerate(candidate.content.parts): - # Check for thought signature on this part - if hasattr(part, "thought_signature") and part.thought_signature: - # Store the thought signature with its part index for reconstruction - thought_signatures.append( - { - "part_index": i, - "signature": part.thought_signature, - "has_text": part.text is not None, - "has_function_call": part.function_call is not None, - "is_thought": hasattr(part, "thought") and part.thought, - } - ) - - if part.text is not None and not (hasattr(part, "thought") and part.thought): - text_parts.append(part.text) - if part.function_call is not None: - tool_call = ToolCall( - tool_name=part.function_call.name or "", - arguments=dict(part.function_call.args) if part.function_call.args else {}, - id=part.function_call.id, - ) - tool_calls.append(tool_call) - # Handle thought parts for Gemini 2.5 series - if hasattr(part, "thought") and part.thought: - # Extract thought content - if part.text: - reasoning_parts.append(part.text) - - # Combine text parts - text = " ".join(text_parts) if text_parts else "" - - usage_metadata = response.usage_metadata - - # Create usage metadata including thinking tokens if available - usage = { - "prompt_tokens": getattr(usage_metadata, "prompt_token_count", 0), - "completion_tokens": getattr(usage_metadata, "candidates_token_count", 0), - "total_tokens": getattr(usage_metadata, "total_token_count", 0), - } - - # Add thinking token count if available - if usage_metadata and hasattr(usage_metadata, "thoughts_token_count") and usage_metadata.thoughts_token_count: - usage["thoughts_token_count"] = usage_metadata.thoughts_token_count - - # Add cached content token count if available (implicit or explicit context caching) - if ( - usage_metadata - and hasattr(usage_metadata, "cached_content_token_count") - and usage_metadata.cached_content_token_count - ): - usage["cached_content_token_count"] = usage_metadata.cached_content_token_count - - usage.update(_convert_usage_metadata_to_serializable(usage_metadata)) - - # Create meta with reasoning content and thought signatures if available - meta: dict[str, Any] = { - "model": model, - "finish_reason": FINISH_REASON_MAPPING.get(finish_reason or ""), - "usage": usage, - } - - # Add thought signatures to meta if present (for multi-turn context preservation) - if thought_signatures: - meta["thought_signatures"] = thought_signatures - - # Create ReasoningContent object if there are reasoning parts - reasoning_content = None - if reasoning_parts: - reasoning_text = " ".join(reasoning_parts) - reasoning_content = ReasoningContent(reasoning_text=reasoning_text) - - # Create ChatMessage - message = ChatMessage.from_assistant(text=text, tool_calls=tool_calls, meta=meta, reasoning=reasoning_content) - - return message - @component class GoogleGenAIChatGenerator: @@ -546,6 +138,18 @@ def weather_function(city: str): messages = [ChatMessage.from_user("What's the weather in Paris?")] response = chat_generator_with_tools.run(messages=messages) ``` + + ### Usage example with FileContent embedded in a ChatMessage + + 
```python + from haystack.dataclasses import ChatMessage, FileContent + from haystack_integrations.components.generators.google_genai import GoogleGenAIChatGenerator + + file_content = FileContent.from_url("https://arxiv.org/pdf/2309.08632") + chat_message = ChatMessage.from_user(content_parts=[file_content, "Summarize this paper in 100 words."]) + chat_generator = GoogleGenAIChatGenerator() + response = chat_generator.run(messages=[chat_message]) + ``` """ def __init__( @@ -649,171 +253,6 @@ def from_dict(cls, data: dict[str, Any]) -> "GoogleGenAIChatGenerator": init_params["streaming_callback"] = deserialize_callable(init_params["streaming_callback"]) return default_from_dict(cls, data) - def _convert_google_chunk_to_streaming_chunk( - self, - chunk: types.GenerateContentResponse, - index: int, - component_info: ComponentInfo, - ) -> StreamingChunk: - """ - Convert a chunk from Google Gen AI to a Haystack StreamingChunk. - - :param chunk: The chunk from Google Gen AI. - :param index: The index of the chunk. - :returns: A StreamingChunk object. - """ - content = "" - tool_calls: list[ToolCallDelta] = [] - finish_reason = None - reasoning_deltas: list[dict[str, str]] = [] - thought_signature_deltas: list[dict[str, Any]] = [] # Track thought signatures in streaming - - if chunk.candidates: - candidate = chunk.candidates[0] - finish_reason = getattr(candidate, "finish_reason", None) - - usage_metadata = chunk.usage_metadata - - usage = { - "prompt_tokens": getattr(usage_metadata, "prompt_token_count", 0) if usage_metadata else 0, - "completion_tokens": getattr(usage_metadata, "candidates_token_count", 0) if usage_metadata else 0, - "total_tokens": getattr(usage_metadata, "total_token_count", 0) if usage_metadata else 0, - } - - # Add thinking token count if available - if usage_metadata and hasattr(usage_metadata, "thoughts_token_count") and usage_metadata.thoughts_token_count: - usage["thoughts_token_count"] = usage_metadata.thoughts_token_count - - if candidate.content and candidate.content.parts: - tc_index = -1 - for part_index, part in enumerate(candidate.content.parts): - # Check for thought signature on this part (for multi-turn context) - if hasattr(part, "thought_signature") and part.thought_signature: - thought_signature_deltas.append( - { - "part_index": part_index, - "signature": part.thought_signature, - "has_text": part.text is not None, - "has_function_call": part.function_call is not None, - "is_thought": hasattr(part, "thought") and part.thought, - } - ) - - if part.text is not None and not (hasattr(part, "thought") and part.thought): - content += part.text - - elif part.function_call: - tc_index += 1 - tool_calls.append( - ToolCallDelta( - # Google GenAI does not provide index, but it is required for tool calls - index=tc_index, - id=part.function_call.id, - tool_name=part.function_call.name or "", - arguments=json.dumps(part.function_call.args) if part.function_call.args else None, - ) - ) - - # Handle thought parts for Gemini 2.5 series - elif hasattr(part, "thought") and part.thought: - thought_delta = { - "type": "reasoning", - "content": part.text if part.text else "", - } - reasoning_deltas.append(thought_delta) - - # start is only used by print_streaming_chunk. We try to make a reasonable assumption here but it should not be - # a problem if we change it in the future. 
- start = index == 0 or len(tool_calls) > 0 - - # Create meta with reasoning deltas and thought signatures if available - meta: dict[str, Any] = { - "received_at": datetime.now(timezone.utc).isoformat(), - "model": self._model, - "usage": usage, - } - - # Add reasoning deltas to meta if available - if reasoning_deltas: - meta["reasoning_deltas"] = reasoning_deltas - - # Add thought signature deltas to meta if available (for multi-turn context) - if thought_signature_deltas: - meta["thought_signature_deltas"] = thought_signature_deltas - - return StreamingChunk( - content="" if tool_calls else content, # prioritize tool calls over content when both are present - tool_calls=tool_calls, - component_info=component_info, - index=index, - start=start, - finish_reason=FINISH_REASON_MAPPING.get(finish_reason or ""), - meta=meta, - ) - - @staticmethod - def _aggregate_streaming_chunks_with_reasoning(chunks: list[StreamingChunk]) -> ChatMessage: - """ - Aggregate streaming chunks into a final ChatMessage with reasoning content and thought signatures. - - This method extends the standard streaming chunk aggregation to handle Google GenAI's - specific reasoning content, thinking token usage, and thought signatures for multi-turn context. - - :param chunks: List of streaming chunks to aggregate. - :returns: Final ChatMessage with aggregated content, reasoning, and thought signatures. - """ - - # Use the generic aggregator for standard content (text, tool calls, basic meta) - message = _convert_streaming_chunks_to_chat_message(chunks) - - # Now enhance with Google-specific features: reasoning content, thinking token usage, and thought signatures - reasoning_text_parts: list[str] = [] - thought_signatures: list[dict[str, Any]] = [] - thoughts_token_count = None - - for chunk in chunks: - # Extract reasoning deltas - if chunk.meta and "reasoning_deltas" in chunk.meta: - reasoning_deltas = chunk.meta["reasoning_deltas"] - if isinstance(reasoning_deltas, list): - for delta in reasoning_deltas: - if delta.get("type") == "reasoning": - reasoning_text_parts.append(delta.get("content", "")) - - # Extract thought signature deltas (for multi-turn context preservation) - if chunk.meta and "thought_signature_deltas" in chunk.meta: - signature_deltas = chunk.meta["thought_signature_deltas"] - if isinstance(signature_deltas, list): - # Aggregate thought signatures - they should come from the final chunks - # We'll keep the last set of signatures as they represent the complete state - thought_signatures = signature_deltas - - # Extract thinking token usage (from the last chunk that has it) - if chunk.meta and "usage" in chunk.meta: - chunk_usage = chunk.meta["usage"] - if "thoughts_token_count" in chunk_usage: - thoughts_token_count = chunk_usage["thoughts_token_count"] - - # Add thinking token count to usage if present - if thoughts_token_count is not None and "usage" in message.meta: - if message.meta["usage"] is None: - message.meta["usage"] = {} - message.meta["usage"]["thoughts_token_count"] = thoughts_token_count - - # Add thought signatures to meta if present (for multi-turn context preservation) - if thought_signatures: - message.meta["thought_signatures"] = thought_signatures - - # If we have reasoning content, reconstruct the message to include it - # Note: ChatMessage doesn't support adding reasoning after creation, reconstruction is necessary - if reasoning_text_parts: - reasoning_content = ReasoningContent(reasoning_text="".join(reasoning_text_parts)) - return ChatMessage.from_assistant( - 
text=message.text, tool_calls=message.tool_calls, meta=message.meta, reasoning=reasoning_content - ) - - return message - def _handle_streaming_response( self, response_stream: Iterator[types.GenerateContentResponse], streaming_callback: StreamingCallbackT ) -> dict[str, list[ChatMessage]]: @@ -828,10 +267,9 @@ def _handle_streaming_response( try: chunks = [] - chunk = None for i, chunk in enumerate(response_stream): - streaming_chunk = self._convert_google_chunk_to_streaming_chunk( - chunk=chunk, index=i, component_info=component_info + streaming_chunk = _convert_google_chunk_to_streaming_chunk( + chunk=chunk, index=i, component_info=component_info, model=self._model ) chunks.append(streaming_chunk) @@ -839,7 +277,7 @@ def _handle_streaming_response( streaming_callback(streaming_chunk) # Use custom aggregation that supports reasoning content - message = GoogleGenAIChatGenerator._aggregate_streaming_chunks_with_reasoning(chunks) + message = _aggregate_streaming_chunks_with_reasoning(chunks) return {"replies": [message]} except Exception as e: @@ -865,8 +303,8 @@ async def _handle_streaming_response_async( async for chunk in response_stream: i += 1 - streaming_chunk = self._convert_google_chunk_to_streaming_chunk( - chunk=chunk, index=i, component_info=component_info + streaming_chunk = _convert_google_chunk_to_streaming_chunk( + chunk=chunk, index=i, component_info=component_info, model=self._model ) chunks.append(streaming_chunk) @@ -874,69 +312,13 @@ async def _handle_streaming_response_async( await streaming_callback(streaming_chunk) # Use custom aggregation that supports reasoning content - message = GoogleGenAIChatGenerator._aggregate_streaming_chunks_with_reasoning(chunks) + message = _aggregate_streaming_chunks_with_reasoning(chunks) return {"replies": [message]} except Exception as e: msg = f"Error in async streaming response: {e}" raise RuntimeError(msg) from e - @staticmethod - def _process_thinking_config(generation_kwargs: dict[str, Any]) -> dict[str, Any]: - """ - Process thinking configuration from generation_kwargs. - - :param generation_kwargs: The generation configuration dictionary. - :returns: Updated generation_kwargs with thinking_config if applicable. - """ - if "thinking_budget" in generation_kwargs: - thinking_budget = generation_kwargs.pop("thinking_budget") - - # Basic type validation - if not isinstance(thinking_budget, int): - logger.warning( - f"Invalid thinking_budget type: {type(thinking_budget)}. Expected int, using dynamic allocation." - ) - # fall back to default: dynamic thinking budget allocation - thinking_budget = -1 - - # Create thinking config - thinking_config = types.ThinkingConfig(thinking_budget=thinking_budget, include_thoughts=True) - generation_kwargs["thinking_config"] = thinking_config - - if "thinking_level" in generation_kwargs: - thinking_level = generation_kwargs.pop("thinking_level") - - # Basic type validation - if not isinstance(thinking_level, str): - logger.warning( - f"Invalid thinking_level type: {type(thinking_level).__name__}. Expected str, " - f"falling back to THINKING_LEVEL_UNSPECIFIED." - ) - thinking_level = types.ThinkingLevel.THINKING_LEVEL_UNSPECIFIED - else: - # Convert to uppercase for case-insensitive matching - thinking_level_upper = thinking_level.upper() - - # Check if the uppercase value is a valid ThinkingLevel enum member - valid_levels = [level.value for level in types.ThinkingLevel] - if thinking_level_upper not in valid_levels: - logger.warning( - f"Invalid thinking_level value: '{thinking_level}'. 
" - f"Must be one of: {valid_levels} (case-insensitive). " - "Falling back to THINKING_LEVEL_UNSPECIFIED." - ) - thinking_level = types.ThinkingLevel.THINKING_LEVEL_UNSPECIFIED - else: - # Parse valid string to ThinkingLevel enum - thinking_level = types.ThinkingLevel(thinking_level_upper) - - # Create thinking config with level - thinking_config = types.ThinkingConfig(thinking_level=thinking_level, include_thoughts=True) - generation_kwargs["thinking_config"] = thinking_config - - return generation_kwargs - @component.output_types(replies=list[ChatMessage]) def run( self, @@ -971,7 +353,7 @@ def run( tools = tools or self._tools # Process thinking configuration - generation_kwargs = GoogleGenAIChatGenerator._process_thinking_config(generation_kwargs) + generation_kwargs = _process_thinking_config(generation_kwargs) # Select appropriate streaming callback streaming_callback = select_streaming_callback( @@ -1081,7 +463,7 @@ async def run_async( tools = tools or self._tools # Process thinking configuration - generation_kwargs = GoogleGenAIChatGenerator._process_thinking_config(generation_kwargs) + generation_kwargs = _process_thinking_config(generation_kwargs) # Select appropriate streaming callback streaming_callback = select_streaming_callback( diff --git a/integrations/google_genai/src/haystack_integrations/components/generators/google_genai/chat/utils.py b/integrations/google_genai/src/haystack_integrations/components/generators/google_genai/chat/utils.py index b256d7dd15..5e66ad0724 100644 --- a/integrations/google_genai/src/haystack_integrations/components/generators/google_genai/chat/utils.py +++ b/integrations/google_genai/src/haystack_integrations/components/generators/google_genai/chat/utils.py @@ -2,8 +2,113 @@ # # SPDX-License-Identifier: Apache-2.0 +import base64 +import json +from datetime import datetime, timezone from typing import Any +from google.genai import types +from google.genai.types import GenerateContentResponseUsageMetadata, UsageMetadata +from haystack import logging +from haystack.components.generators.utils import _convert_streaming_chunks_to_chat_message +from haystack.dataclasses import ( + ComponentInfo, + FileContent, + FinishReason, + ImageContent, + StreamingChunk, + TextContent, + ToolCall, + ToolCallDelta, + ToolCallResult, +) +from haystack.dataclasses.chat_message import ChatMessage, ChatRole, ReasoningContent +from haystack.tools import ( + ToolsType, + flatten_tools_or_toolsets, +) +from jsonref import replace_refs + +logger = logging.getLogger(__name__) + +# Mapping from Google GenAI finish reasons to Haystack FinishReason values +FINISH_REASON_MAPPING: dict[str, FinishReason] = { + "STOP": "stop", + "MAX_TOKENS": "length", + "SAFETY": "content_filter", + "BLOCKLIST": "content_filter", + "PROHIBITED_CONTENT": "content_filter", + "SPII": "content_filter", + "IMAGE_SAFETY": "content_filter", +} + +# Google GenAI supported image MIME types based on documentation +# https://ai.google.dev/gemini-api/docs/image-understanding?lang=python#supported-formats +GOOGLE_GENAI_SUPPORTED_MIME_TYPES = { + "image/png": "png", + "image/jpeg": "jpeg", + "image/jpg": "jpeg", # Common alias + "image/webp": "webp", + "image/heic": "heic", + "image/heif": "heif", +} + + +def _process_thinking_config(generation_kwargs: dict[str, Any]) -> dict[str, Any]: + """ + Process thinking configuration from generation_kwargs. + + :param generation_kwargs: The generation configuration dictionary. + :returns: Updated generation_kwargs with thinking_config if applicable. 
+ """ + if "thinking_budget" in generation_kwargs: + thinking_budget = generation_kwargs.pop("thinking_budget") + + # Basic type validation + if not isinstance(thinking_budget, int): + logger.warning( + f"Invalid thinking_budget type: {type(thinking_budget)}. Expected int, using dynamic allocation." + ) + # fall back to default: dynamic thinking budget allocation + thinking_budget = -1 + + # Create thinking config + thinking_config = types.ThinkingConfig(thinking_budget=thinking_budget, include_thoughts=True) + generation_kwargs["thinking_config"] = thinking_config + + if "thinking_level" in generation_kwargs: + thinking_level = generation_kwargs.pop("thinking_level") + + # Basic type validation + if not isinstance(thinking_level, str): + logger.warning( + f"Invalid thinking_level type: {type(thinking_level).__name__}. Expected str, " + f"falling back to THINKING_LEVEL_UNSPECIFIED." + ) + thinking_level = types.ThinkingLevel.THINKING_LEVEL_UNSPECIFIED + else: + # Convert to uppercase for case-insensitive matching + thinking_level_upper = thinking_level.upper() + + # Check if the uppercase value is a valid ThinkingLevel enum member + valid_levels = [level.value for level in types.ThinkingLevel] + if thinking_level_upper not in valid_levels: + logger.warning( + f"Invalid thinking_level value: '{thinking_level}'. " + f"Must be one of: {valid_levels} (case-insensitive). " + "Falling back to THINKING_LEVEL_UNSPECIFIED." + ) + thinking_level = types.ThinkingLevel.THINKING_LEVEL_UNSPECIFIED + else: + # Parse valid string to ThinkingLevel enum + thinking_level = types.ThinkingLevel(thinking_level_upper) + + # Create thinking config with level + thinking_config = types.ThinkingConfig(thinking_level=thinking_level, include_thoughts=True) + generation_kwargs["thinking_config"] = thinking_config + + return generation_kwargs + def remove_key_from_schema( schema: dict[str, Any] | list[Any] | Any, target_key: str @@ -29,3 +134,551 @@ def remove_key_from_schema( return [remove_key_from_schema(item, target_key) for item in schema] return schema + + +def _sanitize_tool_schema(tool_schema: dict[str, Any]) -> dict[str, Any]: + """ + Sanitizes a tool schema to remove any keys that are not supported by Google Gen AI. + + Google Gen AI does not support additionalProperties, $schema, $defs, or $ref in the tool schema. + + :param tool_schema: The tool schema to sanitize. + :returns: The sanitized tool schema. + """ + # google Gemini does not support additionalProperties and $schema in the tool schema + sanitized_schema = remove_key_from_schema(tool_schema, "additionalProperties") + sanitized_schema = remove_key_from_schema(sanitized_schema, "$schema") + # expand $refs in the tool schema + expanded_schema = replace_refs(sanitized_schema) + # and remove the $defs key leaving the rest of the schema + final_schema = remove_key_from_schema(expanded_schema, "$defs") + + if not isinstance(final_schema, dict): + msg = "Tool schema must be a dictionary after sanitization" + raise ValueError(msg) + + return final_schema + + +def _convert_message_to_google_genai_format(message: ChatMessage) -> types.Content: + """ + Converts a Haystack ChatMessage to Google Gen AI Content format. + + :param message: The Haystack ChatMessage to convert. + :returns: Google Gen AI Content object. + """ + # Check if message has content + if not message._content: + msg = "A `ChatMessage` must contain at least one content part." 
+ raise ValueError(msg) + + parts = [] + + # Check if this message has thought signatures from a previous response + # These need to be reconstructed in their original part structure + thought_signatures = message.meta.get("thought_signatures", []) if message.meta else [] + + # If we have thought signatures, we need to reconstruct the exact part structure + # from the previous assistant response to maintain multi-turn thinking context + if thought_signatures and message.is_from(ChatRole.ASSISTANT): + # Track which tool calls we've used (to handle multiple tool calls) + tool_call_index = 0 + + # Reconstruct parts with their original thought signatures + for sig_info in thought_signatures: + part_dict: dict[str, Any] = {} + + # Check what type of content this part had + if sig_info.get("has_text"): + # Find the corresponding text content + if sig_info.get("is_thought"): + # This was a thought part - find it in reasoning content + if message.reasoning: + part_dict["text"] = message.reasoning.reasoning_text + part_dict["thought"] = True + else: + # Regular text part + part_dict["text"] = message.text or "" + + if sig_info.get("has_function_call"): + # Find the corresponding tool call by index + if message.tool_calls and tool_call_index < len(message.tool_calls): + tool_call = message.tool_calls[tool_call_index] + part_dict["function_call"] = types.FunctionCall( + id=tool_call.id, name=tool_call.tool_name, args=tool_call.arguments + ) + tool_call_index += 1 # Move to next tool call for next part + + # Add the thought signature to preserve context + part_dict["thought_signature"] = sig_info["signature"] + + parts.append(types.Part(**part_dict)) + + # If we reconstructed from signatures, we're done + if parts: + role = "model" # Assistant messages with signatures are always from the model + return types.Content(role=role, parts=parts) + + # Standard processing for messages without thought signatures + for content_part in message._content: + if isinstance(content_part, TextContent): + # Only add text parts that are not empty to avoid unnecessary empty text parts + if content_part.text.strip(): + parts.append(types.Part(text=content_part.text)) + + elif isinstance(content_part, (ImageContent, FileContent)): + cls_name = content_part.__class__.__name__ + + if not message.is_from(ChatRole.USER): + msg = f"{cls_name} is only supported for user messages" + raise ValueError(msg) + + # MIME type validation: must be provided and, for images, one of the supported types + if not content_part.mime_type: + msg = f"MIME type is required to use {cls_name} with GoogleGenAIChatGenerator" + raise ValueError(msg) + + if ( + isinstance(content_part, ImageContent) + and content_part.mime_type not in GOOGLE_GENAI_SUPPORTED_MIME_TYPES + ): + supported_types = list(GOOGLE_GENAI_SUPPORTED_MIME_TYPES.keys()) + msg = ( + f"Unsupported image MIME type: {content_part.mime_type}. 
" + f"Google AI supports the following MIME types: {supported_types}" + ) + raise ValueError(msg) + + # Use inline data approach + try: + base64_data = ( + content_part.base64_data if isinstance(content_part, FileContent) else content_part.base64_image + ) + bytes_data = base64.b64decode(base64_data) + + file_part = types.Part.from_bytes(data=bytes_data, mime_type=content_part.mime_type) + parts.append(file_part) + + except Exception as e: + msg = f"Failed to process {cls_name} data: {e}" + raise RuntimeError(msg) from e + + elif isinstance(content_part, ToolCall): + parts.append( + types.Part( + function_call=types.FunctionCall( + id=content_part.id, name=content_part.tool_name, args=content_part.arguments + ) + ) + ) + + elif isinstance(content_part, ToolCallResult): + if isinstance(content_part.result, str): + parts.append( + types.Part( + function_response=types.FunctionResponse( + id=content_part.origin.id, + name=content_part.origin.tool_name, + response={"result": content_part.result}, + ) + ) + ) + elif isinstance(content_part.result, list): + tool_call_result_parts: list[types.FunctionResponsePart] = [] + for item in content_part.result: + if isinstance(item, TextContent): + tool_call_result_parts.append( + types.FunctionResponsePart( + inline_data=types.FunctionResponseBlob( + data=item.text.encode("utf-8"), mime_type="text/plain" + ), + ) + ) + elif isinstance(item, ImageContent): + tool_call_result_parts.append( + types.FunctionResponsePart( + inline_data=types.FunctionResponseBlob( + data=base64.b64decode(item.base64_image), mime_type=item.mime_type + ), + ) + ) + else: + msg = ( + "Unsupported content type in tool call result list. " + "Only TextContent and ImageContent are supported." + ) + raise ValueError(msg) + parts.append( + types.Part( + function_response=types.FunctionResponse( + id=content_part.origin.id, + name=content_part.origin.tool_name, + parts=tool_call_result_parts, + # the response field is mandatory, but in this case the LLM just needs multimodal parts + response={"result": ""}, + ) + ) + ) + else: + msg = "Unsupported content type in tool call result" + raise ValueError(msg) + elif isinstance(content_part, ReasoningContent): + # Reasoning content is for human transparency only, not for maintaining LLM context + # Thought signatures (stored in message.meta) handle context preservation + # Leave this here so we don't implement reasoning content handling in the future accidentally + pass + + # Determine role + if message.is_from(ChatRole.USER) or message.tool_call_results: + role = "user" + elif message.is_from(ChatRole.ASSISTANT): + role = "model" + elif message.is_from(ChatRole.SYSTEM): + # System messages will be handled separately as system instruction + # When we convert a list of ChatMessage to be sent to google genai, + # we need to handle system messages separately as system instruction and we only take the first message + # as the system instruction - if it is present. + # + # If we find any additional system messages, we will treat them as user messages + role = "user" + else: + msg = f"Unsupported message role: {message._role}" + raise ValueError(msg) + + return types.Content(role=role, parts=parts) + + +def _convert_tools_to_google_genai_format(tools: ToolsType) -> list[types.Tool]: + """ + Converts a list of Haystack Tools, Toolsets, or a mix to Google Gen AI Tool format. + + :param tools: List of Haystack Tool and/or Toolset objects, or a single Toolset. + :returns: List of Google Gen AI Tool objects. 
+ """ + # Flatten Tools and Toolsets into a single list of Tools + flattened_tools = flatten_tools_or_toolsets(tools) + + function_declarations: list[types.FunctionDeclaration] = [] + for tool in flattened_tools: + parameters = _sanitize_tool_schema(tool.parameters) + function_declarations.append( + types.FunctionDeclaration( + name=tool.name, description=tool.description, parameters=types.Schema(**parameters) + ) + ) + + # Return a single Tool object with all function declarations as in the Google GenAI docs + # we could also return multiple Tool objects, doesn't seem to make a difference + # revisit this decision + return [types.Tool(function_declarations=function_declarations)] + + +def _convert_usage_metadata_to_serializable( + usage_metadata: UsageMetadata | GenerateContentResponseUsageMetadata | None, +) -> dict[str, Any]: + """Build a JSON-serializable usage dict from a UsageMetadata object. + + Iterates over known UsageMetadata attribute names and adds each non-None value + in serialized form. Full list of fields: https://ai.google.dev/api/generate-content#UsageMetadata + """ + + def serialize(val: Any) -> Any: + if val is None: + return None + if isinstance(val, (str, int, float, bool)): + return val + if isinstance(val, list): + return [serialize(item) for item in val] + token_count = getattr(val, "token_count", None) or getattr(val, "tokenCount", None) + if hasattr(val, "modality") and token_count is not None: + mod = getattr(val, "modality", None) + mod_str = getattr(mod, "value", getattr(mod, "name", str(mod))) if mod is not None else None + return {"modality": mod_str, "token_count": token_count} + if hasattr(val, "name"): + return getattr(val, "value", getattr(val, "name", val)) + return val + + if not usage_metadata: + return {} + + _usage_attr_names = ( + "prompt_token_count", + "candidates_token_count", + "total_token_count", + "cache_tokens_details", + "candidates_tokens_details", + "prompt_tokens_details", + "tool_use_prompt_token_count", + "tool_use_prompt_tokens_details", + ) + result: dict[str, Any] = {} + for attr in _usage_attr_names: + val = getattr(usage_metadata, attr, None) + if val is not None: + result[attr] = serialize(val) + return result + + +def _convert_google_genai_response_to_chatmessage(response: types.GenerateContentResponse, model: str) -> ChatMessage: + """ + Converts a Google Gen AI response to a Haystack ChatMessage. + + :param response: The response from Google Gen AI. + :param model: The model name. + :returns: A Haystack ChatMessage. 
+ """ + text_parts = [] + tool_calls = [] + reasoning_parts = [] + thought_signatures = [] # Store thought signatures for multi-turn context + + # Extract text, function calls, thoughts, and thought signatures from response + finish_reason = None + if response.candidates: + candidate = response.candidates[0] + finish_reason = getattr(candidate, "finish_reason", None) + if candidate.content is not None and candidate.content.parts is not None: + for i, part in enumerate(candidate.content.parts): + # Check for thought signature on this part + if hasattr(part, "thought_signature") and part.thought_signature: + # Store the thought signature with its part index for reconstruction + thought_signatures.append( + { + "part_index": i, + "signature": part.thought_signature, + "has_text": part.text is not None, + "has_function_call": part.function_call is not None, + "is_thought": hasattr(part, "thought") and part.thought, + } + ) + + if part.text is not None and not (hasattr(part, "thought") and part.thought): + text_parts.append(part.text) + if part.function_call is not None: + tool_call = ToolCall( + tool_name=part.function_call.name or "", + arguments=dict(part.function_call.args) if part.function_call.args else {}, + id=part.function_call.id, + ) + tool_calls.append(tool_call) + # Handle thought parts for Gemini 2.5 series + if hasattr(part, "thought") and part.thought: + # Extract thought content + if part.text: + reasoning_parts.append(part.text) + + # Combine text parts + text = " ".join(text_parts) if text_parts else "" + + usage_metadata = response.usage_metadata + + # Create usage metadata including thinking tokens if available + usage = { + "prompt_tokens": getattr(usage_metadata, "prompt_token_count", 0), + "completion_tokens": getattr(usage_metadata, "candidates_token_count", 0), + "total_tokens": getattr(usage_metadata, "total_token_count", 0), + } + + # Add thinking token count if available + if usage_metadata and hasattr(usage_metadata, "thoughts_token_count") and usage_metadata.thoughts_token_count: + usage["thoughts_token_count"] = usage_metadata.thoughts_token_count + + # Add cached content token count if available (implicit or explicit context caching) + if ( + usage_metadata + and hasattr(usage_metadata, "cached_content_token_count") + and usage_metadata.cached_content_token_count + ): + usage["cached_content_token_count"] = usage_metadata.cached_content_token_count + + usage.update(_convert_usage_metadata_to_serializable(usage_metadata)) + + # Create meta with reasoning content and thought signatures if available + meta: dict[str, Any] = { + "model": model, + "finish_reason": FINISH_REASON_MAPPING.get(finish_reason or ""), + "usage": usage, + } + + # Add thought signatures to meta if present (for multi-turn context preservation) + if thought_signatures: + meta["thought_signatures"] = thought_signatures + + # Create ReasoningContent object if there are reasoning parts + reasoning_content = None + if reasoning_parts: + reasoning_text = " ".join(reasoning_parts) + reasoning_content = ReasoningContent(reasoning_text=reasoning_text) + + # Create ChatMessage + message = ChatMessage.from_assistant(text=text, tool_calls=tool_calls, meta=meta, reasoning=reasoning_content) + + return message + + +def _convert_google_chunk_to_streaming_chunk( + chunk: types.GenerateContentResponse, + index: int, + component_info: ComponentInfo, + model: str, +) -> StreamingChunk: + """ + Convert a chunk from Google Gen AI to a Haystack StreamingChunk. + + :param chunk: The chunk from Google Gen AI. 
+ :param index: The index of the chunk. + :param component_info: The component info. + :param model: The model name. + :returns: A StreamingChunk object. + """ + content = "" + tool_calls: list[ToolCallDelta] = [] + finish_reason = None + reasoning_deltas: list[dict[str, str]] = [] + thought_signature_deltas: list[dict[str, Any]] = [] # Track thought signatures in streaming + + if chunk.candidates: + candidate = chunk.candidates[0] + finish_reason = getattr(candidate, "finish_reason", None) + + usage_metadata = chunk.usage_metadata + + usage = { + "prompt_tokens": getattr(usage_metadata, "prompt_token_count", 0) if usage_metadata else 0, + "completion_tokens": getattr(usage_metadata, "candidates_token_count", 0) if usage_metadata else 0, + "total_tokens": getattr(usage_metadata, "total_token_count", 0) if usage_metadata else 0, + } + + # Add thinking token count if available + if usage_metadata and hasattr(usage_metadata, "thoughts_token_count") and usage_metadata.thoughts_token_count: + usage["thoughts_token_count"] = usage_metadata.thoughts_token_count + + if candidate.content and candidate.content.parts: + tc_index = -1 + for part_index, part in enumerate(candidate.content.parts): + # Check for thought signature on this part (for multi-turn context) + if hasattr(part, "thought_signature") and part.thought_signature: + thought_signature_deltas.append( + { + "part_index": part_index, + "signature": part.thought_signature, + "has_text": part.text is not None, + "has_function_call": part.function_call is not None, + "is_thought": hasattr(part, "thought") and part.thought, + } + ) + + if part.text is not None and not (hasattr(part, "thought") and part.thought): + content += part.text + + elif part.function_call: + tc_index += 1 + tool_calls.append( + ToolCallDelta( + # Google GenAI does not provide index, but it is required for tool calls + index=tc_index, + id=part.function_call.id, + tool_name=part.function_call.name or "", + arguments=json.dumps(part.function_call.args) if part.function_call.args else None, + ) + ) + + # Handle thought parts for Gemini 2.5 series + elif hasattr(part, "thought") and part.thought: + thought_delta = { + "type": "reasoning", + "content": part.text if part.text else "", + } + reasoning_deltas.append(thought_delta) + + # start is only used by print_streaming_chunk. We try to make a reasonable assumption here but it should not be + # a problem if we change it in the future. + start = index == 0 or len(tool_calls) > 0 + + # Create meta with reasoning deltas and thought signatures if available + meta: dict[str, Any] = { + "received_at": datetime.now(timezone.utc).isoformat(), + "model": model, + "usage": usage, + } + + # Add reasoning deltas to meta if available + if reasoning_deltas: + meta["reasoning_deltas"] = reasoning_deltas + + # Add thought signature deltas to meta if available (for multi-turn context) + if thought_signature_deltas: + meta["thought_signature_deltas"] = thought_signature_deltas + + return StreamingChunk( + content="" if tool_calls else content, # prioritize tool calls over content when both are present + tool_calls=tool_calls, + component_info=component_info, + index=index, + start=start, + finish_reason=FINISH_REASON_MAPPING.get(finish_reason or ""), + meta=meta, + ) + + +def _aggregate_streaming_chunks_with_reasoning(chunks: list[StreamingChunk]) -> ChatMessage: + """ + Aggregate streaming chunks into a final ChatMessage with reasoning content and thought signatures. 
+ + This method extends the standard streaming chunk aggregation to handle Google GenAI's + specific reasoning content, thinking token usage, and thought signatures for multi-turn context. + + :param chunks: List of streaming chunks to aggregate. + :returns: Final ChatMessage with aggregated content, reasoning, and thought signatures. + """ + + # Use the generic aggregator for standard content (text, tool calls, basic meta) + message = _convert_streaming_chunks_to_chat_message(chunks) + + # Now enhance with Google-specific features: reasoning content, thinking token usage, and thought signatures + reasoning_text_parts: list[str] = [] + thought_signatures: list[dict[str, Any]] = [] + thoughts_token_count = None + + for chunk in chunks: + # Extract reasoning deltas + if chunk.meta and "reasoning_deltas" in chunk.meta: + reasoning_deltas = chunk.meta["reasoning_deltas"] + if isinstance(reasoning_deltas, list): + for delta in reasoning_deltas: + if delta.get("type") == "reasoning": + reasoning_text_parts.append(delta.get("content", "")) + + # Extract thought signature deltas (for multi-turn context preservation) + if chunk.meta and "thought_signature_deltas" in chunk.meta: + signature_deltas = chunk.meta["thought_signature_deltas"] + if isinstance(signature_deltas, list): + # Aggregate thought signatures - they should come from the final chunks + # We'll keep the last set of signatures as they represent the complete state + thought_signatures = signature_deltas + + # Extract thinking token usage (from the last chunk that has it) + if chunk.meta and "usage" in chunk.meta: + chunk_usage = chunk.meta["usage"] + if "thoughts_token_count" in chunk_usage: + thoughts_token_count = chunk_usage["thoughts_token_count"] + + # Add thinking token count to usage if present + if thoughts_token_count is not None and "usage" in message.meta: + if message.meta["usage"] is None: + message.meta["usage"] = {} + message.meta["usage"]["thoughts_token_count"] = thoughts_token_count + + # Add thought signatures to meta if present (for multi-turn context preservation) + if thought_signatures: + message.meta["thought_signatures"] = thought_signatures + + # If we have reasoning content, reconstruct the message to include it + # Note: ChatMessage doesn't support adding reasoning after creation, reconstruction is necessary + if reasoning_text_parts: + reasoning_content = ReasoningContent(reasoning_text="".join(reasoning_text_parts)) + return ChatMessage.from_assistant( + text=message.text, tool_calls=message.tool_calls, meta=message.meta, reasoning=reasoning_content + ) + + return message diff --git a/integrations/google_genai/tests/test_chat_generator.py b/integrations/google_genai/tests/test_chat_generator.py index ad6d7e95a5..2a15ba4c1d 100644 --- a/integrations/google_genai/tests/test_chat_generator.py +++ b/integrations/google_genai/tests/test_chat_generator.py @@ -3,18 +3,16 @@ # SPDX-License-Identifier: Apache-2.0 import asyncio -import base64 import os -from unittest.mock import Mock import pytest -from google.genai import types from haystack.components.agents import Agent from haystack.components.generators.utils import print_streaming_chunk from haystack.dataclasses import ( ChatMessage, ChatRole, ComponentInfo, + FileContent, ImageContent, ReasoningContent, StreamingChunk, @@ -26,20 +24,9 @@ from haystack_integrations.components.generators.google_genai.chat.chat_generator import ( GoogleGenAIChatGenerator, - _convert_google_genai_response_to_chatmessage, - _convert_message_to_google_genai_format, - 
_convert_usage_metadata_to_serializable, ) -@pytest.fixture -def chat_messages(): - return [ - ChatMessage.from_system("You are a helpful assistant"), - ChatMessage.from_user("What's the capital of France"), - ] - - def weather(city: str): """Get weather information for a city.""" return f"Weather in {city}: 22°C, sunny" @@ -61,402 +48,7 @@ def tools(): ] -class TestStreamingChunkConversion: - def test_convert_google_chunk_to_streaming_chunk_text_only(self, monkeypatch): - monkeypatch.setenv("GOOGLE_API_KEY", "test-api-key") - component = GoogleGenAIChatGenerator() - component_info = ComponentInfo.from_component(component) - - mock_chunk = Mock() - mock_candidate = Mock() - mock_candidate.finish_reason = "STOP" - mock_chunk.candidates = [mock_candidate] - - mock_content = Mock() - mock_content.parts = [] - mock_part = Mock() - mock_part.text = "Hello, world!" - mock_part.function_call = None - # Explicitly set thought=False to simulate a regular (non-thought) part - mock_part.thought = False - mock_content.parts.append(mock_part) - mock_candidate.content = mock_content - - chunk = component._convert_google_chunk_to_streaming_chunk( - chunk=mock_chunk, - index=0, - component_info=component_info, - ) - - assert chunk.content == "Hello, world!" - assert chunk.tool_calls == [] - assert chunk.finish_reason == "stop" - assert chunk.index == 0 - assert "received_at" in chunk.meta - assert chunk.component_info == component_info - - def test_convert_google_chunk_to_streaming_chunk_tool_call(self, monkeypatch): - monkeypatch.setenv("GOOGLE_API_KEY", "test-api-key") - component = GoogleGenAIChatGenerator() - component_info = ComponentInfo.from_component(component) - - mock_chunk = Mock() - mock_candidate = Mock() - mock_candidate.finish_reason = "STOP" - mock_chunk.candidates = [mock_candidate] - - mock_content = Mock() - mock_content.parts = [] - mock_part = Mock() - mock_part.text = None - mock_function_call = Mock() - mock_function_call.name = "weather" - mock_function_call.args = {"city": "Paris"} - mock_function_call.id = "call_123" - mock_part.function_call = mock_function_call - mock_content.parts.append(mock_part) - mock_candidate.content = mock_content - - chunk = component._convert_google_chunk_to_streaming_chunk( - chunk=mock_chunk, index=0, component_info=component_info - ) - - assert chunk.content == "" - assert chunk.tool_calls is not None - assert len(chunk.tool_calls) == 1 - assert chunk.tool_calls[0].tool_name == "weather" - assert chunk.tool_calls[0].arguments == '{"city": "Paris"}' - assert chunk.tool_calls[0].id == "call_123" - assert chunk.finish_reason == "stop" - assert chunk.index == 0 - assert "received_at" in chunk.meta - assert chunk.component_info == component_info - - def test_convert_google_chunk_to_streaming_chunk_mixed_content(self, monkeypatch): - monkeypatch.setenv("GOOGLE_API_KEY", "test-api-key") - component = GoogleGenAIChatGenerator() - component_info = ComponentInfo.from_component(component) - - mock_chunk = Mock() - mock_candidate = Mock() - mock_candidate.finish_reason = "STOP" - mock_chunk.candidates = [mock_candidate] - - mock_content = Mock() - mock_content.parts = [] - - mock_text_part = Mock() - mock_text_part.text = "I'll check the weather for you." 
- mock_text_part.function_call = None - # Explicitly set thought=False to simulate a regular (non-thought) part - mock_text_part.thought = False - mock_content.parts.append(mock_text_part) - - mock_tool_part = Mock() - mock_tool_part.text = None - mock_function_call = Mock() - mock_function_call.name = "weather" - mock_function_call.args = {"city": "London"} - mock_function_call.id = "call_456" - mock_tool_part.function_call = mock_function_call - mock_content.parts.append(mock_tool_part) - - mock_candidate.content = mock_content - - chunk = component._convert_google_chunk_to_streaming_chunk( - chunk=mock_chunk, index=0, component_info=component_info - ) - - # When both text and tool calls are present, tool calls are prioritized - assert chunk.content == "" - assert chunk.tool_calls is not None - assert len(chunk.tool_calls) == 1 - assert chunk.tool_calls[0].tool_name == "weather" - assert chunk.tool_calls[0].arguments == '{"city": "London"}' - assert chunk.finish_reason == "stop" - assert chunk.component_info == component_info - - def test_convert_google_chunk_to_streaming_chunk_empty_parts(self, monkeypatch): - monkeypatch.setenv("GOOGLE_API_KEY", "test-api-key") - component = GoogleGenAIChatGenerator() - component_info = ComponentInfo.from_component(component) - - mock_chunk = Mock() - mock_candidate = Mock() - mock_content = Mock() - mock_content.parts = [] - mock_candidate.content = mock_content - mock_chunk.candidates = [mock_candidate] - - chunk = component._convert_google_chunk_to_streaming_chunk( - chunk=mock_chunk, index=0, component_info=component_info - ) - - assert chunk.content == "" - assert chunk.tool_calls == [] - assert chunk.component_info == component_info - - def test_convert_google_chunk_to_streaming_chunk_real_example(self, monkeypatch): - monkeypatch.setenv("GOOGLE_API_KEY", "test-api-key") - component = GoogleGenAIChatGenerator() - component_info = ComponentInfo.from_component(component) - - # Chunk 1: Text only - chunk1_parts = [ - types.Part( - text="I'll get the weather information for Paris and Berlin", function_call=None, function_response=None - ) - ] - chunk1_content = types.Content(role="model", parts=chunk1_parts) - chunk1_candidate = types.Candidate( - content=chunk1_content, - finish_reason=None, - index=None, - safety_ratings=None, - citation_metadata=None, - grounding_metadata=None, - finish_message=None, - token_count=None, - logprobs_result=None, - avg_logprobs=None, - url_context_metadata=None, - ) - chunk1_usage = types.GenerateContentResponseUsageMetadata( - prompt_token_count=217, candidates_token_count=None, total_token_count=217 - ) - chunk1 = types.GenerateContentResponse( - candidates=[chunk1_candidate], - usage_metadata=chunk1_usage, - model_version="gemini-2.5-flash", - response_id=None, - create_time=None, - prompt_feedback=None, - automatic_function_calling_history=None, - parsed=None, - ) - - streaming_chunk1 = component._convert_google_chunk_to_streaming_chunk( - chunk=chunk1, index=0, component_info=component_info - ) - assert streaming_chunk1.content == "I'll get the weather information for Paris and Berlin" - assert streaming_chunk1.tool_calls == [] - assert streaming_chunk1.finish_reason is None - assert streaming_chunk1.index == 0 - assert "received_at" in streaming_chunk1.meta - assert streaming_chunk1.meta["model"] == "gemini-2.5-flash" - assert "usage" in streaming_chunk1.meta - assert streaming_chunk1.meta["usage"]["prompt_tokens"] == 217 - assert streaming_chunk1.meta["usage"]["completion_tokens"] is None - assert 
streaming_chunk1.meta["usage"]["total_tokens"] == 217 - assert streaming_chunk1.component_info == component_info - - # Chunk 2: Text only - chunk2_parts = [ - types.Part(text=" and present it in a structured format.", function_call=None, function_response=None) - ] - chunk2_content = types.Content(role="model", parts=chunk2_parts) - chunk2_candidate = types.Candidate( - content=chunk2_content, - finish_reason=None, - index=None, - safety_ratings=None, - citation_metadata=None, - grounding_metadata=None, - finish_message=None, - token_count=None, - logprobs_result=None, - avg_logprobs=None, - url_context_metadata=None, - ) - chunk2_usage = types.GenerateContentResponseUsageMetadata( - prompt_token_count=217, candidates_token_count=None, total_token_count=217 - ) - chunk2 = types.GenerateContentResponse( - candidates=[chunk2_candidate], - usage_metadata=chunk2_usage, - model_version="gemini-2.5-flash", - response_id=None, - create_time=None, - prompt_feedback=None, - automatic_function_calling_history=None, - parsed=None, - ) - - streaming_chunk2 = component._convert_google_chunk_to_streaming_chunk( - chunk=chunk2, index=1, component_info=component_info - ) - assert streaming_chunk2.content == " and present it in a structured format." - assert streaming_chunk2.tool_calls == [] - assert streaming_chunk2.finish_reason is None - assert streaming_chunk2.index == 1 - assert "received_at" in streaming_chunk2.meta - assert streaming_chunk2.meta["model"] == "gemini-2.5-flash" - assert "usage" in streaming_chunk2.meta - assert streaming_chunk2.meta["usage"]["prompt_tokens"] == 217 - assert streaming_chunk2.meta["usage"]["completion_tokens"] is None - assert streaming_chunk2.meta["usage"]["total_tokens"] == 217 - assert streaming_chunk2.component_info == component_info - - # Chunk 3: Multiple tool calls (6 function calls) for 2 cities with 3 tools each - fc1 = types.FunctionCall(id=None, name="get_weather", args={"city": "Paris"}) - fc2 = types.FunctionCall(id=None, name="get_population", args={"city": "Paris"}) - fc3 = types.FunctionCall(id=None, name="get_time", args={"city": "Paris"}) - fc4 = types.FunctionCall(id=None, name="get_weather", args={"city": "Berlin"}) - fc5 = types.FunctionCall(id=None, name="get_population", args={"city": "Berlin"}) - fc6 = types.FunctionCall(id=None, name="get_time", args={"city": "Berlin"}) - - parts = [ - types.Part(text=None, function_call=fc1, function_response=None), - types.Part(text=None, function_call=fc2, function_response=None), - types.Part(text=None, function_call=fc3, function_response=None), - types.Part(text=None, function_call=fc4, function_response=None), - types.Part(text=None, function_call=fc5, function_response=None), - types.Part(text=None, function_call=fc6, function_response=None), - ] - - content = types.Content(role="model", parts=parts) - candidate = types.Candidate( - content=content, - finish_reason=types.FinishReason.STOP, - index=None, - safety_ratings=None, - citation_metadata=None, - grounding_metadata=None, - finish_message=None, - token_count=None, - logprobs_result=None, - avg_logprobs=None, - url_context_metadata=None, - ) - - usage_metadata = types.GenerateContentResponseUsageMetadata( - prompt_token_count=144, candidates_token_count=121, total_token_count=265 - ) - chunk = types.GenerateContentResponse( - candidates=[candidate], - usage_metadata=usage_metadata, - model_version="gemini-2.5-flash", - response_id=None, - create_time=None, - prompt_feedback=None, - automatic_function_calling_history=None, - parsed=None, - ) - - 
streaming_chunk = component._convert_google_chunk_to_streaming_chunk( - chunk=chunk, index=2, component_info=component_info - ) - assert streaming_chunk.content == "" - assert streaming_chunk.tool_calls is not None - assert len(streaming_chunk.tool_calls) == 6 - assert streaming_chunk.finish_reason == "stop" - assert streaming_chunk.index == 2 - assert "received_at" in streaming_chunk.meta - assert streaming_chunk.meta["model"] == "gemini-2.5-flash" - assert streaming_chunk.component_info == component_info - assert "usage" in streaming_chunk.meta - assert streaming_chunk.meta["usage"]["prompt_tokens"] == 144 - assert streaming_chunk.meta["usage"]["completion_tokens"] == 121 - assert streaming_chunk.meta["usage"]["total_tokens"] == 265 - - assert streaming_chunk.tool_calls[0].tool_name == "get_weather" - assert streaming_chunk.tool_calls[0].arguments == '{"city": "Paris"}' - assert streaming_chunk.tool_calls[0].id is None - assert streaming_chunk.tool_calls[0].index == 0 - - assert streaming_chunk.tool_calls[1].tool_name == "get_population" - assert streaming_chunk.tool_calls[1].arguments == '{"city": "Paris"}' - assert streaming_chunk.tool_calls[1].id is None - assert streaming_chunk.tool_calls[1].index == 1 - - assert streaming_chunk.tool_calls[2].tool_name == "get_time" - assert streaming_chunk.tool_calls[2].arguments == '{"city": "Paris"}' - assert streaming_chunk.tool_calls[2].id is None - assert streaming_chunk.tool_calls[2].index == 2 - - assert streaming_chunk.tool_calls[3].tool_name == "get_weather" - assert streaming_chunk.tool_calls[3].arguments == '{"city": "Berlin"}' - assert streaming_chunk.tool_calls[3].id is None - assert streaming_chunk.tool_calls[3].index == 3 - - assert streaming_chunk.tool_calls[4].tool_name == "get_population" - assert streaming_chunk.tool_calls[4].arguments == '{"city": "Berlin"}' - assert streaming_chunk.tool_calls[4].id is None - assert streaming_chunk.tool_calls[4].index == 4 - - assert streaming_chunk.tool_calls[5].tool_name == "get_time" - assert streaming_chunk.tool_calls[5].arguments == '{"city": "Berlin"}' - assert streaming_chunk.tool_calls[5].id is None - assert streaming_chunk.tool_calls[5].index == 5 - - -def test_convert_google_genai_response_to_chatmessage_parses_cached_tokens(monkeypatch): - """When the API response includes cached_content_token_count in usage_metadata, it is parsed into meta['usage'].""" - monkeypatch.setenv("GOOGLE_API_KEY", "test-api-key") - - # Minimal candidate with one text part - mock_part = Mock() - mock_part.text = "Four." 
- mock_part.function_call = None - mock_part.thought = False - mock_part.thought_signature = None - mock_content = Mock() - mock_content.parts = [mock_part] - mock_candidate = Mock() - mock_candidate.content = mock_content - mock_candidate.finish_reason = "STOP" - - # Usage metadata including cached tokens (as returned by API when cache is used) - mock_usage = Mock() - mock_usage.prompt_token_count = 1000 - mock_usage.candidates_token_count = 5 - mock_usage.total_token_count = 1005 - mock_usage.cached_content_token_count = 800 - mock_usage.thoughts_token_count = None - - mock_response = Mock() - mock_response.candidates = [mock_candidate] - mock_response.usage_metadata = mock_usage - - message = _convert_google_genai_response_to_chatmessage(mock_response, "gemini-2.5-flash") - - assert message.meta is not None - assert "usage" in message.meta - usage = message.meta["usage"] - assert usage["prompt_tokens"] == 1000 - assert usage["completion_tokens"] == 5 - assert usage["total_tokens"] == 1005 - assert usage["cached_content_token_count"] == 800 - - -def test_convert_usage_metadata_to_serializable(): - """_convert_usage_metadata_to_serializable builds a serializable dict from a UsageMetadata object.""" - assert _convert_usage_metadata_to_serializable(None) == {} - assert _convert_usage_metadata_to_serializable(False) == {} - - usage_metadata = types.GenerateContentResponseUsageMetadata( - prompt_token_count=100, - candidates_token_count=5, - total_token_count=105, - ) - result = _convert_usage_metadata_to_serializable(usage_metadata) - assert result["prompt_token_count"] == 100 - assert result["candidates_token_count"] == 5 - assert result["total_token_count"] == 105 - assert len(result) == 3 - - # Serialization of zero and composite types (ModalityTokenCount, lists) - modality_token_count = types.ModalityTokenCount(modality=types.Modality.TEXT, tokenCount=100) - usage_with_details = types.GenerateContentResponseUsageMetadata( - prompt_token_count=0, - candidates_tokens_details=[modality_token_count], - ) - result2 = _convert_usage_metadata_to_serializable(usage_with_details) - assert result2["prompt_token_count"] == 0 - assert result2["candidates_tokens_details"] == [{"modality": "TEXT", "token_count": 100}] - - -class TestGoogleGenAIChatGenerator: +class TestGoogleGenAIChatGeneratorInitSerDe: def test_init_default(self, monkeypatch): monkeypatch.setenv("GOOGLE_API_KEY", "test-api-key") component = GoogleGenAIChatGenerator() @@ -504,28 +96,6 @@ def test_init_with_toolset(self, tools, monkeypatch): generator = GoogleGenAIChatGenerator(tools=toolset) assert generator._tools == toolset - def test_to_dict_with_toolset(self, tools, monkeypatch): - monkeypatch.setenv("GOOGLE_API_KEY", "test-api-key") - toolset = Toolset(tools) - generator = GoogleGenAIChatGenerator(tools=toolset) - data = generator.to_dict() - - assert data["init_parameters"]["tools"]["type"] == "haystack.tools.toolset.Toolset" - assert "tools" in data["init_parameters"]["tools"]["data"] - assert len(data["init_parameters"]["tools"]["data"]["tools"]) == len(tools) - - def test_from_dict_with_toolset(self, tools, monkeypatch): - monkeypatch.setenv("GOOGLE_API_KEY", "test-api-key") - toolset = Toolset(tools) - component = GoogleGenAIChatGenerator(tools=toolset) - data = component.to_dict() - - deserialized_component = GoogleGenAIChatGenerator.from_dict(data) - - assert isinstance(deserialized_component._tools, Toolset) - assert len(deserialized_component._tools) == len(tools) - assert all(isinstance(tool, Tool) for tool in 
deserialized_component._tools) - def test_init_with_mixed_tools_and_toolsets(self, monkeypatch): """Test initialization with a mixed list of Tools and Toolsets.""" monkeypatch.setenv("GOOGLE_API_KEY", "test-api-key") @@ -561,6 +131,28 @@ def test_init_with_mixed_tools_and_toolsets(self, monkeypatch): assert isinstance(generator._tools[1], Toolset) assert isinstance(generator._tools[2], Tool) + def test_to_dict_with_toolset(self, tools, monkeypatch): + monkeypatch.setenv("GOOGLE_API_KEY", "test-api-key") + toolset = Toolset(tools) + generator = GoogleGenAIChatGenerator(tools=toolset) + data = generator.to_dict() + + assert data["init_parameters"]["tools"]["type"] == "haystack.tools.toolset.Toolset" + assert "tools" in data["init_parameters"]["tools"]["data"] + assert len(data["init_parameters"]["tools"]["data"]["tools"]) == len(tools) + + def test_from_dict_with_toolset(self, tools, monkeypatch): + monkeypatch.setenv("GOOGLE_API_KEY", "test-api-key") + toolset = Toolset(tools) + component = GoogleGenAIChatGenerator(tools=toolset) + data = component.to_dict() + + deserialized_component = GoogleGenAIChatGenerator.from_dict(data) + + assert isinstance(deserialized_component._tools, Toolset) + assert len(deserialized_component._tools) == len(tools) + assert all(isinstance(tool, Tool) for tool in deserialized_component._tools) + def test_serde_with_mixed_tools_and_toolsets(self, monkeypatch): """Test serialization/deserialization with mixed Tools and Toolsets.""" monkeypatch.setenv("GOOGLE_API_KEY", "test-api-key") @@ -599,267 +191,12 @@ def test_serde_with_mixed_tools_and_toolsets(self, monkeypatch): assert len(restored._tools[1]) == 1 -class TestMessagesConversion: - def test_convert_message_to_google_genai_format_complex(self): - """ - Test that the GoogleGenAIChatGenerator can convert a complex sequence of ChatMessages to Google GenAI format. - In particular, we check that different tool results are handled properly in sequence. - """ - messages = [ - ChatMessage.from_system("You are good assistant"), - ChatMessage.from_user("What's the weather like in Paris? And how much is 2+2?"), - ChatMessage.from_assistant( - text="", - tool_calls=[ - ToolCall(id="123", tool_name="weather", arguments={"city": "Paris"}), - ToolCall(id="456", tool_name="math", arguments={"expression": "2+2"}), - ], - ), - ChatMessage.from_tool( - tool_result="22° C", origin=ToolCall(id="123", tool_name="weather", arguments={"city": "Paris"}) - ), - ChatMessage.from_tool( - tool_result="4", origin=ToolCall(id="456", tool_name="math", arguments={"expression": "2+2"}) - ), - ] - - # Test system message handling (should be handled separately in Google GenAI) - system_message = messages[0] - assert system_message.is_from(ChatRole.SYSTEM) - - # Test user message conversion - user_message = messages[1] - google_content = _convert_message_to_google_genai_format(user_message) - assert google_content.role == "user" - assert len(google_content.parts) == 1 - assert google_content.parts[0].text == "What's the weather like in Paris? And how much is 2+2?" 
- - # Test assistant message with tool calls - assistant_message = messages[2] - google_content = _convert_message_to_google_genai_format(assistant_message) - assert google_content.role == "model" - assert len(google_content.parts) == 2 - assert google_content.parts[0].function_call.name == "weather" - assert google_content.parts[0].function_call.args == {"city": "Paris"} - assert google_content.parts[1].function_call.name == "math" - assert google_content.parts[1].function_call.args == {"expression": "2+2"} - - # Test tool result messages - tool_result_1 = messages[3] - google_content = _convert_message_to_google_genai_format(tool_result_1) - assert google_content.role == "user" - assert len(google_content.parts) == 1 - assert google_content.parts[0].function_response.name == "weather" - assert google_content.parts[0].function_response.response == {"result": "22° C"} - - tool_result_2 = messages[4] - google_content = _convert_message_to_google_genai_format(tool_result_2) - assert google_content.role == "user" - assert len(google_content.parts) == 1 - assert google_content.parts[0].function_response.name == "math" - assert google_content.parts[0].function_response.response == {"result": "4"} - - def test_convert_message_to_google_genai_format_with_single_image(self, test_files_path): - """Test converting a message with a single image to Google GenAI format.""" - apple_path = test_files_path / "apple.jpg" - apple_content = ImageContent.from_file_path(apple_path, size=(100, 100)) - - message = ChatMessage.from_user(content_parts=["What's in this image?", apple_content]) - - google_content = _convert_message_to_google_genai_format(message) - - assert google_content.role == "user" - assert len(google_content.parts) == 2 - - # First part should be text - assert google_content.parts[0].text == "What's in this image?" - - # Second part should be image data - assert hasattr(google_content.parts[1], "inline_data") - assert google_content.parts[1].inline_data is not None - assert google_content.parts[1].inline_data.mime_type == "image/jpeg" - assert google_content.parts[1].inline_data.data is not None - assert len(google_content.parts[1].inline_data.data) > 0 - - def test_convert_message_to_google_genai_format_with_multiple_images(self, test_files_path): - """Test converting a message with multiple images in mixed content to Google GenAI format.""" - apple_path = test_files_path / "apple.jpg" - banana_path = test_files_path / "banana.png" - - apple_content = ImageContent.from_file_path(apple_path, size=(100, 100)) - banana_content = ImageContent.from_file_path(banana_path, size=(100, 100)) - - message = ChatMessage.from_user( - content_parts=[ - "Compare these fruits. First:", - apple_content, - "Second:", - banana_content, - "Which is healthier?", - ] - ) - - google_content = _convert_message_to_google_genai_format(message) - - assert google_content.role == "user" - assert len(google_content.parts) == 5 - - # Verify the exact order is preserved - assert google_content.parts[0].text == "Compare these fruits. 
First:" - - # First image (apple) - assert hasattr(google_content.parts[1], "inline_data") - assert google_content.parts[1].inline_data.mime_type == "image/jpeg" - assert google_content.parts[1].inline_data.data is not None - - assert google_content.parts[2].text == "Second:" - - # Second image (banana) - assert hasattr(google_content.parts[3], "inline_data") - assert google_content.parts[3].inline_data.mime_type == "image/png" - assert google_content.parts[3].inline_data.data is not None - - assert google_content.parts[4].text == "Which is healthier?" - - def test_convert_message_to_google_genai_format_image_with_minimal_text(self, test_files_path): - """Test converting a message with minimal text and image to Google GenAI format.""" - apple_path = test_files_path / "apple.jpg" - apple_content = ImageContent.from_file_path(apple_path, size=(100, 100)) - - # Haystack requires at least one textual part for user messages, so we use minimal text - message = ChatMessage.from_user(content_parts=["", apple_content]) - - google_content = _convert_message_to_google_genai_format(message) - - assert google_content.role == "user" - # Empty text should be filtered out by our implementation, leaving only the image - assert len(google_content.parts) == 1 - - # Should only have the image part (empty text filtered out) - assert hasattr(google_content.parts[0], "inline_data") - assert google_content.parts[0].inline_data.mime_type == "image/jpeg" - assert google_content.parts[0].inline_data.data is not None - - def test_convert_message_to_google_genai_format_with_thought_signatures(self): - """Test converting an assistant message with thought signatures for multi-turn context preservation.""" - - # Create an assistant message with tool calls and thought signatures in meta - tool_call = ToolCall(id="call_123", tool_name="weather", arguments={"city": "Paris"}) - - # Thought signatures are stored in meta when thinking is enabled with tools - # They must be base64 encoded as per the API requirements - thought_signatures = [ - { - "part_index": 0, - "signature": base64.b64encode(b"encrypted_mock_thought_signature_1").decode("utf-8"), - "has_text": True, - "has_function_call": False, - "is_thought": False, - }, - { - "part_index": 1, - "signature": base64.b64encode(b"encrypted_mock_thought_signature_2").decode("utf-8"), - "has_text": False, - "has_function_call": True, - "is_thought": False, - }, - ] - - message = ChatMessage.from_assistant( - text="I'll check the weather for you", - tool_calls=[tool_call], - meta={"thought_signatures": thought_signatures}, - ) - - google_content = _convert_message_to_google_genai_format(message) - - assert google_content.role == "model" - assert len(google_content.parts) == 2 - - # First part should have text and its thought signature - assert google_content.parts[0].text == "I'll check the weather for you" - # thought_signature is returned as bytes from the API - assert google_content.parts[0].thought_signature == b"encrypted_mock_thought_signature_1" - - # Second part should have function call and its thought signature - assert google_content.parts[1].function_call.name == "weather" - assert google_content.parts[1].function_call.args == {"city": "Paris"} - assert google_content.parts[1].thought_signature == b"encrypted_mock_thought_signature_2" - - def test_convert_message_to_google_genai_format_with_reasoning_content(self): - """Test that ReasoningContent is properly skipped during conversion.""" - # ReasoningContent is for human transparency only, not sent to the API - 
reasoning = ReasoningContent(reasoning_text="Of Life, the Universe and Everything...") - - # Create a message with both text and reasoning content - message = ChatMessage.from_assistant(text="Forty-two", reasoning=reasoning) - - google_content = _convert_message_to_google_genai_format(message) - - assert google_content.role == "model" - assert len(google_content.parts) == 1 - - # Only the text should be included, reasoning content should be skipped - assert google_content.parts[0].text == "Forty-two" - # Verify no thought part was created (reasoning is not sent to API) - assert not hasattr(google_content.parts[0], "thought") or not google_content.parts[0].thought - - def test_convert_message_to_google_genai_format_tool_message(self): - tool_call = ToolCall(id="123", tool_name="weather", arguments={"city": "Paris"}) - message = ChatMessage.from_tool(tool_result="22° C", origin=tool_call) - google_content = _convert_message_to_google_genai_format(message) - assert google_content.role == "user" - assert len(google_content.parts) == 1 - assert isinstance(google_content.parts[0].function_response, types.FunctionResponse) - assert google_content.parts[0].function_response.id == "123" - assert google_content.parts[0].function_response.name == "weather" - assert google_content.parts[0].function_response.response == {"result": "22° C"} - assert google_content.parts[0].function_response.parts is None - - def test_convert_message_to_google_genai_format_image_in_tool_result(self): - tool_call = ToolCall(id="123", tool_name="image_retriever", arguments={}) - - base64_str = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8/x8AAwMCAO+ip1sAAAAASUVORK5CYII=" - tool_result = [TextContent("Here is the image"), ImageContent(base64_image=base64_str, mime_type="image/jpeg")] - - message = ChatMessage.from_tool(tool_result=tool_result, origin=tool_call) - google_content = _convert_message_to_google_genai_format(message) - assert google_content.role == "user" - assert len(google_content.parts) == 1 - assert isinstance(google_content.parts[0].function_response, types.FunctionResponse) - assert google_content.parts[0].function_response.id == "123" - assert google_content.parts[0].function_response.name == "image_retriever" - assert google_content.parts[0].function_response.response == {"result": ""} - assert len(google_content.parts[0].function_response.parts) == 2 - assert isinstance(google_content.parts[0].function_response.parts[0], types.FunctionResponsePart) - assert google_content.parts[0].function_response.parts[0].inline_data.mime_type == "text/plain" - assert google_content.parts[0].function_response.parts[0].inline_data.data == b"Here is the image" - assert isinstance(google_content.parts[0].function_response.parts[1], types.FunctionResponsePart) - assert google_content.parts[0].function_response.parts[1].inline_data.mime_type == "image/jpeg" - assert google_content.parts[0].function_response.parts[1].inline_data.data == base64.b64decode(base64_str) - assert len(google_content.parts[0].function_response.parts[1].inline_data.data) > 0 - - def test_convert_message_to_google_genai_format_invalid_tool_result_type(self): - tool_call = ToolCall(id="123", tool_name="weather", arguments={"city": "Paris"}) - message = ChatMessage.from_tool(tool_result=256, origin=tool_call) - with pytest.raises(ValueError, match="Unsupported content type in tool call result"): - _convert_message_to_google_genai_format(message) - - message = ChatMessage.from_tool(tool_result=[TextContent("This is supported"), 256], 
origin=tool_call) - with pytest.raises( - ValueError, - match=( - r"Unsupported content type in tool call result list. " - "Only TextContent and ImageContent are supported." - ), - ): - _convert_message_to_google_genai_format(message) - - @pytest.mark.skipif( - not os.environ.get("GOOGLE_API_KEY", None), - reason="Export an env var called GOOGLE_API_KEY containing the Google API key to run this test.", - ) - @pytest.mark.integration +@pytest.mark.skipif( + not os.environ.get("GOOGLE_API_KEY", None), + reason="Export an env var called GOOGLE_API_KEY containing the Google API key to run this test.", +) +@pytest.mark.integration +class TestGoogleGenAIChatGeneratorInference: def test_live_run(self) -> None: chat_messages = [ChatMessage.from_user("What's the capital of France")] component = GoogleGenAIChatGenerator() @@ -870,15 +207,6 @@ def test_live_run(self) -> None: assert "gemini-2.5-flash" in message.meta["model"] assert message.meta["finish_reason"] == "stop" - @pytest.mark.skipif( - not os.environ.get("GOOGLE_API_KEY", None), - reason="Export an env var called GOOGLE_API_KEY containing the Google API key to run this test.", - ) - @pytest.mark.skipif( - not os.environ.get("GOOGLE_API_KEY", None), - reason="Export an env var called GOOGLE_API_KEY containing the Google API key to run this test.", - ) - @pytest.mark.integration def test_run_with_multiple_images_mixed_content(self, test_files_path): """Test that multiple images with interleaved text maintain proper ordering.""" client = GoogleGenAIChatGenerator() @@ -920,11 +248,44 @@ def test_run_with_multiple_images_mixed_content(self, test_files_path): f"Apple should be mentioned before banana in the response. Got: {first_reply.text}" ) - @pytest.mark.skipif( - not os.environ.get("GOOGLE_API_KEY", None), - reason="Export an env var called GOOGLE_API_KEY containing the Google API key to run this test.", - ) - @pytest.mark.integration + def test_live_run_with_file_content(self, test_files_path): + pdf_path = test_files_path / "sample_pdf_3.pdf" + + file_content = FileContent.from_file_path(file_path=pdf_path) + + chat_messages = [ + ChatMessage.from_user( + content_parts=[file_content, "Is this document a paper about LLMs? Respond with 'yes' or 'no' only."] + ) + ] + + generator = GoogleGenAIChatGenerator() + results = generator.run(chat_messages) + + assert len(results["replies"]) == 1 + message: ChatMessage = results["replies"][0] + + assert message.is_from(ChatRole.ASSISTANT) + + assert message.text + indicates_no = any( + phrase in message.text.lower() + for phrase in ( + "no", + "nope", + "not about", + "not a paper about", + "it is not", + "it's not", + "the answer is no", + "does not", + "doesn't", + "negative", + ) + ) + + assert indicates_no is True + def test_live_run_streaming(self): component = GoogleGenAIChatGenerator() component_info = ComponentInfo.from_component(component) @@ -951,11 +312,6 @@ def __call__(self, chunk: StreamingChunk) -> None: assert message.text and "paris" in message.text.lower(), "Response does not contain Paris" assert message.meta["finish_reason"] == "stop" - @pytest.mark.skipif( - not os.environ.get("GOOGLE_API_KEY", None), - reason="Export an env var called GOOGLE_API_KEY containing the Google API key to run this test.", - ) - @pytest.mark.integration def test_live_run_with_tools_streaming(self, tools): """ Integration test that the GoogleGenAIChatGenerator component can run with tools and streaming. 
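For orientation before the streaming hunks that follow, here is a minimal sketch of the call path these integration tests exercise. The `streaming_callback` init parameter and the `run()` signature are inferred from the tests in this patch, and the `weather` tool mirrors the module-level fixture above; treat this as an illustration, not part of the diff.

from haystack.dataclasses import ChatMessage, StreamingChunk
from haystack.tools import Tool

from haystack_integrations.components.generators.google_genai.chat.chat_generator import GoogleGenAIChatGenerator


def weather(city: str) -> str:
    # Mirrors the test fixture: returns a canned reading for any city.
    return f"Weather in {city}: 22°C, sunny"


def on_chunk(chunk: StreamingChunk) -> None:
    # Text deltas carry content; tool-call deltas arrive with empty content.
    print(chunk.content, end="", flush=True)


weather_tool = Tool(
    name="weather",
    description="Get weather information for a city.",
    parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
    function=weather,
)
generator = GoogleGenAIChatGenerator(tools=[weather_tool], streaming_callback=on_chunk)
result = generator.run([ChatMessage.from_user("What's the weather in Paris?")])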
@@ -989,11 +345,6 @@ def test_live_run_with_tools_streaming(self, tools): assert tool_call.tool_name == "weather" assert tool_call.arguments == {"city": "Paris"} - @pytest.mark.skipif( - not os.environ.get("GOOGLE_API_KEY", None), - reason="Export an env var called GOOGLE_API_KEY containing the Google API key to run this test.", - ) - @pytest.mark.integration def test_live_run_with_toolset(self, tools): """Test that GoogleGenAIChatGenerator can run with a Toolset.""" toolset = Toolset(tools) @@ -1029,11 +380,6 @@ def test_live_run_with_toolset(self, tools): "Response does not contain Paris or weather" ) - @pytest.mark.skipif( - not os.environ.get("GOOGLE_API_KEY", None), - reason="Export an env var called GOOGLE_API_KEY containing the Google API key to run this test.", - ) - @pytest.mark.integration def test_live_run_with_parallel_tools(self, tools): """ Integration test that the GoogleGenAIChatGenerator component can run with parallel tools. @@ -1081,11 +427,6 @@ def test_live_run_with_parallel_tools(self, tools): # Check that the response mentions both temperature readings assert "22" in final_message.text and "15" in final_message.text - @pytest.mark.skipif( - not os.environ.get("GOOGLE_API_KEY", None), - reason="Export an env var called GOOGLE_API_KEY containing the Google API key to run this test.", - ) - @pytest.mark.integration def test_live_run_with_mixed_tools(self): """ Integration test that verifies GoogleGenAIChatGenerator works with mixed Tool and Toolset. @@ -1177,18 +518,15 @@ def test_live_run_with_mixed_tools(self): assert "paris" in final_message.text.lower() assert "berlin" in final_message.text.lower() - @pytest.mark.skipif( - not os.environ.get("GOOGLE_API_KEY", None), - reason="Export an env var called GOOGLE_API_KEY containing the Google API key to run this test.", - ) - @pytest.mark.integration def test_live_run_with_thinking(self): """ Integration test for the thinking feature with a model that supports it. """ # We use a model that supports the thinking feature chat_messages = [ChatMessage.from_user("Why is the sky blue? Explain in one sentence.")] - component = GoogleGenAIChatGenerator(model="gemini-2.5-pro", generation_kwargs={"thinking_budget": -1}) + component = GoogleGenAIChatGenerator( + model="gemini-3-flash-preview", generation_kwargs={"thinking_level": "low"} + ) results = component.run(chat_messages) assert len(results["replies"]) == 1 @@ -1206,11 +544,6 @@ def test_live_run_with_thinking(self): assert message.meta["usage"]["thoughts_token_count"] is not None assert message.meta["usage"]["thoughts_token_count"] > 0 - @pytest.mark.skipif( - not os.environ.get("GOOGLE_API_KEY", None), - reason="Export an env var called GOOGLE_API_KEY containing the Google API key to run this test.", - ) - @pytest.mark.integration def test_live_run_with_thinking_and_tools_multi_turn(self, tools): """ Integration test for thought signatures preservation in multi-turn conversations with tools. 
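The next hunk switches the thinking tests from a token budget to a discrete level. For context, both spellings are exercised elsewhere in this patch, and `_process_thinking_config` folds either kwarg into a `types.ThinkingConfig`. A sketch under those assumptions:

from haystack_integrations.components.generators.google_genai.chat.chat_generator import GoogleGenAIChatGenerator

# Gemini 2.5 family: a token budget, where -1 requests dynamic allocation and 0 disables thinking.
budget_generator = GoogleGenAIChatGenerator(
    model="gemini-2.5-pro", generation_kwargs={"thinking_budget": -1}
)

# Gemini 3 preview family: a discrete level such as "low" or "high"; unrecognized
# strings fall back to THINKING_LEVEL_UNSPECIFIED (see test_process_thinking_level).
level_generator = GoogleGenAIChatGenerator(
    model="gemini-3-flash-preview", generation_kwargs={"thinking_level": "low"}
)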
@@ -1218,9 +551,9 @@
         """
         # Use a model that supports thinking with tools
         component = GoogleGenAIChatGenerator(
-            model="gemini-2.5-pro",
+            model="gemini-3-flash-preview",
             tools=tools,
-            generation_kwargs={"thinking_budget": -1},  # Dynamic allocation
+            generation_kwargs={"thinking_level": "low"},  # Low thinking level
         )
 
         # First turn: Ask about the weather
@@ -1266,11 +599,6 @@
         # The model should maintain context from previous turns
         assert "22" in second_response.text or "sunny" in second_response.text.lower()
 
-    @pytest.mark.skipif(
-        not os.environ.get("GOOGLE_API_KEY", None),
-        reason="Export an env var called GOOGLE_API_KEY containing the Google API key to run this test.",
-    )
-    @pytest.mark.integration
     def test_live_run_with_thinking_unsupported_model_fails_fast(self):
         """
         Integration test to verify that thinking configuration fails fast with unsupported models.
@@ -1290,11 +618,6 @@
         assert "thinking_budget" in error_message or "thinking features" in error_message
         assert "Try removing" in error_message or "use a different model" in error_message
 
-    @pytest.mark.integration
-    @pytest.mark.skipif(
-        not os.environ.get("GOOGLE_API_KEY", None),
-        reason="Export an env var called GOOGLE_API_KEY containing the Google API key to run this test.",
-    )
     def test_live_run_agent_with_images_in_tool_result(self, test_files_path):
         def retrieve_image():
             return [
@@ -1325,7 +648,7 @@
         )
 
 @pytest.mark.integration
 @pytest.mark.asyncio
-class TestAsyncGoogleGenAIChatGenerator:
+class TestAsyncGoogleGenAIChatGeneratorInference:
     """Test class for async functionality of GoogleGenAIChatGenerator."""
 
     async def test_live_run_async(self) -> None:
@@ -1386,7 +709,9 @@ async def test_live_run_async_with_thinking(self):
         """
        Async integration test for the thinking feature.
         """
         chat_messages = [ChatMessage.from_user("Why is the sky blue? 
Explain in one sentence.")] - component = GoogleGenAIChatGenerator(model="gemini-2.5-pro", generation_kwargs={"thinking_budget": -1}) + component = GoogleGenAIChatGenerator( + model="gemini-3-flash-preview", generation_kwargs={"thinking_level": "low"} + ) results = await component.run_async(chat_messages) assert len(results["replies"]) == 1 @@ -1445,132 +770,3 @@ async def test_concurrent_async_calls(self): assert result["replies"][0].text assert result["replies"][0].meta["model"] assert result["replies"][0].meta["finish_reason"] == "stop" - - -def test_aggregate_streaming_chunks_with_reasoning(monkeypatch): - """Test the _aggregate_streaming_chunks_with_reasoning method for reasoning content aggregation.""" - monkeypatch.setenv("GOOGLE_API_KEY", "test-api-key") - component = GoogleGenAIChatGenerator() - component_info = ComponentInfo.from_component(component) - - # Create mock streaming chunks with reasoning content - chunk1 = Mock() - chunk1.content = "Hello" - chunk1.tool_calls = [] - chunk1.meta = {"usage": {"prompt_tokens": 10, "completion_tokens": 5}} - chunk1.component_info = component_info - chunk1.reasoning = None - - chunk2 = Mock() - chunk2.content = " world" - chunk2.tool_calls = [] - chunk2.meta = {"usage": {"prompt_tokens": 10, "completion_tokens": 8}} - chunk2.component_info = component_info - chunk2.reasoning = None - - # Mock the final chunk with reasoning - final_chunk = Mock() - final_chunk.content = "" - final_chunk.tool_calls = [] - final_chunk.meta = { - "usage": {"prompt_tokens": 10, "completion_tokens": 13, "thoughts_token_count": 5}, - "model": "gemini-2.5-pro", - } - final_chunk.component_info = component_info - final_chunk.reasoning = ReasoningContent(reasoning_text="I should greet the user politely") - - # Add reasoning deltas to the final chunk meta (this is how the real method works) - final_chunk.meta["reasoning_deltas"] = [{"type": "reasoning", "content": "I should greet the user politely"}] - - # Test aggregation - result = GoogleGenAIChatGenerator._aggregate_streaming_chunks_with_reasoning([chunk1, chunk2, final_chunk]) - - # Verify the aggregated message - assert result.text == "Hello world" - assert result.tool_calls == [] - assert result.reasoning is not None - assert result.reasoning.reasoning_text == "I should greet the user politely" - assert result.meta["usage"]["prompt_tokens"] == 10 - assert result.meta["usage"]["completion_tokens"] == 13 - assert result.meta["usage"]["thoughts_token_count"] == 5 - assert result.meta["model"] == "gemini-2.5-pro" - - -def test_process_thinking_budget(monkeypatch): - """Test the _process_thinking_config method with different thinking_budget values.""" - monkeypatch.setenv("GOOGLE_API_KEY", "test-api-key") - - # Test valid thinking_budget values - generation_kwargs = {"thinking_budget": 1024, "temperature": 0.7} - result = GoogleGenAIChatGenerator._process_thinking_config(generation_kwargs.copy()) - - # thinking_budget should be moved to thinking_config - assert "thinking_budget" not in result - assert "thinking_config" in result - assert result["thinking_config"].thinking_budget == 1024 - # Other kwargs should be preserved - assert result["temperature"] == 0.7 - - # Test dynamic allocation (-1) - generation_kwargs = {"thinking_budget": -1} - result = GoogleGenAIChatGenerator._process_thinking_config(generation_kwargs.copy()) - assert result["thinking_config"].thinking_budget == -1 - - # Test zero (disable thinking) - generation_kwargs = {"thinking_budget": 0} - result = 
GoogleGenAIChatGenerator._process_thinking_config(generation_kwargs.copy()) - assert result["thinking_config"].thinking_budget == 0 - - # Test large value - generation_kwargs = {"thinking_budget": 24576} - result = GoogleGenAIChatGenerator._process_thinking_config(generation_kwargs.copy()) - assert result["thinking_config"].thinking_budget == 24576 - - # Test when thinking_budget is not present - generation_kwargs = {"temperature": 0.5} - result = GoogleGenAIChatGenerator._process_thinking_config(generation_kwargs.copy()) - assert result == generation_kwargs # No changes - - # Test invalid type (should fall back to dynamic) - generation_kwargs = {"thinking_budget": "invalid", "temperature": 0.5} - result = GoogleGenAIChatGenerator._process_thinking_config(generation_kwargs.copy()) - assert result["thinking_config"].thinking_budget == -1 # Dynamic allocation - assert result["temperature"] == 0.5 - - -def test_process_thinking_level(monkeypatch): - """Test the _process_thinking_config method with different thinking_level values.""" - monkeypatch.setenv("GOOGLE_API_KEY", "test-api-key") - component = GoogleGenAIChatGenerator() - - # Test valid thinking_level values - generation_kwargs = {"thinking_level": "high", "temperature": 0.7} - result = component._process_thinking_config(generation_kwargs.copy()) - - # thinking_level should be moved to thinking_config - assert "thinking_level" not in result - assert "thinking_config" in result - assert result["thinking_config"].thinking_level == types.ThinkingLevel.HIGH - # Other kwargs should be preserved - assert result["temperature"] == 0.7 - - # Test THINKING_LEVEL_LOW in upper case - generation_kwargs = {"thinking_level": "LOW"} - result = GoogleGenAIChatGenerator._process_thinking_config(generation_kwargs.copy()) - assert result["thinking_config"].thinking_level == types.ThinkingLevel.LOW - - # Test THINKING_LEVEL_UNSPECIFIED - generation_kwargs = {"thinking_level": "test"} - result = GoogleGenAIChatGenerator._process_thinking_config(generation_kwargs.copy()) - assert result["thinking_config"].thinking_level == types.ThinkingLevel.THINKING_LEVEL_UNSPECIFIED - - # Test when thinking_level is not present - generation_kwargs = {"temperature": 0.5} - result = GoogleGenAIChatGenerator._process_thinking_config(generation_kwargs.copy()) - assert result == generation_kwargs # No changes - - # Test invalid type (should fall back to THINKING_LEVEL_UNSPECIFIED) - generation_kwargs = {"thinking_level": 123, "temperature": 0.5} - result = GoogleGenAIChatGenerator._process_thinking_config(generation_kwargs.copy()) - assert result["thinking_config"].thinking_level == types.ThinkingLevel.THINKING_LEVEL_UNSPECIFIED - assert result["temperature"] == 0.5 diff --git a/integrations/google_genai/tests/test_chat_generator_utils.py b/integrations/google_genai/tests/test_chat_generator_utils.py new file mode 100644 index 0000000000..719d7a3360 --- /dev/null +++ b/integrations/google_genai/tests/test_chat_generator_utils.py @@ -0,0 +1,805 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import base64 +from unittest.mock import Mock + +import pytest +from google.genai import types +from haystack.dataclasses import ( + ChatMessage, + ChatRole, + ComponentInfo, + FileContent, + ImageContent, + ReasoningContent, + TextContent, + ToolCall, +) + +from haystack_integrations.components.generators.google_genai.chat.chat_generator import ( + GoogleGenAIChatGenerator, +) +from 
haystack_integrations.components.generators.google_genai.chat.utils import (
+    _aggregate_streaming_chunks_with_reasoning,
+    _convert_google_chunk_to_streaming_chunk,
+    _convert_google_genai_response_to_chatmessage,
+    _convert_message_to_google_genai_format,
+    _convert_usage_metadata_to_serializable,
+    _process_thinking_config,
+)
+
+
+def test_process_thinking_budget():
+    """Test the _process_thinking_config function with different thinking_budget values."""
+
+    # Test valid thinking_budget values
+    generation_kwargs = {"thinking_budget": 1024, "temperature": 0.7}
+    result = _process_thinking_config(generation_kwargs.copy())
+
+    # thinking_budget should be moved to thinking_config
+    assert "thinking_budget" not in result
+    assert "thinking_config" in result
+    assert result["thinking_config"].thinking_budget == 1024
+    # Other kwargs should be preserved
+    assert result["temperature"] == 0.7
+
+    # Test dynamic allocation (-1)
+    generation_kwargs = {"thinking_budget": -1}
+    result = _process_thinking_config(generation_kwargs.copy())
+    assert result["thinking_config"].thinking_budget == -1
+
+    # Test zero (disable thinking)
+    generation_kwargs = {"thinking_budget": 0}
+    result = _process_thinking_config(generation_kwargs.copy())
+    assert result["thinking_config"].thinking_budget == 0
+
+    # Test large value
+    generation_kwargs = {"thinking_budget": 24576}
+    result = _process_thinking_config(generation_kwargs.copy())
+    assert result["thinking_config"].thinking_budget == 24576
+
+    # Test when thinking_budget is not present
+    generation_kwargs = {"temperature": 0.5}
+    result = _process_thinking_config(generation_kwargs.copy())
+    assert result == generation_kwargs  # No changes
+
+    # Test invalid type (should fall back to dynamic)
+    generation_kwargs = {"thinking_budget": "invalid", "temperature": 0.5}
+    result = _process_thinking_config(generation_kwargs.copy())
+    assert result["thinking_config"].thinking_budget == -1  # Dynamic allocation
+    assert result["temperature"] == 0.5
+
+
+def test_process_thinking_level():
+    """Test the _process_thinking_config function with different thinking_level values."""
+
+    # Test valid thinking_level values
+    generation_kwargs = {"thinking_level": "high", "temperature": 0.7}
+    result = _process_thinking_config(generation_kwargs.copy())
+
+    # thinking_level should be moved to thinking_config
+    assert "thinking_level" not in result
+    assert "thinking_config" in result
+    assert result["thinking_config"].thinking_level == types.ThinkingLevel.HIGH
+    # Other kwargs should be preserved
+    assert result["temperature"] == 0.7
+
+    # Test that an uppercase "LOW" is also accepted
+    generation_kwargs = {"thinking_level": "LOW"}
+    result = _process_thinking_config(generation_kwargs.copy())
+    assert result["thinking_config"].thinking_level == types.ThinkingLevel.LOW
+
+    # Test that an unrecognized string falls back to THINKING_LEVEL_UNSPECIFIED
+    generation_kwargs = {"thinking_level": "test"}
+    result = _process_thinking_config(generation_kwargs.copy())
+    assert result["thinking_config"].thinking_level == types.ThinkingLevel.THINKING_LEVEL_UNSPECIFIED
+
+    # Test when thinking_level is not present
+    generation_kwargs = {"temperature": 0.5}
+    result = _process_thinking_config(generation_kwargs.copy())
+    assert result == generation_kwargs  # No changes
+
+    # Test invalid type (should fall back to THINKING_LEVEL_UNSPECIFIED)
+    generation_kwargs = {"thinking_level": 123, "temperature": 0.5}
+    result = _process_thinking_config(generation_kwargs.copy())
+    assert result["thinking_config"].thinking_level == 
types.ThinkingLevel.THINKING_LEVEL_UNSPECIFIED + assert result["temperature"] == 0.5 + + +class TestStreamingChunkConversion: + def test_convert_google_chunk_to_streaming_chunk_text_only(self, monkeypatch): + monkeypatch.setenv("GOOGLE_API_KEY", "test-api-key") + component = GoogleGenAIChatGenerator() + component_info = ComponentInfo.from_component(component) + + mock_chunk = Mock() + mock_candidate = Mock() + mock_candidate.finish_reason = "STOP" + mock_chunk.candidates = [mock_candidate] + + mock_content = Mock() + mock_content.parts = [] + mock_part = Mock() + mock_part.text = "Hello, world!" + mock_part.function_call = None + # Explicitly set thought=False to simulate a regular (non-thought) part + mock_part.thought = False + mock_content.parts.append(mock_part) + mock_candidate.content = mock_content + + chunk = _convert_google_chunk_to_streaming_chunk( + chunk=mock_chunk, + index=0, + component_info=component_info, + model="gemini-2.5-flash", + ) + + assert chunk.content == "Hello, world!" + assert chunk.tool_calls == [] + assert chunk.finish_reason == "stop" + assert chunk.index == 0 + assert "received_at" in chunk.meta + assert chunk.component_info == component_info + + def test_convert_google_chunk_to_streaming_chunk_tool_call(self, monkeypatch): + monkeypatch.setenv("GOOGLE_API_KEY", "test-api-key") + component = GoogleGenAIChatGenerator() + component_info = ComponentInfo.from_component(component) + + mock_chunk = Mock() + mock_candidate = Mock() + mock_candidate.finish_reason = "STOP" + mock_chunk.candidates = [mock_candidate] + + mock_content = Mock() + mock_content.parts = [] + mock_part = Mock() + mock_part.text = None + mock_function_call = Mock() + mock_function_call.name = "weather" + mock_function_call.args = {"city": "Paris"} + mock_function_call.id = "call_123" + mock_part.function_call = mock_function_call + mock_content.parts.append(mock_part) + mock_candidate.content = mock_content + + chunk = _convert_google_chunk_to_streaming_chunk( + chunk=mock_chunk, index=0, component_info=component_info, model="gemini-2.5-flash" + ) + + assert chunk.content == "" + assert chunk.tool_calls is not None + assert len(chunk.tool_calls) == 1 + assert chunk.tool_calls[0].tool_name == "weather" + assert chunk.tool_calls[0].arguments == '{"city": "Paris"}' + assert chunk.tool_calls[0].id == "call_123" + assert chunk.finish_reason == "stop" + assert chunk.index == 0 + assert "received_at" in chunk.meta + assert chunk.component_info == component_info + + def test_convert_google_chunk_to_streaming_chunk_mixed_content(self, monkeypatch): + monkeypatch.setenv("GOOGLE_API_KEY", "test-api-key") + component = GoogleGenAIChatGenerator() + component_info = ComponentInfo.from_component(component) + + mock_chunk = Mock() + mock_candidate = Mock() + mock_candidate.finish_reason = "STOP" + mock_chunk.candidates = [mock_candidate] + + mock_content = Mock() + mock_content.parts = [] + + mock_text_part = Mock() + mock_text_part.text = "I'll check the weather for you." 
+ mock_text_part.function_call = None + # Explicitly set thought=False to simulate a regular (non-thought) part + mock_text_part.thought = False + mock_content.parts.append(mock_text_part) + + mock_tool_part = Mock() + mock_tool_part.text = None + mock_function_call = Mock() + mock_function_call.name = "weather" + mock_function_call.args = {"city": "London"} + mock_function_call.id = "call_456" + mock_tool_part.function_call = mock_function_call + mock_content.parts.append(mock_tool_part) + + mock_candidate.content = mock_content + + chunk = _convert_google_chunk_to_streaming_chunk( + chunk=mock_chunk, index=0, component_info=component_info, model="gemini-2.5-flash" + ) + + # When both text and tool calls are present, tool calls are prioritized + assert chunk.content == "" + assert chunk.tool_calls is not None + assert len(chunk.tool_calls) == 1 + assert chunk.tool_calls[0].tool_name == "weather" + assert chunk.tool_calls[0].arguments == '{"city": "London"}' + assert chunk.finish_reason == "stop" + assert chunk.component_info == component_info + + def test_convert_google_chunk_to_streaming_chunk_empty_parts(self, monkeypatch): + monkeypatch.setenv("GOOGLE_API_KEY", "test-api-key") + component = GoogleGenAIChatGenerator() + component_info = ComponentInfo.from_component(component) + + mock_chunk = Mock() + mock_candidate = Mock() + mock_content = Mock() + mock_content.parts = [] + mock_candidate.content = mock_content + mock_chunk.candidates = [mock_candidate] + + chunk = _convert_google_chunk_to_streaming_chunk( + chunk=mock_chunk, index=0, component_info=component_info, model="gemini-2.5-flash" + ) + + assert chunk.content == "" + assert chunk.tool_calls == [] + assert chunk.component_info == component_info + + def test_convert_google_chunk_to_streaming_chunk_real_example(self, monkeypatch): + monkeypatch.setenv("GOOGLE_API_KEY", "test-api-key") + component = GoogleGenAIChatGenerator() + component_info = ComponentInfo.from_component(component) + + # Chunk 1: Text only + chunk1_parts = [ + types.Part( + text="I'll get the weather information for Paris and Berlin", function_call=None, function_response=None + ) + ] + chunk1_content = types.Content(role="model", parts=chunk1_parts) + chunk1_candidate = types.Candidate( + content=chunk1_content, + finish_reason=None, + index=None, + safety_ratings=None, + citation_metadata=None, + grounding_metadata=None, + finish_message=None, + token_count=None, + logprobs_result=None, + avg_logprobs=None, + url_context_metadata=None, + ) + chunk1_usage = types.GenerateContentResponseUsageMetadata( + prompt_token_count=217, candidates_token_count=None, total_token_count=217 + ) + chunk1 = types.GenerateContentResponse( + candidates=[chunk1_candidate], + usage_metadata=chunk1_usage, + model_version="gemini-2.5-flash", + response_id=None, + create_time=None, + prompt_feedback=None, + automatic_function_calling_history=None, + parsed=None, + ) + + streaming_chunk1 = _convert_google_chunk_to_streaming_chunk( + chunk=chunk1, index=0, component_info=component_info, model="gemini-2.5-flash" + ) + assert streaming_chunk1.content == "I'll get the weather information for Paris and Berlin" + assert streaming_chunk1.tool_calls == [] + assert streaming_chunk1.finish_reason is None + assert streaming_chunk1.index == 0 + assert "received_at" in streaming_chunk1.meta + assert streaming_chunk1.meta["model"] == "gemini-2.5-flash" + assert "usage" in streaming_chunk1.meta + assert streaming_chunk1.meta["usage"]["prompt_tokens"] == 217 + assert 
streaming_chunk1.meta["usage"]["completion_tokens"] is None + assert streaming_chunk1.meta["usage"]["total_tokens"] == 217 + assert streaming_chunk1.component_info == component_info + + # Chunk 2: Text only + chunk2_parts = [ + types.Part(text=" and present it in a structured format.", function_call=None, function_response=None) + ] + chunk2_content = types.Content(role="model", parts=chunk2_parts) + chunk2_candidate = types.Candidate( + content=chunk2_content, + finish_reason=None, + index=None, + safety_ratings=None, + citation_metadata=None, + grounding_metadata=None, + finish_message=None, + token_count=None, + logprobs_result=None, + avg_logprobs=None, + url_context_metadata=None, + ) + chunk2_usage = types.GenerateContentResponseUsageMetadata( + prompt_token_count=217, candidates_token_count=None, total_token_count=217 + ) + chunk2 = types.GenerateContentResponse( + candidates=[chunk2_candidate], + usage_metadata=chunk2_usage, + model_version="gemini-2.5-flash", + response_id=None, + create_time=None, + prompt_feedback=None, + automatic_function_calling_history=None, + parsed=None, + ) + + streaming_chunk2 = _convert_google_chunk_to_streaming_chunk( + chunk=chunk2, index=1, component_info=component_info, model="gemini-2.5-flash" + ) + assert streaming_chunk2.content == " and present it in a structured format." + assert streaming_chunk2.tool_calls == [] + assert streaming_chunk2.finish_reason is None + assert streaming_chunk2.index == 1 + assert "received_at" in streaming_chunk2.meta + assert streaming_chunk2.meta["model"] == "gemini-2.5-flash" + assert "usage" in streaming_chunk2.meta + assert streaming_chunk2.meta["usage"]["prompt_tokens"] == 217 + assert streaming_chunk2.meta["usage"]["completion_tokens"] is None + assert streaming_chunk2.meta["usage"]["total_tokens"] == 217 + assert streaming_chunk2.component_info == component_info + + # Chunk 3: Multiple tool calls (6 function calls) for 2 cities with 3 tools each + fc1 = types.FunctionCall(id=None, name="get_weather", args={"city": "Paris"}) + fc2 = types.FunctionCall(id=None, name="get_population", args={"city": "Paris"}) + fc3 = types.FunctionCall(id=None, name="get_time", args={"city": "Paris"}) + fc4 = types.FunctionCall(id=None, name="get_weather", args={"city": "Berlin"}) + fc5 = types.FunctionCall(id=None, name="get_population", args={"city": "Berlin"}) + fc6 = types.FunctionCall(id=None, name="get_time", args={"city": "Berlin"}) + + parts = [ + types.Part(text=None, function_call=fc1, function_response=None), + types.Part(text=None, function_call=fc2, function_response=None), + types.Part(text=None, function_call=fc3, function_response=None), + types.Part(text=None, function_call=fc4, function_response=None), + types.Part(text=None, function_call=fc5, function_response=None), + types.Part(text=None, function_call=fc6, function_response=None), + ] + + content = types.Content(role="model", parts=parts) + candidate = types.Candidate( + content=content, + finish_reason=types.FinishReason.STOP, + index=None, + safety_ratings=None, + citation_metadata=None, + grounding_metadata=None, + finish_message=None, + token_count=None, + logprobs_result=None, + avg_logprobs=None, + url_context_metadata=None, + ) + + usage_metadata = types.GenerateContentResponseUsageMetadata( + prompt_token_count=144, candidates_token_count=121, total_token_count=265 + ) + chunk = types.GenerateContentResponse( + candidates=[candidate], + usage_metadata=usage_metadata, + model_version="gemini-2.5-flash", + response_id=None, + create_time=None, + 
prompt_feedback=None, + automatic_function_calling_history=None, + parsed=None, + ) + + streaming_chunk = _convert_google_chunk_to_streaming_chunk( + chunk=chunk, index=2, component_info=component_info, model="gemini-2.5-flash" + ) + assert streaming_chunk.content == "" + assert streaming_chunk.tool_calls is not None + assert len(streaming_chunk.tool_calls) == 6 + assert streaming_chunk.finish_reason == "stop" + assert streaming_chunk.index == 2 + assert "received_at" in streaming_chunk.meta + assert streaming_chunk.meta["model"] == "gemini-2.5-flash" + assert streaming_chunk.component_info == component_info + assert "usage" in streaming_chunk.meta + assert streaming_chunk.meta["usage"]["prompt_tokens"] == 144 + assert streaming_chunk.meta["usage"]["completion_tokens"] == 121 + assert streaming_chunk.meta["usage"]["total_tokens"] == 265 + + assert streaming_chunk.tool_calls[0].tool_name == "get_weather" + assert streaming_chunk.tool_calls[0].arguments == '{"city": "Paris"}' + assert streaming_chunk.tool_calls[0].id is None + assert streaming_chunk.tool_calls[0].index == 0 + + assert streaming_chunk.tool_calls[1].tool_name == "get_population" + assert streaming_chunk.tool_calls[1].arguments == '{"city": "Paris"}' + assert streaming_chunk.tool_calls[1].id is None + assert streaming_chunk.tool_calls[1].index == 1 + + assert streaming_chunk.tool_calls[2].tool_name == "get_time" + assert streaming_chunk.tool_calls[2].arguments == '{"city": "Paris"}' + assert streaming_chunk.tool_calls[2].id is None + assert streaming_chunk.tool_calls[2].index == 2 + + assert streaming_chunk.tool_calls[3].tool_name == "get_weather" + assert streaming_chunk.tool_calls[3].arguments == '{"city": "Berlin"}' + assert streaming_chunk.tool_calls[3].id is None + assert streaming_chunk.tool_calls[3].index == 3 + + assert streaming_chunk.tool_calls[4].tool_name == "get_population" + assert streaming_chunk.tool_calls[4].arguments == '{"city": "Berlin"}' + assert streaming_chunk.tool_calls[4].id is None + assert streaming_chunk.tool_calls[4].index == 4 + + assert streaming_chunk.tool_calls[5].tool_name == "get_time" + assert streaming_chunk.tool_calls[5].arguments == '{"city": "Berlin"}' + assert streaming_chunk.tool_calls[5].id is None + assert streaming_chunk.tool_calls[5].index == 5 + + def test_aggregate_streaming_chunks_with_reasoning(self): + """Test the _aggregate_streaming_chunks_with_reasoning method for reasoning content aggregation.""" + + # Create mock streaming chunks with reasoning content + chunk1 = Mock() + chunk1.content = "Hello" + chunk1.tool_calls = [] + chunk1.meta = {"usage": {"prompt_tokens": 10, "completion_tokens": 5}} + chunk1.reasoning = None + + chunk2 = Mock() + chunk2.content = " world" + chunk2.tool_calls = [] + chunk2.meta = {"usage": {"prompt_tokens": 10, "completion_tokens": 8}} + chunk2.reasoning = None + + # Mock the final chunk with reasoning + final_chunk = Mock() + final_chunk.content = "" + final_chunk.tool_calls = [] + final_chunk.meta = { + "usage": {"prompt_tokens": 10, "completion_tokens": 13, "thoughts_token_count": 5}, + "model": "gemini-2.5-pro", + } + final_chunk.reasoning = ReasoningContent(reasoning_text="I should greet the user politely") + + # Add reasoning deltas to the final chunk meta (this is how the real method works) + final_chunk.meta["reasoning_deltas"] = [{"type": "reasoning", "content": "I should greet the user politely"}] + + # Test aggregation + result = _aggregate_streaming_chunks_with_reasoning([chunk1, chunk2, final_chunk]) + + # Verify the aggregated 
message
+        assert result.text == "Hello world"
+        assert result.tool_calls == []
+        assert result.reasoning is not None
+        assert result.reasoning.reasoning_text == "I should greet the user politely"
+        assert result.meta["usage"]["prompt_tokens"] == 10
+        assert result.meta["usage"]["completion_tokens"] == 13
+        assert result.meta["usage"]["thoughts_token_count"] == 5
+        assert result.meta["model"] == "gemini-2.5-pro"
+
+
+class TestConvertMessageToGoogleGenAI:
+    def test_convert_message_to_google_genai_format_complex(self):
+        """
+        Test that _convert_message_to_google_genai_format converts a complex sequence of ChatMessages
+        to Google GenAI format. In particular, we check that the different tool results in the sequence
+        are handled properly.
+        """
+        messages = [
+            ChatMessage.from_system("You are a good assistant"),
+            ChatMessage.from_user("What's the weather like in Paris? And how much is 2+2?"),
+            ChatMessage.from_assistant(
+                text="",
+                tool_calls=[
+                    ToolCall(id="123", tool_name="weather", arguments={"city": "Paris"}),
+                    ToolCall(id="456", tool_name="math", arguments={"expression": "2+2"}),
+                ],
+            ),
+            ChatMessage.from_tool(
+                tool_result="22° C", origin=ToolCall(id="123", tool_name="weather", arguments={"city": "Paris"})
+            ),
+            ChatMessage.from_tool(
+                tool_result="4", origin=ToolCall(id="456", tool_name="math", arguments={"expression": "2+2"})
+            ),
+        ]
+
+        # Test system message handling (Google GenAI handles system messages separately, as a system instruction)
+        system_message = messages[0]
+        assert system_message.is_from(ChatRole.SYSTEM)
+
+        # Test user message conversion
+        user_message = messages[1]
+        google_content = _convert_message_to_google_genai_format(user_message)
+        assert google_content.role == "user"
+        assert len(google_content.parts) == 1
+        assert google_content.parts[0].text == "What's the weather like in Paris? And how much is 2+2?"
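+
+        # Role mapping note (illustrative): Google GenAI Content objects only
+        # use the roles "user" and "model". Haystack assistant messages map to
+        # "model", while user messages and tool results both map to "user", as
+        # the assertions in this test show. A minimal sketch of that rule, with
+        # a hypothetical helper name:
+        #
+        #     def _expected_google_role(message: ChatMessage) -> str:
+        #         return "model" if message.is_from(ChatRole.ASSISTANT) else "user"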
+ + # Test assistant message with tool calls + assistant_message = messages[2] + google_content = _convert_message_to_google_genai_format(assistant_message) + assert google_content.role == "model" + assert len(google_content.parts) == 2 + assert google_content.parts[0].function_call.name == "weather" + assert google_content.parts[0].function_call.args == {"city": "Paris"} + assert google_content.parts[1].function_call.name == "math" + assert google_content.parts[1].function_call.args == {"expression": "2+2"} + + # Test tool result messages + tool_result_1 = messages[3] + google_content = _convert_message_to_google_genai_format(tool_result_1) + assert google_content.role == "user" + assert len(google_content.parts) == 1 + assert google_content.parts[0].function_response.name == "weather" + assert google_content.parts[0].function_response.response == {"result": "22° C"} + + tool_result_2 = messages[4] + google_content = _convert_message_to_google_genai_format(tool_result_2) + assert google_content.role == "user" + assert len(google_content.parts) == 1 + assert google_content.parts[0].function_response.name == "math" + assert google_content.parts[0].function_response.response == {"result": "4"} + + def test_convert_message_to_google_genai_format_with_multiple_images(self, test_files_path): + """Test converting a message with multiple images in mixed content to Google GenAI format.""" + apple_path = test_files_path / "apple.jpg" + banana_path = test_files_path / "banana.png" + + apple_content = ImageContent.from_file_path(apple_path, size=(100, 100)) + banana_content = ImageContent.from_file_path(banana_path, size=(100, 100)) + + message = ChatMessage.from_user( + content_parts=[ + "Compare these fruits. First:", + apple_content, + "Second:", + banana_content, + "Which is healthier?", + ] + ) + + google_content = _convert_message_to_google_genai_format(message) + + assert google_content.role == "user" + assert len(google_content.parts) == 5 + + # Verify the exact order is preserved + assert google_content.parts[0].text == "Compare these fruits. First:" + + # First image (apple) + assert hasattr(google_content.parts[1], "inline_data") + assert google_content.parts[1].inline_data.mime_type == "image/jpeg" + assert google_content.parts[1].inline_data.data is not None + + assert google_content.parts[2].text == "Second:" + + # Second image (banana) + assert hasattr(google_content.parts[3], "inline_data") + assert google_content.parts[3].inline_data.mime_type == "image/png" + assert google_content.parts[3].inline_data.data is not None + + assert google_content.parts[4].text == "Which is healthier?" 
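+
+    # For orientation, a sketch (under assumptions, not the actual implementation)
+    # of how inline-data image parts like the ones asserted above can be built
+    # with the google-genai SDK, where `image_bytes` stands in for the decoded
+    # base64 payload of the image:
+    #
+    #     part = types.Part.from_bytes(data=image_bytes, mime_type="image/jpeg")
+    #     assert part.inline_data.mime_type == "image/jpeg"
+    #     assert part.inline_data.data == image_bytes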
+ + def test_convert_message_to_google_genai_format_image_with_minimal_text(self, test_files_path): + """Test converting a message with minimal text and image to Google GenAI format.""" + apple_path = test_files_path / "apple.jpg" + apple_content = ImageContent.from_file_path(apple_path, size=(100, 100)) + + # Haystack requires at least one textual part for user messages, so we use minimal text + message = ChatMessage.from_user(content_parts=["", apple_content]) + + google_content = _convert_message_to_google_genai_format(message) + + assert google_content.role == "user" + # Empty text should be filtered out by our implementation, leaving only the image + assert len(google_content.parts) == 1 + + # Should only have the image part (empty text filtered out) + assert hasattr(google_content.parts[0], "inline_data") + assert google_content.parts[0].inline_data.mime_type == "image/jpeg" + assert google_content.parts[0].inline_data.data is not None + + def test_convert_message_to_google_genai_file_content(self, test_files_path): + file_path = test_files_path / "sample_pdf_3.pdf" + file_content = FileContent.from_file_path(file_path) + message = ChatMessage.from_user(content_parts=["Describe this document:", file_content]) + google_content = _convert_message_to_google_genai_format(message) + assert google_content.role == "user" + assert len(google_content.parts) == 2 + + assert google_content.parts[0].text == "Describe this document:" + assert hasattr(google_content.parts[1], "inline_data") + assert google_content.parts[1].inline_data.mime_type == "application/pdf" + assert google_content.parts[1].inline_data.data is not None + + def test_convert_message_to_google_genai_file_content_in_assistant_message(self, test_files_path): + file_path = test_files_path / "sample_pdf_3.pdf" + file_content = FileContent.from_file_path(file_path) + message = ChatMessage.from_assistant("This is a document") + message._content.append(file_content) + + with pytest.raises(ValueError, match="FileContent is only supported for user messages"): + _convert_message_to_google_genai_format(message) + + def test_convert_message_to_google_genai_format_with_thought_signatures(self): + """Test converting an assistant message with thought signatures for multi-turn context preservation.""" + + # Create an assistant message with tool calls and thought signatures in meta + tool_call = ToolCall(id="call_123", tool_name="weather", arguments={"city": "Paris"}) + + # Thought signatures are stored in meta when thinking is enabled with tools + # They must be base64 encoded as per the API requirements + thought_signatures = [ + { + "part_index": 0, + "signature": base64.b64encode(b"encrypted_mock_thought_signature_1").decode("utf-8"), + "has_text": True, + "has_function_call": False, + "is_thought": False, + }, + { + "part_index": 1, + "signature": base64.b64encode(b"encrypted_mock_thought_signature_2").decode("utf-8"), + "has_text": False, + "has_function_call": True, + "is_thought": False, + }, + ] + + message = ChatMessage.from_assistant( + text="I'll check the weather for you", + tool_calls=[tool_call], + meta={"thought_signatures": thought_signatures}, + ) + + google_content = _convert_message_to_google_genai_format(message) + + assert google_content.role == "model" + assert len(google_content.parts) == 2 + + # First part should have text and its thought signature + assert google_content.parts[0].text == "I'll check the weather for you" + # thought_signature is returned as bytes from the API + assert 
google_content.parts[0].thought_signature == b"encrypted_mock_thought_signature_1" + + # Second part should have function call and its thought signature + assert google_content.parts[1].function_call.name == "weather" + assert google_content.parts[1].function_call.args == {"city": "Paris"} + assert google_content.parts[1].thought_signature == b"encrypted_mock_thought_signature_2" + + def test_convert_message_to_google_genai_format_with_reasoning_content(self): + """Test that ReasoningContent is properly skipped during conversion.""" + # ReasoningContent is for human transparency only, not sent to the API + reasoning = ReasoningContent(reasoning_text="Of Life, the Universe and Everything...") + + # Create a message with both text and reasoning content + message = ChatMessage.from_assistant(text="Forty-two", reasoning=reasoning) + + google_content = _convert_message_to_google_genai_format(message) + + assert google_content.role == "model" + assert len(google_content.parts) == 1 + + # Only the text should be included, reasoning content should be skipped + assert google_content.parts[0].text == "Forty-two" + # Verify no thought part was created (reasoning is not sent to API) + assert not hasattr(google_content.parts[0], "thought") or not google_content.parts[0].thought + + def test_convert_message_to_google_genai_format_tool_message(self): + tool_call = ToolCall(id="123", tool_name="weather", arguments={"city": "Paris"}) + message = ChatMessage.from_tool(tool_result="22° C", origin=tool_call) + google_content = _convert_message_to_google_genai_format(message) + assert google_content.role == "user" + assert len(google_content.parts) == 1 + assert isinstance(google_content.parts[0].function_response, types.FunctionResponse) + assert google_content.parts[0].function_response.id == "123" + assert google_content.parts[0].function_response.name == "weather" + assert google_content.parts[0].function_response.response == {"result": "22° C"} + assert google_content.parts[0].function_response.parts is None + + def test_convert_message_to_google_genai_format_image_in_tool_result(self): + tool_call = ToolCall(id="123", tool_name="image_retriever", arguments={}) + + base64_str = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8/x8AAwMCAO+ip1sAAAAASUVORK5CYII=" + tool_result = [TextContent("Here is the image"), ImageContent(base64_image=base64_str, mime_type="image/jpeg")] + + message = ChatMessage.from_tool(tool_result=tool_result, origin=tool_call) + google_content = _convert_message_to_google_genai_format(message) + assert google_content.role == "user" + assert len(google_content.parts) == 1 + assert isinstance(google_content.parts[0].function_response, types.FunctionResponse) + assert google_content.parts[0].function_response.id == "123" + assert google_content.parts[0].function_response.name == "image_retriever" + assert google_content.parts[0].function_response.response == {"result": ""} + assert len(google_content.parts[0].function_response.parts) == 2 + assert isinstance(google_content.parts[0].function_response.parts[0], types.FunctionResponsePart) + assert google_content.parts[0].function_response.parts[0].inline_data.mime_type == "text/plain" + assert google_content.parts[0].function_response.parts[0].inline_data.data == b"Here is the image" + assert isinstance(google_content.parts[0].function_response.parts[1], types.FunctionResponsePart) + assert google_content.parts[0].function_response.parts[1].inline_data.mime_type == "image/jpeg" + assert 
google_content.parts[0].function_response.parts[1].inline_data.data == base64.b64decode(base64_str) + assert len(google_content.parts[0].function_response.parts[1].inline_data.data) > 0 + + def test_convert_message_to_google_genai_format_invalid_tool_result_type(self): + tool_call = ToolCall(id="123", tool_name="weather", arguments={"city": "Paris"}) + message = ChatMessage.from_tool(tool_result=256, origin=tool_call) + with pytest.raises(ValueError, match="Unsupported content type in tool call result"): + _convert_message_to_google_genai_format(message) + + message = ChatMessage.from_tool(tool_result=[TextContent("This is supported"), 256], origin=tool_call) + with pytest.raises( + ValueError, + match=( + r"Unsupported content type in tool call result list. " + "Only TextContent and ImageContent are supported." + ), + ): + _convert_message_to_google_genai_format(message) + + +class TestConvertGoogleGenAIToMessage: + def test_convert_google_genai_response_to_chatmessage_parses_cached_tokens(self): + """ + When the API response includes cached_content_token_count in usage_metadata, + it is parsed into meta['usage']. + """ + + # Minimal candidate with one text part + mock_part = Mock() + mock_part.text = "Four." + mock_part.function_call = None + mock_part.thought = False + mock_part.thought_signature = None + mock_content = Mock() + mock_content.parts = [mock_part] + mock_candidate = Mock() + mock_candidate.content = mock_content + mock_candidate.finish_reason = "STOP" + + # Usage metadata including cached tokens (as returned by API when cache is used) + mock_usage = Mock() + mock_usage.prompt_token_count = 1000 + mock_usage.candidates_token_count = 5 + mock_usage.total_token_count = 1005 + mock_usage.cached_content_token_count = 800 + mock_usage.thoughts_token_count = None + + mock_response = Mock() + mock_response.candidates = [mock_candidate] + mock_response.usage_metadata = mock_usage + + message = _convert_google_genai_response_to_chatmessage(mock_response, "gemini-2.5-flash") + + assert message.meta is not None + assert "usage" in message.meta + usage = message.meta["usage"] + assert usage["prompt_tokens"] == 1000 + assert usage["completion_tokens"] == 5 + assert usage["total_tokens"] == 1005 + assert usage["cached_content_token_count"] == 800 + + def test_convert_usage_metadata_to_serializable(self): + """_convert_usage_metadata_to_serializable builds a serializable dict from a UsageMetadata object.""" + assert _convert_usage_metadata_to_serializable(None) == {} + assert _convert_usage_metadata_to_serializable(False) == {} + + usage_metadata = types.GenerateContentResponseUsageMetadata( + prompt_token_count=100, + candidates_token_count=5, + total_token_count=105, + ) + result = _convert_usage_metadata_to_serializable(usage_metadata) + assert result["prompt_token_count"] == 100 + assert result["candidates_token_count"] == 5 + assert result["total_token_count"] == 105 + assert len(result) == 3 + + # Serialization of zero and composite types (ModalityTokenCount, lists) + modality_token_count = types.ModalityTokenCount(modality=types.Modality.TEXT, tokenCount=100) + usage_with_details = types.GenerateContentResponseUsageMetadata( + prompt_token_count=0, + candidates_tokens_details=[modality_token_count], + ) + result2 = _convert_usage_metadata_to_serializable(usage_with_details) + assert result2["prompt_token_count"] == 0 + assert result2["candidates_tokens_details"] == [{"modality": "TEXT", "token_count": 100}] diff --git a/integrations/google_genai/tests/test_files/sample_pdf_3.pdf 
b/integrations/google_genai/tests/test_files/sample_pdf_3.pdf new file mode 100644 index 0000000000..c0d07eaa68 Binary files /dev/null and b/integrations/google_genai/tests/test_files/sample_pdf_3.pdf differ