Skip to content
177 changes: 173 additions & 4 deletions integrations/mcp/src/haystack_integrations/tools/mcp/mcp_toolset.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from haystack import logging
from haystack.core.serialization import generate_qualified_class_name, import_class_by_name
from haystack.tools import Tool, Toolset
from haystack.utils.callable_serialization import deserialize_callable, serialize_callable

from .mcp_tool import (
AsyncExecutor,
Expand All @@ -27,6 +28,84 @@
logger = logging.getLogger(__name__)


def _serialize_state_config(config: dict[str, dict[str, Any]]) -> dict[str, dict[str, Any]] | None:
"""
Serialize a state configuration dictionary, converting any callable handlers to their string representation.

Works for both outputs_to_state (tool_name -> {state_key -> {source, handler}})
and outputs_to_string (tool_name -> {source, handler}).

:param config: The state configuration dictionary to serialize
:returns: The serialized configuration dictionary, or None if empty
"""
if not config:
Comment thread
mpangrazzi marked this conversation as resolved.
return None

serialized = {}
for tool_name, tool_config in config.items():
if not tool_config:
continue

# Check if this is outputs_to_string format (flat with optional source/handler)
# or outputs_to_state format (nested with state keys)
if "source" in tool_config or "handler" in tool_config:
Comment thread
mpangrazzi marked this conversation as resolved.
# outputs_to_string format: {source?, handler?}
serialized_tool_config = tool_config.copy()
if "handler" in tool_config and callable(tool_config["handler"]):
serialized_tool_config["handler"] = serialize_callable(tool_config["handler"])
serialized[tool_name] = serialized_tool_config
else:
# outputs_to_state format: {state_key -> {source?, handler?}}
serialized_tool_config = {}
for state_key, state_config in tool_config.items():
serialized_state_config = state_config.copy()
if "handler" in state_config and callable(state_config["handler"]):
serialized_state_config["handler"] = serialize_callable(state_config["handler"])
serialized_tool_config[state_key] = serialized_state_config
serialized[tool_name] = serialized_tool_config

return serialized if serialized else None


def _deserialize_state_config(config: dict[str, dict[str, Any]] | None) -> dict[str, dict[str, Any]]:
"""
Deserialize a state configuration dictionary, converting any serialized handlers back to callables.

Works for both outputs_to_state (tool_name -> {state_key -> {source, handler}})
and outputs_to_string (tool_name -> {source, handler}).

:param config: The state configuration dictionary to deserialize
:returns: The deserialized configuration dictionary
"""
if not config:
return {}

deserialized = {}
for tool_name, tool_config in config.items():
if not tool_config:
continue

# Check if this is outputs_to_string format (flat with optional source/handler)
# or outputs_to_state format (nested with state keys)
if "source" in tool_config or "handler" in tool_config:
# outputs_to_string format: {source?, handler?}
deserialized_tool_config = tool_config.copy()
if "handler" in tool_config and isinstance(tool_config["handler"], str):
deserialized_tool_config["handler"] = deserialize_callable(tool_config["handler"])
deserialized[tool_name] = deserialized_tool_config
else:
# outputs_to_state format: {state_key -> {source?, handler?}}
deserialized_tool_config = {}
for state_key, state_config in tool_config.items():
deserialized_state_config = state_config.copy()
if "handler" in state_config and isinstance(state_config["handler"], str):
deserialized_state_config["handler"] = deserialize_callable(state_config["handler"])
deserialized_tool_config[state_key] = deserialized_state_config
deserialized[tool_name] = deserialized_tool_config

return deserialized


class MCPToolset(Toolset):
"""
A Toolset that connects to an MCP (Model Context Protocol) server and provides
Expand Down Expand Up @@ -99,6 +178,30 @@ class MCPToolset(Toolset):
# Use the toolset as shown in the pipeline example above
```

Example with state configuration for Agent integration:
```python
from haystack_integrations.tools.mcp import MCPToolset, StdioServerInfo

# Create the toolset with per-tool state configuration
# This enables tools to read from and write to the Agent's State
toolset = MCPToolset(
server_info=StdioServerInfo(command="uvx", args=["mcp-server-git"]),
tool_names=["git_status", "git_diff", "git_log"],

# Map the state key "repository" to the tool parameter "repo_path" for each tool
inputs_from_state={
"git_status": {"repository": "repo_path"},
"git_diff": {"repository": "repo_path"},
"git_log": {"repository": "repo_path"},
},
# Map tool outputs to state keys for each tool
outputs_to_state={
"git_status": {"status_result": {"source": "status"}}, # Extract "status" from output
"git_diff": {"diff_result": {}}, # use full output with default handling
},
)
```

Example using SSE (deprecated):
```python
from haystack_integrations.tools.mcp import MCPToolset, SSEServerInfo
Expand All @@ -121,6 +224,9 @@ def __init__(
connection_timeout: float = 30.0,
invocation_timeout: float = 30.0,
eager_connect: bool = False,
inputs_from_state: dict[str, dict[str, str]] | None = None,
outputs_to_state: dict[str, dict[str, dict[str, Any]]] | None = None,
outputs_to_string: dict[str, dict[str, Any]] | None = None,
):
"""
Initialize the MCP toolset.
Expand All @@ -132,14 +238,35 @@ def __init__(
:param invocation_timeout: Default timeout in seconds for tool invocations
:param eager_connect: If True, connect to server and load tools during initialization.
If False (default), defer connection to warm_up.
:param inputs_from_state: Optional dictionary mapping tool names to their inputs_from_state config.
Each config maps state keys to tool parameter names.
Tool names should match available tools from the server; a warning is logged for
unknown tools. Note: With Haystack >= 2.22.0, parameter names are validated;
ValueError is raised for invalid parameters. With earlier versions, invalid
parameters fail at runtime.
Example: `{"git_status": {"repository": "repo_path"}}`
:param outputs_to_state: Optional dictionary mapping tool names to their outputs_to_state config.
Each config defines how tool outputs map to state keys with optional handlers.
Tool names should match available tools from the server; a warning is logged for
unknown tools.
Example: `{"git_status": {"status_result": {"source": "status"}}}`
:param outputs_to_string: Optional dictionary mapping tool names to their outputs_to_string config.
Each config defines how tool outputs are converted to strings.
Tool names should match available tools from the server; a warning is logged for
unknown tools.
Example: `{"git_diff": {"source": "diff", "handler": format_diff}}`
:raises MCPToolNotFoundError: If any of the specified tool names are not found on the server
:raises ValueError: If parameter names in inputs_from_state are invalid (Haystack >= 2.22.0 only)
"""
# Store configuration
self.server_info = server_info
self.tool_names = tool_names
self.connection_timeout = connection_timeout
self.invocation_timeout = invocation_timeout
self.eager_connect = eager_connect
self.inputs_from_state = inputs_from_state or {}
self.outputs_to_state = outputs_to_state or {}
self.outputs_to_string = outputs_to_string or {}
self._warmup_called = False

if not eager_connect:
Expand Down Expand Up @@ -226,17 +353,23 @@ def invoke_tool(**kwargs: Any) -> Any:
description=tool_info.description or "",
parameters=tool_info.inputSchema,
function=create_invoke_tool(self, client, tool_info.name, self.invocation_timeout),
inputs_from_state=self.inputs_from_state.get(tool_info.name),
outputs_to_state=self.outputs_to_state.get(tool_info.name),
outputs_to_string=self.outputs_to_string.get(tool_info.name),
)
haystack_tools.append(tool)

# Validate state configs reference known tools
self._validate_state_configs({tool.name for tool in haystack_tools})

return haystack_tools
except Exception as e:
# We need to close because we could connect properly, retrieve tools yet
# fail because of an MCPToolNotFoundError
# fail because of validation errors
self.close()

if isinstance(e, MCPToolNotFoundError):
raise # re-raise MCPToolNotFoundError as is to show original message
if isinstance(e, (MCPToolNotFoundError, ValueError)):
raise # re-raise validation errors as is to show original message

# Create informative error message for SSE connection errors
# Common error handling for HTTP-based transports
Expand Down Expand Up @@ -292,6 +425,31 @@ def invoke_tool(**kwargs: Any) -> Any:

raise MCPConnectionError(message=message, server_info=self.server_info, operation="initialize") from e

def _validate_state_configs(self, available_tool_names: set[str]) -> None:
"""
Validate that state configuration tool names exist in the toolset.

Logs a warning for any tool names in the state configs that don't match
available tools in the toolset.

:param available_tool_names: Set of tool names that are available in the toolset
Comment thread
sjrl marked this conversation as resolved.
"""
configs: list[tuple[str, dict[str, Any]]] = [
("inputs_from_state", self.inputs_from_state),
("outputs_to_state", self.outputs_to_state),
("outputs_to_string", self.outputs_to_string),
]
for config_name, config in configs:
if config:
unknown_tools = set(config.keys()) - available_tool_names
if unknown_tools:
logger.warning(
"{config_name} references unknown tools: {unknown_tools}. Available tools: {available_tools}",
config_name=config_name,
unknown_tools=unknown_tools,
available_tools=available_tool_names,
)

def to_dict(self) -> dict[str, Any]:
"""
Serialize the MCPToolset to a dictionary.
Expand All @@ -306,6 +464,9 @@ def to_dict(self) -> dict[str, Any]:
"connection_timeout": self.connection_timeout,
"invocation_timeout": self.invocation_timeout,
"eager_connect": self.eager_connect,
"inputs_from_state": self.inputs_from_state if self.inputs_from_state else None,
"outputs_to_state": _serialize_state_config(self.outputs_to_state),
"outputs_to_string": _serialize_state_config(self.outputs_to_string),
},
}

Expand All @@ -324,13 +485,21 @@ def from_dict(cls, data: dict[str, Any]) -> "MCPToolset":
server_info_class = import_class_by_name(server_info_dict["type"])
server_info = cast(MCPServerInfo, server_info_class).from_dict(server_info_dict)

# Deserialize state configuration parameters
inputs_from_state = inner_data.get("inputs_from_state")
outputs_to_state = _deserialize_state_config(inner_data.get("outputs_to_state"))
outputs_to_string = _deserialize_state_config(inner_data.get("outputs_to_string"))

# Create a new MCPToolset instance
return cls(
server_info=server_info,
tool_names=inner_data.get("tool_names"),
connection_timeout=inner_data.get("connection_timeout", 30.0),
invocation_timeout=inner_data.get("invocation_timeout", 30.0),
eager_connect=inner_data.get("eager_connect", True),
eager_connect=inner_data.get("eager_connect", False),
inputs_from_state=inputs_from_state if inputs_from_state else None,
outputs_to_state=outputs_to_state if outputs_to_state else None,
outputs_to_string=outputs_to_string if outputs_to_string else None,
)

def close(self):
Expand Down
Loading