Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 47 additions & 16 deletions integrations/mcp/src/haystack_integrations/tools/mcp/mcp_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import asyncio
import concurrent.futures
import threading
import warnings
from abc import ABC, abstractmethod
from collections.abc import Coroutine
from contextlib import AsyncExitStack
Expand All @@ -16,6 +17,7 @@
from haystack.core.serialization import generate_qualified_class_name, import_class_by_name
from haystack.tools import Tool
from haystack.tools.errors import ToolInvocationError
from haystack.utils.url_validation import is_valid_http_url

from mcp import ClientSession, StdioServerParameters, types
from mcp.client.sse import sse_client
Expand Down Expand Up @@ -351,18 +353,19 @@ class SSEClient(MCPClient):
MCP client that connects to servers using SSE transport.
"""

def __init__(self, base_url: str, token: str | None = None, timeout: int = 5) -> None:
def __init__(self, server_info: "SSEServerInfo") -> None:
"""
Initialize an SSE MCP client.
Initialize an SSE MCP client using server configuration.

:param base_url: Base URL of the server
:param token: Authentication token for the server (optional)
:param timeout: Connection timeout in seconds
:param server_info: Configuration object containing URL, token, timeout, etc.
"""
super().__init__()
self.base_url: str = base_url.rstrip("/") # Remove any trailing slashes
self.token: str | None = token
self.timeout: int = timeout

# in post_init we validate the url and set the url field so it is guaranteed to be valid
# safely ignore the mypy warning here
self.url: str = server_info.url # type: ignore[assignment]
self.token: str | None = server_info.token
self.timeout: int = server_info.timeout

async def connect(self) -> list[Tool]:
"""
Expand All @@ -371,12 +374,11 @@ async def connect(self) -> list[Tool]:
:returns: List of available tools on the server
:raises MCPConnectionError: If connection to the server fails
"""
sse_url = f"{self.base_url}/sse"
headers = {"Authorization": f"Bearer {self.token}"} if self.token else None
sse_transport = await self.exit_stack.enter_async_context(
sse_client(sse_url, headers=headers, timeout=self.timeout)
sse_client(self.url, headers=headers, timeout=self.timeout)
)
return await self._initialize_session_with_transport(sse_transport, f"HTTP server at {self.base_url}")
return await self._initialize_session_with_transport(sse_transport, f"HTTP server at {self.url}")


@dataclass
Expand Down Expand Up @@ -432,22 +434,51 @@ class SSEServerInfo(MCPServerInfo):
"""
Data class that encapsulates SSE MCP server connection parameters.

:param base_url: Base URL of the MCP server
:param url: Full URL of the MCP server (including /sse endpoint)
:param base_url: Base URL of the MCP server (deprecated, use url instead)
:param token: Authentication token for the server (optional)
:param timeout: Connection timeout in seconds
"""

base_url: str
url: str | None = None
base_url: str | None = None # deprecated
token: str | None = None
timeout: int = 30

def __post_init__(self):
    """
    Validate the connection parameters and normalize them into ``url``.

    Exactly one of ``url`` or ``base_url`` is expected. ``base_url`` is deprecated: when it is the
    only value given, ``url`` is derived from it by appending ``/sse``. When both are given, ``url``
    takes precedence and ``base_url`` is ignored for connecting, as the emitted warning states;
    ``base_url`` is still validated so that inputs which were rejected before remain rejected.

    :raises ValueError: If neither ``url`` nor ``base_url`` is provided, or if a provided value is
        not a valid HTTP(S) URL.
    """
    if not self.url and not self.base_url:
        message = "Either url or base_url must be provided"
        raise ValueError(message)

    if self.url and self.base_url:
        message = "Only one of url or base_url should be provided, if both are provided, base_url will be ignored"
        warnings.warn(message, DeprecationWarning, stacklevel=2)

    if self.base_url:
        if not is_valid_http_url(self.base_url):
            message = f"Invalid base_url: {self.base_url}"
            raise ValueError(message)

        warnings.warn(
            "base_url is deprecated and will be removed in a future version. Use url instead.",
            DeprecationWarning,
            stacklevel=2,
        )

        if not self.url:
            # Only derive url from base_url when no explicit url was given; an explicit url must
            # never be overwritten, otherwise the "base_url will be ignored" warning would be false.
            # From now on only url is used for the lifetime of the instance, never base_url.
            self.url = f"{self.base_url.rstrip('/')}/sse"
            return

    # Reached when url was supplied explicitly (with or without base_url): validate it.
    if not is_valid_http_url(self.url):
        message = f"Invalid url: {self.url}"
        raise ValueError(message)

def create_client(self) -> MCPClient:
"""
Create an SSE MCP client.

:returns: Configured HttpMCPClient instance
:returns: Configured MCPClient instance
"""
return SSEClient(self.base_url, self.token, self.timeout)
# Pass the validated SSEServerInfo instance directly
return SSEClient(server_info=self)


@dataclass
Expand Down Expand Up @@ -491,7 +522,7 @@ class MCPTool(Tool):
# Create tool instance
tool = MCPTool(
name="add",
server_info=SSEServerInfo(base_url="http://localhost:8000")
server_info=SSEServerInfo(url="http://localhost:8000/sse")
)

# Use the tool
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Any
from urllib.parse import urlparse

import httpx
from exceptiongroup import ExceptionGroup
Expand Down Expand Up @@ -88,7 +89,7 @@ class MCPToolset(Toolset):

# Create the toolset with an SSE connection
sse_toolset = MCPToolset(
server_info=SSEServerInfo(base_url="http://some-remote-server.com:8000"),
server_info=SSEServerInfo(url="http://some-remote-server.com:8000/sse"),
tool_names=["add", "subtract"] # Only include specific tools
)

Expand Down Expand Up @@ -175,7 +176,7 @@ def invoke_tool(**kwargs) -> Any:

except Exception as e:
if isinstance(self.server_info, SSEServerInfo):
base_message = f"Failed to connect to SSE server at {self.server_info.base_url}"
base_message = f"Failed to connect to SSE server at {self.server_info.url}"
checks = ["1. The server is running"]

# Check for ConnectError in exception group or direct exception
Expand All @@ -184,10 +185,24 @@ def invoke_tool(**kwargs) -> Any:
)

if has_connect_error:
port = self.server_info.base_url.split(":")[-1]
checks.append(f"2. The address and port are correct (attempted port: {port})")
# Use urlparse to reliably get scheme, hostname, and port
parsed_url = urlparse(self.server_info.url)
port_str = ""
if parsed_url.port:
port_str = str(parsed_url.port)
elif parsed_url.scheme == "http":
port_str = "80 (default)"
elif parsed_url.scheme == "https":
port_str = "443 (default)"
else:
port_str = "unknown (scheme not http/https or missing)" # Or handle more schemes if needed

# Ensure hostname is handled correctly (it might be None)
hostname_str = str(parsed_url.hostname) if parsed_url.hostname else "<unknown>"
message = f"2. The address '{hostname_str}' and port '{port_str}' are correct"
checks.append(message)
checks.append("3. There are no firewall or network connectivity issues")
message = f"{base_message}. Please check if:\n" + "\n".join(checks)
message = f"{base_message}. Please check if:\n" + "\\n".join(checks)
else:
message = f"{base_message}: {e}"
elif isinstance(self.server_info, StdioServerInfo): # stdio connection
Expand Down
22 changes: 21 additions & 1 deletion integrations/mcp/tests/test_mcp_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,26 @@ def test_http_server_info_serde(self):
assert new_info.token == "test-token"
assert new_info.timeout == 45

def test_url_base_url_validation(self):
    """Test validation of url and base_url parameters, including the resulting url value."""
    # Neither url nor base_url: construction must fail
    with pytest.raises(ValueError, match="Either url or base_url must be provided"):
        SSEServerInfo()

    # Both url and base_url: both the "only one" and the deprecation warnings are expected
    with pytest.warns(DeprecationWarning) as caught:
        server_info = SSEServerInfo(url="http://example.com/sse", base_url="http://example.com")
    messages = [str(w.message) for w in caught]
    assert any("base_url is deprecated" in m for m in messages)
    assert any("Only one of url or base_url" in m for m in messages)
    assert server_info.url == "http://example.com/sse"

    # Only url: used as-is, base_url stays unset
    server_info = SSEServerInfo(url="http://example.com/sse")
    assert server_info.url == "http://example.com/sse"
    assert server_info.base_url is None

    # Only base_url (deprecated but supported): url is derived by appending /sse
    with pytest.warns(DeprecationWarning, match="base_url is deprecated"):
        server_info = SSEServerInfo(base_url="http://example.com")
    assert server_info.base_url == "http://example.com"  # Should preserve original base_url
    assert server_info.url == "http://example.com/sse"

def test_stdio_server_info_serde(self):
"""Test serialization/deserialization of StdioServerInfo."""
server_info = StdioServerInfo(command="python", args=["-m", "mcp_server_time"], env={"TEST_ENV": "value"})
Expand All @@ -157,7 +177,7 @@ def test_create_client(self):
http_client = http_info.create_client()
stdio_client = stdio_info.create_client()

assert http_client.base_url == "http://example.com"
assert http_client.url == "http://example.com/sse"
assert stdio_client.command == "python"


Expand Down