Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 115 additions & 2 deletions python/packages/kagent-adk/src/kagent/adk/_mcp_toolset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,65 @@

import asyncio
import logging
from typing import Optional
from typing import Any, Optional

import httpx
from google.adk.tools import BaseTool
from google.adk.tools.mcp_tool.mcp_tool import McpTool
from google.adk.tools.mcp_tool.mcp_toolset import McpToolset, ReadonlyContext
from google.adk.tools.tool_context import ToolContext
from mcp.shared.exceptions import McpError

logger = logging.getLogger("kagent_adk." + __name__)

# Connection errors that indicate an unreachable MCP server.
# When these occur, the tool should return an error message to the LLM
# instead of raising, so the LLM can respond to the user rather than
# retrying the broken tool indefinitely.
#
# - ConnectionError: stdlib base for ConnectionResetError, ConnectionRefusedError, etc.
# - TimeoutError: stdlib timeout (e.g. socket.timeout)
# - httpx.TransportError: covers httpx.NetworkError (ConnectError, ReadError,
# WriteError, CloseError), httpx.TimeoutException, httpx.ProtocolError, etc.
# These do NOT inherit from stdlib ConnectionError/OSError.
#
# McpError is handled separately in ConnectionSafeMcpTool.run_async() because
# it is the general MCP protocol error class. Only transport-level McpErrors
# (e.g., session read timeouts) should be caught; protocol-level McpErrors
# (e.g., invalid tool arguments) must propagate so the LLM can correct itself.
_CONNECTION_ERROR_TYPES = (
    ConnectionError,  # stdlib base of ConnectionResetError / ConnectionRefusedError
    TimeoutError,  # stdlib timeout
    httpx.TransportError,  # httpx network, timeout and protocol errors
)

# Keywords in McpError messages that indicate transport-level failures
# (as opposed to protocol-level errors like invalid arguments).
_TRANSPORT_MCP_ERROR_KEYWORDS = (
"timeout",
"timed out",
"connection",
"eof",
"reset",
"closed",
"transport",
"stream",
"unreachable",
)


def _is_transport_mcp_error(error: McpError) -> bool:
    """Report whether an McpError looks like a transport-level failure.

    McpError is the general MCP error class, so it covers both transport
    problems (e.g., session read timeouts, stream closures) and protocol
    problems (e.g., invalid tool arguments, server validation failures).
    Only the former should be converted into a non-retryable error result;
    the latter should propagate so the LLM can correct its behavior.

    The decision is a case-insensitive substring search of the error message
    against ``_TRANSPORT_MCP_ERROR_KEYWORDS``.

    Args:
        error: The McpError to classify; ``error.error.message`` is inspected.

    Returns:
        True if any transport keyword occurs in the lower-cased message,
        otherwise False.
    """
    lowered = error.error.message.lower()
    for keyword in _TRANSPORT_MCP_ERROR_KEYWORDS:
        if keyword in lowered:
            return True
    return False


def _enrich_cancelled_error(error: BaseException) -> asyncio.CancelledError:
message = "Failed to create MCP session: operation cancelled"
Expand All @@ -17,6 +69,57 @@ def _enrich_cancelled_error(error: BaseException) -> asyncio.CancelledError:
return asyncio.CancelledError(message)


class ConnectionSafeMcpTool(McpTool):
    """McpTool wrapper that catches connection errors and returns them as
    error text to the LLM instead of raising.

    Without this, a persistent connection failure (e.g. "connection reset by
    peer") causes the LLM to retry the tool call in a tight loop, burning
    100% CPU for up to max_llm_calls iterations.

    The class subclasses McpTool so ``isinstance(tool, McpTool)`` checks keep
    working, but it carries no McpTool state of its own: ``McpTool.__init__``
    is never called, and any attribute not found on the wrapper is looked up
    on the wrapped instance via ``__getattr__``. This avoids the fragile
    __new__ + __dict__ copy pattern that would break if upstream McpTool adds
    __slots__, properties, or post-init hooks.

    See: https://github.com/kagent-dev/kagent/issues/1530
    """

    # The wrapped tool; set in __init__ and used as the delegation target.
    _inner_tool: McpTool

    def __init__(self, inner_tool: McpTool):
        """Wrap *inner_tool* without running ``McpTool.__init__``.

        Args:
            inner_tool: The already-constructed McpTool to delegate to.
        """
        # Store the inner tool without calling McpTool.__init__
        # (which requires connection params we don't have).
        object.__setattr__(self, "_inner_tool", inner_tool)

    def __getattr__(self, name: str) -> Any:
        """Delegate attribute lookups that fail on the wrapper to the inner tool.

        Python only calls ``__getattr__`` after normal lookup has failed.

        Raises:
            AttributeError: If the wrapper has no inner tool yet, or the inner
                tool does not have the attribute either.
        """
        # Fetch _inner_tool with object.__getattribute__ rather than
        # self._inner_tool: if the attribute is missing (instance created via
        # __new__, copy.copy/deepcopy, or unpickling, all of which probe
        # attributes before state is restored), self._inner_tool would
        # re-enter __getattr__ and recurse until RecursionError.
        try:
            inner = object.__getattribute__(self, "_inner_tool")
        except AttributeError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            ) from None
        return getattr(inner, name)

    def _connection_error_response(self, error: Exception) -> dict[str, Any]:
        """Log *error* and build the error payload handed back to the LLM.

        Args:
            error: The connection-level exception raised by the inner tool.

        Returns:
            A dict with a single ``"error"`` key whose text names the tool and
            the exception and tells the LLM not to retry.
        """
        error_message = (
            f"MCP tool '{self.name}' failed due to a connection error: "
            f"{type(error).__name__}: {error}. "
            "The MCP server may be unreachable. "
            "Do not retry this tool — inform the user about the failure."
        )
        logger.error(error_message, exc_info=error)
        return {"error": error_message}

    async def run_async(
        self,
        *,
        args: dict[str, Any],
        tool_context: ToolContext,
    ) -> dict[str, Any]:
        """Run the inner tool, converting connection failures into a result.

        Args:
            args: Tool arguments, forwarded unchanged to the inner tool.
            tool_context: Tool context, forwarded unchanged to the inner tool.

        Returns:
            The inner tool's result, or an ``{"error": ...}`` dict when the
            call failed with one of ``_CONNECTION_ERROR_TYPES`` or with an
            McpError classified as transport-level.

        Raises:
            McpError: If the McpError is not classified as transport-level.
            Any other exception raised by the inner tool propagates unchanged.
        """
        try:
            return await self._inner_tool.run_async(args=args, tool_context=tool_context)
        except _CONNECTION_ERROR_TYPES as error:
            return self._connection_error_response(error)
        except McpError as error:
            # McpError is the general MCP error class; only swallow the ones
            # that look like transport failures.
            if not _is_transport_mcp_error(error):
                raise
            return self._connection_error_response(error)


class KAgentMcpToolset(McpToolset):
"""McpToolset variant that catches and enriches errors during MCP session setup
and handles cancel scope issues during cleanup.
Expand All @@ -27,10 +130,20 @@ class KAgentMcpToolset(McpToolset):

async def get_tools(self, readonly_context: Optional[ReadonlyContext] = None) -> list[BaseTool]:
    """Return the toolset's tools with every McpTool made connection-safe.

    Args:
        readonly_context: Passed through unchanged to ``McpToolset.get_tools``.

    Returns:
        The tools from the parent toolset, in the same order. Each McpTool
        that is not already a ConnectionSafeMcpTool is wrapped in one so
        connection errors are returned as error text instead of raised;
        every other tool is returned as-is.

    Raises:
        asyncio.CancelledError: The enriched error built by
            ``_enrich_cancelled_error`` when the parent call is cancelled.
    """
    try:
        tools = await super().get_tools(readonly_context)
    except asyncio.CancelledError as error:
        raise _enrich_cancelled_error(error) from error

    # Skip tools that are already wrapped so repeated calls never double-wrap.
    return [
        ConnectionSafeMcpTool(tool)
        if isinstance(tool, McpTool) and not isinstance(tool, ConnectionSafeMcpTool)
        else tool
        for tool in tools
    ]

async def close(self) -> None:
"""Close MCP sessions and suppress known anyio cancel scope cleanup errors.

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
"""Tests for ConnectionSafeMcpTool — connection errors are returned as
error text to the LLM instead of raised, preventing tight retry loops.

See: https://github.com/kagent-dev/kagent/issues/1530
"""

import asyncio
from unittest.mock import AsyncMock, MagicMock, patch

import httpx
import pytest
from google.adk.tools.mcp_tool.mcp_tool import McpTool
from google.adk.tools.mcp_tool.mcp_toolset import McpToolset
from mcp.shared.exceptions import McpError
from mcp.types import ErrorData

from kagent.adk._mcp_toolset import ConnectionSafeMcpTool, KAgentMcpToolset


def _make_connection_safe_tool(side_effect):
    """Build a ConnectionSafeMcpTool around a mocked McpTool.

    The mock is named "test-tool" and its ``run_async`` is an AsyncMock
    configured with *side_effect*.
    """
    mocked_inner = MagicMock(spec=McpTool)
    mocked_inner.name = "test-tool"
    mocked_inner.run_async = AsyncMock(side_effect=side_effect)
    return ConnectionSafeMcpTool(mocked_inner)


@pytest.mark.asyncio
async def test_connection_reset_error_returns_error_dict():
    """A ConnectionResetError from the inner tool comes back as error text."""
    raised = ConnectionResetError("Connection reset by peer")
    safe_tool = _make_connection_safe_tool(raised)

    response = await safe_tool.run_async(args={"key": "value"}, tool_context=MagicMock())

    assert "error" in response
    error_text = response["error"]
    assert "ConnectionResetError" in error_text
    assert "Connection reset by peer" in error_text
    assert "Do not retry" in error_text


@pytest.mark.asyncio
async def test_connection_refused_error_returns_error_dict():
    """A ConnectionRefusedError from the inner tool comes back as error text."""
    raised = ConnectionRefusedError("Connection refused")
    safe_tool = _make_connection_safe_tool(raised)

    response = await safe_tool.run_async(args={}, tool_context=MagicMock())

    assert "error" in response
    error_text = response["error"]
    assert "ConnectionRefusedError" in error_text


@pytest.mark.asyncio
async def test_timeout_error_returns_error_dict():
    """A builtin TimeoutError from the inner tool comes back as error text."""
    raised = TimeoutError("timed out")
    safe_tool = _make_connection_safe_tool(raised)

    response = await safe_tool.run_async(args={}, tool_context=MagicMock())

    assert "error" in response
    error_text = response["error"]
    assert "TimeoutError" in error_text


@pytest.mark.asyncio
async def test_httpx_connect_error_returns_error_dict():
    """httpx.ConnectError is handled through the httpx.TransportError entry."""
    raised = httpx.ConnectError("connection refused")
    safe_tool = _make_connection_safe_tool(raised)

    response = await safe_tool.run_async(args={}, tool_context=MagicMock())

    assert "error" in response
    error_text = response["error"]
    assert "ConnectError" in error_text


@pytest.mark.asyncio
async def test_httpx_read_error_returns_error_dict():
    """httpx.ReadError (connection reset by peer) comes back as error text."""
    raised = httpx.ReadError("peer closed connection")
    safe_tool = _make_connection_safe_tool(raised)

    response = await safe_tool.run_async(args={}, tool_context=MagicMock())

    assert "error" in response
    error_text = response["error"]
    assert "ReadError" in error_text


@pytest.mark.asyncio
async def test_httpx_connect_timeout_returns_error_dict():
    """httpx.ConnectTimeout is handled through the httpx.TransportError entry."""
    raised = httpx.ConnectTimeout("timed out")
    safe_tool = _make_connection_safe_tool(raised)

    response = await safe_tool.run_async(args={}, tool_context=MagicMock())

    assert "error" in response
    error_text = response["error"]
    assert "ConnectTimeout" in error_text


@pytest.mark.asyncio
async def test_transport_mcp_error_returns_error_dict():
    """An McpError whose message is transport-level (e.g., session read timeout) is caught."""
    transport_error = McpError(ErrorData(code=-1, message="session read timeout"))
    safe_tool = _make_connection_safe_tool(transport_error)

    response = await safe_tool.run_async(args={}, tool_context=MagicMock())

    assert "error" in response
    error_text = response["error"]
    assert "McpError" in error_text
    assert "session read timeout" in error_text


@pytest.mark.asyncio
async def test_protocol_mcp_error_still_raises():
    """An McpError whose message is protocol-level (e.g., invalid arguments) propagates."""
    protocol_error = McpError(ErrorData(code=-32602, message="Invalid params: unknown tool"))
    safe_tool = _make_connection_safe_tool(protocol_error)

    with pytest.raises(McpError, match="Invalid params"):
        await safe_tool.run_async(args={}, tool_context=MagicMock())


@pytest.mark.asyncio
async def test_non_connection_error_still_raises():
    """Errors unrelated to connectivity (e.g. ValueError) are not swallowed."""
    unrelated_error = ValueError("bad argument")
    safe_tool = _make_connection_safe_tool(unrelated_error)

    with pytest.raises(ValueError, match="bad argument"):
        await safe_tool.run_async(args={}, tool_context=MagicMock())


@pytest.mark.asyncio
async def test_cancelled_error_still_raises():
    """CancelledError is not a connection error and must propagate."""
    cancellation = asyncio.CancelledError("cancelled")
    safe_tool = _make_connection_safe_tool(cancellation)

    with pytest.raises(asyncio.CancelledError):
        await safe_tool.run_async(args={}, tool_context=MagicMock())


@pytest.mark.asyncio
async def test_get_tools_wraps_mcp_tools():
    """KAgentMcpToolset.get_tools wraps McpTool instances and leaves other tools alone."""
    # Bare instance created with __new__ so McpTool.__init__ is not run.
    mcp_tool = McpTool.__new__(McpTool)
    mcp_tool.name = "wrapped-tool"
    mcp_tool._some_attr = "value"

    non_mcp_tool = MagicMock()
    non_mcp_tool.name = "other-tool"

    async def fake_parent_get_tools(_toolset, readonly_context=None):
        return [mcp_tool, non_mcp_tool]

    toolset = KAgentMcpToolset.__new__(KAgentMcpToolset)
    with patch.object(McpToolset, "get_tools", fake_parent_get_tools):
        returned = await toolset.get_tools()

    assert len(returned) == 2

    wrapped = returned[0]
    assert isinstance(wrapped, ConnectionSafeMcpTool)
    # Attribute reads on the wrapper resolve to the wrapped tool's values.
    assert wrapped.name == "wrapped-tool"
    assert wrapped._some_attr == "value"

    assert returned[1] is non_mcp_tool
Loading