1+ import json
12import os
23from datetime import datetime
34from unittest .mock import patch
1112from openai .types import CompletionUsage
1213from openai .types .chat import ChatCompletion , ChatCompletionMessage
1314from openai .types .chat .chat_completion import Choice
15+ from pydantic import BaseModel
1416
1517from haystack_integrations .components .generators .stackit .chat .chat_generator import STACKITChatGenerator
1618
1719
20+ class CalendarEvent (BaseModel ):
21+ event_name : str
22+ event_date : str
23+ event_location : str
24+
25+
26+ @pytest .fixture
27+ def calendar_event_model ():
28+ return CalendarEvent
29+
30+
1831@pytest .fixture
1932def chat_messages ():
2033 return [
@@ -101,14 +114,18 @@ def test_to_dict_default(self, monkeypatch):
101114 for key , value in expected_params .items ():
102115 assert data ["init_parameters" ][key ] == value
103116
104- def test_to_dict_with_parameters (self , monkeypatch ):
117+ def test_to_dict_with_parameters (self , monkeypatch , calendar_event_model ):
105118 monkeypatch .setenv ("ENV_VAR" , "test-api-key" )
106119 component = STACKITChatGenerator (
107120 api_key = Secret .from_env_var ("ENV_VAR" ),
108121 model = "neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8" ,
109122 streaming_callback = print_streaming_chunk ,
110123 api_base_url = "test-base-url" ,
111- generation_kwargs = {"max_tokens" : 10 , "some_test_param" : "test-params" },
124+ generation_kwargs = {
125+ "max_tokens" : 10 ,
126+ "some_test_param" : "test-params" ,
127+ "response_format" : calendar_event_model ,
128+ },
112129 timeout = 10.0 ,
113130 max_retries = 2 ,
114131 http_client_kwargs = {"proxy" : "https://proxy.example.com:8080" },
@@ -125,7 +142,28 @@ def test_to_dict_with_parameters(self, monkeypatch):
125142 "model" : "neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8" ,
126143 "api_base_url" : "test-base-url" ,
127144 "streaming_callback" : "haystack.components.generators.utils.print_streaming_chunk" ,
128- "generation_kwargs" : {"max_tokens" : 10 , "some_test_param" : "test-params" },
145+ "generation_kwargs" : {
146+ "max_tokens" : 10 ,
147+ "some_test_param" : "test-params" ,
148+ "response_format" : {
149+ "type" : "json_schema" ,
150+ "json_schema" : {
151+ "name" : "CalendarEvent" ,
152+ "strict" : True ,
153+ "schema" : {
154+ "properties" : {
155+ "event_name" : {"title" : "Event Name" , "type" : "string" },
156+ "event_date" : {"title" : "Event Date" , "type" : "string" },
157+ "event_location" : {"title" : "Event Location" , "type" : "string" },
158+ },
159+ "required" : ["event_name" , "event_date" , "event_location" ],
160+ "title" : "CalendarEvent" ,
161+ "type" : "object" ,
162+ "additionalProperties" : False ,
163+ },
164+ },
165+ },
166+ },
129167 "timeout" : 10.0 ,
130168 "max_retries" : 2 ,
131169 "http_client_kwargs" : {"proxy" : "https://proxy.example.com:8080" },
@@ -254,3 +292,61 @@ def __call__(self, chunk: StreamingChunk) -> None:
254292
255293 assert callback .counter > 1
256294 assert "Paris" in callback .responses
295+
296+ @pytest .mark .skipif (
297+ not os .environ .get ("STACKIT_API_KEY" , None ),
298+ reason = "Export an env var called STACKIT_API_KEY containing the STACKIT API key to run this test." ,
299+ )
300+ @pytest .mark .integration
301+ def test_live_run_with_response_format_json_schema (self ):
302+ response_schema = {
303+ "type" : "json_schema" ,
304+ "json_schema" : {
305+ "name" : "CapitalCity" ,
306+ "strict" : True ,
307+ "schema" : {
308+ "title" : "CapitalCity" ,
309+ "type" : "object" ,
310+ "properties" : {
311+ "city" : {"title" : "City" , "type" : "string" },
312+ "country" : {"title" : "Country" , "type" : "string" },
313+ },
314+ "required" : ["city" , "country" ],
315+ "additionalProperties" : False ,
316+ },
317+ },
318+ }
319+
320+ chat_messages = [ChatMessage .from_user ("What's the capital of France?" )]
321+ comp = STACKITChatGenerator (
322+ model = "neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8" , generation_kwargs = {"response_format" : response_schema }
323+ )
324+ results = comp .run (chat_messages )
325+ assert len (results ["replies" ]) == 1
326+ message : ChatMessage = results ["replies" ][0 ]
327+ msg = json .loads (message .text )
328+ assert "Paris" in msg ["city" ]
329+ assert isinstance (msg ["country" ], str )
330+ assert "France" in msg ["country" ]
331+ assert message .meta ["finish_reason" ] == "stop"
332+
333+ @pytest .mark .skipif (
334+ not os .environ .get ("STACKIT_API_KEY" , None ),
335+ reason = "Export an env var called STACKIT_API_KEY containing the STACKIT API key to run this test." ,
336+ )
337+ @pytest .mark .integration
338+ def test_live_run_with_response_format_pydantic_model (self , calendar_event_model ):
339+ chat_messages = [
340+ ChatMessage .from_user ("The marketing summit takes place on October12th at the Hilton Hotel downtown." )
341+ ]
342+ component = STACKITChatGenerator (
343+ model = "neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8" ,
344+ generation_kwargs = {"response_format" : calendar_event_model },
345+ )
346+ results = component .run (chat_messages )
347+ assert len (results ["replies" ]) == 1
348+ message : ChatMessage = results ["replies" ][0 ]
349+ msg = json .loads (message .text )
350+ assert "Marketing Summit" in msg ["event_name" ]
351+ assert isinstance (msg ["event_date" ], str )
352+ assert isinstance (msg ["event_location" ], str )
0 commit comments