22#
33# SPDX-License-Identifier: Apache-2.0
44
5- from typing import Any , Dict , List , Optional , Union
5+ from typing import Any , Dict , Optional
66
77from haystack import component , default_to_dict , logging
88from haystack .components .generators .chat import OpenAIChatGenerator
99from haystack .dataclasses import ChatMessage , StreamingCallbackT
10- from haystack .tools import Tool , Toolset , _check_duplicate_tool_names
10+ from haystack .tools import ToolsType , _check_duplicate_tool_names , flatten_tools_or_toolsets , serialize_tools_or_toolset
1111from haystack .utils import serialize_callable
1212from haystack .utils .auth import Secret
1313
@@ -64,7 +64,7 @@ def __init__(
6464 streaming_callback : Optional [StreamingCallbackT ] = None ,
6565 api_base_url : Optional [str ] = "https://openrouter.ai/api/v1" ,
6666 generation_kwargs : Optional [Dict [str , Any ]] = None ,
67- tools : Optional [Union [ List [ Tool ], Toolset ] ] = None ,
67+ tools : Optional [ToolsType ] = None ,
6868 timeout : Optional [float ] = None ,
6969 extra_headers : Optional [Dict [str , Any ]] = None ,
7070 max_retries : Optional [int ] = None ,
@@ -98,6 +98,14 @@ def __init__(
9898 events as they become available, with the stream terminated by a data: [DONE] message.
9999 - `safe_prompt`: Whether to inject a safety prompt before all conversations.
100100 - `random_seed`: The seed to use for random sampling.
101+ - `response_format`: A JSON schema or a Pydantic model that enforces the structure of the model's response.
102+ If provided, the output will always be validated against this
103+ format (unless the model returns a tool call).
104+ For details, see the [OpenAI Structured Outputs documentation](https://platform.openai.com/docs/guides/structured-outputs).
105+ Notes:
106+ - This parameter accepts Pydantic models and JSON schemas for the latest models, starting with GPT-4o.
107+ - For structured outputs with streaming,
108+ the `response_format` must be a JSON schema and not a Pydantic model.
101109 :param tools:
102110 A list of tools or a Toolset for which the model can prepare calls. This parameter can accept either a
103111 list of `Tool` objects or a `Toolset` instance.
@@ -148,7 +156,7 @@ def to_dict(self) -> Dict[str, Any]:
148156 api_base_url = self .api_base_url ,
149157 generation_kwargs = self .generation_kwargs ,
150158 api_key = self .api_key .to_dict (),
151- tools = [ tool . to_dict () for tool in self .tools ] if self . tools else None ,
159+ tools = serialize_tools_or_toolset ( self .tools ) ,
152160 extra_headers = self .extra_headers ,
153161 timeout = self .timeout ,
154162 max_retries = self .max_retries ,
@@ -158,46 +166,64 @@ def to_dict(self) -> Dict[str, Any]:
def _prepare_api_call(
    self,
    *,
    messages: list[ChatMessage],
    streaming_callback: Optional[StreamingCallbackT] = None,
    generation_kwargs: Optional[dict[str, Any]] = None,
    tools: Optional[ToolsType] = None,
    tools_strict: Optional[bool] = None,
) -> dict[str, Any]:
    """
    Build the keyword arguments for the chat completion request.

    :param messages:
        The conversation to send; each message is converted with `to_openai_dict_format()`.
    :param streaming_callback:
        When not None, the request is prepared as a streaming request.
    :param generation_kwargs:
        Per-call generation parameters. They are merged over `self.generation_kwargs`, with the
        per-call values taking precedence. The keys `n` and `response_format` are extracted and
        passed as top-level arguments; everything else is forwarded in `extra_body`.
    :param tools:
        Tools for this call. Falls back to `self.tools` when falsy.
    :param tools_strict:
        Whether to mark tool definitions as strict. Falls back to `self.tools_strict` when None.
    :returns:
        A dictionary of request arguments. It always contains the key `openai_endpoint`
        (`"parse"` or `"create"`), which is a hint for the caller about which endpoint to use and
        is expected to be removed before the API call is made.
    :raises ValueError:
        If streaming is requested together with `n > 1`.
    """
    # Per-call kwargs override the ones configured on the component. Both merges build new dicts,
    # so the `pop` calls below never touch `self.generation_kwargs` or the caller's dictionary.
    generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}
    extra_headers = {**(self.extra_headers or {})}

    is_streaming = streaming_callback is not None
    num_responses = generation_kwargs.pop("n", 1)
    if is_streaming and num_responses > 1:
        msg = "Cannot stream multiple responses, please set n=1."
        raise ValueError(msg)

    # `response_format` is sent as a top-level argument, so it must not also end up in `extra_body`.
    response_format = generation_kwargs.pop("response_format", None)

    # adapt ChatMessage(s) to the format expected by the OpenAI API
    openai_formatted_messages = [message.to_openai_dict_format() for message in messages]

    flattened_tools = flatten_tools_or_toolsets(tools or self.tools)
    tools_strict = tools_strict if tools_strict is not None else self.tools_strict
    _check_duplicate_tool_names(flattened_tools)

    openai_tools = {}
    if flattened_tools:
        tool_definitions = []
        for t in flattened_tools:
            function_spec = {**t.tool_spec}
            if tools_strict:
                function_spec["strict"] = True
                # `{**t.tool_spec}` is only a shallow copy: the nested `parameters` mapping is
                # still the tool's own object. Copy it before adding the key so that a strict
                # call does not permanently modify the tool's schema.
                function_spec["parameters"] = {**function_spec["parameters"], "additionalProperties": False}
            tool_definitions.append({"type": "function", "function": function_spec})
        openai_tools = {"tools": tool_definitions}

    base_args = {
        "model": self.model,
        "messages": openai_formatted_messages,
        "n": num_responses,
        **openai_tools,
        "extra_headers": {**extra_headers},
        "extra_body": {**generation_kwargs},
    }

    if response_format and not is_streaming:
        # Structured outputs without streaming go through the parse endpoint.
        # Note: `stream` cannot be passed to chat.completions.parse, so it is left out here.
        return {**base_args, "response_format": response_format, "openai_endpoint": "parse"}

    # Every other request (streaming, or no structured output requested) uses the create endpoint.
    final_args = {**base_args, "stream": is_streaming, "openai_endpoint": "create"}

    # Only set `response_format` when it is provided, since None is not a valid value in the API.
    if response_format:
        final_args["response_format"] = response_format
    return final_args
0 commit comments