diff --git a/integrations/mcp/tests/mcp_servers_fixtures.py b/integrations/mcp/tests/mcp_servers_fixtures.py index d7d54bf2ee..4260ba2c6b 100644 --- a/integrations/mcp/tests/mcp_servers_fixtures.py +++ b/integrations/mcp/tests/mcp_servers_fixtures.py @@ -1,3 +1,4 @@ +from mcp import types from mcp.server.fastmcp import FastMCP ################################################ @@ -55,3 +56,16 @@ def state_subtract(a: int, b: int) -> dict: def echo(text: str) -> str: """Echo the input text.""" return text + + +################################################ +# Image MCP Server +################################################ + +image_mcp = FastMCP("Image") + + +@image_mcp.tool() +def image_tool() -> list[types.ImageContent]: + """Return image content without any text blocks.""" + return [types.ImageContent(type="image", data="ZmFrZQ==", mimeType="image/png")] diff --git a/integrations/mcp/tests/test_mcp_tool.py b/integrations/mcp/tests/test_mcp_tool.py index 1681053a50..cc2500beb3 100644 --- a/integrations/mcp/tests/test_mcp_tool.py +++ b/integrations/mcp/tests/test_mcp_tool.py @@ -1,6 +1,7 @@ import io import json import os +from types import SimpleNamespace from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -13,12 +14,13 @@ from haystack_integrations.tools.mcp import ( MCPTool, + MCPToolNotFoundError, StdioServerInfo, ) from haystack_integrations.tools.mcp.mcp_tool import StdioClient, _extract_first_text_element from .mcp_memory_transport import InMemoryServerInfo -from .mcp_servers_fixtures import calculator_mcp, echo_mcp +from .mcp_servers_fixtures import calculator_mcp, echo_mcp, image_mcp, state_calculator_mcp @tool @@ -104,6 +106,41 @@ def test_mcp_tool_invoke(self, mcp_add_tool, mcp_echo_tool): echo_result = json.loads(echo_result) assert echo_result["content"][0]["text"] == "Hello MCP!" + def test_mcp_tool_outputs_to_state_falls_back_to_full_response_for_non_text_content(self, mcp_tool_cleanup): + """Test that non-text MCP content returns the full parsed response when state output is enabled.""" + server_info = InMemoryServerInfo(server=image_mcp._mcp_server) + tool = MCPTool( + name="image_tool", + server_info=server_info, + eager_connect=True, + outputs_to_state={"image_payload": {}}, + ) + mcp_tool_cleanup(tool) + + result = tool.invoke() + + assert isinstance(result, dict) + assert len(result["content"]) == 1 + assert result["content"][0]["type"] == "image" + assert result["content"][0]["data"] == "ZmFrZQ==" + assert result["content"][0]["mimeType"] == "image/png" + assert result["isError"] is False + + def test_mcp_tool_outputs_to_state_returns_raw_text_when_text_is_not_json(self, mcp_tool_cleanup): + """Test that plain text content is returned as-is when state output parsing cannot decode JSON.""" + server_info = InMemoryServerInfo(server=echo_mcp._mcp_server) + tool = MCPTool( + name="echo", + server_info=server_info, + eager_connect=True, + outputs_to_state={"echo_payload": {}}, + ) + mcp_tool_cleanup(tool) + + result = tool.invoke(text="Hello MCP!") + + assert result == "Hello MCP!" + def test_mcp_tool_error_handling(self, mcp_error_tool): """Test error handling with the in-memory server.""" with pytest.raises(ToolInvocationError) as exc_info: @@ -114,6 +151,47 @@ def test_mcp_tool_error_handling(self, mcp_error_tool): # The first part of the message comes from ToolInvocationError's formatting assert "Failed to invoke Tool `divide_by_zero`" in error_message + def test_mcp_tool_lazy_missing_tool_raises_with_available_tools(self, mcp_tool_cleanup): + """Test that lazy warm-up surfaces missing-tool errors with the available tool names.""" + server_info = InMemoryServerInfo(server=calculator_mcp._mcp_server) + tool = MCPTool(name="multiply", server_info=server_info, eager_connect=False) + mcp_tool_cleanup(tool) + + mock_worker = MagicMock() + mock_worker.tools.return_value = [ + SimpleNamespace(name="add"), + SimpleNamespace(name="subtract"), + SimpleNamespace(name="divide_by_zero"), + ] + + with ( + patch("haystack_integrations.tools.mcp.mcp_tool._MCPClientSessionManager", return_value=mock_worker), + pytest.raises(MCPToolNotFoundError) as exc_info, + ): + tool.warm_up() + + assert exc_info.value.tool_name == "multiply" + assert set(exc_info.value.available_tools) == {"add", "subtract", "divide_by_zero"} + + def test_mcp_tool_lazy_no_tools_server_raises_tool_not_found(self, mcp_tool_cleanup): + """Test that lazy warm-up fails cleanly when the server exposes no tools.""" + server_info = InMemoryServerInfo(server=calculator_mcp._mcp_server) + tool = MCPTool(name="anything", server_info=server_info, eager_connect=False) + mcp_tool_cleanup(tool) + + mock_worker = MagicMock() + mock_worker.tools.return_value = [] + + with ( + patch("haystack_integrations.tools.mcp.mcp_tool._MCPClientSessionManager", return_value=mock_worker), + pytest.raises(MCPToolNotFoundError) as exc_info, + ): + tool.warm_up() + + assert str(exc_info.value) == "No tools available on server" + assert exc_info.value.tool_name == "anything" + assert exc_info.value.available_tools == [] + def test_mcp_tool_serde(self, mcp_tool_cleanup): """Test serialization and deserialization of MCPTool with in-memory server.""" server_info = InMemoryServerInfo(server=calculator_mcp._mcp_server) @@ -186,6 +264,22 @@ def test_mcp_tool_state_mapping_parameters(self, mcp_tool_cleanup): assert "b" in tool.parameters["properties"] assert "b" in tool.parameters["required"] + def test_mcp_tool_eager_state_mapping_removes_inputs_from_schema(self, mcp_tool_cleanup): + """Test that eager MCPTool initialization removes state-injected params from its public schema.""" + server_info = InMemoryServerInfo(server=calculator_mcp._mcp_server) + tool = MCPTool( + name="add", + server_info=server_info, + eager_connect=True, + inputs_from_state={"state_a": "a"}, + ) + mcp_tool_cleanup(tool) + + assert "a" not in tool.parameters["properties"] + assert "a" not in tool.parameters.get("required", []) + assert "b" in tool.parameters["properties"] + assert "b" in tool.parameters["required"] + def test_mcp_tool_serde_with_state_mapping(self, mcp_tool_cleanup): """Test serialization and deserialization of MCPTool with state-mapping parameters.""" server_info = InMemoryServerInfo(server=calculator_mcp._mcp_server) @@ -219,6 +313,62 @@ def test_mcp_tool_serde_with_state_mapping(self, mcp_tool_cleanup): assert new_tool._inputs_from_state == {"state_a": "a"} assert new_tool._outputs_to_state == {"result": {"source": "output"}} + @pytest.mark.skipif( + not hasattr(__import__("haystack.tools", fromlist=["Tool"]).Tool, "_get_valid_inputs"), + reason="Requires Haystack >= 2.22.0 for inputs_from_state validation", + ) + def test_mcp_tool_lazy_invalid_parameter_raises_on_warm_up(self, mcp_tool_cleanup): + """Test that lazy MCPTool defers invalid inputs_from_state validation until warm_up().""" + server_info = InMemoryServerInfo(server=calculator_mcp._mcp_server) + tool = MCPTool( + name="add", + server_info=server_info, + eager_connect=False, + inputs_from_state={"state_key": "non_existent_param"}, + ) + mcp_tool_cleanup(tool) + + assert tool.parameters == {"type": "object", "properties": {}, "additionalProperties": True} + + with pytest.raises(ValueError, match="unknown parameter"): + tool.warm_up() + + def test_mcp_tool_invoke_auto_warms_up_once(self, mcp_tool_cleanup): + """Test that lazy MCPTool initializes on first invoke and reuses that connection.""" + server_info = InMemoryServerInfo(server=calculator_mcp._mcp_server) + tool = MCPTool(name="add", server_info=server_info, eager_connect=False) + mcp_tool_cleanup(tool) + + assert tool.parameters == {"type": "object", "properties": {}, "additionalProperties": True} + + with patch.object(tool, "_connect_and_initialize", wraps=tool._connect_and_initialize) as mock_connect: + first_result = json.loads(tool.invoke(a=20, b=22)) + second_result = json.loads(tool.invoke(a=1, b=2)) + + assert first_result["content"][0]["text"] == "42" + assert second_result["content"][0]["text"] == "3" + assert "a" in tool.parameters["properties"] + assert "b" in tool.parameters["properties"] + assert mock_connect.call_count == 1 + + @pytest.mark.asyncio + async def test_mcp_tool_ainvoke_matches_invoke_with_outputs_to_state(self, mcp_tool_cleanup): + """Test that sync and async invocation paths return the same parsed state output.""" + server_info = InMemoryServerInfo(server=state_calculator_mcp._mcp_server) + tool = MCPTool( + name="state_add", + server_info=server_info, + eager_connect=True, + outputs_to_state={"result": {"source": "result"}}, + ) + mcp_tool_cleanup(tool) + + sync_result = tool.invoke(a=20, b=22) + async_result = await tool.ainvoke(a=20, b=22) + + assert sync_result == {"result": 42} + assert async_result == sync_result + @pytest.mark.asyncio @pytest.mark.parametrize( "fileno_side_effect,fileno_return_value,notebook_environment", @@ -255,6 +405,24 @@ async def test_stdio_client_stderr_handling(self, fileno_side_effect, fileno_ret else: assert errlog is mock_stderr + @pytest.mark.asyncio + async def test_mcp_client_aclose_clears_references_even_when_cleanup_fails(self, caplog): + """Test that client cleanup always clears connection state, even if exit_stack cleanup raises.""" + client = StdioClient(command="echo") + client.session = MagicMock() + client.stdio = MagicMock() + client.write = MagicMock() + client.exit_stack = MagicMock() + client.exit_stack.aclose = AsyncMock(side_effect=RuntimeError("cleanup failed")) + + with caplog.at_level("WARNING"): + await client.aclose() + + assert any("Error during MCP client cleanup: cleanup failed" in record.message for record in caplog.records) + assert client.session is None + assert client.stdio is None + assert client.write is None + @pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set") @pytest.mark.integration def test_pipeline_warmup_with_mcp_tool(self): diff --git a/integrations/mcp/tests/test_mcp_toolset.py b/integrations/mcp/tests/test_mcp_toolset.py index a58a8e6056..d94bacfdc5 100644 --- a/integrations/mcp/tests/test_mcp_toolset.py +++ b/integrations/mcp/tests/test_mcp_toolset.py @@ -30,7 +30,7 @@ # Import in-memory transport and fixtures from .mcp_memory_transport import InMemoryServerInfo -from .mcp_servers_fixtures import calculator_mcp, echo_mcp +from .mcp_servers_fixtures import calculator_mcp, echo_mcp, image_mcp, state_calculator_mcp logger = logging.getLogger(__name__) @@ -152,6 +152,16 @@ async def test_echo_toolset(self, echo_toolset): assert echo_tool.name == "echo" assert "Echo the input text." in echo_tool.description + async def test_toolset_invoke_returns_raw_json_string_without_outputs_to_state(self, echo_toolset): + """Test that toolset-created tools keep the raw MCP JSON when no state output parsing is configured.""" + echo_tool = echo_toolset.tools[0] + + result = echo_tool.invoke(text="Hello MCP!") + parsed = json.loads(result) + + assert parsed["content"][0]["text"] == "Hello MCP!" + assert parsed["isError"] is False + async def test_toolset_with_filtered_tools(self, calculator_toolset_with_tool_filter): """Test if the MCPToolset correctly filters tools based on tool_names parameter.""" toolset = calculator_toolset_with_tool_filter @@ -172,6 +182,24 @@ async def test_toolset_with_filtered_tools(self, calculator_toolset_with_tool_fi assert tool.name == "add" assert "Add two integers." in tool.description + async def test_toolset_warm_up_replaces_placeholder_and_is_idempotent(self, mcp_tool_cleanup): + """Test lazy toolsets swap the placeholder tool for real tools exactly once.""" + server_info = InMemoryServerInfo(server=calculator_mcp._mcp_server) + toolset = MCPToolset(server_info=server_info, eager_connect=False) + mcp_tool_cleanup(toolset) + + assert len(toolset.tools) == 1 + assert toolset.tools[0].name.startswith("mcp_not_connected_placeholder_") + + toolset.warm_up() + warmed_tool_names = [tool.name for tool in toolset.tools] + + assert set(warmed_tool_names) == {"add", "subtract", "divide_by_zero"} + assert not any(name.startswith("mcp_not_connected_placeholder_") for name in warmed_tool_names) + + toolset.warm_up() + assert [tool.name for tool in toolset.tools] == warmed_tool_names + async def test_toolset_serde(self, calculator_toolset): """Test serialization and deserialization of MCPToolset.""" toolset = calculator_toolset @@ -292,6 +320,59 @@ async def test_toolset_with_state_config(self, calculator_toolset_with_state_con assert add_tool.outputs_to_string is not None assert subtract_tool.outputs_to_string is None + async def test_toolset_invoke_returns_parsed_dict_when_outputs_to_state_configured(self, mcp_tool_cleanup): + """Test that toolset-created tools parse MCP text content into dicts for state updates.""" + server_info = InMemoryServerInfo(server=state_calculator_mcp._mcp_server) + toolset = MCPToolset( + server_info=server_info, + tool_names=["state_add"], + eager_connect=True, + outputs_to_state={"state_add": {"result": {"source": "result"}}}, + ) + mcp_tool_cleanup(toolset) + + add_tool = toolset.tools[0] + result = add_tool.invoke(a=20, b=22) + + assert result == {"result": 42} + + async def test_toolset_returns_full_response_for_non_text_content_with_outputs_to_state(self, mcp_tool_cleanup): + """Test that toolset-created tools preserve full MCP payloads when there is no text content to parse.""" + server_info = InMemoryServerInfo(server=image_mcp._mcp_server) + toolset = MCPToolset( + server_info=server_info, + tool_names=["image_tool"], + eager_connect=True, + outputs_to_state={"image_tool": {"image_payload": {}}}, + ) + mcp_tool_cleanup(toolset) + + image_tool = toolset.tools[0] + result = image_tool.invoke() + + assert isinstance(result, dict) + assert len(result["content"]) == 1 + assert result["content"][0]["type"] == "image" + assert result["content"][0]["data"] == "ZmFrZQ==" + assert result["content"][0]["mimeType"] == "image/png" + assert result["isError"] is False + + async def test_toolset_returns_raw_text_when_outputs_to_state_content_is_not_json(self, mcp_tool_cleanup): + """Test that toolset-created tools preserve plain text when JSON decoding is not possible.""" + server_info = InMemoryServerInfo(server=echo_mcp._mcp_server) + toolset = MCPToolset( + server_info=server_info, + tool_names=["echo"], + eager_connect=True, + outputs_to_state={"echo": {"echo_payload": {}}}, + ) + mcp_tool_cleanup(toolset) + + echo_tool = toolset.tools[0] + result = echo_tool.invoke(text="Hello MCP!") + + assert result == "Hello MCP!" + async def test_toolset_state_config_serde(self, calculator_toolset_with_state_config, mcp_tool_cleanup): """Test serialization and deserialization of MCPToolset with state configuration.""" toolset = calculator_toolset_with_state_config @@ -373,6 +454,29 @@ async def test_toolset_state_config_invalid_parameter_raises_error(self): }, ) + @pytest.mark.skipif( + not hasattr(__import__("haystack.tools", fromlist=["Tool"]).Tool, "_get_valid_inputs"), + reason="Requires Haystack >= 2.22.0 for inputs_from_state validation", + ) + async def test_toolset_lazy_invalid_parameter_raises_on_warm_up(self, mcp_tool_cleanup): + """Test that lazy toolsets defer invalid inputs_from_state validation until warm_up().""" + server_info = InMemoryServerInfo(server=calculator_mcp._mcp_server) + toolset = MCPToolset( + server_info=server_info, + tool_names=["add"], + eager_connect=False, + inputs_from_state={ + "add": {"state_key": "non_existent_param"}, + }, + ) + mcp_tool_cleanup(toolset) + + assert len(toolset.tools) == 1 + assert toolset.tools[0].name.startswith("mcp_not_connected_placeholder_") + + with pytest.raises(ValueError, match="unknown parameter"): + toolset.warm_up() + async def test_toolset_no_state_config(self, calculator_toolset): """Test that tools have no state config when none is provided.""" toolset = calculator_toolset