diff --git a/integrations/mistral/src/haystack_integrations/components/generators/mistral/chat/chat_generator.py b/integrations/mistral/src/haystack_integrations/components/generators/mistral/chat/chat_generator.py index 7db80e5ae2..b3ba353d3e 100644 --- a/integrations/mistral/src/haystack_integrations/components/generators/mistral/chat/chat_generator.py +++ b/integrations/mistral/src/haystack_integrations/components/generators/mistral/chat/chat_generator.py @@ -2,11 +2,19 @@ # # SPDX-License-Identifier: Apache-2.0 +import json from typing import Any, ClassVar from haystack import component, default_to_dict, logging from haystack.components.generators.chat import OpenAIChatGenerator -from haystack.dataclasses import ChatMessage, StreamingCallbackT +from haystack.components.generators.chat.openai import _check_finish_reason +from haystack.dataclasses import ( + ChatMessage, + ReasoningContent, + StreamingCallbackT, + ToolCall, + select_streaming_callback, +) from haystack.tools import ToolsType, serialize_tools_or_toolset from haystack.utils import serialize_callable from haystack.utils.auth import Secret @@ -16,6 +24,75 @@ logger = logging.getLogger(__name__) +def _parse_mistral_content(content: Any) -> tuple[str | None, ReasoningContent | None]: + """Parse Mistral message content which can be a string or an array of typed blocks.""" + if content is None: + return None, None + if isinstance(content, str): + return content or None, None + if not isinstance(content, list): + return str(content), None + + text_parts: list[str] = [] + thinking_parts: list[str] = [] + + for block in content: + if not isinstance(block, dict): + continue + block_type = block.get("type", "") + if block_type == "thinking": + for item in block.get("thinking", []): + if isinstance(item, dict) and item.get("type") == "text": + thinking_parts.append(item.get("text", "")) + elif block_type == "text": + text_parts.append(block.get("text", "")) + + text = "".join(text_parts) or None + reasoning = None + if thinking_parts: + reasoning = ReasoningContent(reasoning_text="".join(thinking_parts)) + + return text, reasoning + + +def _convert_mistral_response_to_chat_messages(response_data: dict[str, Any] | str) -> list[ChatMessage]: + """Convert a raw Mistral API JSON response to a list of ChatMessages, handling array content.""" + data: dict[str, Any] = json.loads(response_data) if isinstance(response_data, str) else response_data + completions: list[ChatMessage] = [] + usage = data.get("usage") + model = data.get("model", "") + + for choice in data.get("choices", []): + message = choice.get("message", {}) + text, reasoning = _parse_mistral_content(message.get("content")) + + tool_calls: list[ToolCall] = [] + for tc in message.get("tool_calls") or []: + func = tc.get("function", {}) + try: + arguments = json.loads(func.get("arguments", "{}")) + tool_calls.append(ToolCall(id=tc.get("id"), tool_name=func.get("name"), arguments=arguments)) + except json.JSONDecodeError: + logger.warning( + "Mistral returned malformed JSON for tool call arguments. " + "Tool call ID: {_id}, Tool name: {_name}, Arguments: {_arguments}", + _id=tc.get("id"), + _name=func.get("name"), + _arguments=func.get("arguments"), + ) + + meta: dict[str, Any] = { + "model": model, + "index": choice.get("index", 0), + "finish_reason": choice.get("finish_reason"), + "usage": usage, + } + + completions.append(ChatMessage.from_assistant(text=text, tool_calls=tool_calls, meta=meta, reasoning=reasoning)) + + return completions + + @component class MistralChatGenerator(OpenAIChatGenerator): """ @@ -28,9 +105,12 @@ class MistralChatGenerator(OpenAIChatGenerator): parameter in `run` method. Key Features and Compatibility: - - **Primary Compatibility**: Designed to work seamlessly with the Mistral API Chat Completion endpoint. + - **Primary Compatibility**: Compatible with the Mistral API Chat Completion endpoint. - **Streaming Support**: Supports streaming responses from the Mistral API Chat Completion endpoint. - **Customizability**: Supports all parameters supported by the Mistral API Chat Completion endpoint. + - **Reasoning Support**: Extracts reasoning/thinking content from models that support it + (e.g., mistral-small with `reasoning_effort`, magistral models) and stores it in the + `ReasoningContent` field on `ChatMessage`. This component uses the ChatMessage format for structuring both input and output, ensuring coherent and contextually relevant responses in chat-based text generation scenarios. @@ -58,6 +138,22 @@ class MistralChatGenerator(OpenAIChatGenerator): >> _meta={'model': 'mistral-small-latest', 'index': 0, 'finish_reason': 'stop', >> 'usage': {'prompt_tokens': 15, 'completion_tokens': 36, 'total_tokens': 51}})]} ``` + + Reasoning usage example: + ```python + from haystack_integrations.components.generators.mistral import MistralChatGenerator + from haystack.dataclasses import ChatMessage + + messages = [ChatMessage.from_user("Solve: if x + 3 = 7, what is x?")] + + client = MistralChatGenerator( + model="mistral-small-latest", + generation_kwargs={"reasoning_effort": "high"}, + ) + response = client.run(messages) + print(response["replies"][0].reasoning) # Access reasoning content + print(response["replies"][0].text) # Access final answer + ``` """ SUPPORTED_MODELS: ClassVar[list[str]] = [ @@ -104,8 +200,6 @@ class MistralChatGenerator(OpenAIChatGenerator): "voxtral-mini-2507", "voxtral-mini-latest", "voxtral-mini-2602", - "voxtral-mini-latest", - "voxtral-mini-2507", ] """A list of models supported by Mistral AI see [Mistral AI docs](https://docs.mistral.ai/getting-started/models) for more information @@ -153,7 +247,12 @@ def __init__( events as they become available, with the stream terminated by a data: [DONE] message. - `safe_prompt`: Whether to inject a safety prompt before all conversations. - `random_seed`: The seed to use for random sampling. - - `response_format`: A JSON schema or a Pydantic model that enforces the structure of the model's response. + - `reasoning_effort`: Controls reasoning/thinking tokens for models that support adjustable reasoning + (e.g., `mistral-small-latest`, `mistral-medium`). Accepted values: `"high"`, `"none"`. + See [Mistral reasoning docs](https://docs.mistral.ai/capabilities/reasoning/). + - `prompt_mode`: For native reasoning models (magistral). Set to `"reasoning"` to use the default + reasoning system prompt, or omit for the model's default behavior. + - `response_format`: A JSON schema or a Pydantic model that enforces the structure of the model's response. If provided, the output will always be validated against this format (unless the model returns a tool call). For details, see the [OpenAI Structured Outputs documentation](https://platform.openai.com/docs/guides/structured-outputs). @@ -202,12 +301,169 @@ def _prepare_api_call( tools=tools, tools_strict=tools_strict, ) - # Mistral does not support response_format and in Haystack 2.18 we always include response_format even if - # it's None + if "response_format" in api_args and api_args["response_format"] is None: api_args.pop("response_format") + + extra_body: dict[str, Any] = {} + for param in ("reasoning_effort", "prompt_mode", "safe_prompt"): + if param in api_args: + extra_body[param] = api_args.pop(param) + if extra_body: + api_args.setdefault("extra_body", {}).update(extra_body) + + for i, chat_msg in enumerate(messages): + if chat_msg.reasoning and chat_msg.reasoning.reasoning_text: + formatted = api_args["messages"][i] + text_content = formatted.get("content", "") or "" + formatted["content"] = [ + {"type": "thinking", "thinking": [{"type": "text", "text": chat_msg.reasoning.reasoning_text}]}, + {"type": "text", "text": text_content}, + ] + return api_args + @component.output_types(replies=list[ChatMessage]) + def run( + self, + messages: list[ChatMessage], + streaming_callback: StreamingCallbackT | None = None, + generation_kwargs: dict[str, Any] | None = None, + *, + tools: ToolsType | None = None, + tools_strict: bool | None = None, + ) -> dict[str, list[ChatMessage]]: + """ + Invokes chat completion on the Mistral API. + + :param messages: + A list of ChatMessage instances representing the input messages. + :param streaming_callback: + A callback function that is called when a new token is received from the stream. + :param generation_kwargs: + Additional keyword arguments for text generation. These parameters will + override the parameters passed during component initialization. + For details on Mistral API parameters, see + [Mistral docs](https://docs.mistral.ai/api/). + :param tools: A list of Tool and/or Toolset objects, or a single Toolset for which the model can prepare calls. + If set, it will override the `tools` parameter provided during initialization. + :param tools_strict: + Whether to enable strict schema adherence for tool calls. + + :returns: + A dictionary with the following key: + - `replies`: A list containing the generated responses as ChatMessage instances. + """ + if not self._is_warmed_up: + self.warm_up() + + if len(messages) == 0: + return {"replies": []} + + streaming_callback = select_streaming_callback( + init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=False + ) + + if streaming_callback is not None: + merged_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})} + if merged_kwargs.get("reasoning_effort") or merged_kwargs.get("prompt_mode"): + logger.warning( + "Streaming with reasoning parameters is active. Reasoning content from thinking " + "blocks will not be captured during streaming. Use non-streaming mode to extract " + "reasoning content." + ) + + api_args = self._prepare_api_call( + messages=messages, + streaming_callback=streaming_callback, + generation_kwargs=generation_kwargs, + tools=tools, + tools_strict=tools_strict, + ) + openai_endpoint = api_args.pop("openai_endpoint") + + if streaming_callback is not None: + chat_completion = getattr(self.client.chat.completions, openai_endpoint)(**api_args) + completions = self._handle_stream_response(chat_completion, streaming_callback) + else: + raw_response = getattr(self.client.chat.completions.with_raw_response, openai_endpoint)(**api_args) + completions = _convert_mistral_response_to_chat_messages(raw_response.text) + + for message in completions: + _check_finish_reason(message.meta) + + return {"replies": completions} + + @component.output_types(replies=list[ChatMessage]) + async def run_async( + self, + messages: list[ChatMessage], + streaming_callback: StreamingCallbackT | None = None, + generation_kwargs: dict[str, Any] | None = None, + *, + tools: ToolsType | None = None, + tools_strict: bool | None = None, + ) -> dict[str, list[ChatMessage]]: + """ + Asynchronously invokes chat completion on the Mistral API. + + :param messages: + A list of ChatMessage instances representing the input messages. + :param streaming_callback: + A callback function that is called when a new token is received from the stream. + Must be a coroutine. + :param generation_kwargs: + Additional keyword arguments for text generation. + :param tools: A list of Tool and/or Toolset objects, or a single Toolset. + :param tools_strict: + Whether to enable strict schema adherence for tool calls. + + :returns: + A dictionary with the following key: + - `replies`: A list containing the generated responses as ChatMessage instances. + """ + if not self._is_warmed_up: + self.warm_up() + + if len(messages) == 0: + return {"replies": []} + + streaming_callback = select_streaming_callback( + init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=True + ) + + if streaming_callback is not None: + merged_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})} + if merged_kwargs.get("reasoning_effort") or merged_kwargs.get("prompt_mode"): + logger.warning( + "Streaming with reasoning parameters is active. Reasoning content from thinking " + "blocks will not be captured during streaming. Use non-streaming mode to extract " + "reasoning content." + ) + + api_args = self._prepare_api_call( + messages=messages, + streaming_callback=streaming_callback, + generation_kwargs=generation_kwargs, + tools=tools, + tools_strict=tools_strict, + ) + openai_endpoint = api_args.pop("openai_endpoint") + + if streaming_callback is not None: + chat_completion = await getattr(self.async_client.chat.completions, openai_endpoint)(**api_args) + completions = await self._handle_async_stream_response(chat_completion, streaming_callback) + else: + raw_response = await getattr(self.async_client.chat.completions.with_raw_response, openai_endpoint)( + **api_args + ) + completions = _convert_mistral_response_to_chat_messages(raw_response.text) + + for message in completions: + _check_finish_reason(message.meta) + + return {"replies": completions} + def to_dict(self) -> dict[str, Any]: """ Serialize this component to a dictionary. diff --git a/integrations/mistral/tests/test_mistral_chat_generator.py b/integrations/mistral/tests/test_mistral_chat_generator.py index 76f1ce5fd0..e996e276a7 100644 --- a/integrations/mistral/tests/test_mistral_chat_generator.py +++ b/integrations/mistral/tests/test_mistral_chat_generator.py @@ -1,25 +1,35 @@ import json +import logging import os -from datetime import datetime from unittest.mock import ANY, patch import pytest -import pytz from haystack import Pipeline from haystack.components.generators.utils import print_streaming_chunk from haystack.components.tools import ToolInvoker -from haystack.dataclasses import ChatMessage, ChatRole, ComponentInfo, StreamingChunk, ToolCall, ToolCallDelta +from haystack.dataclasses import ( + ChatMessage, + ChatRole, + ComponentInfo, + ReasoningContent, + StreamingChunk, + ToolCall, + ToolCallDelta, +) from haystack.tools import Tool, Toolset from haystack.utils.auth import Secret from openai import OpenAIError -from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessage -from openai.types.chat.chat_completion import Choice +from openai.types.chat import ChatCompletionChunk from openai.types.chat.chat_completion_chunk import Choice as ChoiceChunk from openai.types.chat.chat_completion_chunk import ChoiceDelta, ChoiceDeltaToolCall, ChoiceDeltaToolCallFunction from openai.types.completion_usage import CompletionUsage from pydantic import BaseModel -from haystack_integrations.components.generators.mistral.chat.chat_generator import MistralChatGenerator +from haystack_integrations.components.generators.mistral.chat.chat_generator import ( + MistralChatGenerator, + _convert_mistral_response_to_chat_messages, + _parse_mistral_content, +) class CollectorCallback: @@ -90,23 +100,25 @@ def mock_chat_completion(): Mock the OpenAI API completion response and reuse it for tests """ with patch("openai.resources.chat.completions.Completions.create") as mock_chat_completion_create: - completion = ChatCompletion( - id="foo", - model="mistral-small-latest", - object="chat.completion", - choices=[ - Choice( - finish_reason="stop", - logprobs=None, - index=0, - message=ChatCompletionMessage(content="Hello world!", role="assistant"), - ) - ], - created=int(datetime.now(tz=pytz.timezone("UTC")).timestamp()), - usage={"prompt_tokens": 57, "completion_tokens": 40, "total_tokens": 97}, + mock_response = type("MockRawResponse", (), {})() + mock_response.text = json.dumps( + { + "id": "foo", + "model": "mistral-small-latest", + "object": "chat.completion", + "choices": [ + { + "finish_reason": "stop", + "index": 0, + "message": {"role": "assistant", "content": "Hello world!"}, + } + ], + "created": 1234567890, + "usage": {"prompt_tokens": 57, "completion_tokens": 40, "total_tokens": 97}, + } ) - mock_chat_completion_create.return_value = completion + mock_chat_completion_create.return_value = mock_response yield mock_chat_completion_create @@ -837,3 +849,381 @@ def test_live_run_with_mixed_tools(self, mixed_tools): assert "city" in tool_call.arguments assert tool_call.arguments["city"] in ["Paris", "Berlin"] assert tool_call_message.meta["finish_reason"] == "tool_calls" + + @pytest.mark.skipif( + not os.environ.get("MISTRAL_API_KEY", None), + reason="Export an env var called MISTRAL_API_KEY containing the Mistral API key to run this test.", + ) + @pytest.mark.integration + def test_live_run_with_reasoning(self): + chat_messages = [ChatMessage.from_user("If x + 3 = 7, what is x?")] + component = MistralChatGenerator(generation_kwargs={"reasoning_effort": "high"}) + results = component.run(chat_messages) + + assert len(results["replies"]) == 1 + message: ChatMessage = results["replies"][0] + assert message.reasoning is not None + assert message.reasoning.reasoning_text + assert message.text + assert "4" in message.text + assert message.meta["finish_reason"] == "stop" + + +@pytest.fixture +def mock_reasoning_response(): + """Mock that returns a raw-response-like object with reasoning array content.""" + with patch("openai.resources.chat.completions.Completions.create") as mock_create: + mock_response = type("MockRawResponse", (), {})() + mock_response.text = json.dumps( + { + "id": "test-reasoning", + "model": "mistral-small-latest", + "object": "chat.completion", + "choices": [ + { + "finish_reason": "stop", + "index": 0, + "message": { + "role": "assistant", + "content": [ + { + "type": "thinking", + "thinking": [{"type": "text", "text": "Let me solve this step by step. 2+2=4."}], + }, + {"type": "text", "text": "The answer is 4."}, + ], + }, + } + ], + "created": 1234567890, + "usage": {"prompt_tokens": 10, "completion_tokens": 50, "total_tokens": 60}, + } + ) + mock_create.return_value = mock_response + yield mock_create + + +class TestReasoningSupport: + def test_parse_mistral_content_string(self): + text, reasoning = _parse_mistral_content("Hello world") + assert text == "Hello world" + assert reasoning is None + + def test_parse_mistral_content_none(self): + text, reasoning = _parse_mistral_content(None) + assert text is None + assert reasoning is None + + def test_parse_mistral_content_empty_string(self): + text, reasoning = _parse_mistral_content("") + assert text is None + assert reasoning is None + + def test_parse_mistral_content_array_with_reasoning(self): + content = [ + {"type": "thinking", "thinking": [{"type": "text", "text": "Step 1: analyze. Step 2: solve."}]}, + {"type": "text", "text": "The answer is 42."}, + ] + text, reasoning = _parse_mistral_content(content) + assert text == "The answer is 42." + assert reasoning is not None + assert reasoning.reasoning_text == "Step 1: analyze. Step 2: solve." + + def test_parse_mistral_content_array_text_only(self): + content = [ + {"type": "text", "text": "Just a plain response."}, + ] + text, reasoning = _parse_mistral_content(content) + assert text == "Just a plain response." + assert reasoning is None + + def test_parse_mistral_content_array_thinking_only(self): + content = [ + {"type": "thinking", "thinking": [{"type": "text", "text": "Internal reasoning only."}]}, + ] + text, reasoning = _parse_mistral_content(content) + assert text is None + assert reasoning is not None + assert reasoning.reasoning_text == "Internal reasoning only." + + def test_parse_mistral_content_multiple_thinking_blocks(self): + content = [ + {"type": "thinking", "thinking": [{"type": "text", "text": "First thought. "}]}, + {"type": "thinking", "thinking": [{"type": "text", "text": "Second thought."}]}, + {"type": "text", "text": "Final answer."}, + ] + text, reasoning = _parse_mistral_content(content) + assert text == "Final answer." + assert reasoning is not None + assert reasoning.reasoning_text == "First thought. Second thought." + + def test_parse_mistral_content_non_dict_blocks(self): + content = [ + "stray string", + 42, + {"type": "text", "text": "Valid block."}, + ] + text, reasoning = _parse_mistral_content(content) + assert text == "Valid block." + assert reasoning is None + + def test_convert_response_with_reasoning(self): + response_data = { + "id": "test", + "model": "mistral-small-latest", + "choices": [ + { + "finish_reason": "stop", + "index": 0, + "message": { + "role": "assistant", + "content": [ + {"type": "thinking", "thinking": [{"type": "text", "text": "Reasoning here."}]}, + {"type": "text", "text": "Answer here."}, + ], + }, + } + ], + "usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30}, + } + messages = _convert_mistral_response_to_chat_messages(response_data) + assert len(messages) == 1 + msg = messages[0] + assert msg.text == "Answer here." + assert msg.reasoning is not None + assert msg.reasoning.reasoning_text == "Reasoning here." + assert msg.meta["model"] == "mistral-small-latest" + assert msg.meta["finish_reason"] == "stop" + + def test_convert_response_without_reasoning(self): + response_data = { + "id": "test", + "model": "mistral-small-latest", + "choices": [ + { + "finish_reason": "stop", + "index": 0, + "message": {"role": "assistant", "content": "Plain text response."}, + } + ], + "usage": {"prompt_tokens": 5, "completion_tokens": 10, "total_tokens": 15}, + } + messages = _convert_mistral_response_to_chat_messages(response_data) + assert len(messages) == 1 + msg = messages[0] + assert msg.text == "Plain text response." + assert msg.reasoning is None + + def test_convert_response_with_tool_calls(self): + response_data = { + "id": "test", + "model": "mistral-small-latest", + "choices": [ + { + "finish_reason": "tool_calls", + "index": 0, + "message": { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "call_123", + "function": {"name": "weather", "arguments": '{"city": "Paris"}'}, + } + ], + }, + } + ], + "usage": {"prompt_tokens": 10, "completion_tokens": 15, "total_tokens": 25}, + } + messages = _convert_mistral_response_to_chat_messages(response_data) + assert len(messages) == 1 + msg = messages[0] + assert msg.text is None + assert len(msg.tool_calls) == 1 + assert msg.tool_calls[0].tool_name == "weather" + assert msg.tool_calls[0].arguments == {"city": "Paris"} + + def test_convert_response_malformed_tool_call_json(self): + response_data = { + "id": "test", + "model": "mistral-small-latest", + "choices": [ + { + "finish_reason": "tool_calls", + "index": 0, + "message": { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "call_bad", + "function": {"name": "weather", "arguments": "{invalid json}"}, + }, + { + "id": "call_good", + "function": {"name": "weather", "arguments": '{"city": "Paris"}'}, + }, + ], + }, + } + ], + "usage": {"prompt_tokens": 10, "completion_tokens": 15, "total_tokens": 25}, + } + messages = _convert_mistral_response_to_chat_messages(response_data) + assert len(messages) == 1 + msg = messages[0] + assert len(msg.tool_calls) == 1 + assert msg.tool_calls[0].id == "call_good" + + def test_convert_response_with_reasoning_and_tool_calls(self): + response_data = { + "id": "test", + "model": "mistral-small-latest", + "choices": [ + { + "finish_reason": "tool_calls", + "index": 0, + "message": { + "role": "assistant", + "content": [ + {"type": "thinking", "thinking": [{"type": "text", "text": "I should check the weather."}]}, + {"type": "text", "text": "Let me look that up."}, + ], + "tool_calls": [ + { + "id": "call_123", + "function": {"name": "weather", "arguments": '{"city": "Paris"}'}, + } + ], + }, + } + ], + "usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30}, + } + messages = _convert_mistral_response_to_chat_messages(response_data) + assert len(messages) == 1 + msg = messages[0] + assert msg.text == "Let me look that up." + assert msg.reasoning is not None + assert msg.reasoning.reasoning_text == "I should check the weather." + assert len(msg.tool_calls) == 1 + assert msg.tool_calls[0].tool_name == "weather" + assert msg.tool_calls[0].arguments == {"city": "Paris"} + assert msg.meta["finish_reason"] == "tool_calls" + + def test_run_with_reasoning(self, chat_messages, mock_reasoning_response, monkeypatch): # noqa: ARG002 + monkeypatch.setenv("MISTRAL_API_KEY", "fake-api-key") + component = MistralChatGenerator() + response = component.run(chat_messages) + + assert isinstance(response, dict) + assert "replies" in response + assert len(response["replies"]) == 1 + + msg = response["replies"][0] + assert msg.text == "The answer is 4." + assert msg.reasoning is not None + assert msg.reasoning.reasoning_text == "Let me solve this step by step. 2+2=4." + + def test_prepare_api_call_routes_reasoning_effort(self, monkeypatch): + monkeypatch.setenv("MISTRAL_API_KEY", "fake-api-key") + component = MistralChatGenerator( + generation_kwargs={"reasoning_effort": "high", "temperature": 0.7}, + ) + messages = [ChatMessage.from_user("test")] + api_args = component._prepare_api_call(messages=messages) + + assert "reasoning_effort" not in api_args + assert api_args["extra_body"]["reasoning_effort"] == "high" + assert api_args["temperature"] == 0.7 + + def test_prepare_api_call_routes_prompt_mode(self, monkeypatch): + monkeypatch.setenv("MISTRAL_API_KEY", "fake-api-key") + component = MistralChatGenerator( + model="magistral-small-latest", + generation_kwargs={"prompt_mode": "reasoning"}, + ) + messages = [ChatMessage.from_user("test")] + api_args = component._prepare_api_call(messages=messages) + + assert "prompt_mode" not in api_args + assert api_args["extra_body"]["prompt_mode"] == "reasoning" + + def test_streaming_with_reasoning_logs_warning(self, monkeypatch, caplog): + monkeypatch.setenv("MISTRAL_API_KEY", "fake-api-key") + component = MistralChatGenerator( + generation_kwargs={"reasoning_effort": "high"}, + streaming_callback=print_streaming_chunk, + ) + + with ( + caplog.at_level(logging.WARNING), + patch.object(component, "_prepare_api_call", side_effect=RuntimeError), + pytest.raises(RuntimeError), + ): + component.run([ChatMessage.from_user("test")]) + + def test_prepare_api_call_preserves_reasoning(self, monkeypatch): + monkeypatch.setenv("MISTRAL_API_KEY", "fake-api-key") + component = MistralChatGenerator() + messages = [ + ChatMessage.from_user("What is 2+2?"), + ChatMessage.from_assistant( + text="The answer is 4.", + reasoning=ReasoningContent(reasoning_text="2+2 equals 4"), + ), + ChatMessage.from_user("Are you sure?"), + ] + api_args = component._prepare_api_call(messages=messages) + + assistant_msg = api_args["messages"][1] + assert isinstance(assistant_msg["content"], list) + assert len(assistant_msg["content"]) == 2 + assert assistant_msg["content"][0]["type"] == "thinking" + assert assistant_msg["content"][0]["thinking"][0]["text"] == "2+2 equals 4" + assert assistant_msg["content"][1]["type"] == "text" + assert assistant_msg["content"][1]["text"] == "The answer is 4." + + def test_parse_mistral_content_unexpected_type(self): + text, reasoning = _parse_mistral_content(42) + assert text == "42" + assert reasoning is None + + def test_parse_mistral_content_unexpected_object(self): + text, reasoning = _parse_mistral_content(3.14) + assert text == "3.14" + assert reasoning is None + + def test_run_empty_messages(self, monkeypatch): + monkeypatch.setenv("MISTRAL_API_KEY", "fake-api-key") + component = MistralChatGenerator() + response = component.run([]) + assert response == {"replies": []} + + def test_convert_response_from_json_string(self): + json_str = json.dumps( + { + "id": "test", + "model": "mistral-small-latest", + "choices": [ + { + "finish_reason": "stop", + "index": 0, + "message": { + "role": "assistant", + "content": [ + {"type": "thinking", "thinking": [{"type": "text", "text": "Thinking."}]}, + {"type": "text", "text": "Answer."}, + ], + }, + } + ], + "usage": {"prompt_tokens": 5, "completion_tokens": 10, "total_tokens": 15}, + } + ) + messages = _convert_mistral_response_to_chat_messages(json_str) + assert len(messages) == 1 + assert messages[0].text == "Answer." + assert messages[0].reasoning is not None + assert messages[0].reasoning.reasoning_text == "Thinking." diff --git a/integrations/mistral/tests/test_mistral_chat_generator_async.py b/integrations/mistral/tests/test_mistral_chat_generator_async.py index 852136eced..c4e0eb8e28 100644 --- a/integrations/mistral/tests/test_mistral_chat_generator_async.py +++ b/integrations/mistral/tests/test_mistral_chat_generator_async.py @@ -1,9 +1,9 @@ +import json +import logging import os -from datetime import datetime from unittest.mock import AsyncMock, patch import pytest -import pytz from haystack.dataclasses import ( ChatMessage, ChatRole, @@ -11,8 +11,6 @@ ) from haystack.tools import Tool from openai import AsyncOpenAI -from openai.types.chat import ChatCompletion, ChatCompletionMessage -from openai.types.chat.chat_completion import Choice from haystack_integrations.components.generators.mistral.chat.chat_generator import ( MistralChatGenerator, @@ -58,27 +56,24 @@ def mock_async_chat_completion(): "openai.resources.chat.completions.AsyncCompletions.create", new_callable=AsyncMock, ) as mock_chat_completion_create: - completion = ChatCompletion( - id="foo", - model="mistral-small-latest", - object="chat.completion", - choices=[ - Choice( - finish_reason="stop", - logprobs=None, - index=0, - message=ChatCompletionMessage(content="Hello world!", role="assistant"), - ) - ], - created=int(datetime.now(tz=pytz.timezone("UTC")).timestamp()), - usage={ - "prompt_tokens": 57, - "completion_tokens": 40, - "total_tokens": 97, - }, + mock_response = type("MockRawResponse", (), {})() + mock_response.text = json.dumps( + { + "id": "foo", + "model": "mistral-small-latest", + "object": "chat.completion", + "choices": [ + { + "finish_reason": "stop", + "index": 0, + "message": {"role": "assistant", "content": "Hello world!"}, + } + ], + "created": 1234567890, + "usage": {"prompt_tokens": 57, "completion_tokens": 40, "total_tokens": 97}, + } ) - # For async mocks, the return value should be awaitable - mock_chat_completion_create.return_value = completion + mock_chat_completion_create.return_value = mock_response yield mock_chat_completion_create @@ -262,3 +257,93 @@ async def callback(chunk: StreamingChunk): assert tool_call.tool_name == "weather" assert tool_call.arguments == {"city": "Paris"} assert tool_message.meta["finish_reason"] == "tool_calls" + + @pytest.mark.skipif( + not os.environ.get("MISTRAL_API_KEY", None), + reason="Export an env var called MISTRAL_API_KEY containing the Mistral API key to run this test.", + ) + @pytest.mark.integration + @pytest.mark.asyncio + async def test_live_run_async_with_reasoning(self): + chat_messages = [ChatMessage.from_user("If x + 3 = 7, what is x?")] + component = MistralChatGenerator(generation_kwargs={"reasoning_effort": "high"}) + results = await component.run_async(chat_messages) + + assert len(results["replies"]) == 1 + message: ChatMessage = results["replies"][0] + assert message.reasoning is not None + assert message.reasoning.reasoning_text + assert message.text + assert "4" in message.text + assert message.meta["finish_reason"] == "stop" + + @pytest.mark.asyncio + async def test_run_async_with_reasoning(self, chat_messages, monkeypatch): + monkeypatch.setenv("MISTRAL_API_KEY", "fake-api-key") + + mock_response = type("MockRawResponse", (), {})() + mock_response.text = json.dumps( + { + "id": "test-reasoning-async", + "model": "mistral-small-latest", + "object": "chat.completion", + "choices": [ + { + "finish_reason": "stop", + "index": 0, + "message": { + "role": "assistant", + "content": [ + { + "type": "thinking", + "thinking": [{"type": "text", "text": "Async reasoning content."}], + }, + {"type": "text", "text": "Async answer."}, + ], + }, + } + ], + "created": 1234567890, + "usage": {"prompt_tokens": 10, "completion_tokens": 50, "total_tokens": 60}, + } + ) + + with patch( + "openai.resources.chat.completions.AsyncCompletions.create", + new_callable=AsyncMock, + ) as mock_create: + mock_create.return_value = mock_response + component = MistralChatGenerator() + response = await component.run_async(chat_messages) + + assert len(response["replies"]) == 1 + msg = response["replies"][0] + assert msg.text == "Async answer." + assert msg.reasoning is not None + assert msg.reasoning.reasoning_text == "Async reasoning content." + + @pytest.mark.asyncio + async def test_run_async_empty_messages(self, monkeypatch): + monkeypatch.setenv("MISTRAL_API_KEY", "fake-api-key") + component = MistralChatGenerator() + response = await component.run_async([]) + assert response == {"replies": []} + + @pytest.mark.asyncio + async def test_run_async_streaming_with_reasoning_logs_warning(self, monkeypatch, caplog): + monkeypatch.setenv("MISTRAL_API_KEY", "fake-api-key") + + async def async_callback(chunk: StreamingChunk): + pass + + component = MistralChatGenerator( + generation_kwargs={"reasoning_effort": "high"}, + streaming_callback=async_callback, + ) + + with ( + caplog.at_level(logging.WARNING), + patch.object(component, "_prepare_api_call", side_effect=RuntimeError), + pytest.raises(RuntimeError), + ): + await component.run_async([ChatMessage.from_user("test")])