Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,19 @@
#
# SPDX-License-Identifier: Apache-2.0

import json
from typing import Any, ClassVar

from haystack import component, default_to_dict, logging
from haystack.components.generators.chat import OpenAIChatGenerator
from haystack.dataclasses import ChatMessage, StreamingCallbackT
from haystack.components.generators.chat.openai import _check_finish_reason
from haystack.dataclasses import (
ChatMessage,
ReasoningContent,
StreamingCallbackT,
ToolCall,
select_streaming_callback,
)
from haystack.tools import ToolsType, serialize_tools_or_toolset
from haystack.utils import serialize_callable
from haystack.utils.auth import Secret
Expand All @@ -16,6 +24,75 @@
# Module-level logger, created through the `logging` helper imported from `haystack` and named after this module.
logger = logging.getLogger(__name__)


def _parse_mistral_content(content: Any) -> tuple[str | None, ReasoningContent | None]:
    """
    Split a Mistral message ``content`` payload into answer text and reasoning.

    ``content`` may be a plain string or a list of typed blocks. In the list form,
    ``{"type": "text", "text": ...}`` blocks contribute to the answer text and
    ``{"type": "thinking", "thinking": [...]}`` blocks contribute to the reasoning text
    (only their nested ``{"type": "text"}`` items are used). Blocks of any other type,
    and list entries that are not dictionaries, are ignored.

    :param content:
        The raw ``content`` value of a response message.
    :returns:
        A ``(text, reasoning)`` tuple. ``text`` is ``None`` when there is no non-empty answer text;
        ``reasoning`` is ``None`` when no thinking text item was found.
    """
    if content is None:
        return None, None
    if isinstance(content, str):
        # An empty string is normalised to None so callers can test for "no text" uniformly.
        return content or None, None
    if not isinstance(content, list):
        # Unknown payload shape: keep a string rendering rather than dropping the data.
        return str(content), None

    text_parts: list[str] = []
    thinking_parts: list[str] = []

    for block in content:
        if not isinstance(block, dict):
            continue
        block_type = block.get("type", "")
        if block_type == "text":
            # `.get("text", "")` would only cover a missing key; `or ""` also covers an explicit null,
            # which would otherwise make the final join raise TypeError.
            text_parts.append(block.get("text") or "")
        elif block_type == "thinking":
            thinking = block.get("thinking") or []
            if isinstance(thinking, str):
                # NOTE(review): the documented shape is a list of chunks; a bare string is accepted
                # defensively so the reasoning text is not silently discarded.
                thinking_parts.append(thinking)
                continue
            if not isinstance(thinking, list):
                continue
            thinking_parts.extend(
                item.get("text") or ""
                for item in thinking
                if isinstance(item, dict) and item.get("type") == "text"
            )

    text = "".join(text_parts) or None
    reasoning = ReasoningContent(reasoning_text="".join(thinking_parts)) if thinking_parts else None
    return text, reasoning


def _convert_mistral_response_to_chat_messages(response_data: dict[str, Any] | str) -> list[ChatMessage]:
    """
    Convert a raw Mistral chat-completion response into a list of ``ChatMessage`` objects.

    Message content is parsed with ``_parse_mistral_content``, so both plain-string content and
    array-of-blocks content (including thinking blocks) are handled.

    :param response_data:
        The response body, either already decoded into a dictionary or as a JSON string.
    :returns:
        One assistant ``ChatMessage`` per entry in ``choices``. Each message carries ``model``,
        ``index``, ``finish_reason`` and ``usage`` in its ``meta``. Tool calls whose arguments
        cannot be decoded are logged and left out of the message.
    :raises json.JSONDecodeError:
        If ``response_data`` is a string that is not valid JSON.
    """
    data: dict[str, Any] = json.loads(response_data) if isinstance(response_data, str) else response_data
    usage = data.get("usage")
    model = data.get("model", "")
    completions: list[ChatMessage] = []

    # `or` (rather than a `.get` default) also covers keys that are present but explicitly null.
    for choice in data.get("choices") or []:
        message = choice.get("message") or {}
        text, reasoning = _parse_mistral_content(message.get("content"))

        tool_calls: list[ToolCall] = []
        for tc in message.get("tool_calls") or []:
            func = tc.get("function") or {}
            raw_arguments = func.get("arguments", "{}")
            if isinstance(raw_arguments, dict):
                # Arguments that arrive already decoded as an object are used as-is;
                # passing them to json.loads would raise TypeError.
                arguments = raw_arguments
            else:
                try:
                    arguments = json.loads(raw_arguments)
                except (json.JSONDecodeError, TypeError):
                    # TypeError covers null / non-string arguments, which are as unusable as bad JSON.
                    logger.warning(
                        "Mistral returned malformed JSON for tool call arguments. "
                        "Tool call ID: {_id}, Tool name: {_name}, Arguments: {_arguments}",
                        _id=tc.get("id"),
                        _name=func.get("name"),
                        _arguments=func.get("arguments"),
                    )
                    continue
            tool_calls.append(ToolCall(id=tc.get("id"), tool_name=func.get("name"), arguments=arguments))

        meta: dict[str, Any] = {
            "model": model,
            "index": choice.get("index", 0),
            "finish_reason": choice.get("finish_reason"),
            "usage": usage,
        }

        completions.append(ChatMessage.from_assistant(text=text, tool_calls=tool_calls, meta=meta, reasoning=reasoning))

    return completions


@component
class MistralChatGenerator(OpenAIChatGenerator):
"""
Expand All @@ -28,9 +105,12 @@ class MistralChatGenerator(OpenAIChatGenerator):
parameter in `run` method.

Key Features and Compatibility:
- **Primary Compatibility**: Designed to work seamlessly with the Mistral API Chat Completion endpoint.
- **Primary Compatibility**: Compatible with the Mistral API Chat Completion endpoint.
- **Streaming Support**: Supports streaming responses from the Mistral API Chat Completion endpoint.
- **Customizability**: Supports all parameters supported by the Mistral API Chat Completion endpoint.
- **Reasoning Support**: Extracts reasoning/thinking content from models that support it
(e.g., mistral-small with `reasoning_effort`, magistral models) and stores it in the
`ReasoningContent` field on `ChatMessage`.

This component uses the ChatMessage format for structuring both input and output,
ensuring coherent and contextually relevant responses in chat-based text generation scenarios.
Expand Down Expand Up @@ -58,6 +138,22 @@ class MistralChatGenerator(OpenAIChatGenerator):
>> _meta={'model': 'mistral-small-latest', 'index': 0, 'finish_reason': 'stop',
>> 'usage': {'prompt_tokens': 15, 'completion_tokens': 36, 'total_tokens': 51}})]}
```

Reasoning usage example:
```python
from haystack_integrations.components.generators.mistral import MistralChatGenerator
from haystack.dataclasses import ChatMessage

messages = [ChatMessage.from_user("Solve: if x + 3 = 7, what is x?")]

client = MistralChatGenerator(
model="mistral-small-latest",
generation_kwargs={"reasoning_effort": "high"},
)
response = client.run(messages)
print(response["replies"][0].reasoning) # Access reasoning content
print(response["replies"][0].text) # Access final answer
```
"""

SUPPORTED_MODELS: ClassVar[list[str]] = [
Expand Down Expand Up @@ -104,8 +200,6 @@ class MistralChatGenerator(OpenAIChatGenerator):
"voxtral-mini-2507",
"voxtral-mini-latest",
"voxtral-mini-2602",
"voxtral-mini-latest",
"voxtral-mini-2507",
]
"""A list of models supported by Mistral AI
see [Mistral AI docs](https://docs.mistral.ai/getting-started/models) for more information
Expand Down Expand Up @@ -153,7 +247,12 @@ def __init__(
events as they become available, with the stream terminated by a data: [DONE] message.
- `safe_prompt`: Whether to inject a safety prompt before all conversations.
- `random_seed`: The seed to use for random sampling.
- `response_format`: A JSON schema or a Pydantic model that enforces the structure of the model's response.
- `reasoning_effort`: Controls reasoning/thinking tokens for models that support adjustable reasoning
(e.g., `mistral-small-latest`, `mistral-medium`). Accepted values: `"high"`, `"none"`.
See [Mistral reasoning docs](https://docs.mistral.ai/capabilities/reasoning/).
- `prompt_mode`: For native reasoning models (magistral). Set to `"reasoning"` to use the default
reasoning system prompt, or omit for the model's default behavior.
- `response_format`: A JSON schema or a Pydantic model that enforces the structure of the model's response.
If provided, the output will always be validated against this
format (unless the model returns a tool call).
For details, see the [OpenAI Structured Outputs documentation](https://platform.openai.com/docs/guides/structured-outputs).
Expand Down Expand Up @@ -202,12 +301,169 @@ def _prepare_api_call(
tools=tools,
tools_strict=tools_strict,
)
# Mistral does not support response_format and in Haystack 2.18 we always include response_format even if
# it's None

if "response_format" in api_args and api_args["response_format"] is None:
api_args.pop("response_format")

extra_body: dict[str, Any] = {}
for param in ("reasoning_effort", "prompt_mode", "safe_prompt"):
if param in api_args:
extra_body[param] = api_args.pop(param)
if extra_body:
api_args.setdefault("extra_body", {}).update(extra_body)

for i, chat_msg in enumerate(messages):
if chat_msg.reasoning and chat_msg.reasoning.reasoning_text:
formatted = api_args["messages"][i]
text_content = formatted.get("content", "") or ""
formatted["content"] = [
{"type": "thinking", "thinking": [{"type": "text", "text": chat_msg.reasoning.reasoning_text}]},
{"type": "text", "text": text_content},
]

return api_args

@component.output_types(replies=list[ChatMessage])
def run(
    self,
    messages: list[ChatMessage],
    streaming_callback: StreamingCallbackT | None = None,
    generation_kwargs: dict[str, Any] | None = None,
    *,
    tools: ToolsType | None = None,
    tools_strict: bool | None = None,
) -> dict[str, list[ChatMessage]]:
    """
    Invokes chat completion on the Mistral API.

    When no streaming callback is in effect, the request is sent through the client's raw-response
    interface and the response body is converted with `_convert_mistral_response_to_chat_messages`,
    so that array-style content (including thinking blocks) is parsed. When streaming, the
    inherited stream handler is used.

    :param messages:
        A list of ChatMessage instances representing the input messages.
    :param streaming_callback:
        A callback function that is called when a new token is received from the stream.
    :param generation_kwargs:
        Additional keyword arguments for text generation. These parameters will
        override the parameters passed during component initialization.
        For details on Mistral API parameters, see
        [Mistral docs](https://docs.mistral.ai/api/).
    :param tools: A list of Tool and/or Toolset objects, or a single Toolset for which the model can prepare calls.
        If set, it will override the `tools` parameter provided during initialization.
    :param tools_strict:
        Whether to enable strict schema adherence for tool calls.

    :returns:
        A dictionary with the following key:
        - `replies`: A list containing the generated responses as ChatMessage instances.
    """
    if not self._is_warmed_up:
        self.warm_up()

    if not messages:
        return {"replies": []}

    streaming_callback = select_streaming_callback(
        init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=False
    )

    if streaming_callback is not None:
        merged_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}
        reasoning_effort = merged_kwargs.get("reasoning_effort")
        # "none" is a non-empty (truthy) string but is the documented value for running without
        # reasoning, so it must not trigger the "reasoning will be lost" warning.
        reasoning_requested = bool(reasoning_effort) and reasoning_effort != "none"
        if reasoning_requested or merged_kwargs.get("prompt_mode"):
            logger.warning(
                "Streaming with reasoning parameters is active. Reasoning content from thinking "
                "blocks will not be captured during streaming. Use non-streaming mode to extract "
                "reasoning content."
            )

    api_args = self._prepare_api_call(
        messages=messages,
        streaming_callback=streaming_callback,
        generation_kwargs=generation_kwargs,
        tools=tools,
        tools_strict=tools_strict,
    )
    # The endpoint name selects which method of the completions client to call; it is not an API argument.
    openai_endpoint = api_args.pop("openai_endpoint")

    if streaming_callback is not None:
        chat_completion = getattr(self.client.chat.completions, openai_endpoint)(**api_args)
        completions = self._handle_stream_response(chat_completion, streaming_callback)
    else:
        # Use the raw response so the JSON body can be parsed by the Mistral-specific converter.
        raw_response = getattr(self.client.chat.completions.with_raw_response, openai_endpoint)(**api_args)
        completions = _convert_mistral_response_to_chat_messages(raw_response.text)

    for message in completions:
        _check_finish_reason(message.meta)

    return {"replies": completions}

@component.output_types(replies=list[ChatMessage])
async def run_async(
    self,
    messages: list[ChatMessage],
    streaming_callback: StreamingCallbackT | None = None,
    generation_kwargs: dict[str, Any] | None = None,
    *,
    tools: ToolsType | None = None,
    tools_strict: bool | None = None,
) -> dict[str, list[ChatMessage]]:
    """
    Asynchronously invokes chat completion on the Mistral API.

    When no streaming callback is in effect, the request is sent through the async client's
    raw-response interface and the response body is converted with
    `_convert_mistral_response_to_chat_messages`, so that array-style content (including thinking
    blocks) is parsed. When streaming, the inherited async stream handler is used.

    :param messages:
        A list of ChatMessage instances representing the input messages.
    :param streaming_callback:
        A callback function that is called when a new token is received from the stream.
        Must be a coroutine.
    :param generation_kwargs:
        Additional keyword arguments for text generation. These parameters are merged over
        the ones passed during component initialization.
    :param tools: A list of Tool and/or Toolset objects, or a single Toolset.
    :param tools_strict:
        Whether to enable strict schema adherence for tool calls.

    :returns:
        A dictionary with the following key:
        - `replies`: A list containing the generated responses as ChatMessage instances.
    """
    if not self._is_warmed_up:
        self.warm_up()

    if not messages:
        return {"replies": []}

    streaming_callback = select_streaming_callback(
        init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=True
    )

    if streaming_callback is not None:
        merged_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}
        reasoning_effort = merged_kwargs.get("reasoning_effort")
        # "none" is a non-empty (truthy) string but is the documented value for running without
        # reasoning, so it must not trigger the "reasoning will be lost" warning.
        reasoning_requested = bool(reasoning_effort) and reasoning_effort != "none"
        if reasoning_requested or merged_kwargs.get("prompt_mode"):
            logger.warning(
                "Streaming with reasoning parameters is active. Reasoning content from thinking "
                "blocks will not be captured during streaming. Use non-streaming mode to extract "
                "reasoning content."
            )

    api_args = self._prepare_api_call(
        messages=messages,
        streaming_callback=streaming_callback,
        generation_kwargs=generation_kwargs,
        tools=tools,
        tools_strict=tools_strict,
    )
    # The endpoint name selects which method of the completions client to call; it is not an API argument.
    openai_endpoint = api_args.pop("openai_endpoint")

    if streaming_callback is not None:
        chat_completion = await getattr(self.async_client.chat.completions, openai_endpoint)(**api_args)
        completions = await self._handle_async_stream_response(chat_completion, streaming_callback)
    else:
        # Use the raw response so the JSON body can be parsed by the Mistral-specific converter.
        raw_response = await getattr(self.async_client.chat.completions.with_raw_response, openai_endpoint)(
            **api_args
        )
        completions = _convert_mistral_response_to_chat_messages(raw_response.text)

    for message in completions:
        _check_finish_reason(message.meta)

    return {"replies": completions}

def to_dict(self) -> dict[str, Any]:
"""
Serialize this component to a dictionary.
Expand Down
Loading
Loading