diff --git a/python/packages/kagent-adk/src/kagent/adk/_mcp_toolset.py b/python/packages/kagent-adk/src/kagent/adk/_mcp_toolset.py
index 26c4c6df7..817f5eb9e 100644
--- a/python/packages/kagent-adk/src/kagent/adk/_mcp_toolset.py
+++ b/python/packages/kagent-adk/src/kagent/adk/_mcp_toolset.py
@@ -2,13 +2,79 @@
 
 import asyncio
 import logging
-from typing import Optional
+from typing import Any, Optional
 
+import httpx
 from google.adk.tools import BaseTool
+from google.adk.tools.mcp_tool.mcp_tool import McpTool
 from google.adk.tools.mcp_tool.mcp_toolset import McpToolset, ReadonlyContext
+from google.adk.tools.tool_context import ToolContext
+from mcp.shared.exceptions import McpError
 
 logger = logging.getLogger("kagent_adk." + __name__)
 
+# Connection errors that indicate an unreachable MCP server.
+# When these occur, the tool should return an error message to the LLM
+# instead of raising, so the LLM can respond to the user rather than
+# retrying the broken tool indefinitely.
+#
+# - ConnectionError: stdlib base for ConnectionResetError, ConnectionRefusedError, etc.
+# - TimeoutError: stdlib timeout (e.g. socket.timeout)
+# - httpx.TransportError: covers httpx.NetworkError (ConnectError, ReadError,
+#   WriteError, CloseError), httpx.TimeoutException, httpx.ProtocolError, etc.
+#   These do NOT inherit from stdlib ConnectionError/OSError.
+#
+# McpError is handled separately in ConnectionSafeMcpTool.run_async() because
+# it is the general MCP protocol error class. Only transport-level McpErrors
+# (e.g., session read timeouts) should be caught; protocol-level McpErrors
+# (e.g., invalid tool arguments) must propagate so the LLM can correct itself.
+_CONNECTION_ERROR_TYPES = (
+    ConnectionError,
+    TimeoutError,
+    httpx.TransportError,
+)
+
+# Keywords in McpError messages that indicate transport-level failures
+# (as opposed to protocol-level errors like invalid arguments).
+_TRANSPORT_MCP_ERROR_KEYWORDS = (
+    "timeout",
+    "timed out",
+    "connection",
+    "eof",
+    "reset",
+    "closed",
+    "transport",
+    "stream",
+    "unreachable",
+)
+
+# JSON-RPC 2.0 error codes that describe a problem with the request itself:
+# -32601 "method not found" and -32602 "invalid params". An McpError carrying
+# one of these is never a transport failure, even when its message happens to
+# contain a keyword above (e.g. "Invalid params: 'stream' must be a boolean").
+_PROTOCOL_MCP_ERROR_CODES = frozenset({-32601, -32602})
+
+
+def _is_transport_mcp_error(error: McpError) -> bool:
+    """Check if an McpError represents a transport-level failure.
+
+    McpError wraps all MCP protocol errors, but only transport-level failures
+    (e.g., session read timeouts, stream closures) should be caught and
+    returned to the LLM as non-retryable errors. Protocol-level errors
+    (e.g., invalid tool arguments, server validation failures) should
+    propagate so the LLM can correct its behavior.
+
+    The JSON-RPC error code is checked first: "method not found" and
+    "invalid params" are request errors by definition, so they are never
+    classified as transport failures regardless of the message text. All
+    other errors fall back to a case-insensitive keyword match on the message.
+    """
+    details = error.error
+    if details.code in _PROTOCOL_MCP_ERROR_CODES:
+        return False
+    message = details.message.lower()
+    return any(keyword in message for keyword in _TRANSPORT_MCP_ERROR_KEYWORDS)
+
 
 def _enrich_cancelled_error(error: BaseException) -> asyncio.CancelledError:
     message = "Failed to create MCP session: operation cancelled"
@@ -17,6 +83,75 @@ def _enrich_cancelled_error(error: BaseException) -> asyncio.CancelledError:
     return asyncio.CancelledError(message)
 
 
+class ConnectionSafeMcpTool(McpTool):
+    """McpTool wrapper that catches connection errors and returns them as
+    error text to the LLM instead of raising.
+
+    Without this, a persistent connection failure (e.g. "connection reset by
+    peer") causes the LLM to retry the tool call in a tight loop, burning
+    100% CPU for up to max_llm_calls iterations.
+
+    Subclasses McpTool (so isinstance checks against McpTool still succeed),
+    but McpTool.__init__ is never called: attributes missing on the wrapper
+    are delegated to the inner McpTool instance via __getattr__. This avoids
+    the fragile __new__ + __dict__ copy pattern that would break if upstream
+    McpTool adds __slots__, properties, or post-init hooks.
+
+    NOTE(review): __getattr__ only runs when normal lookup fails, so any
+    attribute with a class-level default on McpTool or its bases would be
+    read from the class rather than from the inner tool -- confirm upstream.
+
+    See: https://github.com/kagent-dev/kagent/issues/1530
+    """
+
+    _inner_tool: McpTool
+
+    def __init__(self, inner_tool: McpTool):
+        # Store the inner tool without calling McpTool.__init__
+        # (which requires connection params we don't have).
+        object.__setattr__(self, "_inner_tool", inner_tool)
+
+    def __getattr__(self, name: str) -> Any:
+        # If _inner_tool itself is missing (copy/pickle build instances
+        # without running __init__), delegating would re-enter __getattr__
+        # until RecursionError. Fail with a normal AttributeError instead.
+        if name == "_inner_tool":
+            raise AttributeError(f"{type(self).__name__!r} object has no attribute {name!r}")
+        return getattr(self._inner_tool, name)
+
+    def _connection_error_response(self, error: Exception) -> dict[str, Any]:
+        """Log the failure and build the error payload returned to the LLM."""
+        error_message = (
+            f"MCP tool '{self.name}' failed due to a connection error: "
+            f"{type(error).__name__}: {error}. "
+            "The MCP server may be unreachable. "
+            "Do not retry this tool — inform the user about the failure."
+        )
+        logger.error(error_message, exc_info=error)
+        return {"error": error_message}
+
+    async def run_async(
+        self,
+        *,
+        args: dict[str, Any],
+        tool_context: ToolContext,
+    ) -> dict[str, Any]:
+        """Run the inner tool, converting connection failures to error text.
+
+        Returns the inner tool's result, or ``{"error": ...}`` when a
+        connection-level exception or transport-level McpError occurs.
+        Any other exception, including protocol-level McpError, propagates.
+        """
+        try:
+            return await self._inner_tool.run_async(args=args, tool_context=tool_context)
+        except _CONNECTION_ERROR_TYPES as error:
+            return self._connection_error_response(error)
+        except McpError as error:
+            if not _is_transport_mcp_error(error):
+                raise
+            return self._connection_error_response(error)
+
+
 class KAgentMcpToolset(McpToolset):
     """McpToolset variant that catches and enriches errors during MCP session setup and
     handles cancel scope issues during cleanup.
@@ -27,10 +162,20 @@ class KAgentMcpToolset(McpToolset):
 
     async def get_tools(self, readonly_context: Optional[ReadonlyContext] = None) -> list[BaseTool]:
         try:
-            return await super().get_tools(readonly_context)
+            tools = await super().get_tools(readonly_context)
         except asyncio.CancelledError as error:
             raise _enrich_cancelled_error(error) from error
 
+        # Wrap each McpTool with ConnectionSafeMcpTool so that connection
+        # errors are returned as error text instead of raised. Tools that are
+        # already wrapped, or are not McpTools at all, pass through untouched.
+        return [
+            ConnectionSafeMcpTool(tool)
+            if isinstance(tool, McpTool) and not isinstance(tool, ConnectionSafeMcpTool)
+            else tool
+            for tool in tools
+        ]
+
     async def close(self) -> None:
         """Close MCP sessions and suppress known anyio cancel scope cleanup errors.
 
diff --git a/python/packages/kagent-adk/tests/unittests/test_mcp_connection_error_handling.py b/python/packages/kagent-adk/tests/unittests/test_mcp_connection_error_handling.py
new file mode 100644
index 000000000..a062f2eaf
--- /dev/null
+++ b/python/packages/kagent-adk/tests/unittests/test_mcp_connection_error_handling.py
@@ -0,0 +1,176 @@
+"""Tests for ConnectionSafeMcpTool — connection errors are returned as
+error text to the LLM instead of raised, preventing tight retry loops.
+
+See: https://github.com/kagent-dev/kagent/issues/1530
+"""
+
+import asyncio
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import httpx
+import pytest
+from google.adk.tools.mcp_tool.mcp_tool import McpTool
+from google.adk.tools.mcp_tool.mcp_toolset import McpToolset
+from mcp.shared.exceptions import McpError
+from mcp.types import ErrorData
+
+from kagent.adk._mcp_toolset import ConnectionSafeMcpTool, KAgentMcpToolset
+
+
+def _make_connection_safe_tool(side_effect):
+    """Create a ConnectionSafeMcpTool wrapping a mock McpTool."""
+    inner_tool = MagicMock(spec=McpTool)
+    inner_tool.name = "test-tool"
+    inner_tool.run_async = AsyncMock(side_effect=side_effect)
+    return ConnectionSafeMcpTool(inner_tool)
+
+
+@pytest.mark.asyncio
+async def test_connection_reset_error_returns_error_dict():
+    """ConnectionResetError should be caught and returned as error text."""
+    tool = _make_connection_safe_tool(ConnectionResetError("Connection reset by peer"))
+
+    result = await tool.run_async(args={"key": "value"}, tool_context=MagicMock())
+
+    assert "error" in result
+    assert "ConnectionResetError" in result["error"]
+    assert "Connection reset by peer" in result["error"]
+    assert "Do not retry" in result["error"]
+
+
+@pytest.mark.asyncio
+async def test_connection_refused_error_returns_error_dict():
+    """ConnectionRefusedError should be caught and returned as error text."""
+    tool = _make_connection_safe_tool(ConnectionRefusedError("Connection refused"))
+
+    result = await tool.run_async(args={}, tool_context=MagicMock())
+
+    assert "error" in result
+    assert "ConnectionRefusedError" in result["error"]
+
+
+@pytest.mark.asyncio
+async def test_timeout_error_returns_error_dict():
+    """TimeoutError should be caught and returned as error text."""
+    tool = _make_connection_safe_tool(TimeoutError("timed out"))
+
+    result = await tool.run_async(args={}, tool_context=MagicMock())
+
+    assert "error" in result
+    assert "TimeoutError" in result["error"]
+
+
+@pytest.mark.asyncio
+async def test_httpx_connect_error_returns_error_dict():
+    """httpx.ConnectError should be caught via httpx.TransportError."""
+    tool = _make_connection_safe_tool(httpx.ConnectError("connection refused"))
+
+    result = await tool.run_async(args={}, tool_context=MagicMock())
+
+    assert "error" in result
+    assert "ConnectError" in result["error"]
+
+
+@pytest.mark.asyncio
+async def test_httpx_read_error_returns_error_dict():
+    """httpx.ReadError (connection reset by peer) should be caught."""
+    tool = _make_connection_safe_tool(httpx.ReadError("peer closed connection"))
+
+    result = await tool.run_async(args={}, tool_context=MagicMock())
+
+    assert "error" in result
+    assert "ReadError" in result["error"]
+
+
+@pytest.mark.asyncio
+async def test_httpx_connect_timeout_returns_error_dict():
+    """httpx.ConnectTimeout should be caught via httpx.TransportError."""
+    tool = _make_connection_safe_tool(httpx.ConnectTimeout("timed out"))
+
+    result = await tool.run_async(args={}, tool_context=MagicMock())
+
+    assert "error" in result
+    assert "ConnectTimeout" in result["error"]
+
+
+@pytest.mark.asyncio
+async def test_transport_mcp_error_returns_error_dict():
+    """McpError with a transport-level message (e.g., session read timeout) should be caught."""
+    tool = _make_connection_safe_tool(McpError(ErrorData(code=-1, message="session read timeout")))
+
+    result = await tool.run_async(args={}, tool_context=MagicMock())
+
+    assert "error" in result
+    assert "McpError" in result["error"]
+    assert "session read timeout" in result["error"]
+
+
+@pytest.mark.asyncio
+async def test_protocol_mcp_error_still_raises():
+    """McpError with a protocol-level message (e.g., invalid arguments) should propagate."""
+    tool = _make_connection_safe_tool(McpError(ErrorData(code=-32602, message="Invalid params: unknown tool")))
+
+    with pytest.raises(McpError, match="Invalid params"):
+        await tool.run_async(args={}, tool_context=MagicMock())
+
+
+@pytest.mark.asyncio
+async def test_non_connection_error_still_raises():
+    """Non-connection errors (e.g. ValueError) should still propagate."""
+    tool = _make_connection_safe_tool(ValueError("bad argument"))
+
+    with pytest.raises(ValueError, match="bad argument"):
+        await tool.run_async(args={}, tool_context=MagicMock())
+
+
+@pytest.mark.asyncio
+async def test_cancelled_error_still_raises():
+    """CancelledError must propagate — it's not a connection error."""
+    tool = _make_connection_safe_tool(asyncio.CancelledError("cancelled"))
+
+    with pytest.raises(asyncio.CancelledError):
+        await tool.run_async(args={}, tool_context=MagicMock())
+
+
+@pytest.mark.asyncio
+async def test_get_tools_wraps_mcp_tools():
+    """KAgentMcpToolset.get_tools should wrap McpTool instances with ConnectionSafeMcpTool."""
+    fake_mcp_tool = McpTool.__new__(McpTool)
+    fake_mcp_tool.name = "wrapped-tool"
+    fake_mcp_tool._some_attr = "value"
+
+    fake_other_tool = MagicMock()
+    fake_other_tool.name = "other-tool"
+
+    toolset = KAgentMcpToolset.__new__(KAgentMcpToolset)
+
+    async def mock_super_get_tools(self_arg, readonly_context=None):
+        return [fake_mcp_tool, fake_other_tool]
+
+    with patch.object(McpToolset, "get_tools", mock_super_get_tools):
+        tools = await toolset.get_tools()
+
+    assert len(tools) == 2
+    assert isinstance(tools[0], ConnectionSafeMcpTool)
+    assert tools[0].name == "wrapped-tool"
+    assert tools[0]._some_attr == "value"
+    assert tools[1] is fake_other_tool
+
+
+@pytest.mark.asyncio
+async def test_get_tools_does_not_double_wrap():
+    """Already-wrapped tools should be returned as-is, not wrapped a second time."""
+    inner_tool = MagicMock(spec=McpTool)
+    inner_tool.name = "already-wrapped"
+    already_wrapped = ConnectionSafeMcpTool(inner_tool)
+
+    toolset = KAgentMcpToolset.__new__(KAgentMcpToolset)
+
+    async def mock_super_get_tools(self_arg, readonly_context=None):
+        return [already_wrapped]
+
+    with patch.object(McpToolset, "get_tools", mock_super_get_tools):
+        tools = await toolset.get_tools()
+
+    assert len(tools) == 1
+    assert tools[0] is already_wrapped