1717from openai .types .chat .chat_completion_chunk import Choice as ChoiceChunk
1818from openai .types .chat .chat_completion_chunk import ChoiceDelta , ChoiceDeltaToolCall , ChoiceDeltaToolCallFunction
1919from openai .types .completion_usage import CompletionUsage
20+ from pydantic import BaseModel
2021
2122from haystack_integrations .components .generators .mistral .chat .chat_generator import MistralChatGenerator
2223
@@ -136,12 +137,44 @@ def test_to_dict_default(self, monkeypatch):
136137
137138 def test_to_dict_with_parameters (self , monkeypatch ):
138139 monkeypatch .setenv ("ENV_VAR" , "test-api-key" )
140+
141+ class NobelPrizeInfo (BaseModel ):
142+ recipient_name : str
143+ award_year : int
144+
145+ schema = {
146+ "json_schema" : {
147+ "name" : "NobelPrizeInfo" ,
148+ "schema" : {
149+ "additionalProperties" : False ,
150+ "properties" : {
151+ "award_year" : {
152+ "title" : "Award Year" ,
153+ "type" : "integer" ,
154+ },
155+ "recipient_name" : {
156+ "title" : "Recipient Name" ,
157+ "type" : "string" ,
158+ },
159+ },
160+ "required" : [
161+ "recipient_name" ,
162+ "award_year" ,
163+ ],
164+ "title" : "NobelPrizeInfo" ,
165+ "type" : "object" ,
166+ },
167+ "strict" : True ,
168+ },
169+ "type" : "json_schema" ,
170+ }
171+
139172 component = MistralChatGenerator (
140173 api_key = Secret .from_env_var ("ENV_VAR" ),
141174 model = "mistral-small" ,
142175 streaming_callback = print_streaming_chunk ,
143176 api_base_url = "test-base-url" ,
144- generation_kwargs = {"max_tokens" : 10 , "some_test_param" : "test-params" },
177+ generation_kwargs = {"max_tokens" : 10 , "some_test_param" : "test-params" , "response_format" : NobelPrizeInfo },
145178 )
146179 data = component .to_dict ()
147180
@@ -155,7 +188,7 @@ def test_to_dict_with_parameters(self, monkeypatch):
155188 "model" : "mistral-small" ,
156189 "api_base_url" : "test-base-url" ,
157190 "streaming_callback" : "haystack.components.generators.utils.print_streaming_chunk" ,
158- "generation_kwargs" : {"max_tokens" : 10 , "some_test_param" : "test-params" },
191+ "generation_kwargs" : {"max_tokens" : 10 , "some_test_param" : "test-params" , "response_format" : schema },
159192 }
160193
161194 for key , value in expected_params .items ():
@@ -357,7 +390,7 @@ def test_run_with_params(self, chat_messages, mock_chat_completion, monkeypatch)
357390
358391 @pytest .mark .skipif (
359392 not os .environ .get ("MISTRAL_API_KEY" , None ),
360- reason = "Export an env var called MISTRAL_API_KEY containing the OpenAI API key to run this test." ,
393+ reason = "Export an env var called MISTRAL_API_KEY containing the Mistral API key to run this test." ,
361394 )
362395 @pytest .mark .integration
363396 def test_live_run (self ):
@@ -372,7 +405,7 @@ def test_live_run(self):
372405
373406 @pytest .mark .skipif (
374407 not os .environ .get ("MISTRAL_API_KEY" , None ),
375- reason = "Export an env var called MISTRAL_API_KEY containing the OpenAI API key to run this test." ,
408+ reason = "Export an env var called MISTRAL_API_KEY containing the Mistral API key to run this test." ,
376409 )
377410 @pytest .mark .integration
378411 def test_live_run_wrong_model (self , chat_messages ):
@@ -382,7 +415,7 @@ def test_live_run_wrong_model(self, chat_messages):
382415
383416 @pytest .mark .skipif (
384417 not os .environ .get ("MISTRAL_API_KEY" , None ),
385- reason = "Export an env var called MISTRAL_API_KEY containing the OpenAI API key to run this test." ,
418+ reason = "Export an env var called MISTRAL_API_KEY containing the Mistral API key to run this test." ,
386419 )
387420 @pytest .mark .integration
388421 def test_live_run_streaming (self ):
@@ -411,17 +444,25 @@ def __call__(self, chunk: StreamingChunk) -> None:
411444
412445 @pytest .mark .skipif (
413446 not os .environ .get ("MISTRAL_API_KEY" , None ),
414- reason = "Export an env var called MISTRAL_API_KEY containing the OpenAI API key to run this test." ,
447+ reason = "Export an env var called MISTRAL_API_KEY containing the Mistral API key to run this test." ,
415448 )
416449 @pytest .mark .integration
417450 def test_live_run_response_format (self ):
451+ class NobelPrizeInfo (BaseModel ):
452+ recipient_name : str
453+ award_year : int
454+ category : str
455+ achievement_description : str
456+ nationality : str
457+
418458 chat_messages = [
419459 ChatMessage .from_user (
420- 'Provide the answer in JSON format with a key "answer". What\' s the capital of France?'
421- 'For example, respond with {"answer": "Paris"}.'
460+ "In 2021, American scientist David Julius received the Nobel Prize in"
461+ " Physiology or Medicine for his groundbreaking discoveries on how the human body"
462+ " senses temperature and touch."
422463 )
423464 ]
424- component = MistralChatGenerator (generation_kwargs = {"response_format" : { "type" : "json_object" } })
465+ component = MistralChatGenerator (generation_kwargs = {"response_format" : NobelPrizeInfo })
425466 results = component .run (chat_messages )
426467 assert isinstance (results , dict )
427468 assert "replies" in results
@@ -430,13 +471,51 @@ def test_live_run_response_format(self):
430471 assert isinstance (results ["replies" ][0 ], ChatMessage )
431472 message = results ["replies" ][0 ]
432473 assert isinstance (message .text , str )
433- assert "paris" in message .text .lower ()
434474 msg = json .loads (message .text )
435- assert "answer" in msg
475+ assert msg ["recipient_name" ] == "David Julius"
476+ assert msg ["award_year" ] == 2021
477+ assert "category" in msg
478+ assert "achievement_description" in msg
479+ assert msg ["nationality" ] == "American"
436480
437481 @pytest .mark .skipif (
438482 not os .environ .get ("MISTRAL_API_KEY" , None ),
439- reason = "Export an env var called MISTRAL_API_KEY containing the OpenAI API key to run this test." ,
483+ reason = "Export an env var called MISTRAL_API_KEY containing the Mistral API key to run this test." ,
484+ )
485+ @pytest .mark .integration
486+ def test_live_run_with_response_format_json_schema (self ):
487+ response_schema = {
488+ "type" : "json_schema" ,
489+ "json_schema" : {
490+ "name" : "CapitalCity" ,
491+ "strict" : True ,
492+ "schema" : {
493+ "title" : "CapitalCity" ,
494+ "type" : "object" ,
495+ "properties" : {
496+ "city" : {"title" : "City" , "type" : "string" },
497+ "country" : {"title" : "Country" , "type" : "string" },
498+ },
499+ "required" : ["city" , "country" ],
500+ "additionalProperties" : False ,
501+ },
502+ },
503+ }
504+
505+ chat_messages = [ChatMessage .from_user ("What's the capital of France?" )]
506+ comp = MistralChatGenerator (generation_kwargs = {"response_format" : response_schema })
507+ results = comp .run (chat_messages )
508+ assert len (results ["replies" ]) == 1
509+ message : ChatMessage = results ["replies" ][0 ]
510+ msg = json .loads (message .text )
511+ assert "Paris" in msg ["city" ]
512+ assert isinstance (msg ["country" ], str )
513+ assert "France" in msg ["country" ]
514+ assert message .meta ["finish_reason" ] == "stop"
515+
516+ @pytest .mark .skipif (
517+ not os .environ .get ("MISTRAL_API_KEY" , None ),
518+ reason = "Export an env var called MISTRAL_API_KEY containing the Mistral API key to run this test." ,
440519 )
441520 @pytest .mark .integration
442521 def test_live_run_with_tools (self , tools ):
@@ -456,7 +535,7 @@ def test_live_run_with_tools(self, tools):
456535
457536 @pytest .mark .skipif (
458537 not os .environ .get ("MISTRAL_API_KEY" , None ),
459- reason = "Export an env var called MISTRAL_API_KEY containing the OpenAI API key to run this test." ,
538+ reason = "Export an env var called MISTRAL_API_KEY containing the Mistral API key to run this test." ,
460539 )
461540 @pytest .mark .integration
462541 def test_live_run_with_tools_and_response (self , tools ):
@@ -504,7 +583,7 @@ def test_live_run_with_tools_and_response(self, tools):
504583
505584 @pytest .mark .skipif (
506585 not os .environ .get ("MISTRAL_API_KEY" , None ),
507- reason = "Export an env var called MISTRAL_API_KEY containing the OpenAI API key to run this test." ,
586+ reason = "Export an env var called MISTRAL_API_KEY containing the Mistral API key to run this test." ,
508587 )
509588 @pytest .mark .integration
510589 def test_live_run_with_tools_streaming (self , tools ):
0 commit comments