Skip to content

Commit e130d30

Browse files
authored
feat!: Deprecate SSEServerInfo base_url in favour of url (#1662)
* Deprecate SSEServerInfo base_url * Lint and mypy * Leave base_url field until deprecation removed * Add unit test, final touches * Update example * More URL checks, better error reporting * Simplify internals * Minor simplification * Small test fix * PR feedback
1 parent 86178c1 commit e130d30

3 files changed

Lines changed: 88 additions & 22 deletions

File tree

integrations/mcp/src/haystack_integrations/tools/mcp/mcp_tool.py

Lines changed: 47 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import asyncio
66
import concurrent.futures
77
import threading
8+
import warnings
89
from abc import ABC, abstractmethod
910
from collections.abc import Coroutine
1011
from contextlib import AsyncExitStack
@@ -16,6 +17,7 @@
1617
from haystack.core.serialization import generate_qualified_class_name, import_class_by_name
1718
from haystack.tools import Tool
1819
from haystack.tools.errors import ToolInvocationError
20+
from haystack.utils.url_validation import is_valid_http_url
1921

2022
from mcp import ClientSession, StdioServerParameters, types
2123
from mcp.client.sse import sse_client
@@ -351,18 +353,19 @@ class SSEClient(MCPClient):
351353
MCP client that connects to servers using SSE transport.
352354
"""
353355

354-
def __init__(self, base_url: str, token: str | None = None, timeout: int = 5) -> None:
356+
def __init__(self, server_info: "SSEServerInfo") -> None:
355357
"""
356-
Initialize an SSE MCP client.
358+
Initialize an SSE MCP client using server configuration.
357359
358-
:param base_url: Base URL of the server
359-
:param token: Authentication token for the server (optional)
360-
:param timeout: Connection timeout in seconds
360+
:param server_info: Configuration object containing URL, token, timeout, etc.
361361
"""
362362
super().__init__()
363-
self.base_url: str = base_url.rstrip("/") # Remove any trailing slashes
364-
self.token: str | None = token
365-
self.timeout: int = timeout
363+
364+
# in post_init we validate the url and set the url field so it is guaranteed to be valid
365+
# safely ignore the mypy warning here
366+
self.url: str = server_info.url # type: ignore[assignment]
367+
self.token: str | None = server_info.token
368+
self.timeout: int = server_info.timeout
366369

367370
async def connect(self) -> list[Tool]:
368371
"""
@@ -371,12 +374,11 @@ async def connect(self) -> list[Tool]:
371374
:returns: List of available tools on the server
372375
:raises MCPConnectionError: If connection to the server fails
373376
"""
374-
sse_url = f"{self.base_url}/sse"
375377
headers = {"Authorization": f"Bearer {self.token}"} if self.token else None
376378
sse_transport = await self.exit_stack.enter_async_context(
377-
sse_client(sse_url, headers=headers, timeout=self.timeout)
379+
sse_client(self.url, headers=headers, timeout=self.timeout)
378380
)
379-
return await self._initialize_session_with_transport(sse_transport, f"HTTP server at {self.base_url}")
381+
return await self._initialize_session_with_transport(sse_transport, f"HTTP server at {self.url}")
380382

381383

382384
@dataclass
@@ -432,22 +434,51 @@ class SSEServerInfo(MCPServerInfo):
432434
"""
433435
Data class that encapsulates SSE MCP server connection parameters.
434436
435-
:param base_url: Base URL of the MCP server
437+
:param url: Full URL of the MCP server (including /sse endpoint)
438+
:param base_url: Base URL of the MCP server (deprecated, use url instead)
436439
:param token: Authentication token for the server (optional)
437440
:param timeout: Connection timeout in seconds
438441
"""
439442

440-
base_url: str
443+
url: str | None = None
444+
base_url: str | None = None # deprecated
441445
token: str | None = None
442446
timeout: int = 30
443447

448+
def __post_init__(self):
449+
"""Validate that either url or base_url is provided."""
450+
if not self.url and not self.base_url:
451+
message = "Either url or base_url must be provided"
452+
raise ValueError(message)
453+
if self.url and self.base_url:
454+
message = "Only one of url or base_url should be provided, if both are provided, base_url will be ignored"
455+
warnings.warn(message, DeprecationWarning, stacklevel=2)
456+
457+
if self.base_url:
458+
if not is_valid_http_url(self.base_url):
459+
message = f"Invalid base_url: {self.base_url}"
460+
raise ValueError(message)
461+
462+
warnings.warn(
463+
"base_url is deprecated and will be removed in a future version. Use url instead.",
464+
DeprecationWarning,
465+
stacklevel=2,
466+
)
467+
# from now on only use url for the lifetime of the SSEServerInfo instance, never base_url
468+
self.url = f"{self.base_url.rstrip('/')}/sse"
469+
470+
elif not is_valid_http_url(self.url):
471+
message = f"Invalid url: {self.url}"
472+
raise ValueError(message)
473+
444474
def create_client(self) -> MCPClient:
445475
"""
446476
Create an SSE MCP client.
447477
448-
:returns: Configured HttpMCPClient instance
478+
:returns: Configured MCPClient instance
449479
"""
450-
return SSEClient(self.base_url, self.token, self.timeout)
480+
# Pass the validated SSEServerInfo instance directly
481+
return SSEClient(server_info=self)
451482

452483

453484
@dataclass
@@ -491,7 +522,7 @@ class MCPTool(Tool):
491522
# Create tool instance
492523
tool = MCPTool(
493524
name="add",
494-
server_info=SSEServerInfo(base_url="http://localhost:8000")
525+
server_info=SSEServerInfo(url="http://localhost:8000/sse")
495526
)
496527
497528
# Use the tool

integrations/mcp/src/haystack_integrations/tools/mcp/mcp_toolset.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# SPDX-License-Identifier: Apache-2.0
44

55
from typing import Any
6+
from urllib.parse import urlparse
67

78
import httpx
89
from exceptiongroup import ExceptionGroup
@@ -88,7 +89,7 @@ class MCPToolset(Toolset):
8889
8990
# Create the toolset with an SSE connection
9091
sse_toolset = MCPToolset(
91-
server_info=SSEServerInfo(base_url="http://some-remote-server.com:8000"),
92+
server_info=SSEServerInfo(url="http://some-remote-server.com:8000/sse"),
9293
tool_names=["add", "subtract"] # Only include specific tools
9394
)
9495
@@ -175,7 +176,7 @@ def invoke_tool(**kwargs) -> Any:
175176

176177
except Exception as e:
177178
if isinstance(self.server_info, SSEServerInfo):
178-
base_message = f"Failed to connect to SSE server at {self.server_info.base_url}"
179+
base_message = f"Failed to connect to SSE server at {self.server_info.url}"
179180
checks = ["1. The server is running"]
180181

181182
# Check for ConnectError in exception group or direct exception
@@ -184,10 +185,24 @@ def invoke_tool(**kwargs) -> Any:
184185
)
185186

186187
if has_connect_error:
187-
port = self.server_info.base_url.split(":")[-1]
188-
checks.append(f"2. The address and port are correct (attempted port: {port})")
188+
# Use urlparse to reliably get scheme, hostname, and port
189+
parsed_url = urlparse(self.server_info.url)
190+
port_str = ""
191+
if parsed_url.port:
192+
port_str = str(parsed_url.port)
193+
elif parsed_url.scheme == "http":
194+
port_str = "80 (default)"
195+
elif parsed_url.scheme == "https":
196+
port_str = "443 (default)"
197+
else:
198+
port_str = "unknown (scheme not http/https or missing)" # Or handle more schemes if needed
199+
200+
# Ensure hostname is handled correctly (it might be None)
201+
hostname_str = str(parsed_url.hostname) if parsed_url.hostname else "<unknown>"
202+
message = f"2. The address '{hostname_str}' and port '{port_str}' are correct"
203+
checks.append(message)
189204
checks.append("3. There are no firewall or network connectivity issues")
190-
message = f"{base_message}. Please check if:\n" + "\n".join(checks)
205+
message = f"{base_message}. Please check if:\n" + "\\n".join(checks)
191206
else:
192207
message = f"{base_message}: {e}"
193208
elif isinstance(self.server_info, StdioServerInfo): # stdio connection

integrations/mcp/tests/test_mcp_tool.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,26 @@ def test_http_server_info_serde(self):
132132
assert new_info.token == "test-token"
133133
assert new_info.timeout == 45
134134

135+
def test_url_base_url_validation(self):
136+
"""Test validation of url and base_url parameters."""
137+
# Test with neither url nor base_url
138+
with pytest.raises(ValueError, match="Either url or base_url must be provided"):
139+
SSEServerInfo()
140+
141+
# Test with both url and base_url
142+
with pytest.warns(DeprecationWarning, match="base_url is deprecated"):
143+
SSEServerInfo(url="http://example.com/sse", base_url="http://example.com")
144+
145+
# Test with only url
146+
server_info = SSEServerInfo(url="http://example.com/sse")
147+
assert server_info.url == "http://example.com/sse"
148+
assert server_info.base_url is None
149+
150+
# Test with only base_url (deprecated but supported)
151+
with pytest.warns(DeprecationWarning, match="base_url is deprecated"):
152+
server_info = SSEServerInfo(base_url="http://example.com")
153+
assert server_info.base_url == "http://example.com" # Should preserve original base_url
154+
135155
def test_stdio_server_info_serde(self):
136156
"""Test serialization/deserialization of StdioServerInfo."""
137157
server_info = StdioServerInfo(command="python", args=["-m", "mcp_server_time"], env={"TEST_ENV": "value"})
@@ -157,7 +177,7 @@ def test_create_client(self):
157177
http_client = http_info.create_client()
158178
stdio_client = stdio_info.create_client()
159179

160-
assert http_client.base_url == "http://example.com"
180+
assert http_client.url == "http://example.com/sse"
161181
assert stdio_client.command == "python"
162182

163183

0 commit comments

Comments
 (0)