11import os
22from datetime import datetime
3- from unittest .mock import patch
3+ from unittest .mock import ANY , patch
44
55import pytest
66import pytz
77from haystack import Pipeline
88from haystack .components .generators .utils import print_streaming_chunk
99from haystack .components .tools import ToolInvoker
10- from haystack .dataclasses import ChatMessage , ChatRole , StreamingChunk , ToolCall
10+ from haystack .dataclasses import ChatMessage , ChatRole , ComponentInfo , StreamingChunk , ToolCall , ToolCallDelta
1111from haystack .tools import Tool
1212from haystack .utils .auth import Secret
1313from openai import OpenAIError
14- from openai .types .chat import ChatCompletion , ChatCompletionMessage
14+ from openai .types .chat import ChatCompletion , ChatCompletionChunk , ChatCompletionMessage
1515from openai .types .chat .chat_completion import Choice
16+ from openai .types .chat .chat_completion_chunk import Choice as ChoiceChunk
17+ from openai .types .chat .chat_completion_chunk import ChoiceDelta , ChoiceDeltaToolCall , ChoiceDeltaToolCallFunction
18+ from openai .types .completion_usage import CompletionUsage
1619
1720from haystack_integrations .components .generators .mistral .chat .chat_generator import MistralChatGenerator
1821
1922
class CollectorCallback:
    """Test helper: a streaming callback that simply records every chunk it is given."""

    def __init__(self):
        # All chunks received so far, in arrival order.
        self.chunks = []

    def __call__(self, chunk: StreamingChunk) -> None:
        """Append the received streaming chunk to the internal record."""
        self.chunks.append(chunk)
34+
2035@pytest .fixture
2136def chat_messages ():
2237 return [
@@ -179,6 +194,137 @@ def test_from_dict_fail_wo_env_var(self, monkeypatch):
179194 with pytest .raises (ValueError , match = "None of the .* environment variables are set" ):
180195 MistralChatGenerator .from_dict (data )
181196
    def test_handle_stream_response(self):
        """
        _handle_stream_response should turn raw OpenAI-style completion chunks into
        StreamingChunk objects (forwarded to the callback) and aggregate them into a
        single assistant ChatMessage whose tool calls are fully parsed.
        """
        # Two-chunk stream as Mistral emits it: an initial role/empty-content chunk,
        # then one chunk carrying both tool-call deltas plus finish_reason and usage.
        mistral_chunks = [
            ChatCompletionChunk(
                id="76535283139540de943bc2036121d4c5",
                choices=[ChoiceChunk(delta=ChoiceDelta(content="", role="assistant"), index=0)],
                created=1750076261,
                model="mistral-small-latest",
                object="chat.completion.chunk",
            ),
            ChatCompletionChunk(
                id="76535283139540de943bc2036121d4c5",
                choices=[
                    ChoiceChunk(
                        delta=ChoiceDelta(
                            tool_calls=[
                                ChoiceDeltaToolCall(
                                    index=0,
                                    id="FL1FFlqUG",
                                    function=ChoiceDeltaToolCallFunction(arguments='{"city": "Paris"}', name="weather"),
                                ),
                                ChoiceDeltaToolCall(
                                    index=1,
                                    id="xSuhp66iB",
                                    function=ChoiceDeltaToolCallFunction(
                                        arguments='{"city": "Berlin"}', name="weather"
                                    ),
                                ),
                            ],
                        ),
                        finish_reason="tool_calls",
                        index=0,
                    )
                ],
                created=1750076261,
                model="mistral-small-latest",
                object="chat.completion.chunk",
                usage=CompletionUsage(
                    completion_tokens=35,
                    prompt_tokens=77,
                    total_tokens=112,
                ),
            ),
        ]

        collector_callback = CollectorCallback()
        llm = MistralChatGenerator(api_key=Secret.from_token("test-api-key"))
        result = llm._handle_stream_response(mistral_chunks, callback=collector_callback)[0]  # type: ignore

        # Verify the callback collected the expected number of chunks
        # We expect 2 chunks: one for the initial empty content and one for the tool calls
        assert len(collector_callback.chunks) == 2
        # First chunk: empty content, no tool calls / finish reason / usage yet.
        # "received_at" is wall-clock time, so it is matched with mock.ANY.
        assert collector_callback.chunks[0] == StreamingChunk(
            content="",
            meta={
                "model": "mistral-small-latest",
                "index": 0,
                "tool_calls": None,
                "finish_reason": None,
                "received_at": ANY,
                "usage": None,
            },
            component_info=ComponentInfo(
                type="haystack_integrations.components.generators.mistral.chat.chat_generator.MistralChatGenerator",
                name=None,
            ),
        )
        # Second chunk: both raw tool-call deltas preserved in meta, parsed
        # ToolCallDelta objects on the chunk itself, plus finish_reason and usage.
        assert collector_callback.chunks[1] == StreamingChunk(
            content="",
            meta={
                "model": "mistral-small-latest",
                "index": 0,
                "tool_calls": [
                    ChoiceDeltaToolCall(
                        index=0,
                        id="FL1FFlqUG",
                        function=ChoiceDeltaToolCallFunction(arguments='{"city": "Paris"}', name="weather"),
                    ),
                    ChoiceDeltaToolCall(
                        index=1,
                        id="xSuhp66iB",
                        function=ChoiceDeltaToolCallFunction(arguments='{"city": "Berlin"}', name="weather"),
                    ),
                ],
                "finish_reason": "tool_calls",
                "received_at": ANY,
                "usage": {
                    "completion_tokens": 35,
                    "prompt_tokens": 77,
                    "total_tokens": 112,
                    "completion_tokens_details": None,
                    "prompt_tokens_details": None,
                },
            },
            component_info=ComponentInfo(
                type="haystack_integrations.components.generators.mistral.chat.chat_generator.MistralChatGenerator",
                name=None,
            ),
            index=0,
            tool_calls=[
                ToolCallDelta(index=0, tool_name="weather", arguments='{"city": "Paris"}', id="FL1FFlqUG"),
                ToolCallDelta(index=1, tool_name="weather", arguments='{"city": "Berlin"}', id="xSuhp66iB"),
            ],
            start=True,
            finish_reason="tool_calls",
        )

        # Assert no text content was produced (tool-call-only reply)
        assert result.text is None

        # Verify both tool calls were found and processed
        assert len(result.tool_calls) == 2
        assert result.tool_calls[0].id == "FL1FFlqUG"
        assert result.tool_calls[0].tool_name == "weather"
        assert result.tool_calls[0].arguments == {"city": "Paris"}
        assert result.tool_calls[1].id == "xSuhp66iB"
        assert result.tool_calls[1].tool_name == "weather"
        assert result.tool_calls[1].arguments == {"city": "Berlin"}

        # Verify meta information
        assert result.meta["model"] == "mistral-small-latest"
        assert result.meta["finish_reason"] == "tool_calls"
        assert result.meta["index"] == 0
        assert result.meta["completion_start_time"] is not None
        assert result.meta["usage"] == {
            "completion_tokens": 35,
            "prompt_tokens": 77,
            "total_tokens": 112,
            "completion_tokens_details": None,
            "prompt_tokens_details": None,
        }
327+
182328 def test_run (self , chat_messages , mock_chat_completion , monkeypatch ): # noqa: ARG002
183329 monkeypatch .setenv ("MISTRAL_API_KEY" , "fake-api-key" )
184330 component = MistralChatGenerator ()
@@ -291,42 +437,44 @@ def test_live_run_with_tools_and_response(self, tools):
291437 """
292438 Integration test that the MistralChatGenerator component can run with tools and get a response.
293439 """
294- initial_messages = [ChatMessage .from_user ("What's the weather like in Paris?" )]
440+ initial_messages = [ChatMessage .from_user ("What's the weather like in Paris and Berlin ?" )]
295441 component = MistralChatGenerator (tools = tools )
296442 results = component .run (messages = initial_messages , generation_kwargs = {"tool_choice" : "any" })
297443
298- assert len (results ["replies" ]) > 0 , "No replies received"
444+ assert len (results ["replies" ]) == 1
299445
300446 # Find the message with tool calls
301- tool_message = None
302- for message in results [ "replies" ]:
303- if message . tool_call :
304- tool_message = message
305- break
306-
307- assert tool_message is not None , "No message with tool call found"
308- assert isinstance ( tool_message , ChatMessage ), "Tool message is not a ChatMessage instance"
309- assert ChatMessage . is_from ( tool_message , ChatRole . ASSISTANT ), "Tool message is not from the assistant"
310-
311- tool_call = tool_message . tool_call
312- assert tool_call . id , "Tool call does not contain value for 'id' key"
313- assert tool_call .tool_name == "weather"
314- assert tool_call . arguments == {"city" : "Paris" }
447+ tool_message = results [ "replies" ][ 0 ]
448+
449+ assert isinstance ( tool_message , ChatMessage )
450+ tool_calls = tool_message . tool_calls
451+ assert len ( tool_calls ) == 2
452+ assert ChatMessage . is_from ( tool_message , ChatRole . ASSISTANT )
453+
454+ for tool_call in tool_calls :
455+ assert tool_call . id is not None
456+ assert isinstance ( tool_call , ToolCall )
457+ assert tool_call . tool_name == "weather"
458+
459+ arguments = [ tool_call .arguments for tool_call in tool_calls ]
460+ assert sorted ( arguments , key = lambda x : x [ "city" ]) == [ {"city" : "Berlin" }, { "city" : " Paris" }]
315461 assert tool_message .meta ["finish_reason" ] == "tool_calls"
316462
317463 new_messages = [
318464 initial_messages [0 ],
319465 tool_message ,
320- ChatMessage .from_tool (tool_result = "22° C" , origin = tool_call ),
466+ ChatMessage .from_tool (tool_result = "22° C and sunny" , origin = tool_calls [0 ]),
467+ ChatMessage .from_tool (tool_result = "16° C and windy" , origin = tool_calls [1 ]),
321468 ]
322469 # Pass the tool result to the model to get the final response
323470 results = component .run (new_messages )
324471
325472 assert len (results ["replies" ]) == 1
326473 final_message = results ["replies" ][0 ]
327- assert not final_message .tool_call
474+ assert final_message .is_from ( ChatRole . ASSISTANT )
328475 assert len (final_message .text ) > 0
329476 assert "paris" in final_message .text .lower ()
477+ assert "berlin" in final_message .text .lower ()
330478
331479 @pytest .mark .skipif (
332480 not os .environ .get ("MISTRAL_API_KEY" , None ),
@@ -337,45 +485,29 @@ def test_live_run_with_tools_streaming(self, tools):
337485 """
338486 Integration test that the MistralChatGenerator component can run with tools and streaming.
339487 """
340-
341- class Callback :
342- def __init__ (self ):
343- self .responses = ""
344- self .counter = 0
345- self .tool_calls = []
346-
347- def __call__ (self , chunk : StreamingChunk ) -> None :
348- self .counter += 1
349- if chunk .content :
350- self .responses += chunk .content
351- if chunk .meta .get ("tool_calls" ):
352- self .tool_calls .extend (chunk .meta ["tool_calls" ])
353-
354- callback = Callback ()
355- component = MistralChatGenerator (tools = tools , streaming_callback = callback )
488+ component = MistralChatGenerator (tools = tools , streaming_callback = print_streaming_chunk )
356489 results = component .run (
357- [ChatMessage .from_user ("What's the weather like in Paris?" )], generation_kwargs = {"tool_choice" : "any" }
490+ [ChatMessage .from_user ("What's the weather like in Paris and Berlin?" )],
491+ generation_kwargs = {"tool_choice" : "any" },
358492 )
359493
360- assert len (results ["replies" ]) > 0 , "No replies received"
361- assert callback .counter > 1 , "Streaming callback was not called multiple times"
362- assert callback .tool_calls , "No tool calls received in streaming"
494+ assert len (results ["replies" ]) == 1
363495
364496 # Find the message with tool calls
365- tool_message = None
366- for message in results [ "replies" ]:
367- if message . tool_call :
368- tool_message = message
369- break
370-
371- assert tool_message is not None , "No message with tool call found"
372- assert isinstance ( tool_message , ChatMessage ), "Tool message is not a ChatMessage instance"
373- assert ChatMessage . is_from ( tool_message , ChatRole . ASSISTANT ), "Tool message is not from the assistant"
374-
375- tool_call = tool_message . tool_call
376- assert tool_call . id , "Tool call does not contain value for 'id' key"
377- assert tool_call .tool_name == "weather"
378- assert tool_call . arguments == {"city" : "Paris" }
497+ tool_message = results [ "replies" ][ 0 ]
498+
499+ assert isinstance ( tool_message , ChatMessage )
500+ tool_calls = tool_message . tool_calls
501+ assert len ( tool_calls ) == 2
502+ assert ChatMessage . is_from ( tool_message , ChatRole . ASSISTANT )
503+
504+ for tool_call in tool_calls :
505+ assert tool_call . id is not None
506+ assert isinstance ( tool_call , ToolCall )
507+ assert tool_call . tool_name == "weather"
508+
509+ arguments = [ tool_call .arguments for tool_call in tool_calls ]
510+ assert sorted ( arguments , key = lambda x : x [ "city" ]) == [ {"city" : "Berlin" }, { "city" : " Paris" }]
379511 assert tool_message .meta ["finish_reason" ] == "tool_calls"
380512
381513 @pytest .mark .skipif (
0 commit comments