Skip to content

Commit 37dae5f

Browse files
authored
test: Check that _convert_chat_completion_chunk_to_streaming_chunk works for MistralChatGenerator (#1953)
* Update and add a test * updates * Fix * Update other integration test to use two tool calls * Update deps * Finish todos * Updates * Revert change
1 parent 4a58714 commit 37dae5f

2 files changed

Lines changed: 188 additions & 56 deletions

File tree

integrations/mistral/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ classifiers = [
2323
"Programming Language :: Python :: Implementation :: CPython",
2424
"Programming Language :: Python :: Implementation :: PyPy",
2525
]
26-
dependencies = ["haystack-ai>=2.13.0"]
26+
dependencies = ["haystack-ai>=2.15.1"]
2727

2828
[project.urls]
2929
Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/mistral#readme"

integrations/mistral/tests/test_mistral_chat_generator.py

Lines changed: 187 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,37 @@
11
import os
22
from datetime import datetime
3-
from unittest.mock import patch
3+
from unittest.mock import ANY, patch
44

55
import pytest
66
import pytz
77
from haystack import Pipeline
88
from haystack.components.generators.utils import print_streaming_chunk
99
from haystack.components.tools import ToolInvoker
10-
from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk, ToolCall
10+
from haystack.dataclasses import ChatMessage, ChatRole, ComponentInfo, StreamingChunk, ToolCall, ToolCallDelta
1111
from haystack.tools import Tool
1212
from haystack.utils.auth import Secret
1313
from openai import OpenAIError
14-
from openai.types.chat import ChatCompletion, ChatCompletionMessage
14+
from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessage
1515
from openai.types.chat.chat_completion import Choice
16+
from openai.types.chat.chat_completion_chunk import Choice as ChoiceChunk
17+
from openai.types.chat.chat_completion_chunk import ChoiceDelta, ChoiceDeltaToolCall, ChoiceDeltaToolCallFunction
18+
from openai.types.completion_usage import CompletionUsage
1619

1720
from haystack_integrations.components.generators.mistral.chat.chat_generator import MistralChatGenerator
1821

1922

23+
class CollectorCallback:
24+
"""
25+
Callback to collect streaming chunks for testing purposes.
26+
"""
27+
28+
def __init__(self):
29+
self.chunks = []
30+
31+
def __call__(self, chunk: StreamingChunk) -> None:
32+
self.chunks.append(chunk)
33+
34+
2035
@pytest.fixture
2136
def chat_messages():
2237
return [
@@ -179,6 +194,137 @@ def test_from_dict_fail_wo_env_var(self, monkeypatch):
179194
with pytest.raises(ValueError, match="None of the .* environment variables are set"):
180195
MistralChatGenerator.from_dict(data)
181196

197+
def test_handle_stream_response(self):
198+
mistral_chunks = [
199+
ChatCompletionChunk(
200+
id="76535283139540de943bc2036121d4c5",
201+
choices=[ChoiceChunk(delta=ChoiceDelta(content="", role="assistant"), index=0)],
202+
created=1750076261,
203+
model="mistral-small-latest",
204+
object="chat.completion.chunk",
205+
),
206+
ChatCompletionChunk(
207+
id="76535283139540de943bc2036121d4c5",
208+
choices=[
209+
ChoiceChunk(
210+
delta=ChoiceDelta(
211+
tool_calls=[
212+
ChoiceDeltaToolCall(
213+
index=0,
214+
id="FL1FFlqUG",
215+
function=ChoiceDeltaToolCallFunction(arguments='{"city": "Paris"}', name="weather"),
216+
),
217+
ChoiceDeltaToolCall(
218+
index=1,
219+
id="xSuhp66iB",
220+
function=ChoiceDeltaToolCallFunction(
221+
arguments='{"city": "Berlin"}', name="weather"
222+
),
223+
),
224+
],
225+
),
226+
finish_reason="tool_calls",
227+
index=0,
228+
)
229+
],
230+
created=1750076261,
231+
model="mistral-small-latest",
232+
object="chat.completion.chunk",
233+
usage=CompletionUsage(
234+
completion_tokens=35,
235+
prompt_tokens=77,
236+
total_tokens=112,
237+
),
238+
),
239+
]
240+
241+
collector_callback = CollectorCallback()
242+
llm = MistralChatGenerator(api_key=Secret.from_token("test-api-key"))
243+
result = llm._handle_stream_response(mistral_chunks, callback=collector_callback)[0] # type: ignore
244+
245+
# Verify the callback collected the expected number of chunks
246+
# We expect 2 chunks: one for the initial empty content and one for the tool calls
247+
assert len(collector_callback.chunks) == 2
248+
assert collector_callback.chunks[0] == StreamingChunk(
249+
content="",
250+
meta={
251+
"model": "mistral-small-latest",
252+
"index": 0,
253+
"tool_calls": None,
254+
"finish_reason": None,
255+
"received_at": ANY,
256+
"usage": None,
257+
},
258+
component_info=ComponentInfo(
259+
type="haystack_integrations.components.generators.mistral.chat.chat_generator.MistralChatGenerator",
260+
name=None,
261+
),
262+
)
263+
assert collector_callback.chunks[1] == StreamingChunk(
264+
content="",
265+
meta={
266+
"model": "mistral-small-latest",
267+
"index": 0,
268+
"tool_calls": [
269+
ChoiceDeltaToolCall(
270+
index=0,
271+
id="FL1FFlqUG",
272+
function=ChoiceDeltaToolCallFunction(arguments='{"city": "Paris"}', name="weather"),
273+
),
274+
ChoiceDeltaToolCall(
275+
index=1,
276+
id="xSuhp66iB",
277+
function=ChoiceDeltaToolCallFunction(arguments='{"city": "Berlin"}', name="weather"),
278+
),
279+
],
280+
"finish_reason": "tool_calls",
281+
"received_at": ANY,
282+
"usage": {
283+
"completion_tokens": 35,
284+
"prompt_tokens": 77,
285+
"total_tokens": 112,
286+
"completion_tokens_details": None,
287+
"prompt_tokens_details": None,
288+
},
289+
},
290+
component_info=ComponentInfo(
291+
type="haystack_integrations.components.generators.mistral.chat.chat_generator.MistralChatGenerator",
292+
name=None,
293+
),
294+
index=0,
295+
tool_calls=[
296+
ToolCallDelta(index=0, tool_name="weather", arguments='{"city": "Paris"}', id="FL1FFlqUG"),
297+
ToolCallDelta(index=1, tool_name="weather", arguments='{"city": "Berlin"}', id="xSuhp66iB"),
298+
],
299+
start=True,
300+
finish_reason="tool_calls",
301+
)
302+
303+
# Assert text is empty
304+
assert result.text is None
305+
306+
# Verify both tool calls were found and processed
307+
assert len(result.tool_calls) == 2
308+
assert result.tool_calls[0].id == "FL1FFlqUG"
309+
assert result.tool_calls[0].tool_name == "weather"
310+
assert result.tool_calls[0].arguments == {"city": "Paris"}
311+
assert result.tool_calls[1].id == "xSuhp66iB"
312+
assert result.tool_calls[1].tool_name == "weather"
313+
assert result.tool_calls[1].arguments == {"city": "Berlin"}
314+
315+
# Verify meta information
316+
assert result.meta["model"] == "mistral-small-latest"
317+
assert result.meta["finish_reason"] == "tool_calls"
318+
assert result.meta["index"] == 0
319+
assert result.meta["completion_start_time"] is not None
320+
assert result.meta["usage"] == {
321+
"completion_tokens": 35,
322+
"prompt_tokens": 77,
323+
"total_tokens": 112,
324+
"completion_tokens_details": None,
325+
"prompt_tokens_details": None,
326+
}
327+
182328
def test_run(self, chat_messages, mock_chat_completion, monkeypatch): # noqa: ARG002
183329
monkeypatch.setenv("MISTRAL_API_KEY", "fake-api-key")
184330
component = MistralChatGenerator()
@@ -291,42 +437,44 @@ def test_live_run_with_tools_and_response(self, tools):
291437
"""
292438
Integration test that the MistralChatGenerator component can run with tools and get a response.
293439
"""
294-
initial_messages = [ChatMessage.from_user("What's the weather like in Paris?")]
440+
initial_messages = [ChatMessage.from_user("What's the weather like in Paris and Berlin?")]
295441
component = MistralChatGenerator(tools=tools)
296442
results = component.run(messages=initial_messages, generation_kwargs={"tool_choice": "any"})
297443

298-
assert len(results["replies"]) > 0, "No replies received"
444+
assert len(results["replies"]) == 1
299445

300446
# Find the message with tool calls
301-
tool_message = None
302-
for message in results["replies"]:
303-
if message.tool_call:
304-
tool_message = message
305-
break
306-
307-
assert tool_message is not None, "No message with tool call found"
308-
assert isinstance(tool_message, ChatMessage), "Tool message is not a ChatMessage instance"
309-
assert ChatMessage.is_from(tool_message, ChatRole.ASSISTANT), "Tool message is not from the assistant"
310-
311-
tool_call = tool_message.tool_call
312-
assert tool_call.id, "Tool call does not contain value for 'id' key"
313-
assert tool_call.tool_name == "weather"
314-
assert tool_call.arguments == {"city": "Paris"}
447+
tool_message = results["replies"][0]
448+
449+
assert isinstance(tool_message, ChatMessage)
450+
tool_calls = tool_message.tool_calls
451+
assert len(tool_calls) == 2
452+
assert ChatMessage.is_from(tool_message, ChatRole.ASSISTANT)
453+
454+
for tool_call in tool_calls:
455+
assert tool_call.id is not None
456+
assert isinstance(tool_call, ToolCall)
457+
assert tool_call.tool_name == "weather"
458+
459+
arguments = [tool_call.arguments for tool_call in tool_calls]
460+
assert sorted(arguments, key=lambda x: x["city"]) == [{"city": "Berlin"}, {"city": "Paris"}]
315461
assert tool_message.meta["finish_reason"] == "tool_calls"
316462

317463
new_messages = [
318464
initial_messages[0],
319465
tool_message,
320-
ChatMessage.from_tool(tool_result="22° C", origin=tool_call),
466+
ChatMessage.from_tool(tool_result="22° C and sunny", origin=tool_calls[0]),
467+
ChatMessage.from_tool(tool_result="16° C and windy", origin=tool_calls[1]),
321468
]
322469
# Pass the tool result to the model to get the final response
323470
results = component.run(new_messages)
324471

325472
assert len(results["replies"]) == 1
326473
final_message = results["replies"][0]
327-
assert not final_message.tool_call
474+
assert final_message.is_from(ChatRole.ASSISTANT)
328475
assert len(final_message.text) > 0
329476
assert "paris" in final_message.text.lower()
477+
assert "berlin" in final_message.text.lower()
330478

331479
@pytest.mark.skipif(
332480
not os.environ.get("MISTRAL_API_KEY", None),
@@ -337,45 +485,29 @@ def test_live_run_with_tools_streaming(self, tools):
337485
"""
338486
Integration test that the MistralChatGenerator component can run with tools and streaming.
339487
"""
340-
341-
class Callback:
342-
def __init__(self):
343-
self.responses = ""
344-
self.counter = 0
345-
self.tool_calls = []
346-
347-
def __call__(self, chunk: StreamingChunk) -> None:
348-
self.counter += 1
349-
if chunk.content:
350-
self.responses += chunk.content
351-
if chunk.meta.get("tool_calls"):
352-
self.tool_calls.extend(chunk.meta["tool_calls"])
353-
354-
callback = Callback()
355-
component = MistralChatGenerator(tools=tools, streaming_callback=callback)
488+
component = MistralChatGenerator(tools=tools, streaming_callback=print_streaming_chunk)
356489
results = component.run(
357-
[ChatMessage.from_user("What's the weather like in Paris?")], generation_kwargs={"tool_choice": "any"}
490+
[ChatMessage.from_user("What's the weather like in Paris and Berlin?")],
491+
generation_kwargs={"tool_choice": "any"},
358492
)
359493

360-
assert len(results["replies"]) > 0, "No replies received"
361-
assert callback.counter > 1, "Streaming callback was not called multiple times"
362-
assert callback.tool_calls, "No tool calls received in streaming"
494+
assert len(results["replies"]) == 1
363495

364496
# Find the message with tool calls
365-
tool_message = None
366-
for message in results["replies"]:
367-
if message.tool_call:
368-
tool_message = message
369-
break
370-
371-
assert tool_message is not None, "No message with tool call found"
372-
assert isinstance(tool_message, ChatMessage), "Tool message is not a ChatMessage instance"
373-
assert ChatMessage.is_from(tool_message, ChatRole.ASSISTANT), "Tool message is not from the assistant"
374-
375-
tool_call = tool_message.tool_call
376-
assert tool_call.id, "Tool call does not contain value for 'id' key"
377-
assert tool_call.tool_name == "weather"
378-
assert tool_call.arguments == {"city": "Paris"}
497+
tool_message = results["replies"][0]
498+
499+
assert isinstance(tool_message, ChatMessage)
500+
tool_calls = tool_message.tool_calls
501+
assert len(tool_calls) == 2
502+
assert ChatMessage.is_from(tool_message, ChatRole.ASSISTANT)
503+
504+
for tool_call in tool_calls:
505+
assert tool_call.id is not None
506+
assert isinstance(tool_call, ToolCall)
507+
assert tool_call.tool_name == "weather"
508+
509+
arguments = [tool_call.arguments for tool_call in tool_calls]
510+
assert sorted(arguments, key=lambda x: x["city"]) == [{"city": "Berlin"}, {"city": "Paris"}]
379511
assert tool_message.meta["finish_reason"] == "tool_calls"
380512

381513
@pytest.mark.skipif(

0 commit comments

Comments
 (0)