Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# SPDX-FileCopyrightText: 2024-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0


# This example demonstrates how to use the STACKITChatGenerator component
# with structured outputs.
# To run this example, you will need to
# set `STACKIT_API_KEY` environment variable

from haystack.dataclasses import ChatMessage
from pydantic import BaseModel

from haystack_integrations.components.generators.stackit import STACKITChatGenerator


# Pydantic model passed as `response_format`: the generator's reply text is
# constrained to JSON matching this schema (structured outputs).
class NobelPrizeInfo(BaseModel):
    # NOTE(review): intentionally no class docstring — pydantic would surface
    # it as the schema "description"; confirm before adding one. Field order
    # fixes the key order of the generated JSON schema.
    recipient_name: str
    award_year: int
    category: str
    achievement_description: str
    nationality: str


# Ask about a known Nobel Prize so the model can fill every NobelPrizeInfo field.
chat_messages = [
    ChatMessage.from_user(
        "In 2021, American scientist David Julius received the Nobel Prize in"
        " Physiology or Medicine for his groundbreaking discoveries on how the human body"
        " senses temperature and touch."
    )
]

# Passing the Pydantic model as `response_format` enables structured outputs:
# the reply text will be JSON conforming to NobelPrizeInfo's schema.
component = STACKITChatGenerator(
    model="neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8", generation_kwargs={"response_format": NobelPrizeInfo}
)
results = component.run(chat_messages)

# Show the structured JSON reply; previously this was commented out, so the
# example produced no visible output.
print(results["replies"][0].text)
Comment thread
Amnah199 marked this conversation as resolved.
2 changes: 1 addition & 1 deletion integrations/stackit/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ classifiers = [
"Programming Language :: Python :: Implementation :: CPython",
"Programming Language :: Python :: Implementation :: PyPy",
]
dependencies = ["haystack-ai>=2.13.0"]
dependencies = ["haystack-ai>=2.19.0"]

[project.urls]
Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/stackit#readme"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from haystack.dataclasses import StreamingCallbackT
from haystack.utils import serialize_callable
from haystack.utils.auth import Secret
from openai.lib._pydantic import to_strict_json_schema
from pydantic import BaseModel


@component
Expand Down Expand Up @@ -74,6 +76,13 @@ def __init__(
events as they become available, with the stream terminated by a data: [DONE] message.
- `safe_prompt`: Whether to inject a safety prompt before all conversations.
- `random_seed`: The seed to use for random sampling.
- `response_format`: A JSON schema or a Pydantic model that enforces the structure of the model's response.
If provided, the output will always be validated against this
format (unless the model returns a tool call).
For details, see the [OpenAI Structured Outputs documentation](https://platform.openai.com/docs/guides/structured-outputs).
Notes:
- For structured outputs with streaming,
the `response_format` must be a JSON schema and not a Pydantic model.
:param timeout:
Timeout for STACKIT client calls. If not set, it defaults to either the `OPENAI_TIMEOUT` environment
variable, or 30 seconds.
Expand Down Expand Up @@ -104,6 +113,21 @@ def to_dict(self) -> dict[str, Any]:
The serialized component as a dictionary.
"""
callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None
generation_kwargs = self.generation_kwargs.copy()
response_format = generation_kwargs.get("response_format")
# If the response format is a Pydantic model, it's converted to openai's json schema format
# If it's already a json schema, it's left as is
if response_format and isinstance(response_format, type) and issubclass(response_format, BaseModel):
json_schema = {
"type": "json_schema",
"json_schema": {
"name": response_format.__name__,
"strict": True,
"schema": to_strict_json_schema(response_format),
},
}

generation_kwargs["response_format"] = json_schema

# if we didn't implement the to_dict method here then the to_dict method of the superclass would be used
# which would serialiaze some fields that we don't want to serialize (e.g. the ones we don't have in
Expand All @@ -114,7 +138,7 @@ def to_dict(self) -> dict[str, Any]:
model=self.model,
streaming_callback=callback_name,
api_base_url=self.api_base_url,
generation_kwargs=self.generation_kwargs,
generation_kwargs=generation_kwargs,
api_key=self.api_key.to_dict(),
timeout=self.timeout,
max_retries=self.max_retries,
Expand Down
102 changes: 99 additions & 3 deletions integrations/stackit/tests/test_stackit_chat_generator.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import os
from datetime import datetime
from unittest.mock import patch
Expand All @@ -11,10 +12,22 @@
from openai.types import CompletionUsage
from openai.types.chat import ChatCompletion, ChatCompletionMessage
from openai.types.chat.chat_completion import Choice
from pydantic import BaseModel

from haystack_integrations.components.generators.stackit.chat.chat_generator import STACKITChatGenerator


# Pydantic model used as a structured-output `response_format` in the tests.
class CalendarEvent(BaseModel):
    # NOTE(review): no class docstring on purpose — pydantic would emit it as
    # a schema "description", and the expected schema in
    # test_to_dict_with_parameters has no such key. Field order fixes the
    # property/required order asserted there.
    event_name: str
    event_date: str
    event_location: str


@pytest.fixture
def calendar_event_model():
    """Return the CalendarEvent model class itself (not an instance)."""
    return CalendarEvent


@pytest.fixture
def chat_messages():
return [
Expand Down Expand Up @@ -101,14 +114,18 @@ def test_to_dict_default(self, monkeypatch):
for key, value in expected_params.items():
assert data["init_parameters"][key] == value

def test_to_dict_with_parameters(self, monkeypatch):
def test_to_dict_with_parameters(self, monkeypatch, calendar_event_model):
monkeypatch.setenv("ENV_VAR", "test-api-key")
component = STACKITChatGenerator(
api_key=Secret.from_env_var("ENV_VAR"),
model="neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8",
streaming_callback=print_streaming_chunk,
api_base_url="test-base-url",
generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
generation_kwargs={
"max_tokens": 10,
"some_test_param": "test-params",
"response_format": calendar_event_model,
},
timeout=10.0,
max_retries=2,
http_client_kwargs={"proxy": "https://proxy.example.com:8080"},
Expand All @@ -125,7 +142,28 @@ def test_to_dict_with_parameters(self, monkeypatch):
"model": "neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8",
"api_base_url": "test-base-url",
"streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
"generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
"generation_kwargs": {
"max_tokens": 10,
"some_test_param": "test-params",
"response_format": {
"type": "json_schema",
"json_schema": {
"name": "CalendarEvent",
"strict": True,
"schema": {
"properties": {
"event_name": {"title": "Event Name", "type": "string"},
"event_date": {"title": "Event Date", "type": "string"},
"event_location": {"title": "Event Location", "type": "string"},
},
"required": ["event_name", "event_date", "event_location"],
"title": "CalendarEvent",
"type": "object",
"additionalProperties": False,
},
},
},
},
"timeout": 10.0,
"max_retries": 2,
"http_client_kwargs": {"proxy": "https://proxy.example.com:8080"},
Expand Down Expand Up @@ -254,3 +292,61 @@ def __call__(self, chunk: StreamingChunk) -> None:

assert callback.counter > 1
assert "Paris" in callback.responses

@pytest.mark.skipif(
    not os.environ.get("STACKIT_API_KEY", None),
    reason="Export an env var called STACKIT_API_KEY containing the STACKIT API key to run this test.",
)
@pytest.mark.integration
def test_live_run_with_response_format_json_schema(self):
    """Live test: a raw JSON-schema `response_format` yields a reply whose text parses to the schema."""
    capital_city_properties = {
        "city": {"title": "City", "type": "string"},
        "country": {"title": "Country", "type": "string"},
    }
    response_schema = {
        "type": "json_schema",
        "json_schema": {
            "name": "CapitalCity",
            "strict": True,
            "schema": {
                "title": "CapitalCity",
                "type": "object",
                "properties": capital_city_properties,
                "required": ["city", "country"],
                "additionalProperties": False,
            },
        },
    }

    messages = [ChatMessage.from_user("What's the capital of France?")]
    generator = STACKITChatGenerator(
        model="neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8",
        generation_kwargs={"response_format": response_schema},
    )
    output = generator.run(messages)
    replies = output["replies"]
    assert len(replies) == 1
    reply: ChatMessage = replies[0]
    # The reply text must be valid JSON with the schema's two required keys.
    parsed = json.loads(reply.text)
    assert "Paris" in parsed["city"]
    assert isinstance(parsed["country"], str)
    assert "France" in parsed["country"]
    assert reply.meta["finish_reason"] == "stop"

@pytest.mark.skipif(
    not os.environ.get("STACKIT_API_KEY", None),
    reason="Export an env var called STACKIT_API_KEY containing the STACKIT API key to run this test.",
)
@pytest.mark.integration
def test_live_run_with_response_format_pydantic_model(self, calendar_event_model):
    """Live test: a Pydantic model `response_format` yields a JSON reply matching its fields.

    Fix: the prompt previously read "October12th" (missing space), which could
    degrade the model's extraction of `event_date`.
    """
    chat_messages = [
        ChatMessage.from_user("The marketing summit takes place on October 12th at the Hilton Hotel downtown.")
    ]
    component = STACKITChatGenerator(
        model="neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8",
        generation_kwargs={"response_format": calendar_event_model},
    )
    results = component.run(chat_messages)
    assert len(results["replies"]) == 1
    message: ChatMessage = results["replies"][0]
    msg = json.loads(message.text)
    # NOTE(review): exact-phrase check on model output is brittle; kept as-is
    # to match the original test's contract.
    assert "Marketing Summit" in msg["event_name"]
    assert isinstance(msg["event_date"], str)
    assert isinstance(msg["event_location"], str)