diff --git a/integrations/google_genai/src/haystack_integrations/components/generators/google_genai/chat/chat_generator.py b/integrations/google_genai/src/haystack_integrations/components/generators/google_genai/chat/chat_generator.py index c6fdfa90d4..673c045b7a 100644 --- a/integrations/google_genai/src/haystack_integrations/components/generators/google_genai/chat/chat_generator.py +++ b/integrations/google_genai/src/haystack_integrations/components/generators/google_genai/chat/chat_generator.py @@ -19,6 +19,7 @@ serialize_tools_or_toolset, ) from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable +from pydantic import BaseModel from haystack_integrations.components.common.google_genai.utils import _get_client from haystack_integrations.components.generators.google_genai.chat.utils import ( @@ -27,6 +28,7 @@ _convert_google_genai_response_to_chatmessage, _convert_message_to_google_genai_format, _convert_tools_to_google_genai_format, + _process_response_format, _process_thinking_config, ) @@ -139,6 +141,28 @@ def weather_function(city: str): response = chat_generator_with_tools.run(messages=messages) ``` + ### Usage example with structured output + + ```python + from pydantic import BaseModel + from haystack.dataclasses.chat_message import ChatMessage + from haystack_integrations.components.generators.google_genai import GoogleGenAIChatGenerator + + class City(BaseModel): + name: str + country: str + population: int + + chat_generator = GoogleGenAIChatGenerator( + model="gemini-2.5-flash", + generation_kwargs={"response_format": City} + ) + + messages = [ChatMessage.from_user("Tell me about Paris")] + response = chat_generator.run(messages=messages) + print(response["replies"][0].text) # JSON output matching the City schema + ``` + ### Usage example with FileContent embedded in a ChatMessage ```python @@ -247,6 +271,12 @@ def to_dict(self) -> dict[str, Any]: """ callback_name = serialize_callable(self._streaming_callback) if self._streaming_callback else None serialized_tools = serialize_tools_or_toolset(self._tools) if self._tools else None + + generation_kwargs = self._generation_kwargs.copy() + response_format = generation_kwargs.get("response_format") + if response_format and isinstance(response_format, type) and issubclass(response_format, BaseModel): + generation_kwargs["response_format"] = response_format.model_json_schema() + return default_to_dict( self, api_key=self._api_key.to_dict(), @@ -254,7 +284,7 @@ def to_dict(self) -> dict[str, Any]: vertex_ai_project=self._vertex_ai_project, vertex_ai_location=self._vertex_ai_location, model=self._model, - generation_kwargs=self._generation_kwargs, + generation_kwargs=generation_kwargs, safety_settings=self._safety_settings, streaming_callback=callback_name, tools=serialized_tools, @@ -376,8 +406,9 @@ def run( safety_settings = safety_settings or self._safety_settings tools = tools or self._tools - # Process thinking configuration + # Process thinking configuration and response format generation_kwargs = _process_thinking_config(generation_kwargs) + generation_kwargs = _process_response_format(generation_kwargs) # Select appropriate streaming callback streaming_callback = select_streaming_callback( @@ -486,8 +517,9 @@ async def run_async( safety_settings = safety_settings or self._safety_settings tools = tools or self._tools - # Process thinking configuration + # Process thinking configuration and response format generation_kwargs = _process_thinking_config(generation_kwargs) + generation_kwargs = _process_response_format(generation_kwargs) # Select appropriate streaming callback streaming_callback = select_streaming_callback( diff --git a/integrations/google_genai/src/haystack_integrations/components/generators/google_genai/chat/utils.py b/integrations/google_genai/src/haystack_integrations/components/generators/google_genai/chat/utils.py index 96a4b49c99..82e64a6144 100644 --- a/integrations/google_genai/src/haystack_integrations/components/generators/google_genai/chat/utils.py +++ b/integrations/google_genai/src/haystack_integrations/components/generators/google_genai/chat/utils.py @@ -28,6 +28,7 @@ flatten_tools_or_toolsets, ) from jsonref import replace_refs +from pydantic import BaseModel logger = logging.getLogger(__name__) @@ -54,6 +55,49 @@ } +def _process_response_format(generation_kwargs: dict[str, Any]) -> dict[str, Any]: + """ + Process `response_format` from generation_kwargs into Google GenAI's native + `response_schema` and `response_mime_type` parameters. + + Accepts either a Pydantic BaseModel class or a JSON schema dict. When + `response_format` is present, it is popped and replaced with the two + Google-native keys. If `response_schema` or `response_mime_type` are + already set, they take precedence and `response_format` is ignored. + + Does not mutate the input dict; returns a new dict. + + :param generation_kwargs: The generation configuration dictionary. + :returns: A new dict with response_schema/response_mime_type if applicable. + """ + generation_kwargs = dict(generation_kwargs) + + # If the user already set Google-native keys, leave them alone + if "response_schema" in generation_kwargs or "response_mime_type" in generation_kwargs: + generation_kwargs.pop("response_format", None) + return generation_kwargs + + response_format = generation_kwargs.pop("response_format", None) + if response_format is None: + return generation_kwargs + + if isinstance(response_format, type) and issubclass(response_format, BaseModel): + generation_kwargs["response_schema"] = response_format + generation_kwargs["response_mime_type"] = "application/json" + return generation_kwargs + + if isinstance(response_format, dict): + generation_kwargs["response_schema"] = response_format + generation_kwargs["response_mime_type"] = "application/json" + return generation_kwargs + + msg = ( + f"Unsupported response_format type: {type(response_format).__name__}. " + "Expected a Pydantic model class or a JSON schema dict." + ) + raise TypeError(msg) + + def _process_thinking_config(generation_kwargs: dict[str, Any]) -> dict[str, Any]: """ Process thinking configuration from generation_kwargs. diff --git a/integrations/google_genai/tests/test_chat_generator.py b/integrations/google_genai/tests/test_chat_generator.py index 612ea147b6..93af7d193e 100644 --- a/integrations/google_genai/tests/test_chat_generator.py +++ b/integrations/google_genai/tests/test_chat_generator.py @@ -3,6 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 import asyncio +import json import os import pytest @@ -21,6 +22,7 @@ ) from haystack.tools import Tool, Toolset, create_tool_from_function from haystack.utils.auth import Secret +from pydantic import BaseModel from haystack_integrations.components.generators.google_genai.chat.chat_generator import ( GoogleGenAIChatGenerator, @@ -204,6 +206,52 @@ def test_serde_with_mixed_tools_and_toolsets(self, monkeypatch): assert restored._tools[0].name == "tool1" assert len(restored._tools[1]) == 1 + def test_to_dict_with_response_format_pydantic(self, monkeypatch): + """Test that to_dict serializes a Pydantic response_format to a JSON schema dict.""" + monkeypatch.setenv("GOOGLE_API_KEY", "test-api-key") + + class City(BaseModel): + name: str + country: str + population: int + + generator = GoogleGenAIChatGenerator(generation_kwargs={"response_format": City}) + data = generator.to_dict() + + response_format = data["init_parameters"]["generation_kwargs"]["response_format"] + assert response_format == { + "properties": { + "name": {"title": "Name", "type": "string"}, + "country": {"title": "Country", "type": "string"}, + "population": {"title": "Population", "type": "integer"}, + }, + "required": ["name", "country", "population"], + "title": "City", + "type": "object", + } + + def test_to_dict_with_response_format_dict(self, monkeypatch): + """Test that to_dict preserves a dict response_format as is.""" + monkeypatch.setenv("GOOGLE_API_KEY", "test-api-key") + + schema = {"type": "object", "properties": {"name": {"type": "string"}}} + generator = GoogleGenAIChatGenerator(generation_kwargs={"response_format": schema}) + data = generator.to_dict() + + assert data["init_parameters"]["generation_kwargs"]["response_format"] == schema + + def test_serde_with_response_format(self, monkeypatch): + """Test serialization/deserialization round-trip with response_format.""" + monkeypatch.setenv("GOOGLE_API_KEY", "test-api-key") + + schema = {"type": "object", "properties": {"name": {"type": "string"}}} + generator = GoogleGenAIChatGenerator(generation_kwargs={"response_format": schema, "temperature": 0.5}) + data = generator.to_dict() + + restored = GoogleGenAIChatGenerator.from_dict(data) + assert restored._generation_kwargs["response_format"] == schema + assert restored._generation_kwargs["temperature"] == 0.5 + @pytest.mark.skipif( not os.environ.get("GOOGLE_API_KEY", None), @@ -632,6 +680,48 @@ def test_live_run_with_thinking_unsupported_model_fails_fast(self): assert "thinking_budget" in error_message or "thinking features" in error_message assert "Try removing" in error_message or "use a different model" in error_message + def test_live_run_with_structured_output_pydantic(self): + """Test that response_format with a Pydantic model returns valid structured JSON output.""" + + class City(BaseModel): + name: str + country: str + population: int + + component = GoogleGenAIChatGenerator(generation_kwargs={"response_format": City}) + results = component.run([ChatMessage.from_user("Tell me about Paris. Respond in JSON.")]) + + assert len(results["replies"]) == 1 + message = results["replies"][0] + assert message.text + + parsed = json.loads(message.text) + assert "name" in parsed + assert "country" in parsed + assert "population" in parsed + + def test_live_run_with_structured_output_dict_schema(self): + """Test that response_format with a JSON schema dict returns valid structured JSON output.""" + schema = { + "type": "object", + "properties": { + "name": {"type": "string"}, + "country": {"type": "string"}, + }, + "required": ["name", "country"], + } + + component = GoogleGenAIChatGenerator(generation_kwargs={"response_format": schema}) + results = component.run([ChatMessage.from_user("Tell me about Paris. Respond in JSON.")]) + + assert len(results["replies"]) == 1 + message = results["replies"][0] + assert message.text + + parsed = json.loads(message.text) + assert "name" in parsed + assert "country" in parsed + def test_live_run_agent_with_images_in_tool_result(self, test_files_path): def retrieve_image(): return [ @@ -763,6 +853,26 @@ async def test_live_run_async_with_thinking_unsupported_model_fails_fast(self): assert "thinking_budget" in error_message or "thinking features" in error_message assert "Try removing" in error_message or "use a different model" in error_message + async def test_live_run_async_with_structured_output(self): + """Async integration test for structured output with a Pydantic model.""" + + class City(BaseModel): + name: str + country: str + population: int + + component = GoogleGenAIChatGenerator(generation_kwargs={"response_format": City}) + results = await component.run_async([ChatMessage.from_user("Tell me about Paris. Respond in JSON.")]) + + assert len(results["replies"]) == 1 + message = results["replies"][0] + assert message.text + + parsed = json.loads(message.text) + assert "name" in parsed + assert "country" in parsed + assert "population" in parsed + async def test_concurrent_async_calls(self): """Test multiple concurrent async calls.""" component = GoogleGenAIChatGenerator() diff --git a/integrations/google_genai/tests/test_chat_generator_utils.py b/integrations/google_genai/tests/test_chat_generator_utils.py index 4513bb54b5..f3ce9a1816 100644 --- a/integrations/google_genai/tests/test_chat_generator_utils.py +++ b/integrations/google_genai/tests/test_chat_generator_utils.py @@ -17,6 +17,7 @@ TextContent, ToolCall, ) +from pydantic import BaseModel from haystack_integrations.components.generators.google_genai.chat.chat_generator import ( GoogleGenAIChatGenerator, @@ -27,6 +28,7 @@ _convert_google_genai_response_to_chatmessage, _convert_message_to_google_genai_format, _convert_usage_metadata_to_serializable, + _process_response_format, _process_thinking_config, ) @@ -160,6 +162,63 @@ def test_process_thinking_config_explicit_include_thoughts(): assert result == {"temperature": 0.5} +def test_process_response_format(): + """Test the _process_response_format function with different response_format values.""" + + class City(BaseModel): + name: str + country: str + population: int + + # Test Pydantic model + generation_kwargs = {"response_format": City, "temperature": 0.7} + result = _process_response_format(generation_kwargs) + + # response_format should be replaced with response_schema and response_mime_type + assert "response_format" not in result + assert result["response_schema"] is City + assert result["response_mime_type"] == "application/json" + # Other kwargs should be preserved + assert result["temperature"] == 0.7 + + # Test JSON schema dict + schema = {"type": "object", "properties": {"name": {"type": "string"}}} + generation_kwargs = {"response_format": schema, "temperature": 0.5} + result = _process_response_format(generation_kwargs) + assert "response_format" not in result + assert result["response_schema"] == schema + assert result["response_mime_type"] == "application/json" + assert result["temperature"] == 0.5 + + # Test when response_format is not present + generation_kwargs = {"temperature": 0.5} + result = _process_response_format(generation_kwargs) + assert result == generation_kwargs # No changes + + # Test that native keys take precedence + native_schema = {"type": "object", "properties": {"x": {"type": "string"}}} + generation_kwargs = { + "response_format": City, + "response_schema": native_schema, + "response_mime_type": "application/json", + } + result = _process_response_format(generation_kwargs) + assert "response_format" not in result + assert result["response_schema"] == native_schema + assert result["response_mime_type"] == "application/json" + + # Test unsupported type raises TypeError + generation_kwargs = {"response_format": "invalid"} + with pytest.raises(TypeError, match="Unsupported response_format type"): + _process_response_format(generation_kwargs) + + # Test that input dict is not mutated + generation_kwargs = {"response_format": City, "temperature": 0.7} + original = generation_kwargs.copy() + _process_response_format(generation_kwargs) + assert generation_kwargs == original + + class TestStreamingChunkConversion: def test_convert_google_chunk_to_streaming_chunk_text_only(self, monkeypatch): monkeypatch.setenv("GOOGLE_API_KEY", "test-api-key")