Merged
Changes from 5 commits
@@ -19,6 +19,7 @@
serialize_tools_or_toolset,
)
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
from pydantic import BaseModel

from haystack_integrations.components.common.google_genai.utils import _get_client
from haystack_integrations.components.generators.google_genai.chat.utils import (
@@ -27,6 +28,7 @@
_convert_google_genai_response_to_chatmessage,
_convert_message_to_google_genai_format,
_convert_tools_to_google_genai_format,
_process_response_format,
_process_thinking_config,
)

@@ -139,6 +141,28 @@ def weather_function(city: str):
response = chat_generator_with_tools.run(messages=messages)
```

### Usage example with structured output

```python
from pydantic import BaseModel
from haystack.dataclasses.chat_message import ChatMessage
from haystack_integrations.components.generators.google_genai import GoogleGenAIChatGenerator

class City(BaseModel):
name: str
country: str
population: int

chat_generator = GoogleGenAIChatGenerator(
model="gemini-2.5-flash",
generation_kwargs={"response_format": City}
)

messages = [ChatMessage.from_user("Tell me about Paris")]
response = chat_generator.run(messages=messages)
print(response["replies"][0].text) # JSON output matching the City schema
```
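
The JSON reply can then be validated back into the Pydantic model — a minimal sketch, continuing from the snippet above and using Pydantic's `model_validate_json`:

```python
# Parse the structured JSON reply into a City instance (continuation of the example above)
city = City.model_validate_json(response["replies"][0].text)
print(city.name, city.country, city.population)
```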

### Usage example with FileContent embedded in a ChatMessage

```python
@@ -247,14 +271,20 @@ def to_dict(self) -> dict[str, Any]:
"""
callback_name = serialize_callable(self._streaming_callback) if self._streaming_callback else None
serialized_tools = serialize_tools_or_toolset(self._tools) if self._tools else None

generation_kwargs = self._generation_kwargs.copy()
response_format = generation_kwargs.get("response_format")
if response_format and isinstance(response_format, type) and issubclass(response_format, BaseModel):
generation_kwargs["response_format"] = response_format.model_json_schema()

return default_to_dict(
self,
api_key=self._api_key.to_dict(),
api=self._api,
vertex_ai_project=self._vertex_ai_project,
vertex_ai_location=self._vertex_ai_location,
model=self._model,
generation_kwargs=self._generation_kwargs,
generation_kwargs=generation_kwargs,
safety_settings=self._safety_settings,
streaming_callback=callback_name,
tools=serialized_tools,
@@ -376,8 +406,9 @@ def run(
safety_settings = safety_settings or self._safety_settings
tools = tools or self._tools

# Process thinking configuration
# Process thinking configuration and response format
generation_kwargs = _process_thinking_config(generation_kwargs)
generation_kwargs = _process_response_format(generation_kwargs)

# Select appropriate streaming callback
streaming_callback = select_streaming_callback(
@@ -486,8 +517,9 @@ async def run_async(
safety_settings = safety_settings or self._safety_settings
tools = tools or self._tools

# Process thinking configuration
# Process thinking configuration and response format
generation_kwargs = _process_thinking_config(generation_kwargs)
generation_kwargs = _process_response_format(generation_kwargs)

# Select appropriate streaming callback
streaming_callback = select_streaming_callback(
@@ -28,6 +28,7 @@
flatten_tools_or_toolsets,
)
from jsonref import replace_refs
from pydantic import BaseModel

logger = logging.getLogger(__name__)

@@ -54,6 +55,49 @@
}


def _process_response_format(generation_kwargs: dict[str, Any]) -> dict[str, Any]:
"""
Process `response_format` from generation_kwargs into Google GenAI's native
``response_schema`` and ``response_mime_type`` parameters.

Accepts either a Pydantic BaseModel class or a JSON schema dict. When
``response_format`` is present it is popped and replaced with the two
Google-native keys. If ``response_schema`` or ``response_mime_type`` are
already set they take precedence and ``response_format`` is ignored.

Does not mutate the input dict; returns a new dict.

:param generation_kwargs: The generation configuration dictionary.
:returns: A new dict with response_schema/response_mime_type if applicable.
"""
generation_kwargs = dict(generation_kwargs)

# If the user already set Google-native keys, leave them alone
if "response_schema" in generation_kwargs or "response_mime_type" in generation_kwargs:
generation_kwargs.pop("response_format", None)
return generation_kwargs

response_format = generation_kwargs.pop("response_format", None)
if response_format is None:
return generation_kwargs

if isinstance(response_format, type) and issubclass(response_format, BaseModel):
generation_kwargs["response_schema"] = response_format
generation_kwargs["response_mime_type"] = "application/json"
return generation_kwargs

if isinstance(response_format, dict):
generation_kwargs["response_schema"] = response_format
generation_kwargs["response_mime_type"] = "application/json"
return generation_kwargs

msg = (
f"Unsupported response_format type: {type(response_format).__name__}. "
"Expected a Pydantic model class or a JSON schema dict."
)
raise TypeError(msg)


def _process_thinking_config(generation_kwargs: dict[str, Any]) -> dict[str, Any]:
"""
Process thinking configuration from generation_kwargs.
110 changes: 110 additions & 0 deletions integrations/google_genai/tests/test_chat_generator.py
@@ -3,6 +3,7 @@
# SPDX-License-Identifier: Apache-2.0

import asyncio
import json
import os

import pytest
@@ -21,6 +22,7 @@
)
from haystack.tools import Tool, Toolset, create_tool_from_function
from haystack.utils.auth import Secret
from pydantic import BaseModel

from haystack_integrations.components.generators.google_genai.chat.chat_generator import (
GoogleGenAIChatGenerator,
@@ -204,6 +206,52 @@ def test_serde_with_mixed_tools_and_toolsets(self, monkeypatch):
assert restored._tools[0].name == "tool1"
assert len(restored._tools[1]) == 1

def test_to_dict_with_response_format_pydantic(self, monkeypatch):
"""Test that to_dict serializes a Pydantic response_format to a JSON schema dict."""
monkeypatch.setenv("GOOGLE_API_KEY", "test-api-key")

class City(BaseModel):
name: str
country: str
population: int

generator = GoogleGenAIChatGenerator(generation_kwargs={"response_format": City})
data = generator.to_dict()

response_format = data["init_parameters"]["generation_kwargs"]["response_format"]
assert response_format == {
"properties": {
"name": {"title": "Name", "type": "string"},
"country": {"title": "Country", "type": "string"},
"population": {"title": "Population", "type": "integer"},
},
"required": ["name", "country", "population"],
"title": "City",
"type": "object",
}

def test_to_dict_with_response_format_dict(self, monkeypatch):
"""Test that to_dict preserves a dict response_format as is."""
monkeypatch.setenv("GOOGLE_API_KEY", "test-api-key")

schema = {"type": "object", "properties": {"name": {"type": "string"}}}
generator = GoogleGenAIChatGenerator(generation_kwargs={"response_format": schema})
data = generator.to_dict()

assert data["init_parameters"]["generation_kwargs"]["response_format"] == schema

def test_serde_with_response_format(self, monkeypatch):
"""Test serialization/deserialization round-trip with response_format."""
monkeypatch.setenv("GOOGLE_API_KEY", "test-api-key")

schema = {"type": "object", "properties": {"name": {"type": "string"}}}
generator = GoogleGenAIChatGenerator(generation_kwargs={"response_format": schema, "temperature": 0.5})
data = generator.to_dict()

restored = GoogleGenAIChatGenerator.from_dict(data)
assert restored._generation_kwargs["response_format"] == schema
assert restored._generation_kwargs["temperature"] == 0.5


@pytest.mark.skipif(
not os.environ.get("GOOGLE_API_KEY", None),
@@ -632,6 +680,48 @@ def test_live_run_with_thinking_unsupported_model_fails_fast(self):
assert "thinking_budget" in error_message or "thinking features" in error_message
assert "Try removing" in error_message or "use a different model" in error_message

def test_live_run_with_structured_output_pydantic(self):
"""Test that response_format with a Pydantic model returns valid structured JSON output."""

class City(BaseModel):
name: str
country: str
population: int

component = GoogleGenAIChatGenerator(generation_kwargs={"response_format": City})
results = component.run([ChatMessage.from_user("Tell me about Paris. Respond in JSON.")])

assert len(results["replies"]) == 1
message = results["replies"][0]
assert message.text

parsed = json.loads(message.text)
assert "name" in parsed
assert "country" in parsed
assert "population" in parsed

def test_live_run_with_structured_output_dict_schema(self):
"""Test that response_format with a JSON schema dict returns valid structured JSON output."""
schema = {
"type": "object",
"properties": {
"name": {"type": "string"},
"country": {"type": "string"},
},
"required": ["name", "country"],
}

component = GoogleGenAIChatGenerator(generation_kwargs={"response_format": schema})
results = component.run([ChatMessage.from_user("Tell me about Paris. Respond in JSON.")])

assert len(results["replies"]) == 1
message = results["replies"][0]
assert message.text

parsed = json.loads(message.text)
assert "name" in parsed
assert "country" in parsed

def test_live_run_agent_with_images_in_tool_result(self, test_files_path):
def retrieve_image():
return [
@@ -763,6 +853,26 @@ async def test_live_run_async_with_thinking_unsupported_model_fails_fast(self):
assert "thinking_budget" in error_message or "thinking features" in error_message
assert "Try removing" in error_message or "use a different model" in error_message

async def test_live_run_async_with_structured_output(self):
"""Async integration test for structured output with a Pydantic model."""

class City(BaseModel):
name: str
country: str
population: int

component = GoogleGenAIChatGenerator(generation_kwargs={"response_format": City})
results = await component.run_async([ChatMessage.from_user("Tell me about Paris. Respond in JSON.")])

assert len(results["replies"]) == 1
message = results["replies"][0]
assert message.text

parsed = json.loads(message.text)
assert "name" in parsed
assert "country" in parsed
assert "population" in parsed

async def test_concurrent_async_calls(self):
"""Test multiple concurrent async calls."""
component = GoogleGenAIChatGenerator()
59 changes: 59 additions & 0 deletions integrations/google_genai/tests/test_chat_generator_utils.py
@@ -17,6 +17,7 @@
TextContent,
ToolCall,
)
from pydantic import BaseModel

from haystack_integrations.components.generators.google_genai.chat.chat_generator import (
GoogleGenAIChatGenerator,
@@ -27,6 +28,7 @@
_convert_google_genai_response_to_chatmessage,
_convert_message_to_google_genai_format,
_convert_usage_metadata_to_serializable,
_process_response_format,
_process_thinking_config,
)

@@ -160,6 +162,63 @@ def test_process_thinking_config_explicit_include_thoughts():
assert result == {"temperature": 0.5}


def test_process_response_format():
"""Test the _process_response_format function with different response_format values."""

class City(BaseModel):
name: str
country: str
population: int

# Test Pydantic model
generation_kwargs = {"response_format": City, "temperature": 0.7}
result = _process_response_format(generation_kwargs)

# response_format should be replaced with response_schema and response_mime_type
assert "response_format" not in result
assert result["response_schema"] is City
assert result["response_mime_type"] == "application/json"
# Other kwargs should be preserved
assert result["temperature"] == 0.7

# Test JSON schema dict
schema = {"type": "object", "properties": {"name": {"type": "string"}}}
generation_kwargs = {"response_format": schema, "temperature": 0.5}
result = _process_response_format(generation_kwargs)
assert "response_format" not in result
assert result["response_schema"] == schema
assert result["response_mime_type"] == "application/json"
assert result["temperature"] == 0.5

# Test when response_format is not present
generation_kwargs = {"temperature": 0.5}
result = _process_response_format(generation_kwargs)
assert result == generation_kwargs # No changes

# Test that native keys take precedence
native_schema = {"type": "object", "properties": {"x": {"type": "string"}}}
generation_kwargs = {
"response_format": City,
"response_schema": native_schema,
"response_mime_type": "application/json",
}
result = _process_response_format(generation_kwargs)
assert "response_format" not in result
assert result["response_schema"] == native_schema
assert result["response_mime_type"] == "application/json"

# Test unsupported type raises TypeError
generation_kwargs = {"response_format": "invalid"}
with pytest.raises(TypeError, match="Unsupported response_format type"):
_process_response_format(generation_kwargs)

# Test that input dict is not mutated
generation_kwargs = {"response_format": City, "temperature": 0.7}
original = generation_kwargs.copy()
_process_response_format(generation_kwargs)
assert generation_kwargs == original


class TestStreamingChunkConversion:
def test_convert_google_chunk_to_streaming_chunk_text_only(self, monkeypatch):
monkeypatch.setenv("GOOGLE_API_KEY", "test-api-key")