Skip to content
Merged
14 changes: 14 additions & 0 deletions integrations/mcp/tests/mcp_servers_fixtures.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from mcp import types
from mcp.server.fastmcp import FastMCP

################################################
Expand Down Expand Up @@ -55,3 +56,16 @@ def state_subtract(a: int, b: int) -> dict:
def echo(text: str) -> str:
"""Echo the input text."""
return text


################################################
# Image MCP Server
################################################

image_mcp = FastMCP("Image")


@image_mcp.tool()
def image_tool() -> list[types.ImageContent]:
"""Return image content without any text blocks."""
return [types.ImageContent(type="image", data="ZmFrZQ==", mimeType="image/png")]
170 changes: 169 additions & 1 deletion integrations/mcp/tests/test_mcp_tool.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import io
import json
import os
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
Expand All @@ -13,12 +14,13 @@

from haystack_integrations.tools.mcp import (
MCPTool,
MCPToolNotFoundError,
StdioServerInfo,
)
from haystack_integrations.tools.mcp.mcp_tool import StdioClient, _extract_first_text_element

from .mcp_memory_transport import InMemoryServerInfo
from .mcp_servers_fixtures import calculator_mcp, echo_mcp
from .mcp_servers_fixtures import calculator_mcp, echo_mcp, image_mcp, state_calculator_mcp


@tool
Expand Down Expand Up @@ -104,6 +106,41 @@ def test_mcp_tool_invoke(self, mcp_add_tool, mcp_echo_tool):
echo_result = json.loads(echo_result)
assert echo_result["content"][0]["text"] == "Hello MCP!"

def test_mcp_tool_outputs_to_state_falls_back_to_full_response_for_non_text_content(self, mcp_tool_cleanup):
"""Test that non-text MCP content returns the full parsed response when state output is enabled."""
server_info = InMemoryServerInfo(server=image_mcp._mcp_server)
tool = MCPTool(
name="image_tool",
server_info=server_info,
eager_connect=True,
outputs_to_state={"image_payload": {}},
)
mcp_tool_cleanup(tool)

result = tool.invoke()

assert isinstance(result, dict)
assert len(result["content"]) == 1
assert result["content"][0]["type"] == "image"
assert result["content"][0]["data"] == "ZmFrZQ=="
assert result["content"][0]["mimeType"] == "image/png"
assert result["isError"] is False

def test_mcp_tool_outputs_to_state_returns_raw_text_when_text_is_not_json(self, mcp_tool_cleanup):
"""Test that plain text content is returned as-is when state output parsing cannot decode JSON."""
server_info = InMemoryServerInfo(server=echo_mcp._mcp_server)
tool = MCPTool(
name="echo",
server_info=server_info,
eager_connect=True,
outputs_to_state={"echo_payload": {}},
)
mcp_tool_cleanup(tool)

result = tool.invoke(text="Hello MCP!")

assert result == "Hello MCP!"

def test_mcp_tool_error_handling(self, mcp_error_tool):
"""Test error handling with the in-memory server."""
with pytest.raises(ToolInvocationError) as exc_info:
Expand All @@ -114,6 +151,47 @@ def test_mcp_tool_error_handling(self, mcp_error_tool):
# The first part of the message comes from ToolInvocationError's formatting
assert "Failed to invoke Tool `divide_by_zero`" in error_message

def test_mcp_tool_lazy_missing_tool_raises_with_available_tools(self, mcp_tool_cleanup):
"""Test that lazy warm-up surfaces missing-tool errors with the available tool names."""
server_info = InMemoryServerInfo(server=calculator_mcp._mcp_server)
tool = MCPTool(name="multiply", server_info=server_info, eager_connect=False)
mcp_tool_cleanup(tool)

mock_worker = MagicMock()
mock_worker.tools.return_value = [
SimpleNamespace(name="add"),
SimpleNamespace(name="subtract"),
SimpleNamespace(name="divide_by_zero"),
]

with (
patch("haystack_integrations.tools.mcp.mcp_tool._MCPClientSessionManager", return_value=mock_worker),
pytest.raises(MCPToolNotFoundError) as exc_info,
):
tool.warm_up()

assert exc_info.value.tool_name == "multiply"
assert set(exc_info.value.available_tools) == {"add", "subtract", "divide_by_zero"}

def test_mcp_tool_lazy_no_tools_server_raises_tool_not_found(self, mcp_tool_cleanup):
"""Test that lazy warm-up fails cleanly when the server exposes no tools."""
server_info = InMemoryServerInfo(server=calculator_mcp._mcp_server)
tool = MCPTool(name="anything", server_info=server_info, eager_connect=False)
mcp_tool_cleanup(tool)

mock_worker = MagicMock()
mock_worker.tools.return_value = []

with (
patch("haystack_integrations.tools.mcp.mcp_tool._MCPClientSessionManager", return_value=mock_worker),
pytest.raises(MCPToolNotFoundError) as exc_info,
):
tool.warm_up()

assert str(exc_info.value) == "No tools available on server"
assert exc_info.value.tool_name == "anything"
assert exc_info.value.available_tools == []

def test_mcp_tool_serde(self, mcp_tool_cleanup):
"""Test serialization and deserialization of MCPTool with in-memory server."""
server_info = InMemoryServerInfo(server=calculator_mcp._mcp_server)
Expand Down Expand Up @@ -186,6 +264,22 @@ def test_mcp_tool_state_mapping_parameters(self, mcp_tool_cleanup):
assert "b" in tool.parameters["properties"]
assert "b" in tool.parameters["required"]

def test_mcp_tool_eager_state_mapping_removes_inputs_from_schema(self, mcp_tool_cleanup):
"""Test that eager MCPTool initialization removes state-injected params from its public schema."""
server_info = InMemoryServerInfo(server=calculator_mcp._mcp_server)
tool = MCPTool(
name="add",
server_info=server_info,
eager_connect=True,
inputs_from_state={"state_a": "a"},
)
mcp_tool_cleanup(tool)

assert "a" not in tool.parameters["properties"]
assert "a" not in tool.parameters.get("required", [])
assert "b" in tool.parameters["properties"]
assert "b" in tool.parameters["required"]

def test_mcp_tool_serde_with_state_mapping(self, mcp_tool_cleanup):
"""Test serialization and deserialization of MCPTool with state-mapping parameters."""
server_info = InMemoryServerInfo(server=calculator_mcp._mcp_server)
Expand Down Expand Up @@ -219,6 +313,62 @@ def test_mcp_tool_serde_with_state_mapping(self, mcp_tool_cleanup):
assert new_tool._inputs_from_state == {"state_a": "a"}
assert new_tool._outputs_to_state == {"result": {"source": "output"}}

@pytest.mark.skipif(
not hasattr(__import__("haystack.tools", fromlist=["Tool"]).Tool, "_get_valid_inputs"),
reason="Requires Haystack >= 2.22.0 for inputs_from_state validation",
)
def test_mcp_tool_lazy_invalid_parameter_raises_on_warm_up(self, mcp_tool_cleanup):
"""Test that lazy MCPTool defers invalid inputs_from_state validation until warm_up()."""
server_info = InMemoryServerInfo(server=calculator_mcp._mcp_server)
tool = MCPTool(
name="add",
server_info=server_info,
eager_connect=False,
inputs_from_state={"state_key": "non_existent_param"},
)
mcp_tool_cleanup(tool)

assert tool.parameters == {"type": "object", "properties": {}, "additionalProperties": True}

with pytest.raises(ValueError, match="unknown parameter"):
tool.warm_up()

def test_mcp_tool_invoke_auto_warms_up_once(self, mcp_tool_cleanup):
"""Test that lazy MCPTool initializes on first invoke and reuses that connection."""
server_info = InMemoryServerInfo(server=calculator_mcp._mcp_server)
tool = MCPTool(name="add", server_info=server_info, eager_connect=False)
mcp_tool_cleanup(tool)

assert tool.parameters == {"type": "object", "properties": {}, "additionalProperties": True}

with patch.object(tool, "_connect_and_initialize", wraps=tool._connect_and_initialize) as mock_connect:
first_result = json.loads(tool.invoke(a=20, b=22))
second_result = json.loads(tool.invoke(a=1, b=2))

assert first_result["content"][0]["text"] == "42"
assert second_result["content"][0]["text"] == "3"
assert "a" in tool.parameters["properties"]
assert "b" in tool.parameters["properties"]
assert mock_connect.call_count == 1

@pytest.mark.asyncio
async def test_mcp_tool_ainvoke_matches_invoke_with_outputs_to_state(self, mcp_tool_cleanup):
"""Test that sync and async invocation paths return the same parsed state output."""
server_info = InMemoryServerInfo(server=state_calculator_mcp._mcp_server)
tool = MCPTool(
name="state_add",
server_info=server_info,
eager_connect=True,
outputs_to_state={"result": {"source": "result"}},
)
mcp_tool_cleanup(tool)

sync_result = tool.invoke(a=20, b=22)
async_result = await tool.ainvoke(a=20, b=22)

assert sync_result == {"result": 42}
assert async_result == sync_result

@pytest.mark.asyncio
@pytest.mark.parametrize(
"fileno_side_effect,fileno_return_value,notebook_environment",
Expand Down Expand Up @@ -255,6 +405,24 @@ async def test_stdio_client_stderr_handling(self, fileno_side_effect, fileno_ret
else:
assert errlog is mock_stderr

@pytest.mark.asyncio
async def test_mcp_client_aclose_clears_references_even_when_cleanup_fails(self, caplog):
"""Test that client cleanup always clears connection state, even if exit_stack cleanup raises."""
client = StdioClient(command="echo")
client.session = MagicMock()
client.stdio = MagicMock()
client.write = MagicMock()
client.exit_stack = MagicMock()
client.exit_stack.aclose = AsyncMock(side_effect=RuntimeError("cleanup failed"))

with caplog.at_level("WARNING"):
await client.aclose()

assert any("Error during MCP client cleanup: cleanup failed" in record.message for record in caplog.records)
assert client.session is None
assert client.stdio is None
assert client.write is None

@pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set")
@pytest.mark.integration
def test_pipeline_warmup_with_mcp_tool(self):
Expand Down
106 changes: 105 additions & 1 deletion integrations/mcp/tests/test_mcp_toolset.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

# Import in-memory transport and fixtures
from .mcp_memory_transport import InMemoryServerInfo
from .mcp_servers_fixtures import calculator_mcp, echo_mcp
from .mcp_servers_fixtures import calculator_mcp, echo_mcp, image_mcp, state_calculator_mcp

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -152,6 +152,16 @@ async def test_echo_toolset(self, echo_toolset):
assert echo_tool.name == "echo"
assert "Echo the input text." in echo_tool.description

async def test_toolset_invoke_returns_raw_json_string_without_outputs_to_state(self, echo_toolset):
"""Test that toolset-created tools keep the raw MCP JSON when no state output parsing is configured."""
echo_tool = echo_toolset.tools[0]

result = echo_tool.invoke(text="Hello MCP!")
parsed = json.loads(result)

assert parsed["content"][0]["text"] == "Hello MCP!"
assert parsed["isError"] is False

async def test_toolset_with_filtered_tools(self, calculator_toolset_with_tool_filter):
"""Test if the MCPToolset correctly filters tools based on tool_names parameter."""
toolset = calculator_toolset_with_tool_filter
Expand All @@ -172,6 +182,24 @@ async def test_toolset_with_filtered_tools(self, calculator_toolset_with_tool_fi
assert tool.name == "add"
assert "Add two integers." in tool.description

async def test_toolset_warm_up_replaces_placeholder_and_is_idempotent(self, mcp_tool_cleanup):
"""Test lazy toolsets swap the placeholder tool for real tools exactly once."""
server_info = InMemoryServerInfo(server=calculator_mcp._mcp_server)
toolset = MCPToolset(server_info=server_info, eager_connect=False)
mcp_tool_cleanup(toolset)

assert len(toolset.tools) == 1
assert toolset.tools[0].name.startswith("mcp_not_connected_placeholder_")

toolset.warm_up()
warmed_tool_names = [tool.name for tool in toolset.tools]

assert set(warmed_tool_names) == {"add", "subtract", "divide_by_zero"}
assert not any(name.startswith("mcp_not_connected_placeholder_") for name in warmed_tool_names)

toolset.warm_up()
assert [tool.name for tool in toolset.tools] == warmed_tool_names

async def test_toolset_serde(self, calculator_toolset):
"""Test serialization and deserialization of MCPToolset."""
toolset = calculator_toolset
Expand Down Expand Up @@ -292,6 +320,59 @@ async def test_toolset_with_state_config(self, calculator_toolset_with_state_con
assert add_tool.outputs_to_string is not None
assert subtract_tool.outputs_to_string is None

async def test_toolset_invoke_returns_parsed_dict_when_outputs_to_state_configured(self, mcp_tool_cleanup):
"""Test that toolset-created tools parse MCP text content into dicts for state updates."""
server_info = InMemoryServerInfo(server=state_calculator_mcp._mcp_server)
toolset = MCPToolset(
server_info=server_info,
tool_names=["state_add"],
eager_connect=True,
outputs_to_state={"state_add": {"result": {"source": "result"}}},
)
mcp_tool_cleanup(toolset)

add_tool = toolset.tools[0]
result = add_tool.invoke(a=20, b=22)

assert result == {"result": 42}

async def test_toolset_returns_full_response_for_non_text_content_with_outputs_to_state(self, mcp_tool_cleanup):
"""Test that toolset-created tools preserve full MCP payloads when there is no text content to parse."""
server_info = InMemoryServerInfo(server=image_mcp._mcp_server)
toolset = MCPToolset(
server_info=server_info,
tool_names=["image_tool"],
eager_connect=True,
outputs_to_state={"image_tool": {"image_payload": {}}},
)
mcp_tool_cleanup(toolset)

image_tool = toolset.tools[0]
result = image_tool.invoke()

assert isinstance(result, dict)
assert len(result["content"]) == 1
assert result["content"][0]["type"] == "image"
assert result["content"][0]["data"] == "ZmFrZQ=="
assert result["content"][0]["mimeType"] == "image/png"
assert result["isError"] is False

async def test_toolset_returns_raw_text_when_outputs_to_state_content_is_not_json(self, mcp_tool_cleanup):
"""Test that toolset-created tools preserve plain text when JSON decoding is not possible."""
server_info = InMemoryServerInfo(server=echo_mcp._mcp_server)
toolset = MCPToolset(
server_info=server_info,
tool_names=["echo"],
eager_connect=True,
outputs_to_state={"echo": {"echo_payload": {}}},
)
mcp_tool_cleanup(toolset)

echo_tool = toolset.tools[0]
result = echo_tool.invoke(text="Hello MCP!")

assert result == "Hello MCP!"

async def test_toolset_state_config_serde(self, calculator_toolset_with_state_config, mcp_tool_cleanup):
"""Test serialization and deserialization of MCPToolset with state configuration."""
toolset = calculator_toolset_with_state_config
Expand Down Expand Up @@ -373,6 +454,29 @@ async def test_toolset_state_config_invalid_parameter_raises_error(self):
},
)

@pytest.mark.skipif(
not hasattr(__import__("haystack.tools", fromlist=["Tool"]).Tool, "_get_valid_inputs"),
reason="Requires Haystack >= 2.22.0 for inputs_from_state validation",
)
async def test_toolset_lazy_invalid_parameter_raises_on_warm_up(self, mcp_tool_cleanup):
"""Test that lazy toolsets defer invalid inputs_from_state validation until warm_up()."""
server_info = InMemoryServerInfo(server=calculator_mcp._mcp_server)
toolset = MCPToolset(
server_info=server_info,
tool_names=["add"],
eager_connect=False,
inputs_from_state={
"add": {"state_key": "non_existent_param"},
},
)
mcp_tool_cleanup(toolset)

assert len(toolset.tools) == 1
assert toolset.tools[0].name.startswith("mcp_not_connected_placeholder_")

with pytest.raises(ValueError, match="unknown parameter"):
toolset.warm_up()

async def test_toolset_no_state_config(self, calculator_toolset):
"""Test that tools have no state config when none is provided."""
toolset = calculator_toolset
Expand Down
Loading