Skip to content

Commit a54ea66

Browse files
authored
feat: support structured outputs in STACKITChatGenerator (#2536)
* Support structured output
* Add an example
* Update tests
1 parent c51d305 commit a54ea66

4 files changed

Lines changed: 162 additions & 5 deletions

File tree

Lines changed: 37 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,37 @@
1+
# SPDX-FileCopyrightText: 2024-present deepset GmbH <info@deepset.ai>
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
6+
# This example demonstrates how to use the STACKITChatGenerator component
7+
# with structured outputs.
8+
# To run this example, you will need to
9+
# set `STACKIT_API_KEY` environment variable
10+
11+
from haystack.dataclasses import ChatMessage
12+
from pydantic import BaseModel
13+
14+
from haystack_integrations.components.generators.stackit import STACKITChatGenerator
15+
16+
17+
class NobelPrizeInfo(BaseModel):
18+
recipient_name: str
19+
award_year: int
20+
category: str
21+
achievement_description: str
22+
nationality: str
23+
24+
25+
chat_messages = [
26+
ChatMessage.from_user(
27+
"In 2021, American scientist David Julius received the Nobel Prize in"
28+
" Physiology or Medicine for his groundbreaking discoveries on how the human body"
29+
" senses temperature and touch."
30+
)
31+
]
32+
component = STACKITChatGenerator(
33+
model="neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8", generation_kwargs={"response_format": NobelPrizeInfo}
34+
)
35+
results = component.run(chat_messages)
36+
37+
# print(results)

integrations/stackit/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,7 @@ classifiers = [
2323
"Programming Language :: Python :: Implementation :: CPython",
2424
"Programming Language :: Python :: Implementation :: PyPy",
2525
]
26-
dependencies = ["haystack-ai>=2.13.0"]
26+
dependencies = ["haystack-ai>=2.19.0"]
2727

2828
[project.urls]
2929
Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/stackit#readme"

integrations/stackit/src/haystack_integrations/components/generators/stackit/chat/chat_generator.py

Lines changed: 25 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,8 @@
88
from haystack.dataclasses import StreamingCallbackT
99
from haystack.utils import serialize_callable
1010
from haystack.utils.auth import Secret
11+
from openai.lib._pydantic import to_strict_json_schema
12+
from pydantic import BaseModel
1113

1214

1315
@component
@@ -74,6 +76,13 @@ def __init__(
7476
events as they become available, with the stream terminated by a data: [DONE] message.
7577
- `safe_prompt`: Whether to inject a safety prompt before all conversations.
7678
- `random_seed`: The seed to use for random sampling.
79+
- `response_format`: A JSON schema or a Pydantic model that enforces the structure of the model's response.
80+
If provided, the output will always be validated against this
81+
format (unless the model returns a tool call).
82+
For details, see the [OpenAI Structured Outputs documentation](https://platform.openai.com/docs/guides/structured-outputs).
83+
Notes:
84+
- For structured outputs with streaming,
85+
the `response_format` must be a JSON schema and not a Pydantic model.
7786
:param timeout:
7887
Timeout for STACKIT client calls. If not set, it defaults to either the `OPENAI_TIMEOUT` environment
7988
variable, or 30 seconds.
@@ -104,6 +113,21 @@ def to_dict(self) -> dict[str, Any]:
104113
The serialized component as a dictionary.
105114
"""
106115
callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None
116+
generation_kwargs = self.generation_kwargs.copy()
117+
response_format = generation_kwargs.get("response_format")
118+
# If the response format is a Pydantic model, it's converted to openai's json schema format
119+
# If it's already a json schema, it's left as is
120+
if response_format and isinstance(response_format, type) and issubclass(response_format, BaseModel):
121+
json_schema = {
122+
"type": "json_schema",
123+
"json_schema": {
124+
"name": response_format.__name__,
125+
"strict": True,
126+
"schema": to_strict_json_schema(response_format),
127+
},
128+
}
129+
130+
generation_kwargs["response_format"] = json_schema
107131

108132
# if we didn't implement the to_dict method here then the to_dict method of the superclass would be used
109133
# which would serialize some fields that we don't want to serialize (e.g. the ones we don't have in
@@ -114,7 +138,7 @@ def to_dict(self) -> dict[str, Any]:
114138
model=self.model,
115139
streaming_callback=callback_name,
116140
api_base_url=self.api_base_url,
117-
generation_kwargs=self.generation_kwargs,
141+
generation_kwargs=generation_kwargs,
118142
api_key=self.api_key.to_dict(),
119143
timeout=self.timeout,
120144
max_retries=self.max_retries,

integrations/stackit/tests/test_stackit_chat_generator.py

Lines changed: 99 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,4 @@
1+
import json
12
import os
23
from datetime import datetime
34
from unittest.mock import patch
@@ -11,10 +12,22 @@
1112
from openai.types import CompletionUsage
1213
from openai.types.chat import ChatCompletion, ChatCompletionMessage
1314
from openai.types.chat.chat_completion import Choice
15+
from pydantic import BaseModel
1416

1517
from haystack_integrations.components.generators.stackit.chat.chat_generator import STACKITChatGenerator
1618

1719

20+
class CalendarEvent(BaseModel):
21+
event_name: str
22+
event_date: str
23+
event_location: str
24+
25+
26+
@pytest.fixture
27+
def calendar_event_model():
28+
return CalendarEvent
29+
30+
1831
@pytest.fixture
1932
def chat_messages():
2033
return [
@@ -101,14 +114,18 @@ def test_to_dict_default(self, monkeypatch):
101114
for key, value in expected_params.items():
102115
assert data["init_parameters"][key] == value
103116

104-
def test_to_dict_with_parameters(self, monkeypatch):
117+
def test_to_dict_with_parameters(self, monkeypatch, calendar_event_model):
105118
monkeypatch.setenv("ENV_VAR", "test-api-key")
106119
component = STACKITChatGenerator(
107120
api_key=Secret.from_env_var("ENV_VAR"),
108121
model="neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8",
109122
streaming_callback=print_streaming_chunk,
110123
api_base_url="test-base-url",
111-
generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
124+
generation_kwargs={
125+
"max_tokens": 10,
126+
"some_test_param": "test-params",
127+
"response_format": calendar_event_model,
128+
},
112129
timeout=10.0,
113130
max_retries=2,
114131
http_client_kwargs={"proxy": "https://proxy.example.com:8080"},
@@ -125,7 +142,28 @@ def test_to_dict_with_parameters(self, monkeypatch):
125142
"model": "neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8",
126143
"api_base_url": "test-base-url",
127144
"streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
128-
"generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
145+
"generation_kwargs": {
146+
"max_tokens": 10,
147+
"some_test_param": "test-params",
148+
"response_format": {
149+
"type": "json_schema",
150+
"json_schema": {
151+
"name": "CalendarEvent",
152+
"strict": True,
153+
"schema": {
154+
"properties": {
155+
"event_name": {"title": "Event Name", "type": "string"},
156+
"event_date": {"title": "Event Date", "type": "string"},
157+
"event_location": {"title": "Event Location", "type": "string"},
158+
},
159+
"required": ["event_name", "event_date", "event_location"],
160+
"title": "CalendarEvent",
161+
"type": "object",
162+
"additionalProperties": False,
163+
},
164+
},
165+
},
166+
},
129167
"timeout": 10.0,
130168
"max_retries": 2,
131169
"http_client_kwargs": {"proxy": "https://proxy.example.com:8080"},
@@ -254,3 +292,61 @@ def __call__(self, chunk: StreamingChunk) -> None:
254292

255293
assert callback.counter > 1
256294
assert "Paris" in callback.responses
295+
296+
@pytest.mark.skipif(
297+
not os.environ.get("STACKIT_API_KEY", None),
298+
reason="Export an env var called STACKIT_API_KEY containing the STACKIT API key to run this test.",
299+
)
300+
@pytest.mark.integration
301+
def test_live_run_with_response_format_json_schema(self):
302+
response_schema = {
303+
"type": "json_schema",
304+
"json_schema": {
305+
"name": "CapitalCity",
306+
"strict": True,
307+
"schema": {
308+
"title": "CapitalCity",
309+
"type": "object",
310+
"properties": {
311+
"city": {"title": "City", "type": "string"},
312+
"country": {"title": "Country", "type": "string"},
313+
},
314+
"required": ["city", "country"],
315+
"additionalProperties": False,
316+
},
317+
},
318+
}
319+
320+
chat_messages = [ChatMessage.from_user("What's the capital of France?")]
321+
comp = STACKITChatGenerator(
322+
model="neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8", generation_kwargs={"response_format": response_schema}
323+
)
324+
results = comp.run(chat_messages)
325+
assert len(results["replies"]) == 1
326+
message: ChatMessage = results["replies"][0]
327+
msg = json.loads(message.text)
328+
assert "Paris" in msg["city"]
329+
assert isinstance(msg["country"], str)
330+
assert "France" in msg["country"]
331+
assert message.meta["finish_reason"] == "stop"
332+
333+
@pytest.mark.skipif(
334+
not os.environ.get("STACKIT_API_KEY", None),
335+
reason="Export an env var called STACKIT_API_KEY containing the STACKIT API key to run this test.",
336+
)
337+
@pytest.mark.integration
338+
def test_live_run_with_response_format_pydantic_model(self, calendar_event_model):
339+
chat_messages = [
340+
ChatMessage.from_user("The marketing summit takes place on October 12th at the Hilton Hotel downtown.")
341+
]
342+
component = STACKITChatGenerator(
343+
model="neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8",
344+
generation_kwargs={"response_format": calendar_event_model},
345+
)
346+
results = component.run(chat_messages)
347+
assert len(results["replies"]) == 1
348+
message: ChatMessage = results["replies"][0]
349+
msg = json.loads(message.text)
350+
assert "Marketing Summit" in msg["event_name"]
351+
assert isinstance(msg["event_date"], str)
352+
assert isinstance(msg["event_location"], str)

0 commit comments

Comments (0)