diff --git a/integrations/mistral/pyproject.toml b/integrations/mistral/pyproject.toml
index c62e279f89..611806255c 100644
--- a/integrations/mistral/pyproject.toml
+++ b/integrations/mistral/pyproject.toml
@@ -23,7 +23,7 @@ classifiers = [
   "Programming Language :: Python :: Implementation :: CPython",
   "Programming Language :: Python :: Implementation :: PyPy",
 ]
-dependencies = ["haystack-ai>=2.13.0"]
+dependencies = ["haystack-ai>=2.15.1"]
 
 [project.urls]
 Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/mistral#readme"
diff --git a/integrations/mistral/tests/test_mistral_chat_generator.py b/integrations/mistral/tests/test_mistral_chat_generator.py
index ac868505e5..8487123a32 100644
--- a/integrations/mistral/tests/test_mistral_chat_generator.py
+++ b/integrations/mistral/tests/test_mistral_chat_generator.py
@@ -1,22 +1,37 @@
 import os
 from datetime import datetime
-from unittest.mock import patch
+from unittest.mock import ANY, patch
 
 import pytest
 import pytz
 from haystack import Pipeline
 from haystack.components.generators.utils import print_streaming_chunk
 from haystack.components.tools import ToolInvoker
-from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk, ToolCall
+from haystack.dataclasses import ChatMessage, ChatRole, ComponentInfo, StreamingChunk, ToolCall, ToolCallDelta
 from haystack.tools import Tool
 from haystack.utils.auth import Secret
 from openai import OpenAIError
-from openai.types.chat import ChatCompletion, ChatCompletionMessage
+from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessage
 from openai.types.chat.chat_completion import Choice
+from openai.types.chat.chat_completion_chunk import Choice as ChoiceChunk
+from openai.types.chat.chat_completion_chunk import ChoiceDelta, ChoiceDeltaToolCall, ChoiceDeltaToolCallFunction
+from openai.types.completion_usage import CompletionUsage
 
 from haystack_integrations.components.generators.mistral.chat.chat_generator import MistralChatGenerator
 
 
+class CollectorCallback:
+    """
+    Callback to collect streaming chunks for testing purposes.
+ """ + + def __init__(self): + self.chunks = [] + + def __call__(self, chunk: StreamingChunk) -> None: + self.chunks.append(chunk) + + @pytest.fixture def chat_messages(): return [ @@ -179,6 +194,137 @@ def test_from_dict_fail_wo_env_var(self, monkeypatch): with pytest.raises(ValueError, match="None of the .* environment variables are set"): MistralChatGenerator.from_dict(data) + def test_handle_stream_response(self): + mistral_chunks = [ + ChatCompletionChunk( + id="76535283139540de943bc2036121d4c5", + choices=[ChoiceChunk(delta=ChoiceDelta(content="", role="assistant"), index=0)], + created=1750076261, + model="mistral-small-latest", + object="chat.completion.chunk", + ), + ChatCompletionChunk( + id="76535283139540de943bc2036121d4c5", + choices=[ + ChoiceChunk( + delta=ChoiceDelta( + tool_calls=[ + ChoiceDeltaToolCall( + index=0, + id="FL1FFlqUG", + function=ChoiceDeltaToolCallFunction(arguments='{"city": "Paris"}', name="weather"), + ), + ChoiceDeltaToolCall( + index=1, + id="xSuhp66iB", + function=ChoiceDeltaToolCallFunction( + arguments='{"city": "Berlin"}', name="weather" + ), + ), + ], + ), + finish_reason="tool_calls", + index=0, + ) + ], + created=1750076261, + model="mistral-small-latest", + object="chat.completion.chunk", + usage=CompletionUsage( + completion_tokens=35, + prompt_tokens=77, + total_tokens=112, + ), + ), + ] + + collector_callback = CollectorCallback() + llm = MistralChatGenerator(api_key=Secret.from_token("test-api-key")) + result = llm._handle_stream_response(mistral_chunks, callback=collector_callback)[0] # type: ignore + + # Verify the callback collected the expected number of chunks + # We expect 2 chunks: one for the initial empty content and one for the tool calls + assert len(collector_callback.chunks) == 2 + assert collector_callback.chunks[0] == StreamingChunk( + content="", + meta={ + "model": "mistral-small-latest", + "index": 0, + "tool_calls": None, + "finish_reason": None, + "received_at": ANY, + "usage": None, + }, + component_info=ComponentInfo( + type="haystack_integrations.components.generators.mistral.chat.chat_generator.MistralChatGenerator", + name=None, + ), + ) + assert collector_callback.chunks[1] == StreamingChunk( + content="", + meta={ + "model": "mistral-small-latest", + "index": 0, + "tool_calls": [ + ChoiceDeltaToolCall( + index=0, + id="FL1FFlqUG", + function=ChoiceDeltaToolCallFunction(arguments='{"city": "Paris"}', name="weather"), + ), + ChoiceDeltaToolCall( + index=1, + id="xSuhp66iB", + function=ChoiceDeltaToolCallFunction(arguments='{"city": "Berlin"}', name="weather"), + ), + ], + "finish_reason": "tool_calls", + "received_at": ANY, + "usage": { + "completion_tokens": 35, + "prompt_tokens": 77, + "total_tokens": 112, + "completion_tokens_details": None, + "prompt_tokens_details": None, + }, + }, + component_info=ComponentInfo( + type="haystack_integrations.components.generators.mistral.chat.chat_generator.MistralChatGenerator", + name=None, + ), + index=0, + tool_calls=[ + ToolCallDelta(index=0, tool_name="weather", arguments='{"city": "Paris"}', id="FL1FFlqUG"), + ToolCallDelta(index=1, tool_name="weather", arguments='{"city": "Berlin"}', id="xSuhp66iB"), + ], + start=True, + finish_reason="tool_calls", + ) + + # Assert text is empty + assert result.text is None + + # Verify both tool calls were found and processed + assert len(result.tool_calls) == 2 + assert result.tool_calls[0].id == "FL1FFlqUG" + assert result.tool_calls[0].tool_name == "weather" + assert result.tool_calls[0].arguments == {"city": "Paris"} + assert 
+        assert result.tool_calls[1].tool_name == "weather"
+        assert result.tool_calls[1].arguments == {"city": "Berlin"}
+
+        # Verify meta information
+        assert result.meta["model"] == "mistral-small-latest"
+        assert result.meta["finish_reason"] == "tool_calls"
+        assert result.meta["index"] == 0
+        assert result.meta["completion_start_time"] is not None
+        assert result.meta["usage"] == {
+            "completion_tokens": 35,
+            "prompt_tokens": 77,
+            "total_tokens": 112,
+            "completion_tokens_details": None,
+            "prompt_tokens_details": None,
+        }
+
     def test_run(self, chat_messages, mock_chat_completion, monkeypatch):  # noqa: ARG002
         monkeypatch.setenv("MISTRAL_API_KEY", "fake-api-key")
         component = MistralChatGenerator()
@@ -291,42 +437,44 @@ def test_live_run_with_tools_and_response(self, tools):
         """
         Integration test that the MistralChatGenerator component can run with tools and get a response.
         """
-        initial_messages = [ChatMessage.from_user("What's the weather like in Paris?")]
+        initial_messages = [ChatMessage.from_user("What's the weather like in Paris and Berlin?")]
         component = MistralChatGenerator(tools=tools)
         results = component.run(messages=initial_messages, generation_kwargs={"tool_choice": "any"})
 
-        assert len(results["replies"]) > 0, "No replies received"
+        assert len(results["replies"]) == 1
 
         # Find the message with tool calls
-        tool_message = None
-        for message in results["replies"]:
-            if message.tool_call:
-                tool_message = message
-                break
-
-        assert tool_message is not None, "No message with tool call found"
-        assert isinstance(tool_message, ChatMessage), "Tool message is not a ChatMessage instance"
-        assert ChatMessage.is_from(tool_message, ChatRole.ASSISTANT), "Tool message is not from the assistant"
-
-        tool_call = tool_message.tool_call
-        assert tool_call.id, "Tool call does not contain value for 'id' key"
-        assert tool_call.tool_name == "weather"
-        assert tool_call.arguments == {"city": "Paris"}
+        tool_message = results["replies"][0]
+
+        assert isinstance(tool_message, ChatMessage)
+        tool_calls = tool_message.tool_calls
+        assert len(tool_calls) == 2
+        assert ChatMessage.is_from(tool_message, ChatRole.ASSISTANT)
+
+        for tool_call in tool_calls:
+            assert tool_call.id is not None
+            assert isinstance(tool_call, ToolCall)
+            assert tool_call.tool_name == "weather"
+
+        arguments = [tool_call.arguments for tool_call in tool_calls]
+        assert sorted(arguments, key=lambda x: x["city"]) == [{"city": "Berlin"}, {"city": "Paris"}]
         assert tool_message.meta["finish_reason"] == "tool_calls"
 
         new_messages = [
             initial_messages[0],
             tool_message,
-            ChatMessage.from_tool(tool_result="22° C", origin=tool_call),
+            ChatMessage.from_tool(tool_result="22° C and sunny", origin=tool_calls[0]),
+            ChatMessage.from_tool(tool_result="16° C and windy", origin=tool_calls[1]),
         ]
         # Pass the tool result to the model to get the final response
        results = component.run(new_messages)
 
         assert len(results["replies"]) == 1
         final_message = results["replies"][0]
-        assert not final_message.tool_call
+        assert final_message.is_from(ChatRole.ASSISTANT)
         assert len(final_message.text) > 0
         assert "paris" in final_message.text.lower()
+        assert "berlin" in final_message.text.lower()
 
     @pytest.mark.skipif(
         not os.environ.get("MISTRAL_API_KEY", None),
@@ -337,45 +485,29 @@ def test_live_run_with_tools_streaming(self, tools):
         """
         Integration test that the MistralChatGenerator component can run with tools and streaming.
""" - - class Callback: - def __init__(self): - self.responses = "" - self.counter = 0 - self.tool_calls = [] - - def __call__(self, chunk: StreamingChunk) -> None: - self.counter += 1 - if chunk.content: - self.responses += chunk.content - if chunk.meta.get("tool_calls"): - self.tool_calls.extend(chunk.meta["tool_calls"]) - - callback = Callback() - component = MistralChatGenerator(tools=tools, streaming_callback=callback) + component = MistralChatGenerator(tools=tools, streaming_callback=print_streaming_chunk) results = component.run( - [ChatMessage.from_user("What's the weather like in Paris?")], generation_kwargs={"tool_choice": "any"} + [ChatMessage.from_user("What's the weather like in Paris and Berlin?")], + generation_kwargs={"tool_choice": "any"}, ) - assert len(results["replies"]) > 0, "No replies received" - assert callback.counter > 1, "Streaming callback was not called multiple times" - assert callback.tool_calls, "No tool calls received in streaming" + assert len(results["replies"]) == 1 # Find the message with tool calls - tool_message = None - for message in results["replies"]: - if message.tool_call: - tool_message = message - break - - assert tool_message is not None, "No message with tool call found" - assert isinstance(tool_message, ChatMessage), "Tool message is not a ChatMessage instance" - assert ChatMessage.is_from(tool_message, ChatRole.ASSISTANT), "Tool message is not from the assistant" - - tool_call = tool_message.tool_call - assert tool_call.id, "Tool call does not contain value for 'id' key" - assert tool_call.tool_name == "weather" - assert tool_call.arguments == {"city": "Paris"} + tool_message = results["replies"][0] + + assert isinstance(tool_message, ChatMessage) + tool_calls = tool_message.tool_calls + assert len(tool_calls) == 2 + assert ChatMessage.is_from(tool_message, ChatRole.ASSISTANT) + + for tool_call in tool_calls: + assert tool_call.id is not None + assert isinstance(tool_call, ToolCall) + assert tool_call.tool_name == "weather" + + arguments = [tool_call.arguments for tool_call in tool_calls] + assert sorted(arguments, key=lambda x: x["city"]) == [{"city": "Berlin"}, {"city": "Paris"}] assert tool_message.meta["finish_reason"] == "tool_calls" @pytest.mark.skipif(