diff --git a/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_tool.py b/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_tool.py
index 8f8e3eaf14..68f1bf7de0 100644
--- a/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_tool.py
+++ b/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_tool.py
@@ -5,6 +5,7 @@
 import asyncio
 import concurrent.futures
 import threading
+import warnings
 from abc import ABC, abstractmethod
 from collections.abc import Coroutine
 from contextlib import AsyncExitStack
@@ -16,6 +17,7 @@
 from haystack.core.serialization import generate_qualified_class_name, import_class_by_name
 from haystack.tools import Tool
 from haystack.tools.errors import ToolInvocationError
+from haystack.utils.url_validation import is_valid_http_url
 from mcp import ClientSession, StdioServerParameters, types
 from mcp.client.sse import sse_client
 
@@ -351,18 +353,19 @@ class SSEClient(MCPClient):
     MCP client that connects to servers using SSE transport.
     """
 
-    def __init__(self, base_url: str, token: str | None = None, timeout: int = 5) -> None:
+    def __init__(self, server_info: "SSEServerInfo") -> None:
         """
-        Initialize an SSE MCP client.
+        Initialize an SSE MCP client using server configuration.
 
-        :param base_url: Base URL of the server
-        :param token: Authentication token for the server (optional)
-        :param timeout: Connection timeout in seconds
+        :param server_info: Configuration object containing URL, token, timeout, etc.
         """
         super().__init__()
-        self.base_url: str = base_url.rstrip("/")  # Remove any trailing slashes
-        self.token: str | None = token
-        self.timeout: int = timeout
+
+        # SSEServerInfo.__post_init__ validates the endpoint and always leaves url set to a str,
+        # so the ignore below only silences mypy's "str | None assigned to str" complaint
+        self.url: str = server_info.url  # type: ignore[assignment]
+        self.token: str | None = server_info.token
+        self.timeout: int = server_info.timeout
 
     async def connect(self) -> list[Tool]:
         """
@@ -371,12 +374,11 @@ async def connect(self) -> list[Tool]:
         :returns: List of available tools on the server
         :raises MCPConnectionError: If connection to the server fails
         """
-        sse_url = f"{self.base_url}/sse"
         headers = {"Authorization": f"Bearer {self.token}"} if self.token else None
         sse_transport = await self.exit_stack.enter_async_context(
-            sse_client(sse_url, headers=headers, timeout=self.timeout)
+            sse_client(self.url, headers=headers, timeout=self.timeout)
         )
-        return await self._initialize_session_with_transport(sse_transport, f"HTTP server at {self.base_url}")
+        return await self._initialize_session_with_transport(sse_transport, f"HTTP server at {self.url}")
 
 
 @dataclass
@@ -432,22 +434,51 @@ class SSEServerInfo(MCPServerInfo):
     """
     Data class that encapsulates SSE MCP server connection parameters.
 
-    :param base_url: Base URL of the MCP server
+    :param url: Full URL of the MCP server (including /sse endpoint)
+    :param base_url: Base URL of the MCP server (deprecated, use url instead; ignored when url is given)
     :param token: Authentication token for the server (optional)
     :param timeout: Connection timeout in seconds
     """
 
-    base_url: str
+    url: str | None = None
+    base_url: str | None = None  # deprecated
     token: str | None = None
     timeout: int = 30
 
+    def __post_init__(self):
+        """Validate url/base_url and make sure url holds the endpoint that will be used."""
+        if not self.url and not self.base_url:
+            message = "Either url or base_url must be provided"
+            raise ValueError(message)
+        if self.url and self.base_url:
+            message = "Only one of url or base_url should be provided, if both are provided, base_url will be ignored"
+            warnings.warn(message, DeprecationWarning, stacklevel=3)
+
+        if self.base_url:
+            if not is_valid_http_url(self.base_url):
+                message = f"Invalid base_url: {self.base_url}"
+                raise ValueError(message)
+
+            warnings.warn(
+                "base_url is deprecated and will be removed in a future version. Use url instead.",
+                DeprecationWarning,
+                stacklevel=3,  # skip __post_init__ and the generated __init__ so the caller is blamed
+            )
+            # an explicit url wins; otherwise derive it from base_url. Only url is read from now on
+            self.url = self.url or f"{self.base_url.rstrip('/')}/sse"
+
+        if not is_valid_http_url(self.url):
+            message = f"Invalid url: {self.url}"
+            raise ValueError(message)
+
     def create_client(self) -> MCPClient:
         """
         Create an SSE MCP client.
 
-        :returns: Configured HttpMCPClient instance
+        :returns: Configured SSEClient instance
         """
-        return SSEClient(self.base_url, self.token, self.timeout)
+        # Pass the validated SSEServerInfo instance directly
+        return SSEClient(server_info=self)
 
 
 @dataclass
@@ -491,7 +522,7 @@ class MCPTool(Tool):
     # Create tool instance
     tool = MCPTool(
         name="add",
-        server_info=SSEServerInfo(base_url="http://localhost:8000")
+        server_info=SSEServerInfo(url="http://localhost:8000/sse")
     )
 
     # Use the tool
diff --git a/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_toolset.py b/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_toolset.py
index df0bd15994..93928b84b9 100644
--- a/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_toolset.py
+++ b/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_toolset.py
@@ -3,6 +3,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 from typing import Any
+from urllib.parse import urlparse
 
 import httpx
 from exceptiongroup import ExceptionGroup
@@ -88,7 +89,7 @@ class MCPToolset(Toolset):
 
     # Create the toolset with an SSE connection
     sse_toolset = MCPToolset(
-        server_info=SSEServerInfo(base_url="http://some-remote-server.com:8000"),
+        server_info=SSEServerInfo(url="http://some-remote-server.com:8000/sse"),
         tool_names=["add", "subtract"]  # Only include specific tools
     )
 
@@ -175,7 +176,7 @@ def invoke_tool(**kwargs) -> Any:
 
         except Exception as e:
             if isinstance(self.server_info, SSEServerInfo):
-                base_message = f"Failed to connect to SSE server at {self.server_info.base_url}"
+                base_message = f"Failed to connect to SSE server at {self.server_info.url}"
                 checks = ["1. The server is running"]
 
                 # Check for ConnectError in exception group or direct exception
@@ -184,10 +185,24 @@ def invoke_tool(**kwargs) -> Any:
                 )
 
                 if has_connect_error:
-                    port = self.server_info.base_url.split(":")[-1]
-                    checks.append(f"2. The address and port are correct (attempted port: {port})")
+                    # Use urlparse to reliably get scheme, hostname, and port
+                    parsed_url = urlparse(self.server_info.url)
+                    port_str = ""
+                    if parsed_url.port:
+                        port_str = str(parsed_url.port)
+                    elif parsed_url.scheme == "http":
+                        port_str = "80 (default)"
+                    elif parsed_url.scheme == "https":
+                        port_str = "443 (default)"
+                    else:
+                        port_str = "unknown (scheme not http/https or missing)"  # Or handle more schemes if needed
+
+                    # Ensure hostname is handled correctly (it might be None)
+                    hostname_str = str(parsed_url.hostname) if parsed_url.hostname else ""
+                    message = f"2. The address '{hostname_str}' and port '{port_str}' are correct"
+                    checks.append(message)
                     checks.append("3. There are no firewall or network connectivity issues")
                     message = f"{base_message}. Please check if:\n" + "\n".join(checks)
                 else:
                     message = f"{base_message}: {e}"
             elif isinstance(self.server_info, StdioServerInfo):  # stdio connection
diff --git a/integrations/mcp/tests/test_mcp_tool.py b/integrations/mcp/tests/test_mcp_tool.py
index 2997ca6b03..cf6d106a42 100644
--- a/integrations/mcp/tests/test_mcp_tool.py
+++ b/integrations/mcp/tests/test_mcp_tool.py
@@ -132,6 +132,28 @@ def test_http_server_info_serde(self):
         assert new_info.token == "test-token"
         assert new_info.timeout == 45
 
+    def test_url_base_url_validation(self):
+        """Test validation of url and base_url parameters."""
+        # Test with neither url nor base_url
+        with pytest.raises(ValueError, match="Either url or base_url must be provided"):
+            SSEServerInfo()
+
+        # Test with both url and base_url: the explicit url wins and base_url is ignored
+        with pytest.warns(DeprecationWarning, match="base_url will be ignored"):
+            server_info = SSEServerInfo(url="http://example.com/custom", base_url="http://example.com")
+        assert server_info.url == "http://example.com/custom"
+
+        # Test with only url
+        server_info = SSEServerInfo(url="http://example.com/sse")
+        assert server_info.url == "http://example.com/sse"
+        assert server_info.base_url is None
+
+        # Test with only base_url (deprecated but supported)
+        with pytest.warns(DeprecationWarning, match="base_url is deprecated"):
+            server_info = SSEServerInfo(base_url="http://example.com")
+            assert server_info.base_url == "http://example.com"  # Should preserve original base_url
+            assert server_info.url == "http://example.com/sse"  # url is derived from base_url
+
     def test_stdio_server_info_serde(self):
         """Test serialization/deserialization of StdioServerInfo."""
         server_info = StdioServerInfo(command="python", args=["-m", "mcp_server_time"], env={"TEST_ENV": "value"})
@@ -157,7 +179,7 @@ def test_create_client(self):
         http_client = http_info.create_client()
         stdio_client = stdio_info.create_client()
 
-        assert http_client.base_url == "http://example.com"
+        assert http_client.url == "http://example.com/sse"
         assert stdio_client.command == "python"