Skip to content

Commit 6c8442c

Browse files
authored
feat: add support for structured outputs in OpenRouterChatGenerator (#2406)
* Add support
* Updates
* Update workflow
* Fix license
* Fix linting
* Update python version
* Update files
* Fix linting
* Update chat_generator.py
* Fix tools
* Update haystack version
* Updates
1 parent c8db40a commit 6c8442c

4 files changed

Lines changed: 157 additions & 30 deletions

File tree

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# SPDX-FileCopyrightText: 2024-present deepset GmbH <info@deepset.ai>
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
6+
# This example demonstrates how to use the OpenRouterChatGenerator component
7+
# with structured outputs.
8+
# To run this example, you will need to
9+
# set `OPENROUTER_API_KEY` environment variable
10+
11+
from haystack.dataclasses import ChatMessage
12+
from pydantic import BaseModel
13+
14+
from haystack_integrations.components.generators.openrouter import OpenRouterChatGenerator
15+
16+
17+
class NobelPrizeInfo(BaseModel):
18+
recipient_name: str
19+
award_year: int
20+
category: str
21+
achievement_description: str
22+
nationality: str
23+
24+
25+
chat_messages = [
26+
ChatMessage.from_user(
27+
"In 2021, American scientist David Julius received the Nobel Prize in"
28+
" Physiology or Medicine for his groundbreaking discoveries on how the human body"
29+
" senses temperature and touch."
30+
)
31+
]
32+
component = OpenRouterChatGenerator(generation_kwargs={"response_format": NobelPrizeInfo})
33+
results = component.run(chat_messages)
34+
35+
# print(results)

integrations/openrouter/pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ classifiers = [
2323
"Programming Language :: Python :: Implementation :: CPython",
2424
"Programming Language :: Python :: Implementation :: PyPy",
2525
]
26-
dependencies = ["haystack-ai>=2.13.1"]
26+
dependencies = ["haystack-ai>=2.19.0"]
2727

2828
[project.urls]
2929
Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/openrouter#readme"
@@ -154,4 +154,4 @@ addopts = "--strict-markers"
154154
markers = [
155155
"integration: integration tests",
156156
]
157-
log_cli = true
157+
log_cli = true

integrations/openrouter/src/haystack_integrations/components/generators/openrouter/chat/chat_generator.py

Lines changed: 54 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5-
from typing import Any, Dict, List, Optional, Union
5+
from typing import Any, Dict, Optional
66

77
from haystack import component, default_to_dict, logging
88
from haystack.components.generators.chat import OpenAIChatGenerator
99
from haystack.dataclasses import ChatMessage, StreamingCallbackT
10-
from haystack.tools import Tool, Toolset, _check_duplicate_tool_names
10+
from haystack.tools import ToolsType, _check_duplicate_tool_names, flatten_tools_or_toolsets, serialize_tools_or_toolset
1111
from haystack.utils import serialize_callable
1212
from haystack.utils.auth import Secret
1313

@@ -64,7 +64,7 @@ def __init__(
6464
streaming_callback: Optional[StreamingCallbackT] = None,
6565
api_base_url: Optional[str] = "https://openrouter.ai/api/v1",
6666
generation_kwargs: Optional[Dict[str, Any]] = None,
67-
tools: Optional[Union[List[Tool], Toolset]] = None,
67+
tools: Optional[ToolsType] = None,
6868
timeout: Optional[float] = None,
6969
extra_headers: Optional[Dict[str, Any]] = None,
7070
max_retries: Optional[int] = None,
@@ -98,6 +98,14 @@ def __init__(
9898
events as they become available, with the stream terminated by a data: [DONE] message.
9999
- `safe_prompt`: Whether to inject a safety prompt before all conversations.
100100
- `random_seed`: The seed to use for random sampling.
101+
- `response_format`: A JSON schema or a Pydantic model that enforces the structure of the model's response.
102+
If provided, the output will always be validated against this
103+
format (unless the model returns a tool call).
104+
For details, see the [OpenAI Structured Outputs documentation](https://platform.openai.com/docs/guides/structured-outputs).
105+
Notes:
106+
- This parameter accepts Pydantic models and JSON schemas for latest models starting from GPT-4o.
107+
- For structured outputs with streaming,
108+
the `response_format` must be a JSON schema and not a Pydantic model.
101109
:param tools:
102110
A list of tools or a Toolset for which the model can prepare calls. This parameter can accept either a
103111
list of `Tool` objects or a `Toolset` instance.
@@ -148,7 +156,7 @@ def to_dict(self) -> Dict[str, Any]:
148156
api_base_url=self.api_base_url,
149157
generation_kwargs=self.generation_kwargs,
150158
api_key=self.api_key.to_dict(),
151-
tools=[tool.to_dict() for tool in self.tools] if self.tools else None,
159+
tools=serialize_tools_or_toolset(self.tools),
152160
extra_headers=self.extra_headers,
153161
timeout=self.timeout,
154162
max_retries=self.max_retries,
@@ -158,46 +166,64 @@ def to_dict(self) -> Dict[str, Any]:
158166
def _prepare_api_call(
159167
self,
160168
*,
161-
messages: List[ChatMessage],
169+
messages: list[ChatMessage],
162170
streaming_callback: Optional[StreamingCallbackT] = None,
163-
generation_kwargs: Optional[Dict[str, Any]] = None,
164-
tools: Optional[Union[List[Tool], Toolset]] = None,
171+
generation_kwargs: Optional[dict[str, Any]] = None,
172+
tools: Optional[ToolsType] = None,
165173
tools_strict: Optional[bool] = None,
166-
) -> Dict[str, Any]:
174+
) -> dict[str, Any]:
167175
# update generation kwargs by merging with the generation kwargs passed to the run method
168176
generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}
169177
extra_headers = {**(self.extra_headers or {})}
170178

179+
is_streaming = streaming_callback is not None
180+
num_responses = generation_kwargs.pop("n", 1)
181+
182+
if is_streaming and num_responses > 1:
183+
msg = "Cannot stream multiple responses, please set n=1."
184+
raise ValueError(msg)
185+
response_format = generation_kwargs.pop("response_format", None)
186+
171187
# adapt ChatMessage(s) to the format expected by the OpenAI API
172188
openai_formatted_messages = [message.to_openai_dict_format() for message in messages]
173189

174-
tools = tools or self.tools
175-
if isinstance(tools, Toolset):
176-
tools = list(tools)
190+
flattened_tools = flatten_tools_or_toolsets(tools or self.tools)
177191
tools_strict = tools_strict if tools_strict is not None else self.tools_strict
178-
_check_duplicate_tool_names(list(tools or []))
192+
_check_duplicate_tool_names(flattened_tools)
179193

180194
openai_tools = {}
181-
if tools:
182-
tool_definitions = [
183-
{"type": "function", "function": {**t.tool_spec, **({"strict": tools_strict} if tools_strict else {})}}
184-
for t in tools
185-
]
195+
if flattened_tools:
196+
tool_definitions = []
197+
for t in flattened_tools:
198+
function_spec = {**t.tool_spec}
199+
if tools_strict:
200+
function_spec["strict"] = True
201+
function_spec["parameters"]["additionalProperties"] = False
202+
tool_definitions.append({"type": "function", "function": function_spec})
186203
openai_tools = {"tools": tool_definitions}
187204

188-
is_streaming = streaming_callback is not None
189-
num_responses = generation_kwargs.pop("n", 1)
190-
if is_streaming and num_responses > 1:
191-
msg = "Cannot stream multiple responses, please set n=1."
192-
raise ValueError(msg)
193-
194-
return {
205+
base_args = {
195206
"model": self.model,
196-
"messages": openai_formatted_messages, # type: ignore[arg-type] # openai expects list of specific message types
197-
"stream": streaming_callback is not None,
207+
"messages": openai_formatted_messages,
198208
"n": num_responses,
199209
**openai_tools,
200-
"extra_body": {**generation_kwargs},
201210
"extra_headers": {**extra_headers},
202-
"openai_endpoint": "create",
211+
"extra_body": {**generation_kwargs},
203212
}
213+
214+
if response_format and not is_streaming:
215+
# for structured outputs without streaming, we use openai's parse endpoint
216+
# Note: `stream` cannot be passed to chat.completions.parse
217+
# we pass a key `openai_endpoint` as a hint to the run method to use the parse endpoint
218+
# this key will be removed before the API call is made
219+
return {**base_args, "response_format": response_format, "openai_endpoint": "parse"}
220+
221+
# for structured outputs with streaming, we use openai's create endpoint
222+
# we pass a key `openai_endpoint` as a hint to the run method to use the create endpoint
223+
# this key will be removed before the API call is made
224+
final_args = {**base_args, "stream": is_streaming, "openai_endpoint": "create"}
225+
226+
# We only set the response_format parameter if it's not None since None is not a valid value in the API.
227+
if response_format:
228+
final_args["response_format"] = response_format
229+
return final_args

integrations/openrouter/tests/test_openrouter_chat_generator.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import json
12
import os
23
from datetime import datetime
34
from unittest.mock import patch
@@ -16,10 +17,22 @@
1617
from openai.types.chat.chat_completion_chunk import Choice as ChoiceChunk
1718
from openai.types.chat.chat_completion_chunk import ChoiceDelta, ChoiceDeltaToolCall, ChoiceDeltaToolCallFunction
1819
from openai.types.completion_usage import CompletionTokensDetails, CompletionUsage, PromptTokensDetails
20+
from pydantic import BaseModel
1921

2022
from haystack_integrations.components.generators.openrouter.chat.chat_generator import OpenRouterChatGenerator
2123

2224

25+
class CalendarEvent(BaseModel):
26+
event_name: str
27+
event_date: str
28+
event_location: str
29+
30+
31+
@pytest.fixture
32+
def calendar_event_model():
33+
return CalendarEvent
34+
35+
2336
class CollectorCallback:
2437
"""
2538
Callback to collect streaming chunks for testing purposes.
@@ -440,6 +453,41 @@ def test_pipeline_with_openrouter_chat_generator(self, tools):
440453
== results["tool_invoker"]["tool_messages"][0].tool_call_result.result
441454
)
442455

456+
@pytest.mark.skipif(
457+
not os.environ.get("OPENROUTER_API_KEY", None),
458+
reason="Export an env var called OPENROUTER_API_KEY containing the OpenRouter API key to run this test.",
459+
)
460+
@pytest.mark.integration
461+
def test_live_run_with_response_format_json_schema(self):
462+
response_schema = {
463+
"type": "json_schema",
464+
"json_schema": {
465+
"name": "CapitalCity",
466+
"strict": True,
467+
"schema": {
468+
"title": "CapitalCity",
469+
"type": "object",
470+
"properties": {
471+
"city": {"title": "City", "type": "string"},
472+
"country": {"title": "Country", "type": "string"},
473+
},
474+
"required": ["city", "country"],
475+
"additionalProperties": False,
476+
},
477+
},
478+
}
479+
480+
chat_messages = [ChatMessage.from_user("What's the capital of France?")]
481+
comp = OpenRouterChatGenerator(generation_kwargs={"response_format": response_schema})
482+
results = comp.run(chat_messages)
483+
assert len(results["replies"]) == 1
484+
message: ChatMessage = results["replies"][0]
485+
msg = json.loads(message.text)
486+
assert "Paris" in msg["city"]
487+
assert isinstance(msg["country"], str)
488+
assert "France" in msg["country"]
489+
assert message.meta["finish_reason"] == "stop"
490+
443491
def test_serde_in_pipeline(self, monkeypatch):
444492
"""
445493
Test serialization/deserialization of OpenRouterChatGenerator in a Pipeline,
@@ -539,6 +587,24 @@ def test_serde_in_pipeline(self, monkeypatch):
539587
assert loaded_generator.tools[0].description == generator.tools[0].description
540588
assert loaded_generator.tools[0].parameters == generator.tools[0].parameters
541589

590+
@pytest.mark.skipif(
591+
not os.environ.get("OPENROUTER_API_KEY", None),
592+
reason="Export an env var called OPENROUTER_API_KEY containing the OpenRouter API key to run this test.",
593+
)
594+
@pytest.mark.integration
595+
def test_live_run_with_response_format_pydantic_model(self, calendar_event_model):
596+
chat_messages = [
597+
ChatMessage.from_user("The marketing summit takes place on October12th at the Hilton Hotel downtown.")
598+
]
599+
component = OpenRouterChatGenerator(generation_kwargs={"response_format": calendar_event_model})
600+
results = component.run(chat_messages)
601+
assert len(results["replies"]) == 1
602+
message: ChatMessage = results["replies"][0]
603+
msg = json.loads(message.text)
604+
assert "Marketing Summit" in msg["event_name"]
605+
assert isinstance(msg["event_date"], str)
606+
assert isinstance(msg["event_location"], str)
607+
542608

543609
class TestChatCompletionChunkConversion:
544610
def test_handle_stream_response(self):

0 commit comments

Comments (0)