Skip to content

Commit 6d81a71

Browse files
vblagojempangrazzi
andauthored
feat: MCPTool and MCPToolset async resource management improvements (#1758)
* Add MCPClientSessionManager to connect/close mcp clients * Update and refactor mcp tests * More descriptive connection error raising * Proper test cleanup * Testing CI windows * linting * Improve connection error raise * PR feedback * Proper naming, and more precise cleanup sequence --------- Co-authored-by: Michele Pangrazzi <xmikex83@gmail.com>
1 parent 2531848 commit 6d81a71

9 files changed

Lines changed: 866 additions & 627 deletions

File tree

integrations/mcp/src/haystack_integrations/tools/mcp/mcp_tool.py

Lines changed: 173 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
import threading
88
import warnings
99
from abc import ABC, abstractmethod
10-
from collections.abc import Coroutine
10+
from collections.abc import Callable, Coroutine
11+
from concurrent.futures import Future
1112
from contextlib import AsyncExitStack
1213
from dataclasses import dataclass, fields
1314
from typing import Any, cast
@@ -73,6 +74,47 @@ def run(self, coro: Coroutine[Any, Any, Any], timeout: float | None = None) -> A
7374
message = f"Operation timed out after {timeout} seconds"
7475
raise TimeoutError(message) from e
7576

77+
def get_loop(self):
78+
"""
79+
Get the event loop.
80+
81+
:returns: The event loop
82+
"""
83+
return self._loop
84+
85+
def run_background(
86+
self, coro_factory: Callable[[asyncio.Event], Coroutine[Any, Any, Any]], timeout: float | None = None
87+
) -> tuple[concurrent.futures.Future[Any], asyncio.Event]:
88+
"""
89+
Schedule `coro_factory` to run in the executor's event loop **without** blocking the
90+
caller thread.
91+
92+
The factory receives an :class:`asyncio.Event` that can be used to cooperatively shut
93+
the coroutine down. The method returns **both** the concurrent future (to observe
94+
completion or failure) and the created *stop_event* so that callers can signal termination.
95+
96+
:param coro_factory: A callable receiving the stop_event and returning the coroutine to execute.
97+
:param timeout: Optional timeout while waiting for the stop_event to be created.
98+
:returns: Tuple ``(future, stop_event)``.
99+
"""
100+
# A promise that will be fulfilled from inside the coroutine_with_stop_event coroutine once the
101+
# stop_event is created *inside* the target event loop to ensure it is bound to the
102+
# correct loop and can safely be set from other threads via *call_soon_threadsafe*.
103+
stop_event_promise: Future[asyncio.Event] = Future()
104+
105+
async def _coroutine_with_stop_event():
106+
stop_event = asyncio.Event()
107+
stop_event_promise.set_result(stop_event)
108+
await coro_factory(stop_event)
109+
110+
# Schedule the coroutine
111+
future = asyncio.run_coroutine_threadsafe(_coroutine_with_stop_event(), self._loop)
112+
113+
# This ensures that the stop_event is fully initialized and ready for use before
114+
# the run_background method returns, allowing the caller to immediately
115+
# use it to control the coroutine.
116+
return future, stop_event_promise.result(timeout)
117+
76118
def shutdown(self, timeout: float = 2):
77119
"""
78120
Shut down the background event loop and thread.
@@ -213,7 +255,7 @@ async def call_tool(self, tool_name: str, tool_args: dict[str, Any]) -> Any:
213255
raise
214256
except Exception as e:
215257
# Wrap other exceptions with context about which tool failed
216-
message = f"Failed to invoke tool '{tool_name}'"
258+
message = f"Failed to invoke tool '{tool_name}' due to: {e}"
217259
raise MCPInvocationError(message, tool_name, tool_args) from e
218260

219261
def _validate_response(self, tool_name: str, result: types.CallToolResult) -> types.CallToolResult:
@@ -254,7 +296,7 @@ def _validate_response(self, tool_name: str, result: types.CallToolResult) -> ty
254296
# Return the original result object
255297
return result
256298

257-
async def close(self) -> None:
299+
async def aclose(self) -> None:
258300
"""
259301
Close the connection and clean up resources.
260302
@@ -273,15 +315,6 @@ async def close(self) -> None:
273315
self.stdio = None
274316
self.write = None
275317

276-
def close_sync(self) -> None:
277-
"""Synchronous version of close for use in __del__ - ensures resources are cleaned up."""
278-
logger.debug("PROCESS: Closing StdioClient (sync)")
279-
280-
try:
281-
AsyncExecutor.get_instance().run(self.close(), timeout=2)
282-
except Exception as e:
283-
logger.debug(f"PROCESS: Error during async cleanup in sync close: {e!s}")
284-
285318
async def _initialize_session_with_transport(
286319
self,
287320
transport_tuple: tuple[
@@ -310,7 +343,7 @@ async def _initialize_session_with_transport(
310343
return response.tools
311344

312345
except Exception as e:
313-
await self.close()
346+
# We'll clean up the session in the calling code, so we don't need to do it here.
314347
message = f"Failed to connect to {connection_type}: {e}"
315348
raise MCPConnectionError(message=message, operation="connect") from e
316349

@@ -572,22 +605,17 @@ def __init__(
572605

573606
logger.debug(f"TOOL: Initializing MCPTool '{name}'")
574607

575-
# Create client
576-
self._client = server_info.create_client()
577-
logger.debug(f"TOOL: Created client for MCPTool '{name}'")
578-
579608
try:
609+
# Create client and spin up a long-lived worker that keeps the
610+
# connect/close lifecycle inside one coroutine.
611+
self._client = server_info.create_client()
612+
logger.debug(f"TOOL: Created client for MCPTool '{name}'")
580613

581-
async def connect():
582-
logger.debug(f"TOOL: Inside connect coroutine for '{name}'")
583-
result = await asyncio.wait_for(self._client.connect(), timeout=connection_timeout)
584-
logger.debug(f"TOOL: Connect successful for '{name}', found {len(result)} tools")
585-
return result
586-
587-
logger.debug(f"TOOL: About to run connect for '{name}'")
588-
tools = AsyncExecutor.get_instance().run(connect(), timeout=connection_timeout)
589-
logger.debug(f"TOOL: Connection complete for '{name}'")
614+
# The worker starts immediately and blocks here until the connection
615+
# is established (or fails), returning the tool list.
616+
self._worker = _MCPClientSessionManager(self._client, timeout=connection_timeout)
590617

618+
tools = self._worker.tools()
591619
# Handle no tools case
592620
if not tools:
593621
logger.debug(f"TOOL: No tools found for '{name}'")
@@ -617,17 +645,28 @@ async def connect():
617645
logger.debug(f"TOOL: Initialization complete for '{name}'")
618646

619647
except Exception as e:
620-
# Clean up resources on error
621-
logger.debug(f"TOOL: Error during initialization of '{name}': {e!s}")
622-
if self._client:
623-
try:
624-
logger.debug(f"TOOL: Attempting cleanup after initialization failure for '{name}'")
625-
AsyncExecutor.get_instance().run(self._client.close(), timeout=5)
626-
logger.debug(f"TOOL: Cleanup successful for '{name}'")
627-
except Exception as cleanup_error:
628-
logger.debug(f"TOOL: Error during cleanup after initialization failure: {cleanup_error!s}")
629-
630-
message = f"Failed to initialize MCPTool '{name}': {e}"
648+
# We need to close because we could connect properly, retrieve tools yet
649+
# fail because of an MCPToolNotFoundError
650+
self.close()
651+
652+
# Extract more detailed error information from TaskGroup/ExceptionGroup exceptions
653+
from exceptiongroup import ExceptionGroup
654+
655+
error_message = str(e)
656+
# Handle ExceptionGroup to extract more useful error messages
657+
if isinstance(e, ExceptionGroup):
658+
if e.exceptions:
659+
first_exception = e.exceptions[0]
660+
error_message = (
661+
first_exception.message if hasattr(first_exception, "message") else str(first_exception)
662+
)
663+
664+
# Ensure we always have a meaningful error message
665+
if not error_message or error_message.strip() == "":
666+
# Provide platform-independent fallback message for connection errors
667+
error_message = f"Connection failed to MCP server (using {type(server_info).__name__})"
668+
669+
message = f"Failed to initialize MCPTool '{name}': {error_message}"
631670
raise MCPConnectionError(message=message, server_info=server_info, operation="initialize") from e
632671

633672
def _invoke_tool(self, **kwargs: Any) -> Any:
@@ -743,13 +782,106 @@ def from_dict(cls, data: dict[str, Any]) -> "Tool":
743782
invocation_timeout=invocation_timeout,
744783
)
745784

785+
def close(self):
786+
"""Close the tool synchronously."""
787+
if hasattr(self, "_client") and self._client:
788+
try:
789+
# Tell the background worker to shut down gracefully.
790+
if hasattr(self, "_worker") and self._worker:
791+
self._worker.stop()
792+
except Exception as e:
793+
logger.debug(f"TOOL: Error during synchronous worker stop: {e!s}")
794+
746795
def __del__(self):
747796
"""Cleanup resources when the tool is garbage collected."""
748797
logger.debug(f"TOOL: __del__ called for MCPTool '{self.name if hasattr(self, 'name') else 'unknown'}'")
749798

750-
# Call synchronous close on the client
751-
if hasattr(self, "_client") and self._client:
799+
self.close()
800+
801+
802+
class _MCPClientSessionManager:
803+
"""Runs an MCPClient connect/close inside the AsyncExecutor's event loop.
804+
805+
Life-cycle:
806+
1. Create the worker to schedule a long-running coroutine in the
807+
dedicated background loop.
808+
2. The coroutine calls *connect* on mcp client; when it has the tool list it fulfils
809+
a concurrent future so the synchronous thread can continue.
810+
3. It then waits on an `asyncio.Event`.
811+
4. `stop()` sets the event from any thread. The same coroutine then calls
812+
*close()* on mcp client and finishes without the dreaded
813+
`Attempted to exit cancel scope in a different task than it was entered in` error
814+
thus properly closing the client.
815+
"""
816+
817+
# Maximum time to wait for worker shutdown in seconds
818+
WORKER_SHUTDOWN_TIMEOUT = 2.0
819+
820+
def __init__(self, client: "MCPClient", *, timeout: float | None = None):
821+
self._client = client
822+
self.executor = AsyncExecutor.get_instance()
823+
824+
# Where the tool list (or an exception) will be delivered.
825+
self._tools_promise: Future[list[Tool]] = Future()
826+
827+
# Kick off the worker coroutine in the background loop
828+
self._worker_future, self._stop_event = self.executor.run_background(self._run, timeout=None)
829+
830+
# Wait (in the caller thread) until connect() finishes or raises.
831+
try:
832+
self._tools_promise.result(timeout)
833+
except BaseException:
834+
# If connect failed we should cancel the worker so it doesn't hang.
835+
self.stop()
836+
raise
837+
838+
def tools(self) -> list[Tool]:
839+
"""Return the tool list already collected during startup."""
840+
841+
return self._tools_promise.result()
842+
843+
def stop(self) -> None:
844+
"""Request the worker to shut down and block until done."""
845+
846+
def _set(ev: asyncio.Event):
847+
if not ev.is_set():
848+
ev.set()
849+
850+
if self.executor.get_loop().is_closed():
851+
return
852+
853+
# The stop event is created inside the worker *before* the connect
854+
# promise is fulfilled, so at this point it must exist.
855+
self.executor.get_loop().call_soon_threadsafe(_set, self._stop_event) # type: ignore[attr-defined]
856+
857+
# Wait for the worker coroutine to finish so resources are fully
858+
# released before returning. Swallow any errors during shutdown.
859+
try:
860+
self._worker_future.result(timeout=self.WORKER_SHUTDOWN_TIMEOUT)
861+
except Exception as e:
862+
logger.debug(f"Error during worker future result: {e}")
863+
pass
864+
865+
async def _run(self, stop_event: asyncio.Event):
866+
"""Background coroutine living in AsyncExecutor's loop."""
867+
868+
try:
869+
# logger.debug(f"TOOL: _run current task: {asyncio.current_task()}")
870+
tools = await self._client.connect()
871+
# Deliver the tool list to the waiting synchronous code.
872+
if not self._tools_promise.done():
873+
self._tools_promise.set_result(tools)
874+
# Park until told to stop.
875+
await stop_event.wait()
876+
except Exception as exc:
877+
logger.debug(f"Error during _run: {exc}")
878+
if not self._tools_promise.done():
879+
self._tools_promise.set_exception(exc)
880+
raise
881+
finally:
882+
# logger.debug(f"TOOL: _run current task: {asyncio.current_task()}")
883+
# Close the client in the same couroutine that connected it
752884
try:
753-
self._client.close_sync()
885+
await self._client.aclose()
754886
except Exception as e:
755-
logger.debug(f"TOOL: Error during synchronous client close: {e!s}")
887+
logger.debug(f"Error during MCP client cleanup: {e!s}")

integrations/mcp/src/haystack_integrations/tools/mcp/mcp_toolset.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
MCPToolNotFoundError,
1919
SSEServerInfo,
2020
StdioServerInfo,
21+
_MCPClientSessionManager,
2122
)
2223

2324
logger = logging.getLogger(__name__)
@@ -122,11 +123,12 @@ def __init__(
122123

123124
# Connect and load tools
124125
try:
125-
# Create the appropriate client using the factory method
126+
# Create the client and spin up a worker so open/close happen in the
127+
# same coroutine, avoiding AnyIO cancel-scope issues.
126128
client = self.server_info.create_client()
129+
self._worker = _MCPClientSessionManager(client, timeout=self.connection_timeout)
127130

128-
# Connect and get available tools using AsyncExecutor
129-
tools = AsyncExecutor.get_instance().run(client.connect(), timeout=self.connection_timeout)
131+
tools = self._worker.tools()
130132

131133
# If tool_names is provided, validate that all requested tools exist
132134
if self.tool_names:
@@ -175,6 +177,11 @@ def invoke_tool(**kwargs) -> Any:
175177
super().__init__(tools=haystack_tools)
176178

177179
except Exception as e:
180+
# We need to close because we could connect properly, retrieve tools yet
181+
# fail because of an MCPToolNotFoundError
182+
self.close()
183+
184+
# Create informative error message for SSE connection errors
178185
if isinstance(self.server_info, SSEServerInfo):
179186
base_message = f"Failed to connect to SSE server at {self.server_info.url}"
180187
checks = ["1. The server is running"]
@@ -205,6 +212,7 @@ def invoke_tool(**kwargs) -> Any:
205212
message = f"{base_message}. Please check if:\n" + "\\n".join(checks)
206213
else:
207214
message = f"{base_message}: {e}"
215+
# and for stdio connection errors
208216
elif isinstance(self.server_info, StdioServerInfo): # stdio connection
209217
base_message = "Failed to start MCP server process"
210218
stdio_info = self.server_info
@@ -255,3 +263,14 @@ def from_dict(cls, data: dict[str, Any]) -> "MCPToolset":
255263
connection_timeout=inner_data.get("connection_timeout", 30.0),
256264
invocation_timeout=inner_data.get("invocation_timeout", 30.0),
257265
)
266+
267+
def close(self):
268+
"""Close the underlying MCP client safely."""
269+
if hasattr(self, "_worker") and self._worker:
270+
try:
271+
self._worker.stop()
272+
except Exception as e:
273+
logger.debug(f"TOOLSET: error during worker stop: {e!s}")
274+
275+
def __del__(self):
276+
self.close()

integrations/mcp/tests/conftest.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import pytest
2+
3+
from haystack_integrations.tools.mcp import MCPTool, MCPToolset
4+
5+
6+
@pytest.fixture
7+
def mcp_tool_cleanup():
8+
"""Fixture to ensure all MCPTool and MCPToolset instances are properly closed after tests."""
9+
tools = []
10+
toolsets = []
11+
12+
def _register(item):
13+
"""Register an MCP component for cleanup."""
14+
if isinstance(item, MCPTool):
15+
tools.append(item)
16+
elif isinstance(item, MCPToolset):
17+
toolsets.append(item)
18+
return item
19+
20+
yield _register
21+
22+
# Finalizer to close all tools and toolsets
23+
for tool in tools:
24+
tool.close()
25+
26+
for toolset in toolsets:
27+
toolset.close()

0 commit comments

Comments
 (0)