Skip to content

Commit 43e483b

Browse files
aelhajjan and akin87 authored
feat: Add support for structured output (response_format) in GoogleGenAIChatGenerator (#2946)
* feat: Enhance GoogleGenAIChatGenerator with response format processing * test: Add tests for response_format handling in GoogleGenAIChatGenerator * Update integrations/google_genai/src/haystack_integrations/components/generators/google_genai/chat/utils.py Co-authored-by: Stefano Fiorucci <stefanofiorucci@gmail.com> * test: adjust test for genai chat generator * docs: add structured output example for GoogleGenAI * Update integrations/google_genai/src/haystack_integrations/components/generators/google_genai/chat/utils.py --------- Co-authored-by: Stefano Fiorucci <stefanofiorucci@gmail.com>
1 parent 246ed51 commit 43e483b

4 files changed

Lines changed: 248 additions & 3 deletions

File tree

integrations/google_genai/src/haystack_integrations/components/generators/google_genai/chat/chat_generator.py

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
serialize_tools_or_toolset,
2020
)
2121
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
22+
from pydantic import BaseModel
2223

2324
from haystack_integrations.components.common.google_genai.utils import _get_client
2425
from haystack_integrations.components.generators.google_genai.chat.utils import (
@@ -27,6 +28,7 @@
2728
_convert_google_genai_response_to_chatmessage,
2829
_convert_message_to_google_genai_format,
2930
_convert_tools_to_google_genai_format,
31+
_process_response_format,
3032
_process_thinking_config,
3133
)
3234

@@ -139,6 +141,28 @@ def weather_function(city: str):
139141
response = chat_generator_with_tools.run(messages=messages)
140142
```
141143
144+
### Usage example with structured output
145+
146+
```python
147+
from pydantic import BaseModel
148+
from haystack.dataclasses.chat_message import ChatMessage
149+
from haystack_integrations.components.generators.google_genai import GoogleGenAIChatGenerator
150+
151+
class City(BaseModel):
152+
name: str
153+
country: str
154+
population: int
155+
156+
chat_generator = GoogleGenAIChatGenerator(
157+
model="gemini-2.5-flash",
158+
generation_kwargs={"response_format": City}
159+
)
160+
161+
messages = [ChatMessage.from_user("Tell me about Paris")]
162+
response = chat_generator.run(messages=messages)
163+
print(response["replies"][0].text) # JSON output matching the City schema
164+
```
165+
142166
### Usage example with FileContent embedded in a ChatMessage
143167
144168
```python
@@ -250,14 +274,20 @@ def to_dict(self) -> dict[str, Any]:
250274
"""
251275
callback_name = serialize_callable(self._streaming_callback) if self._streaming_callback else None
252276
serialized_tools = serialize_tools_or_toolset(self._tools) if self._tools else None
277+
278+
generation_kwargs = self._generation_kwargs.copy()
279+
response_format = generation_kwargs.get("response_format")
280+
if response_format and isinstance(response_format, type) and issubclass(response_format, BaseModel):
281+
generation_kwargs["response_format"] = response_format.model_json_schema()
282+
253283
return default_to_dict(
254284
self,
255285
api_key=self._api_key.to_dict(),
256286
api=self._api,
257287
vertex_ai_project=self._vertex_ai_project,
258288
vertex_ai_location=self._vertex_ai_location,
259289
model=self._model,
260-
generation_kwargs=self._generation_kwargs,
290+
generation_kwargs=generation_kwargs,
261291
safety_settings=self._safety_settings,
262292
streaming_callback=callback_name,
263293
tools=serialized_tools,
@@ -379,8 +409,9 @@ def run(
379409
safety_settings = safety_settings or self._safety_settings
380410
tools = tools or self._tools
381411

382-
# Process thinking configuration
412+
# Process thinking configuration and response format
383413
generation_kwargs = _process_thinking_config(generation_kwargs)
414+
generation_kwargs = _process_response_format(generation_kwargs)
384415

385416
# Select appropriate streaming callback
386417
streaming_callback = select_streaming_callback(
@@ -489,8 +520,9 @@ async def run_async(
489520
safety_settings = safety_settings or self._safety_settings
490521
tools = tools or self._tools
491522

492-
# Process thinking configuration
523+
# Process thinking configuration and response format
493524
generation_kwargs = _process_thinking_config(generation_kwargs)
525+
generation_kwargs = _process_response_format(generation_kwargs)
494526

495527
# Select appropriate streaming callback
496528
streaming_callback = select_streaming_callback(

integrations/google_genai/src/haystack_integrations/components/generators/google_genai/chat/utils.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
flatten_tools_or_toolsets,
2929
)
3030
from jsonref import replace_refs
31+
from pydantic import BaseModel
3132

3233
logger = logging.getLogger(__name__)
3334

@@ -54,6 +55,49 @@
5455
}
5556

5657

58+
def _process_response_format(generation_kwargs: dict[str, Any]) -> dict[str, Any]:
59+
"""
60+
Process `response_format` from generation_kwargs into Google GenAI's native
61+
`response_schema` and `response_mime_type` parameters.
62+
63+
Accepts either a Pydantic BaseModel class or a JSON schema dict. When
64+
`response_format` is present, it is popped and replaced with the two
65+
Google-native keys. If `response_schema` or `response_mime_type` are
66+
already set, they take precedence and `response_format` is ignored.
67+
68+
Does not mutate the input dict; returns a new dict.
69+
70+
:param generation_kwargs: The generation configuration dictionary.
71+
:returns: A new dict with response_schema/response_mime_type if applicable.
72+
"""
73+
generation_kwargs = dict(generation_kwargs)
74+
75+
# If the user already set Google-native keys, leave them alone
76+
if "response_schema" in generation_kwargs or "response_mime_type" in generation_kwargs:
77+
generation_kwargs.pop("response_format", None)
78+
return generation_kwargs
79+
80+
response_format = generation_kwargs.pop("response_format", None)
81+
if response_format is None:
82+
return generation_kwargs
83+
84+
if isinstance(response_format, type) and issubclass(response_format, BaseModel):
85+
generation_kwargs["response_schema"] = response_format
86+
generation_kwargs["response_mime_type"] = "application/json"
87+
return generation_kwargs
88+
89+
if isinstance(response_format, dict):
90+
generation_kwargs["response_schema"] = response_format
91+
generation_kwargs["response_mime_type"] = "application/json"
92+
return generation_kwargs
93+
94+
msg = (
95+
f"Unsupported response_format type: {type(response_format).__name__}. "
96+
"Expected a Pydantic model class or a JSON schema dict."
97+
)
98+
raise TypeError(msg)
99+
100+
57101
def _process_thinking_config(generation_kwargs: dict[str, Any]) -> dict[str, Any]:
58102
"""
59103
Process thinking configuration from generation_kwargs.

integrations/google_genai/tests/test_chat_generator.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# SPDX-License-Identifier: Apache-2.0
44

55
import asyncio
6+
import json
67
import os
78

89
import pytest
@@ -21,6 +22,7 @@
2122
)
2223
from haystack.tools import Tool, Toolset, create_tool_from_function
2324
from haystack.utils.auth import Secret
25+
from pydantic import BaseModel
2426

2527
from haystack_integrations.components.generators.google_genai.chat.chat_generator import (
2628
GoogleGenAIChatGenerator,
@@ -204,6 +206,52 @@ def test_serde_with_mixed_tools_and_toolsets(self, monkeypatch):
204206
assert restored._tools[0].name == "tool1"
205207
assert len(restored._tools[1]) == 1
206208

209+
def test_to_dict_with_response_format_pydantic(self, monkeypatch):
210+
"""Test that to_dict serializes a Pydantic response_format to a JSON schema dict."""
211+
monkeypatch.setenv("GOOGLE_API_KEY", "test-api-key")
212+
213+
class City(BaseModel):
214+
name: str
215+
country: str
216+
population: int
217+
218+
generator = GoogleGenAIChatGenerator(generation_kwargs={"response_format": City})
219+
data = generator.to_dict()
220+
221+
response_format = data["init_parameters"]["generation_kwargs"]["response_format"]
222+
assert response_format == {
223+
"properties": {
224+
"name": {"title": "Name", "type": "string"},
225+
"country": {"title": "Country", "type": "string"},
226+
"population": {"title": "Population", "type": "integer"},
227+
},
228+
"required": ["name", "country", "population"],
229+
"title": "City",
230+
"type": "object",
231+
}
232+
233+
def test_to_dict_with_response_format_dict(self, monkeypatch):
234+
"""Test that to_dict preserves a dict response_format as is."""
235+
monkeypatch.setenv("GOOGLE_API_KEY", "test-api-key")
236+
237+
schema = {"type": "object", "properties": {"name": {"type": "string"}}}
238+
generator = GoogleGenAIChatGenerator(generation_kwargs={"response_format": schema})
239+
data = generator.to_dict()
240+
241+
assert data["init_parameters"]["generation_kwargs"]["response_format"] == schema
242+
243+
def test_serde_with_response_format(self, monkeypatch):
244+
"""Test serialization/deserialization round-trip with response_format."""
245+
monkeypatch.setenv("GOOGLE_API_KEY", "test-api-key")
246+
247+
schema = {"type": "object", "properties": {"name": {"type": "string"}}}
248+
generator = GoogleGenAIChatGenerator(generation_kwargs={"response_format": schema, "temperature": 0.5})
249+
data = generator.to_dict()
250+
251+
restored = GoogleGenAIChatGenerator.from_dict(data)
252+
assert restored._generation_kwargs["response_format"] == schema
253+
assert restored._generation_kwargs["temperature"] == 0.5
254+
207255

208256
@pytest.mark.skipif(
209257
not os.environ.get("GOOGLE_API_KEY", None),
@@ -632,6 +680,48 @@ def test_live_run_with_thinking_unsupported_model_fails_fast(self):
632680
assert "thinking_budget" in error_message or "thinking features" in error_message
633681
assert "Try removing" in error_message or "use a different model" in error_message
634682

683+
def test_live_run_with_structured_output_pydantic(self):
684+
"""Test that response_format with a Pydantic model returns valid structured JSON output."""
685+
686+
class City(BaseModel):
687+
name: str
688+
country: str
689+
population: int
690+
691+
component = GoogleGenAIChatGenerator(generation_kwargs={"response_format": City})
692+
results = component.run([ChatMessage.from_user("Tell me about Paris. Respond in JSON.")])
693+
694+
assert len(results["replies"]) == 1
695+
message = results["replies"][0]
696+
assert message.text
697+
698+
parsed = json.loads(message.text)
699+
assert "name" in parsed
700+
assert "country" in parsed
701+
assert "population" in parsed
702+
703+
def test_live_run_with_structured_output_dict_schema(self):
704+
"""Test that response_format with a JSON schema dict returns valid structured JSON output."""
705+
schema = {
706+
"type": "object",
707+
"properties": {
708+
"name": {"type": "string"},
709+
"country": {"type": "string"},
710+
},
711+
"required": ["name", "country"],
712+
}
713+
714+
component = GoogleGenAIChatGenerator(generation_kwargs={"response_format": schema})
715+
results = component.run([ChatMessage.from_user("Tell me about Paris. Respond in JSON.")])
716+
717+
assert len(results["replies"]) == 1
718+
message = results["replies"][0]
719+
assert message.text
720+
721+
parsed = json.loads(message.text)
722+
assert "name" in parsed
723+
assert "country" in parsed
724+
635725
def test_live_run_agent_with_images_in_tool_result(self, test_files_path):
636726
def retrieve_image():
637727
return [
@@ -763,6 +853,26 @@ async def test_live_run_async_with_thinking_unsupported_model_fails_fast(self):
763853
assert "thinking_budget" in error_message or "thinking features" in error_message
764854
assert "Try removing" in error_message or "use a different model" in error_message
765855

856+
async def test_live_run_async_with_structured_output(self):
857+
"""Async integration test for structured output with a Pydantic model."""
858+
859+
class City(BaseModel):
860+
name: str
861+
country: str
862+
population: int
863+
864+
component = GoogleGenAIChatGenerator(generation_kwargs={"response_format": City})
865+
results = await component.run_async([ChatMessage.from_user("Tell me about Paris. Respond in JSON.")])
866+
867+
assert len(results["replies"]) == 1
868+
message = results["replies"][0]
869+
assert message.text
870+
871+
parsed = json.loads(message.text)
872+
assert "name" in parsed
873+
assert "country" in parsed
874+
assert "population" in parsed
875+
766876
async def test_concurrent_async_calls(self):
767877
"""Test multiple concurrent async calls."""
768878
component = GoogleGenAIChatGenerator()

integrations/google_genai/tests/test_chat_generator_utils.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
TextContent,
1818
ToolCall,
1919
)
20+
from pydantic import BaseModel
2021

2122
from haystack_integrations.components.generators.google_genai.chat.chat_generator import (
2223
GoogleGenAIChatGenerator,
@@ -27,6 +28,7 @@
2728
_convert_google_genai_response_to_chatmessage,
2829
_convert_message_to_google_genai_format,
2930
_convert_usage_metadata_to_serializable,
31+
_process_response_format,
3032
_process_thinking_config,
3133
)
3234

@@ -160,6 +162,63 @@ def test_process_thinking_config_explicit_include_thoughts():
160162
assert result == {"temperature": 0.5}
161163

162164

165+
def test_process_response_format():
166+
"""Test the _process_response_format function with different response_format values."""
167+
168+
class City(BaseModel):
169+
name: str
170+
country: str
171+
population: int
172+
173+
# Test Pydantic model
174+
generation_kwargs = {"response_format": City, "temperature": 0.7}
175+
result = _process_response_format(generation_kwargs)
176+
177+
# response_format should be replaced with response_schema and response_mime_type
178+
assert "response_format" not in result
179+
assert result["response_schema"] is City
180+
assert result["response_mime_type"] == "application/json"
181+
# Other kwargs should be preserved
182+
assert result["temperature"] == 0.7
183+
184+
# Test JSON schema dict
185+
schema = {"type": "object", "properties": {"name": {"type": "string"}}}
186+
generation_kwargs = {"response_format": schema, "temperature": 0.5}
187+
result = _process_response_format(generation_kwargs)
188+
assert "response_format" not in result
189+
assert result["response_schema"] == schema
190+
assert result["response_mime_type"] == "application/json"
191+
assert result["temperature"] == 0.5
192+
193+
# Test when response_format is not present
194+
generation_kwargs = {"temperature": 0.5}
195+
result = _process_response_format(generation_kwargs)
196+
assert result == generation_kwargs # No changes
197+
198+
# Test that native keys take precedence
199+
native_schema = {"type": "object", "properties": {"x": {"type": "string"}}}
200+
generation_kwargs = {
201+
"response_format": City,
202+
"response_schema": native_schema,
203+
"response_mime_type": "application/json",
204+
}
205+
result = _process_response_format(generation_kwargs)
206+
assert "response_format" not in result
207+
assert result["response_schema"] == native_schema
208+
assert result["response_mime_type"] == "application/json"
209+
210+
# Test unsupported type raises TypeError
211+
generation_kwargs = {"response_format": "invalid"}
212+
with pytest.raises(TypeError, match="Unsupported response_format type"):
213+
_process_response_format(generation_kwargs)
214+
215+
# Test that input dict is not mutated
216+
generation_kwargs = {"response_format": City, "temperature": 0.7}
217+
original = generation_kwargs.copy()
218+
_process_response_format(generation_kwargs)
219+
assert generation_kwargs == original
220+
221+
163222
class TestStreamingChunkConversion:
164223
def test_convert_google_chunk_to_streaming_chunk_text_only(self, monkeypatch):
165224
monkeypatch.setenv("GOOGLE_API_KEY", "test-api-key")

0 commit comments

Comments (0)