Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 75 additions & 6 deletions src/services/mcp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import time
from typing import Any, Callable

from src.utils.abort_controller import AbortError, AbortSignal

from .errors import (
McpAuthError,
McpSessionExpiredError,
Expand Down Expand Up @@ -335,6 +337,11 @@ async def _receive_loop(self) -> None:
break
if msg.id is not None and msg.id in self._pending_requests:
future = self._pending_requests.pop(msg.id)
if future.done():
# Aborted (cancelled) request whose response raced
# the cleanup — resolving it would raise
# InvalidStateError and kill the receive loop.
continue
if msg.error:
future.set_exception(
McpToolCallError(
async def _send_request(
    self,
    method: str,
    params: dict[str, Any] | None = None,
    abort_signal: AbortSignal | None = None,
) -> Any:
    """Send one JSON-RPC request and wait for its response.

    A future is registered in ``self._pending_requests`` under a fresh
    request id, the message is written to the transport, and the call
    then waits (bounded by the tool timeout) for the receive loop to
    resolve that future.

    Args:
        method: JSON-RPC method name.
        params: Optional JSON-RPC params object.
        abort_signal: Optional user-abort signal. When it fires, the
            pending future is cancelled, a best-effort
            ``notifications/cancelled`` message is sent to the server,
            and ``AbortError`` is raised.

    Returns:
        Whatever value the receive loop resolves the future with.

    Raises:
        RuntimeError: The transport is not connected.
        AbortError: ``abort_signal`` was already aborted, or fired while
            the request was in flight.
        asyncio.TimeoutError: No response arrived within the tool
            timeout.
        asyncio.CancelledError: The wait was cancelled and the abort
            signal is absent or not aborted.
    """
    if self._transport is None:
        raise RuntimeError("Transport not connected")
    if abort_signal is not None:
        abort_signal.throw_if_aborted()
    request_id = self._next_id()
    msg = JsonRpcMessage(
        method=method,
        params=params,
        id=request_id,
    )
    # Create the future exactly once, on the running loop; the abort
    # listener below needs this same loop object to hop back onto.
    loop = asyncio.get_running_loop()
    future: asyncio.Future[Any] = loop.create_future()
    self._pending_requests[request_id] = future

    # ESC-cancel (#277): the abort listener fires on the aborting
    # thread (TUI/REPL ESC handler), so hop onto this loop to cancel
    # the pending future — wait_for then raises CancelledError, which
    # we convert to AbortError after notifying the server. When user
    # abort and a genuine task-cancel coincide, AbortError wins: this
    # coroutine runs on a per-call asyncio.run loop in production, so
    # no external task cancellation can reach it anyway.
    registered_abort: Callable[[], None] | None = None
    if abort_signal is not None:
        def _on_abort() -> None:
            loop.call_soon_threadsafe(future.cancel)

        registered_abort = abort_signal.add_listener(_on_abort, once=True)
        if abort_signal.aborted:
            # Abort fired between throw_if_aborted and add_listener —
            # the listener will never fire. Skip the send entirely.
            # We are still outside the try/finally, so clean up by hand.
            self._pending_requests.pop(request_id, None)
            abort_signal.remove_listener(registered_abort)
            raise AbortError(abort_signal.reason or "user_interrupt")

    timeout_s = _get_tool_timeout_ms() / 1000.0
    try:
        # The single send lives inside the try so the finally's
        # listener/pending cleanup also covers a send that raises
        # (closed transport, broken pipe) or is cancelled mid-await.
        await self._transport.send(msg)
        return await asyncio.wait_for(future, timeout=timeout_s)
    except asyncio.CancelledError:
        if abort_signal is not None and abort_signal.aborted:
            # JSON-RPC cancellation per MCP spec: best-effort
            # notification so a compliant server stops the work; a
            # server that ignores it merely leaks one request — the
            # client is already unblocked. Bounded so a wedged
            # transport (the likely cause of the hang being escaped)
            # cannot block the unblock path.
            try:
                await asyncio.wait_for(
                    self._send_notification(
                        "notifications/cancelled",
                        {
                            "requestId": request_id,
                            "reason": abort_signal.reason or "user_interrupt",
                        },
                    ),
                    timeout=2,
                )
            except Exception:
                # Deliberate best-effort: the user is already being
                # unblocked, so a failed notification is only logged.
                logger.debug(
                    "failed to send cancellation for request %s", request_id
                )
            raise AbortError(abort_signal.reason or "user_interrupt") from None
        # Not a user abort: let the cancellation propagate unchanged.
        raise
    finally:
        # No-op on the success path (the receive loop pops when it
        # resolves); guarantees no stranded never-resolving future on
        # timeout, abort, task-cancel, or send failure.
        self._pending_requests.pop(request_id, None)
        if abort_signal is not None and registered_abort is not None:
            abort_signal.remove_listener(registered_abort)

async def _send_notification(
self,
Expand Down Expand Up @@ -411,6 +475,7 @@ async def call_tool(
tool_name: str,
arguments: dict[str, Any] | None = None,
meta: dict[str, Any] | None = None,
abort_signal: AbortSignal | None = None,
) -> McpToolResult:
params: dict[str, Any] = {
"name": tool_name,
Expand All @@ -425,7 +490,9 @@ async def call_tool(
# cache is cleared on detection so the next request reconnects
# against a fresh session rather than reusing the expired one.
try:
result = await self._send_request("tools/call", params)
result = await self._send_request(
"tools/call", params, abort_signal=abort_signal
)
except McpToolCallError as err:
if not is_mcp_session_expired_error(err):
# Regular tool error (invalid params, server-rejected, etc.) —
Expand All @@ -437,7 +504,9 @@ async def call_tool(
# failed and re-raised). A second session-expired here means
# the server is unstable / the retry hit a fresh session that
# already expired — propagate so we don't loop indefinitely.
result = await self._send_request("tools/call", params)
result = await self._send_request(
"tools/call", params, abort_signal=abort_signal
)
if not result or not isinstance(result, dict):
return McpToolResult()

Expand Down
17 changes: 16 additions & 1 deletion src/services/mcp/tool_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from src.tool_system.build_tool import McpInfo, Tool, build_tool
from src.tool_system.context import ToolContext
from src.tool_system.protocol import ToolResult
from src.utils.abort_controller import AbortError

from .client import McpClient
from .mcp_string_utils import build_mcp_tool_name
Expand Down Expand Up @@ -253,8 +254,18 @@ async def _async_call(args: dict[str, Any], ctx: ToolContext) -> ToolResult:
is_error=True,
)

# getattr: duck-typed/mocked contexts (spec'd mocks don't expose
# default_factory dataclass fields) may lack abort_controller.
abort_controller = getattr(ctx, "abort_controller", None)
abort_signal = (
abort_controller.signal if abort_controller is not None else None
)
try:
result = await client.call_tool(mcp_tool.name, args)
result = await client.call_tool(
mcp_tool.name,
args,
abort_signal=abort_signal,
)
content_blocks: list[dict[str, Any]] = list(result.content) if result.content else []

# WI-8.2: budget-truncate before rendering so the model never
Expand Down Expand Up @@ -324,6 +335,10 @@ async def _async_call(args: dict[str, Any], ctx: ToolContext) -> ToolResult:
is_error=False,
mcp_meta=mcp_meta,
)
except AbortError:
# ESC-cancel (#277): propagate so the dispatch layer renders
# the user-cancel message instead of a generic tool error.
raise
except Exception as e:
return ToolResult(
name=fully_qualified_name,
Expand Down
Loading