diff --git a/integrations/stackit/examples/chat_generators_with_structured_outputs.py b/integrations/stackit/examples/chat_generators_with_structured_outputs.py
new file mode 100644
index 0000000000..7b30f1a44b
--- /dev/null
+++ b/integrations/stackit/examples/chat_generators_with_structured_outputs.py
@@ -0,0 +1,37 @@
+# SPDX-FileCopyrightText: 2024-present deepset GmbH
+#
+# SPDX-License-Identifier: Apache-2.0
+
+
+# This example demonstrates how to use the STACKITChatGenerator component
+# with structured outputs.
+# To run this example, you will need to
+# set the `STACKIT_API_KEY` environment variable.
+
+from haystack.dataclasses import ChatMessage
+from pydantic import BaseModel
+
+from haystack_integrations.components.generators.stackit import STACKITChatGenerator
+
+
+class NobelPrizeInfo(BaseModel):
+    recipient_name: str
+    award_year: int
+    category: str
+    achievement_description: str
+    nationality: str
+
+
+chat_messages = [
+    ChatMessage.from_user(
+        "In 2021, American scientist David Julius received the Nobel Prize in"
+        " Physiology or Medicine for his groundbreaking discoveries on how the human body"
+        " senses temperature and touch."
+    )
+]
+component = STACKITChatGenerator(
+    model="neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8", generation_kwargs={"response_format": NobelPrizeInfo}
+)
+results = component.run(chat_messages)
+
+# print(results)
diff --git a/integrations/stackit/pyproject.toml b/integrations/stackit/pyproject.toml
index 470c2f4750..280a33cdd6 100644
--- a/integrations/stackit/pyproject.toml
+++ b/integrations/stackit/pyproject.toml
@@ -23,7 +23,7 @@ classifiers = [
   "Programming Language :: Python :: Implementation :: CPython",
   "Programming Language :: Python :: Implementation :: PyPy",
 ]
-dependencies = ["haystack-ai>=2.13.0"]
+dependencies = ["haystack-ai>=2.19.0"]
 
 [project.urls]
 Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/stackit#readme"
diff --git a/integrations/stackit/src/haystack_integrations/components/generators/stackit/chat/chat_generator.py b/integrations/stackit/src/haystack_integrations/components/generators/stackit/chat/chat_generator.py
index d937ead733..3b006560be 100644
--- a/integrations/stackit/src/haystack_integrations/components/generators/stackit/chat/chat_generator.py
+++ b/integrations/stackit/src/haystack_integrations/components/generators/stackit/chat/chat_generator.py
@@ -8,6 +8,8 @@
 from haystack.dataclasses import StreamingCallbackT
 from haystack.utils import serialize_callable
 from haystack.utils.auth import Secret
+from openai.lib._pydantic import to_strict_json_schema
+from pydantic import BaseModel
 
 
 @component
@@ -74,6 +76,13 @@ def __init__(
                 events as they become available, with the stream terminated by a data: [DONE] message.
             - `safe_prompt`: Whether to inject a safety prompt before all conversations.
             - `random_seed`: The seed to use for random sampling.
+            - `response_format`: A JSON schema or a Pydantic model that enforces the structure of the model's response.
+              If provided, the output will always be validated against this
+              format (unless the model returns a tool call).
+              For details, see the [OpenAI Structured Outputs documentation](https://platform.openai.com/docs/guides/structured-outputs).
+              Notes:
+              - For structured outputs with streaming,
+                the `response_format` must be a JSON schema and not a Pydantic model.
         :param timeout:
             Timeout for STACKIT client calls. If not set, it defaults to either the
             `OPENAI_TIMEOUT` environment variable, or 30 seconds.
@@ -104,6 +113,21 @@ def to_dict(self) -> dict[str, Any]:
             The serialized component as a dictionary.
         """
         callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None
+        generation_kwargs = self.generation_kwargs.copy()
+        response_format = generation_kwargs.get("response_format")
+        # If the response format is a Pydantic model, it's converted to openai's json schema format
+        # If it's already a json schema, it's left as is
+        if response_format and isinstance(response_format, type) and issubclass(response_format, BaseModel):
+            json_schema = {
+                "type": "json_schema",
+                "json_schema": {
+                    "name": response_format.__name__,
+                    "strict": True,
+                    "schema": to_strict_json_schema(response_format),
+                },
+            }
+
+            generation_kwargs["response_format"] = json_schema
 
         # if we didn't implement the to_dict method here then the to_dict method of the superclass would be used
         # which would serialiaze some fields that we don't want to serialize (e.g. the ones we don't have in
@@ -114,7 +138,7 @@ def to_dict(self) -> dict[str, Any]:
             model=self.model,
             streaming_callback=callback_name,
             api_base_url=self.api_base_url,
-            generation_kwargs=self.generation_kwargs,
+            generation_kwargs=generation_kwargs,
             api_key=self.api_key.to_dict(),
             timeout=self.timeout,
             max_retries=self.max_retries,
diff --git a/integrations/stackit/tests/test_stackit_chat_generator.py b/integrations/stackit/tests/test_stackit_chat_generator.py
index 162af740f9..a31084f293 100644
--- a/integrations/stackit/tests/test_stackit_chat_generator.py
+++ b/integrations/stackit/tests/test_stackit_chat_generator.py
@@ -1,3 +1,4 @@
+import json
 import os
 from datetime import datetime
 from unittest.mock import patch
@@ -11,10 +12,22 @@
 from openai.types import CompletionUsage
 from openai.types.chat import ChatCompletion, ChatCompletionMessage
 from openai.types.chat.chat_completion import Choice
+from pydantic import BaseModel
 
 from haystack_integrations.components.generators.stackit.chat.chat_generator import STACKITChatGenerator
 
 
+class CalendarEvent(BaseModel):
+    event_name: str
+    event_date: str
+    event_location: str
+
+
+@pytest.fixture
+def calendar_event_model():
+    return CalendarEvent
+
+
 @pytest.fixture
 def chat_messages():
     return [
@@ -101,14 +114,18 @@ def test_to_dict_default(self, monkeypatch):
         for key, value in expected_params.items():
             assert data["init_parameters"][key] == value
 
-    def test_to_dict_with_parameters(self, monkeypatch):
+    def test_to_dict_with_parameters(self, monkeypatch, calendar_event_model):
         monkeypatch.setenv("ENV_VAR", "test-api-key")
         component = STACKITChatGenerator(
             api_key=Secret.from_env_var("ENV_VAR"),
             model="neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8",
             streaming_callback=print_streaming_chunk,
             api_base_url="test-base-url",
-            generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
+            generation_kwargs={
+                "max_tokens": 10,
+                "some_test_param": "test-params",
+                "response_format": calendar_event_model,
+            },
             timeout=10.0,
             max_retries=2,
             http_client_kwargs={"proxy": "https://proxy.example.com:8080"},
@@ -125,7 +142,28 @@ def test_to_dict_with_parameters(self, monkeypatch):
             "model": "neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8",
             "api_base_url": "test-base-url",
             "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
-            "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
+            "generation_kwargs": {
+                "max_tokens": 10,
+                "some_test_param": "test-params",
+                "response_format": {
+                    "type": "json_schema",
+                    "json_schema": {
+                        "name": "CalendarEvent",
"CalendarEvent", + "strict": True, + "schema": { + "properties": { + "event_name": {"title": "Event Name", "type": "string"}, + "event_date": {"title": "Event Date", "type": "string"}, + "event_location": {"title": "Event Location", "type": "string"}, + }, + "required": ["event_name", "event_date", "event_location"], + "title": "CalendarEvent", + "type": "object", + "additionalProperties": False, + }, + }, + }, + }, "timeout": 10.0, "max_retries": 2, "http_client_kwargs": {"proxy": "https://proxy.example.com:8080"}, @@ -254,3 +292,61 @@ def __call__(self, chunk: StreamingChunk) -> None: assert callback.counter > 1 assert "Paris" in callback.responses + + @pytest.mark.skipif( + not os.environ.get("STACKIT_API_KEY", None), + reason="Export an env var called STACKIT_API_KEY containing the STACKIT API key to run this test.", + ) + @pytest.mark.integration + def test_live_run_with_response_format_json_schema(self): + response_schema = { + "type": "json_schema", + "json_schema": { + "name": "CapitalCity", + "strict": True, + "schema": { + "title": "CapitalCity", + "type": "object", + "properties": { + "city": {"title": "City", "type": "string"}, + "country": {"title": "Country", "type": "string"}, + }, + "required": ["city", "country"], + "additionalProperties": False, + }, + }, + } + + chat_messages = [ChatMessage.from_user("What's the capital of France?")] + comp = STACKITChatGenerator( + model="neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8", generation_kwargs={"response_format": response_schema} + ) + results = comp.run(chat_messages) + assert len(results["replies"]) == 1 + message: ChatMessage = results["replies"][0] + msg = json.loads(message.text) + assert "Paris" in msg["city"] + assert isinstance(msg["country"], str) + assert "France" in msg["country"] + assert message.meta["finish_reason"] == "stop" + + @pytest.mark.skipif( + not os.environ.get("STACKIT_API_KEY", None), + reason="Export an env var called STACKIT_API_KEY containing the STACKIT API key to run this test.", + ) + @pytest.mark.integration + def test_live_run_with_response_format_pydantic_model(self, calendar_event_model): + chat_messages = [ + ChatMessage.from_user("The marketing summit takes place on October12th at the Hilton Hotel downtown.") + ] + component = STACKITChatGenerator( + model="neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8", + generation_kwargs={"response_format": calendar_event_model}, + ) + results = component.run(chat_messages) + assert len(results["replies"]) == 1 + message: ChatMessage = results["replies"][0] + msg = json.loads(message.text) + assert "Marketing Summit" in msg["event_name"] + assert isinstance(msg["event_date"], str) + assert isinstance(msg["event_location"], str)