Skip to content

Commit 8c8e05d

Browse files
authored
feat: add support for structured outputs for MistralChatGenerator (#2390)
* Update mistral * Add response format support * Update openrouter * Add structured outputs * Update open router * remove nvidia * remove openrouter * remove extra files * Update the ruff version * Updates * Update the versions * Restore * Update mistral.yml * Serialization * Add example * Fix error * Update tools * Fix haystack version * Update docstring
1 parent 6c8442c commit 8c8e05d

4 files changed

Lines changed: 159 additions & 20 deletions

File tree

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# SPDX-FileCopyrightText: 2024-present deepset GmbH <info@deepset.ai>
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
6+
# This example demonstrates how to use the MistralChatGenerator component
7+
# with structured outputs.
8+
# To run this example, you will need to
9+
# set `MISTRAL_API_KEY` environment variable
10+
11+
from haystack.dataclasses import ChatMessage
12+
from pydantic import BaseModel
13+
14+
from haystack_integrations.components.generators.mistral import MistralChatGenerator
15+
16+
17+
class NobelPrizeInfo(BaseModel):
18+
recipient_name: str
19+
award_year: int
20+
category: str
21+
achievement_description: str
22+
nationality: str
23+
24+
25+
chat_messages = [
26+
ChatMessage.from_user(
27+
"In 2021, American scientist David Julius received the Nobel Prize in"
28+
" Physiology or Medicine for his groundbreaking discoveries on how the human body"
29+
" senses temperature and touch."
30+
)
31+
]
32+
component = MistralChatGenerator(generation_kwargs={"response_format": NobelPrizeInfo})
33+
results = component.run(chat_messages)
34+
35+
# print(results)

integrations/mistral/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ classifiers = [
2323
"Programming Language :: Python :: Implementation :: CPython",
2424
"Programming Language :: Python :: Implementation :: PyPy",
2525
]
26-
dependencies = ["haystack-ai>=2.15.1"]
26+
dependencies = ["haystack-ai>=2.19.0"]
2727

2828
[project.urls]
2929
Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/mistral#readme"

integrations/mistral/src/haystack_integrations/components/generators/mistral/chat/chat_generator.py

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,16 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5-
from typing import Any, Dict, List, Optional, Union
5+
from typing import Any, Dict, Optional
66

77
from haystack import component, default_to_dict, logging
88
from haystack.components.generators.chat import OpenAIChatGenerator
99
from haystack.dataclasses import ChatMessage, StreamingCallbackT
10-
from haystack.tools import Tool, Toolset
10+
from haystack.tools import ToolsType
1111
from haystack.utils import serialize_callable
1212
from haystack.utils.auth import Secret
13+
from openai.lib._pydantic import to_strict_json_schema
14+
from pydantic import BaseModel
1315

1416
logger = logging.getLogger(__name__)
1517

@@ -64,7 +66,7 @@ def __init__(
6466
streaming_callback: Optional[StreamingCallbackT] = None,
6567
api_base_url: Optional[str] = "https://api.mistral.ai/v1",
6668
generation_kwargs: Optional[Dict[str, Any]] = None,
67-
tools: Optional[Union[List[Tool], Toolset]] = None,
69+
tools: Optional[ToolsType] = None,
6870
*,
6971
timeout: Optional[float] = None,
7072
max_retries: Optional[int] = None,
@@ -98,6 +100,13 @@ def __init__(
98100
events as they become available, with the stream terminated by a data: [DONE] message.
99101
- `safe_prompt`: Whether to inject a safety prompt before all conversations.
100102
- `random_seed`: The seed to use for random sampling.
103+
- `response_format`: A JSON schema or a Pydantic model that enforces the structure of the model's response.
104+
If provided, the output will always be validated against this
105+
format (unless the model returns a tool call).
106+
For details, see the [OpenAI Structured Outputs documentation](https://platform.openai.com/docs/guides/structured-outputs).
107+
Notes:
108+
- For structured outputs with streaming,
109+
the `response_format` must be a JSON schema and not a Pydantic model.
101110
:param tools:
102111
A list of tools or a Toolset for which the model can prepare calls. This parameter can accept either a
103112
list of `Tool` objects or a `Toolset` instance.
@@ -130,7 +139,7 @@ def _prepare_api_call(
130139
messages: list[ChatMessage],
131140
streaming_callback: Optional[StreamingCallbackT] = None,
132141
generation_kwargs: Optional[dict[str, Any]] = None,
133-
tools: Optional[Union[list[Tool], Toolset]] = None,
142+
tools: Optional[ToolsType] = None,
134143
tools_strict: Optional[bool] = None,
135144
) -> dict[str, Any]:
136145
api_args = super(MistralChatGenerator, self)._prepare_api_call( # noqa: UP008
@@ -154,6 +163,22 @@ def to_dict(self) -> Dict[str, Any]:
154163
The serialized component as a dictionary.
155164
"""
156165
callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None
166+
generation_kwargs = self.generation_kwargs.copy()
167+
response_format = generation_kwargs.get("response_format")
168+
169+
# If the response format is a Pydantic model, it's converted to openai's json schema format
170+
# If it's already a json schema, it's left as is
171+
if response_format and isinstance(response_format, type) and issubclass(response_format, BaseModel):
172+
json_schema = {
173+
"type": "json_schema",
174+
"json_schema": {
175+
"name": response_format.__name__,
176+
"strict": True,
177+
"schema": to_strict_json_schema(response_format),
178+
},
179+
}
180+
181+
generation_kwargs["response_format"] = json_schema
157182

158183
# if we didn't implement the to_dict method here then the to_dict method of the superclass would be used
159184
# which would serialize some fields that we don't want to serialize (e.g. the ones we don't have in
@@ -163,7 +188,7 @@ def to_dict(self) -> Dict[str, Any]:
163188
model=self.model,
164189
streaming_callback=callback_name,
165190
api_base_url=self.api_base_url,
166-
generation_kwargs=self.generation_kwargs,
191+
generation_kwargs=generation_kwargs,
167192
api_key=self.api_key.to_dict(),
168193
tools=[tool.to_dict() for tool in self.tools] if self.tools else None,
169194
timeout=self.timeout,

integrations/mistral/tests/test_mistral_chat_generator.py

Lines changed: 93 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from openai.types.chat.chat_completion_chunk import Choice as ChoiceChunk
1818
from openai.types.chat.chat_completion_chunk import ChoiceDelta, ChoiceDeltaToolCall, ChoiceDeltaToolCallFunction
1919
from openai.types.completion_usage import CompletionUsage
20+
from pydantic import BaseModel
2021

2122
from haystack_integrations.components.generators.mistral.chat.chat_generator import MistralChatGenerator
2223

@@ -136,12 +137,44 @@ def test_to_dict_default(self, monkeypatch):
136137

137138
def test_to_dict_with_parameters(self, monkeypatch):
138139
monkeypatch.setenv("ENV_VAR", "test-api-key")
140+
141+
class NobelPrizeInfo(BaseModel):
142+
recipient_name: str
143+
award_year: int
144+
145+
schema = {
146+
"json_schema": {
147+
"name": "NobelPrizeInfo",
148+
"schema": {
149+
"additionalProperties": False,
150+
"properties": {
151+
"award_year": {
152+
"title": "Award Year",
153+
"type": "integer",
154+
},
155+
"recipient_name": {
156+
"title": "Recipient Name",
157+
"type": "string",
158+
},
159+
},
160+
"required": [
161+
"recipient_name",
162+
"award_year",
163+
],
164+
"title": "NobelPrizeInfo",
165+
"type": "object",
166+
},
167+
"strict": True,
168+
},
169+
"type": "json_schema",
170+
}
171+
139172
component = MistralChatGenerator(
140173
api_key=Secret.from_env_var("ENV_VAR"),
141174
model="mistral-small",
142175
streaming_callback=print_streaming_chunk,
143176
api_base_url="test-base-url",
144-
generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
177+
generation_kwargs={"max_tokens": 10, "some_test_param": "test-params", "response_format": NobelPrizeInfo},
145178
)
146179
data = component.to_dict()
147180

@@ -155,7 +188,7 @@ def test_to_dict_with_parameters(self, monkeypatch):
155188
"model": "mistral-small",
156189
"api_base_url": "test-base-url",
157190
"streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
158-
"generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
191+
"generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params", "response_format": schema},
159192
}
160193

161194
for key, value in expected_params.items():
@@ -357,7 +390,7 @@ def test_run_with_params(self, chat_messages, mock_chat_completion, monkeypatch)
357390

358391
@pytest.mark.skipif(
359392
not os.environ.get("MISTRAL_API_KEY", None),
360-
reason="Export an env var called MISTRAL_API_KEY containing the OpenAI API key to run this test.",
393+
reason="Export an env var called MISTRAL_API_KEY containing the Mistral API key to run this test.",
361394
)
362395
@pytest.mark.integration
363396
def test_live_run(self):
@@ -372,7 +405,7 @@ def test_live_run(self):
372405

373406
@pytest.mark.skipif(
374407
not os.environ.get("MISTRAL_API_KEY", None),
375-
reason="Export an env var called MISTRAL_API_KEY containing the OpenAI API key to run this test.",
408+
reason="Export an env var called MISTRAL_API_KEY containing the Mistral API key to run this test.",
376409
)
377410
@pytest.mark.integration
378411
def test_live_run_wrong_model(self, chat_messages):
@@ -382,7 +415,7 @@ def test_live_run_wrong_model(self, chat_messages):
382415

383416
@pytest.mark.skipif(
384417
not os.environ.get("MISTRAL_API_KEY", None),
385-
reason="Export an env var called MISTRAL_API_KEY containing the OpenAI API key to run this test.",
418+
reason="Export an env var called MISTRAL_API_KEY containing the Mistral API key to run this test.",
386419
)
387420
@pytest.mark.integration
388421
def test_live_run_streaming(self):
@@ -411,17 +444,25 @@ def __call__(self, chunk: StreamingChunk) -> None:
411444

412445
@pytest.mark.skipif(
413446
not os.environ.get("MISTRAL_API_KEY", None),
414-
reason="Export an env var called MISTRAL_API_KEY containing the OpenAI API key to run this test.",
447+
reason="Export an env var called MISTRAL_API_KEY containing the Mistral API key to run this test.",
415448
)
416449
@pytest.mark.integration
417450
def test_live_run_response_format(self):
451+
class NobelPrizeInfo(BaseModel):
452+
recipient_name: str
453+
award_year: int
454+
category: str
455+
achievement_description: str
456+
nationality: str
457+
418458
chat_messages = [
419459
ChatMessage.from_user(
420-
'Provide the answer in JSON format with a key "answer". What\'s the capital of France?'
421-
'For example, respond with {"answer": "Paris"}.'
460+
"In 2021, American scientist David Julius received the Nobel Prize in"
461+
" Physiology or Medicine for his groundbreaking discoveries on how the human body"
462+
" senses temperature and touch."
422463
)
423464
]
424-
component = MistralChatGenerator(generation_kwargs={"response_format": {"type": "json_object"}})
465+
component = MistralChatGenerator(generation_kwargs={"response_format": NobelPrizeInfo})
425466
results = component.run(chat_messages)
426467
assert isinstance(results, dict)
427468
assert "replies" in results
@@ -430,13 +471,51 @@ def test_live_run_response_format(self):
430471
assert isinstance(results["replies"][0], ChatMessage)
431472
message = results["replies"][0]
432473
assert isinstance(message.text, str)
433-
assert "paris" in message.text.lower()
434474
msg = json.loads(message.text)
435-
assert "answer" in msg
475+
assert msg["recipient_name"] == "David Julius"
476+
assert msg["award_year"] == 2021
477+
assert "category" in msg
478+
assert "achievement_description" in msg
479+
assert msg["nationality"] == "American"
436480

437481
@pytest.mark.skipif(
438482
not os.environ.get("MISTRAL_API_KEY", None),
439-
reason="Export an env var called MISTRAL_API_KEY containing the OpenAI API key to run this test.",
483+
reason="Export an env var called MISTRAL_API_KEY containing the Mistral API key to run this test.",
484+
)
485+
@pytest.mark.integration
486+
def test_live_run_with_response_format_json_schema(self):
487+
response_schema = {
488+
"type": "json_schema",
489+
"json_schema": {
490+
"name": "CapitalCity",
491+
"strict": True,
492+
"schema": {
493+
"title": "CapitalCity",
494+
"type": "object",
495+
"properties": {
496+
"city": {"title": "City", "type": "string"},
497+
"country": {"title": "Country", "type": "string"},
498+
},
499+
"required": ["city", "country"],
500+
"additionalProperties": False,
501+
},
502+
},
503+
}
504+
505+
chat_messages = [ChatMessage.from_user("What's the capital of France?")]
506+
comp = MistralChatGenerator(generation_kwargs={"response_format": response_schema})
507+
results = comp.run(chat_messages)
508+
assert len(results["replies"]) == 1
509+
message: ChatMessage = results["replies"][0]
510+
msg = json.loads(message.text)
511+
assert "Paris" in msg["city"]
512+
assert isinstance(msg["country"], str)
513+
assert "France" in msg["country"]
514+
assert message.meta["finish_reason"] == "stop"
515+
516+
@pytest.mark.skipif(
517+
not os.environ.get("MISTRAL_API_KEY", None),
518+
reason="Export an env var called MISTRAL_API_KEY containing the Mistral API key to run this test.",
440519
)
441520
@pytest.mark.integration
442521
def test_live_run_with_tools(self, tools):
@@ -456,7 +535,7 @@ def test_live_run_with_tools(self, tools):
456535

457536
@pytest.mark.skipif(
458537
not os.environ.get("MISTRAL_API_KEY", None),
459-
reason="Export an env var called MISTRAL_API_KEY containing the OpenAI API key to run this test.",
538+
reason="Export an env var called MISTRAL_API_KEY containing the Mistral API key to run this test.",
460539
)
461540
@pytest.mark.integration
462541
def test_live_run_with_tools_and_response(self, tools):
@@ -504,7 +583,7 @@ def test_live_run_with_tools_and_response(self, tools):
504583

505584
@pytest.mark.skipif(
506585
not os.environ.get("MISTRAL_API_KEY", None),
507-
reason="Export an env var called MISTRAL_API_KEY containing the OpenAI API key to run this test.",
586+
reason="Export an env var called MISTRAL_API_KEY containing the Mistral API key to run this test.",
508587
)
509588
@pytest.mark.integration
510589
def test_live_run_with_tools_streaming(self, tools):

0 commit comments

Comments
 (0)