Skip to content

Commit 657ff18

Browse files
committed
Refactor to increase code reuse, simplify
1 parent a65926e commit 657ff18

2 files changed

Lines changed: 244 additions & 170 deletions

File tree

integrations/mcp/src/haystack_integrations/tools/mcp/mcp_tool.py

Lines changed: 224 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -632,6 +632,207 @@ def create_client(self) -> MCPClient:
632632
return StdioClient(self.command, self.args, self.env)
633633

634634

635+
def _extract_error_message(exception: Exception) -> str:
636+
"""
637+
Extracts meaningful error message from various exception types.
638+
Handles ExceptionGroup, empty messages, etc.
639+
"""
640+
error_message = str(exception)
641+
# Handle ExceptionGroup to extract more useful error messages
642+
if isinstance(exception, ExceptionGroup):
643+
if exception.exceptions:
644+
first_exception = exception.exceptions[0]
645+
error_message = first_exception.message if hasattr(first_exception, "message") else str(first_exception)
646+
647+
# Ensure we always have a meaningful error message
648+
if not error_message or error_message.strip() == "":
649+
# Provide platform-independent fallback message for connection errors
650+
error_message = "Connection failed to MCP server"
651+
652+
return error_message
653+
654+
655+
def _create_stdio_connection_error_message(server_info: StdioServerInfo, operation: str, context: str) -> str:
656+
"""
657+
Creates stdio-specific error messages with command troubleshooting.
658+
"""
659+
base_message = f"Failed to {operation} {context} via stdio"
660+
661+
# Build command string for diagnostics
662+
args_str = " ".join(server_info.args) if server_info.args else ""
663+
cmd = f"{server_info.command}{' ' + args_str if args_str else ''}"
664+
665+
checks = [f"1. The command and arguments are correct (attempted: {cmd})"]
666+
667+
return f"{base_message}. Please check if:\n" + "\n".join(checks)
668+
669+
670+
def _create_http_connection_error_message(
671+
server_info: SSEServerInfo | StreamableHttpServerInfo, exception: Exception, operation: str, context: str
672+
) -> str:
673+
"""
674+
Creates HTTP-specific error messages with troubleshooting guidance.
675+
"""
676+
# Determine transport type
677+
transport_name = "SSE" if isinstance(server_info, SSEServerInfo) else "streamable HTTP"
678+
server_url = server_info.url
679+
680+
base_message = f"Failed to {operation} {context} via {transport_name}"
681+
682+
# Standard troubleshooting steps
683+
checks = [
684+
f"1. The server URL is correct (attempted: {server_url})",
685+
"2. The server is running and accessible",
686+
"3. Authentication token is correct (if required)",
687+
]
688+
689+
# Check if exception indicates a network connection error
690+
import httpx
691+
692+
has_connect_error = isinstance(exception, httpx.ConnectError) or (
693+
isinstance(exception, ExceptionGroup)
694+
and any(isinstance(exc, httpx.ConnectError) for exc in exception.exceptions)
695+
)
696+
697+
# Add network-specific guidance for connection errors
698+
if has_connect_error:
699+
from urllib.parse import urlparse
700+
701+
# Use urlparse to reliably get scheme, hostname, and port
702+
parsed_url = urlparse(server_url)
703+
port_str = ""
704+
if parsed_url.port:
705+
port_str = str(parsed_url.port)
706+
elif parsed_url.scheme == "http":
707+
port_str = "80 (default)"
708+
elif parsed_url.scheme == "https":
709+
port_str = "443 (default)"
710+
else:
711+
port_str = "unknown (scheme not http/https or missing)"
712+
713+
# Ensure hostname is handled correctly (it might be None)
714+
hostname_str = str(parsed_url.hostname) if parsed_url.hostname else "<unknown>"
715+
716+
checks[1] = f"2. The address '{hostname_str}' and port '{port_str}' are correct"
717+
checks.append("4. There are no firewall or network connectivity issues")
718+
719+
return f"{base_message}. Please check if:\n" + "\n".join(checks)
720+
721+
722+
def _create_connection_error_message(
723+
server_info: MCPServerInfo, exception: Exception, operation: str, context: str = ""
724+
) -> str:
725+
"""
726+
Creates contextual error messages based on server type and failure details.
727+
This replaces the duplicate error handling blocks in both classes.
728+
"""
729+
730+
# Generate server-type specific guidance
731+
if isinstance(server_info, SSEServerInfo | StreamableHttpServerInfo):
732+
return _create_http_connection_error_message(server_info, exception, operation, context)
733+
elif isinstance(server_info, StdioServerInfo):
734+
return _create_stdio_connection_error_message(server_info, operation, context)
735+
else:
736+
error_message = _extract_error_message(exception)
737+
return f"Failed to {operation} {context}: {error_message}"
738+
739+
740+
class MCPConnectionManager:
741+
"""
742+
Utility class that encapsulates common MCP connection logic shared between
743+
MCPTool and MCPToolset.
744+
"""
745+
746+
def __init__(self, server_info: MCPServerInfo, connection_timeout: float):
747+
self.server_info = server_info
748+
self.connection_timeout = connection_timeout
749+
self._client = None
750+
self._worker = None
751+
752+
def connect_and_discover_tools(self) -> list[Tool]:
753+
"""
754+
Establishes connection and returns available tools.
755+
This replaces the duplicate connection logic in both classes.
756+
"""
757+
try:
758+
# Create the client and spin up a worker so open/close happen in the
759+
# same coroutine, avoiding AnyIO cancel-scope issues.
760+
self._client = self.server_info.create_client()
761+
self._worker = _MCPClientSessionManager(self._client, timeout=self.connection_timeout)
762+
return self._worker.tools()
763+
except Exception:
764+
# Handle cleanup internally
765+
self.close()
766+
raise
767+
768+
def validate_requested_tools(self, requested_tool_names: list[str], available_tools: list[Tool]) -> None:
769+
"""
770+
Validates that requested tools exist on the server.
771+
Shared validation logic between both classes.
772+
"""
773+
available_tool_names = {tool.name for tool in available_tools}
774+
missing_tools = set(requested_tool_names) - available_tool_names
775+
if missing_tools:
776+
message = (
777+
f"The following tools were not found: {', '.join(missing_tools)}. "
778+
f"Available tools: {', '.join(available_tool_names)}"
779+
)
780+
raise MCPToolNotFoundError(
781+
message=message, tool_name=next(iter(missing_tools)), available_tools=list(available_tool_names)
782+
)
783+
784+
def create_tool_invoke_function(self, tool_name: str, invocation_timeout: float):
785+
"""
786+
Creates the invoke function that both classes use.
787+
MCPTool uses this directly, MCPToolset uses this in its closure factory.
788+
"""
789+
790+
def invoke_tool(**kwargs) -> str:
791+
"""Unified invoke logic - no more duplication"""
792+
logger.debug(f"Invoking tool '{tool_name}' with args: {kwargs}")
793+
try:
794+
795+
async def invoke():
796+
logger.debug(f"Inside invoke coroutine for '{tool_name}'")
797+
result = await asyncio.wait_for(
798+
self._client.call_tool(tool_name, kwargs), timeout=invocation_timeout
799+
)
800+
logger.debug(f"Invoke successful for '{tool_name}'")
801+
return result
802+
803+
logger.debug(f"About to run invoke for '{tool_name}'")
804+
result = AsyncExecutor.get_instance().run(invoke(), timeout=invocation_timeout)
805+
logger.debug(f"Invoke complete for '{tool_name}', result type: {type(result)}")
806+
return result
807+
except (MCPError, TimeoutError) as e:
808+
logger.debug(f"Known error during invoke of '{tool_name}': {e!s}")
809+
# Pass through known errors
810+
raise
811+
except Exception as e:
812+
# Wrap other errors
813+
logger.debug(f"Unknown error during invoke of '{tool_name}': {e!s}")
814+
message = f"Failed to invoke tool '{tool_name}' with args: {kwargs} , got error: {e!s}"
815+
raise MCPInvocationError(message, tool_name, kwargs) from e
816+
817+
return invoke_tool
818+
819+
def get_client(self):
820+
"""Allow direct access to client for MCPTool's async method access"""
821+
return self._client
822+
823+
def close(self):
824+
"""Shared cleanup logic"""
825+
if hasattr(self, "_worker") and self._worker:
826+
try:
827+
self._worker.stop()
828+
except Exception as e:
829+
logger.debug(f"Error during worker stop: {e!s}")
830+
831+
# Clear references
832+
self._worker = None
833+
self._client = None
834+
835+
635836
class MCPTool(Tool):
636837
"""
637838
A Tool that represents a single tool from an MCP server.
@@ -706,41 +907,35 @@ def __init__(
706907
logger.debug(f"TOOL: Initializing MCPTool '{name}'")
707908

708909
try:
709-
# Create client and spin up a long-lived worker that keeps the
710-
# connect/close lifecycle inside one coroutine.
711-
self._client = server_info.create_client()
712-
logger.debug(f"TOOL: Created client for MCPTool '{name}'")
910+
# Use shared connection logic
911+
self._connection_manager = MCPConnectionManager(server_info, connection_timeout)
912+
available_tools = self._connection_manager.connect_and_discover_tools()
713913

714-
# The worker starts immediately and blocks here until the connection
715-
# is established (or fails), returning the tool list.
716-
self._worker = _MCPClientSessionManager(self._client, timeout=connection_timeout)
717-
718-
tools = self._worker.tools()
719914
# Handle no tools case
720-
if not tools:
915+
if not available_tools:
721916
logger.debug(f"TOOL: No tools found for '{name}'")
722917
message = "No tools available on server"
723918
raise MCPToolNotFoundError(message, tool_name=name)
724919

920+
# Validate that the requested tool exists
921+
self._connection_manager.validate_requested_tools([name], available_tools)
922+
725923
# Find the specified tool
726-
tool_dict = {t.name: t for t in tools}
924+
tool_dict = {t.name: t for t in available_tools}
727925
logger.debug(f"TOOL: Available tools: {list(tool_dict.keys())}")
926+
tool_info = tool_dict[name] # Safe to use direct access since validation passed
728927

729-
tool_info = tool_dict.get(name)
928+
logger.debug(f"TOOL: Found tool '{name}', initializing Tool parent class")
730929

731-
if not tool_info:
732-
available = list(tool_dict.keys())
733-
logger.debug(f"TOOL: Tool '{name}' not found in available tools")
734-
message = f"Tool '{name}' not found on server. Available tools: {', '.join(available)}"
735-
raise MCPToolNotFoundError(message, tool_name=name, available_tools=available)
930+
# Create shared invoke function
931+
invoke_func = self._connection_manager.create_tool_invoke_function(name, invocation_timeout)
736932

737-
logger.debug(f"TOOL: Found tool '{name}', initializing Tool parent class")
738933
# Initialize the parent class
739934
super().__init__(
740935
name=name,
741936
description=description or tool_info.description,
742937
parameters=tool_info.inputSchema,
743-
function=self._invoke_tool,
938+
function=invoke_func,
744939
)
745940
logger.debug(f"TOOL: Initialization complete for '{name}'")
746941

@@ -749,55 +944,11 @@ def __init__(
749944
# fail because of an MCPToolNotFoundError
750945
self.close()
751946

752-
# Extract more detailed error information from TaskGroup/ExceptionGroup exceptions
753-
error_message = str(e)
754-
# Handle ExceptionGroup to extract more useful error messages
755-
if isinstance(e, ExceptionGroup):
756-
if e.exceptions:
757-
first_exception = e.exceptions[0]
758-
error_message = (
759-
first_exception.message if hasattr(first_exception, "message") else str(first_exception)
760-
)
761-
762-
# Ensure we always have a meaningful error message
763-
if not error_message or error_message.strip() == "":
764-
# Provide platform-independent fallback message for connection errors
765-
error_message = f"Connection failed to MCP server (using {type(server_info).__name__})"
766-
767-
message = f"Failed to initialize MCPTool '{name}': {error_message}"
768-
raise MCPConnectionError(message=message, server_info=server_info, operation="initialize") from e
769-
770-
def _invoke_tool(self, **kwargs: Any) -> str:
771-
"""
772-
Synchronous tool invocation.
773-
774-
:param kwargs: Arguments to pass to the tool
775-
:returns: JSON string representation of the tool invocation result
776-
"""
777-
logger.debug(f"TOOL: Invoking tool '{self.name}' with args: {kwargs}")
778-
try:
779-
780-
async def invoke():
781-
logger.debug(f"TOOL: Inside invoke coroutine for '{self.name}'")
782-
result = await asyncio.wait_for(
783-
self._client.call_tool(self.name, kwargs), timeout=self._invocation_timeout
784-
)
785-
logger.debug(f"TOOL: Invoke successful for '{self.name}'")
786-
return result
787-
788-
logger.debug(f"TOOL: About to run invoke for '{self.name}'")
789-
result = AsyncExecutor.get_instance().run(invoke(), timeout=self._invocation_timeout)
790-
logger.debug(f"TOOL: Invoke complete for '{self.name}', result type: {type(result)}")
791-
return result
792-
except (MCPError, TimeoutError) as e:
793-
logger.debug(f"TOOL: Known error during invoke of '{self.name}': {e!s}")
794-
# Pass through known errors
795-
raise
796-
except Exception as e:
797-
# Wrap other errors
798-
logger.debug(f"TOOL: Unknown error during invoke of '{self.name}': {e!s}")
799-
message = f"Failed to invoke tool '{self.name}' with args: {kwargs} , got error: {e!s}"
800-
raise MCPInvocationError(message, self.name, kwargs) from e
947+
# Use shared error handling
948+
error_message = _create_connection_error_message(
949+
server_info=server_info, exception=e, operation="initialize", context=f"MCPTool '{name}'"
950+
)
951+
raise MCPConnectionError(message=error_message, server_info=server_info, operation="initialize") from e
801952

802953
async def ainvoke(self, **kwargs: Any) -> str:
803954
"""
@@ -809,7 +960,8 @@ async def ainvoke(self, **kwargs: Any) -> str:
809960
:raises TimeoutError: If the operation times out
810961
"""
811962
try:
812-
return await asyncio.wait_for(self._client.call_tool(self.name, kwargs), timeout=self._invocation_timeout)
963+
client = self._connection_manager.get_client()
964+
return await asyncio.wait_for(client.call_tool(self.name, kwargs), timeout=self._invocation_timeout)
813965
except asyncio.TimeoutError as e:
814966
message = f"Tool invocation timed out after {self._invocation_timeout} seconds"
815967
raise TimeoutError(message) from e
@@ -881,13 +1033,11 @@ def from_dict(cls, data: dict[str, Any]) -> "Tool":
8811033

8821034
def close(self):
8831035
"""Close the tool synchronously."""
884-
if hasattr(self, "_client") and self._client:
1036+
if hasattr(self, "_connection_manager") and self._connection_manager:
8851037
try:
886-
# Tell the background worker to shut down gracefully.
887-
if hasattr(self, "_worker") and self._worker:
888-
self._worker.stop()
1038+
self._connection_manager.close()
8891039
except Exception as e:
890-
logger.debug(f"TOOL: Error during synchronous worker stop: {e!s}")
1040+
logger.debug(f"TOOL: Error during connection manager close: {e!s}")
8911041

8921042
def __del__(self):
8931043
"""Cleanup resources when the tool is garbage collected."""

0 commit comments

Comments
 (0)