Skip to content

Commit 5e717fc

Browse files
committed
fix: return MCP connection errors to LLM instead of raising
When an MCP HTTP tool call fails with a persistent connection error (e.g. "connection reset by peer"), the error propagates to the LLM as a function error. The LLM interprets this as transient and retries the same tool call, creating a tight loop that burns 100% CPU for up to max_llm_calls (500) iterations. Wrap McpTool instances with ConnectionSafeMcpTool that catches connection errors (ConnectionError, TimeoutError, httpx.TransportError, McpError) and returns them as error text. This lets the LLM inform the user about the failure instead of retrying indefinitely. Fixes #1530 Signed-off-by: Jaison Paul <paul.jaison@gmail.com>
1 parent fcad888 commit 5e717fc

File tree

2 files changed

+240
-2
lines changed

2 files changed

+240
-2
lines changed

python/packages/kagent-adk/src/kagent/adk/_mcp_toolset.py

Lines changed: 72 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,37 @@
22

33
import asyncio
44
import logging
5-
from typing import Optional
5+
from typing import Any, Dict, Optional
66

7+
import httpx
78
from google.adk.tools import BaseTool
9+
from google.adk.tools.mcp_tool.mcp_tool import McpTool
810
from google.adk.tools.mcp_tool.mcp_toolset import McpToolset, ReadonlyContext
11+
from google.adk.tools.tool_context import ToolContext
12+
from mcp.shared.exceptions import McpError
913

1014
logger = logging.getLogger("kagent_adk." + __name__)
1115

16+
# Exception types treated as "the MCP server cannot be reached".
# A tool call that fails with one of these hands an error message back to
# the LLM rather than raising, so the LLM can answer the user instead of
# re-issuing the same broken tool call over and over.
_CONNECTION_ERROR_TYPES = (
    # stdlib base class of ConnectionResetError, ConnectionRefusedError, etc.
    ConnectionError,
    # stdlib timeout (e.g. socket.timeout)
    TimeoutError,
    # Covers httpx.NetworkError (ConnectError, ReadError, WriteError,
    # CloseError), httpx.TimeoutException, httpx.ProtocolError, etc.
    # These do NOT inherit from stdlib ConnectionError/OSError, so they
    # have to be listed separately.
    httpx.TransportError,
    # Raised by mcp.shared.session.send_request() when the underlying
    # SSE/HTTP stream drops or a tool call hits the session read timeout.
    # The MCP client wraps the transport-level error into McpError before
    # it reaches us.
    McpError,
)
35+
1236

1337
def _enrich_cancelled_error(error: BaseException) -> asyncio.CancelledError:
1438
message = "Failed to create MCP session: operation cancelled"
@@ -17,6 +41,36 @@ def _enrich_cancelled_error(error: BaseException) -> asyncio.CancelledError:
1741
return asyncio.CancelledError(message)
1842

1943

44+
class ConnectionSafeMcpTool(McpTool):
    """McpTool wrapper that catches connection errors and returns them as
    error text to the LLM instead of raising.

    Without this, a persistent connection failure (e.g. "connection reset by
    peer") causes the LLM to retry the tool call in a tight loop, burning
    100% CPU for up to max_llm_calls iterations.

    See: https://github.com/kagent-dev/kagent/issues/1530
    """

    async def run_async(
        self,
        *,
        args: Dict[str, Any],
        tool_context: ToolContext,
    ) -> Dict[str, Any]:
        """Run the tool, converting connection failures into an error result.

        Args:
            args: Tool arguments, forwarded unchanged to ``McpTool.run_async``.
            tool_context: Invocation context, forwarded unchanged.

        Returns:
            Whatever ``McpTool.run_async`` returns on success. If the call
            raises one of ``_CONNECTION_ERROR_TYPES``, a dict with a single
            ``"error"`` key whose text tells the LLM not to retry.

        Raises:
            Any exception that is not in ``_CONNECTION_ERROR_TYPES``
            (including ``asyncio.CancelledError``) propagates unchanged.
        """
        try:
            return await super().run_async(args=args, tool_context=tool_context)
        except _CONNECTION_ERROR_TYPES as error:
            # NOTE(review): McpError is caught wholesale here; it may also be
            # raised for failures that are not connection problems — confirm
            # whether those should get the "server may be unreachable" text.
            error_name = type(error).__name__
            error_text = str(error)
            # Timeouts are frequently raised with no message at all; fall
            # back to the bare type name so the LLM does not see a dangling
            # "ReadTimeout: ." fragment.
            detail = f"{error_name}: {error_text}" if error_text else error_name
            error_message = (
                f"MCP tool '{self.name}' failed due to a connection error: "
                f"{detail}. "
                "The MCP server may be unreachable. "
                "Do not retry this tool — inform the user about the failure."
            )
            # Attach the traceback: the summary line alone does not show
            # which transport layer actually failed.
            logger.error(error_message, exc_info=True)
            return {"error": error_message}
72+
73+
2074
class KAgentMcpToolset(McpToolset):
2175
"""McpToolset variant that catches and enriches errors during MCP session setup
2276
and handles cancel scope issues during cleanup.
@@ -27,10 +81,26 @@ class KAgentMcpToolset(McpToolset):
2781

2882
async def get_tools(self, readonly_context: Optional[ReadonlyContext] = None) -> list[BaseTool]:
    """Return the toolset's tools with every McpTool made connection-safe.

    Args:
        readonly_context: Passed straight through to ``McpToolset.get_tools``.

    Returns:
        The parent's tools in their original order. Each ``McpTool`` that is
        not already a ``ConnectionSafeMcpTool`` is replaced by a
        ``ConnectionSafeMcpTool`` carrying the same instance attributes;
        every other tool is returned as the same object.

    Raises:
        asyncio.CancelledError: Re-raised with an enriched message (via
            ``_enrich_cancelled_error``) when the parent call is cancelled.
    """
    try:
        discovered = await super().get_tools(readonly_context)
    except asyncio.CancelledError as error:
        raise _enrich_cancelled_error(error) from error

    def _as_connection_safe(candidate: BaseTool) -> BaseTool:
        # Leave non-MCP tools and already-wrapped tools untouched.
        needs_wrapping = isinstance(candidate, McpTool) and not isinstance(
            candidate, ConnectionSafeMcpTool
        )
        if not needs_wrapping:
            return candidate
        # Re-type the instance without running McpTool.__init__, which needs
        # connection params that are not available here. Copying the
        # instance dict is sufficient because McpTool keeps its state in
        # plain instance attributes, not __slots__ or descriptors.
        clone = ConnectionSafeMcpTool.__new__(ConnectionSafeMcpTool)
        vars(clone).update(vars(candidate))
        return clone

    return [_as_connection_safe(candidate) for candidate in discovered]
103+
34104
async def close(self) -> None:
35105
"""Close MCP sessions and suppress known anyio cancel scope cleanup errors.
36106
Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
"""Tests for ConnectionSafeMcpTool — connection errors are returned as
2+
error text to the LLM instead of raised, preventing tight retry loops.
3+
4+
See: https://github.com/kagent-dev/kagent/issues/1530
5+
"""
6+
7+
import asyncio
8+
from unittest.mock import AsyncMock, MagicMock, patch
9+
10+
import httpx
11+
import pytest
12+
from google.adk.tools.mcp_tool.mcp_tool import McpTool
13+
from google.adk.tools.mcp_tool.mcp_toolset import McpToolset
14+
from mcp.shared.exceptions import McpError
15+
16+
from kagent.adk._mcp_toolset import ConnectionSafeMcpTool, KAgentMcpToolset
17+
18+
19+
def _make_connection_safe_tool(side_effect):
    """Build a bare ConnectionSafeMcpTool plus a stand-in for the parent call.

    The instance is created with ``__new__`` so ``McpTool.__init__`` never
    runs. An ``AsyncMock`` configured with ``side_effect`` is stored on the
    instance as ``_parent_run_async``; tests patch it over
    ``McpTool.run_async`` so that ``super().run_async`` hits the mock.
    """
    safe_tool = ConnectionSafeMcpTool.__new__(ConnectionSafeMcpTool)
    safe_tool.name = "test-tool"

    inner_tool = MagicMock()
    inner_tool.name = "test-tool"
    safe_tool._mcp_tool = inner_tool

    safe_tool._mcp_session_manager = AsyncMock()
    for unset_attr in (
        "_header_provider",
        "_auth_config",
        "_confirmation_config",
        "_progress_callback",
    ):
        setattr(safe_tool, unset_attr, None)

    safe_tool._parent_run_async = AsyncMock(side_effect=side_effect)
    return safe_tool
32+
33+
34+
@pytest.mark.asyncio
async def test_connection_reset_error_returns_error_dict():
    """A ConnectionResetError from the parent call comes back as error text."""
    failure = ConnectionResetError("Connection reset by peer")
    safe_tool = _make_connection_safe_tool(failure)

    with patch.object(McpTool, "run_async", safe_tool._parent_run_async):
        outcome = await safe_tool.run_async(
            args={"key": "value"},
            tool_context=MagicMock(),
        )

    assert "error" in outcome
    error_text = outcome["error"]
    # The text must name the exception, carry its message, and tell the
    # LLM not to retry.
    for fragment in ("ConnectionResetError", "Connection reset by peer", "Do not retry"):
        assert fragment in error_text
46+
47+
48+
@pytest.mark.asyncio
async def test_connection_refused_error_returns_error_dict():
    """A ConnectionRefusedError from the parent call comes back as error text."""
    failure = ConnectionRefusedError("Connection refused")
    safe_tool = _make_connection_safe_tool(failure)

    with patch.object(McpTool, "run_async", safe_tool._parent_run_async):
        outcome = await safe_tool.run_async(args={}, tool_context=MagicMock())

    assert "error" in outcome
    error_text = outcome["error"]
    assert "ConnectionRefusedError" in error_text
58+
59+
60+
@pytest.mark.asyncio
async def test_timeout_error_returns_error_dict():
    """A stdlib TimeoutError from the parent call comes back as error text."""
    failure = TimeoutError("timed out")
    safe_tool = _make_connection_safe_tool(failure)

    with patch.object(McpTool, "run_async", safe_tool._parent_run_async):
        outcome = await safe_tool.run_async(args={}, tool_context=MagicMock())

    assert "error" in outcome
    error_text = outcome["error"]
    assert "TimeoutError" in error_text
70+
71+
72+
@pytest.mark.asyncio
async def test_httpx_connect_error_returns_error_dict():
    """httpx.ConnectError is handled through the httpx.TransportError entry."""
    failure = httpx.ConnectError("connection refused")
    safe_tool = _make_connection_safe_tool(failure)

    with patch.object(McpTool, "run_async", safe_tool._parent_run_async):
        outcome = await safe_tool.run_async(args={}, tool_context=MagicMock())

    assert "error" in outcome
    error_text = outcome["error"]
    assert "ConnectError" in error_text
82+
83+
84+
@pytest.mark.asyncio
async def test_httpx_read_error_returns_error_dict():
    """httpx.ReadError (connection reset by peer) comes back as error text."""
    failure = httpx.ReadError("peer closed connection")
    safe_tool = _make_connection_safe_tool(failure)

    with patch.object(McpTool, "run_async", safe_tool._parent_run_async):
        outcome = await safe_tool.run_async(args={}, tool_context=MagicMock())

    assert "error" in outcome
    error_text = outcome["error"]
    assert "ReadError" in error_text
94+
95+
96+
@pytest.mark.asyncio
async def test_httpx_connect_timeout_returns_error_dict():
    """httpx.ConnectTimeout is handled through the httpx.TransportError entry."""
    failure = httpx.ConnectTimeout("timed out")
    safe_tool = _make_connection_safe_tool(failure)

    with patch.object(McpTool, "run_async", safe_tool._parent_run_async):
        outcome = await safe_tool.run_async(args={}, tool_context=MagicMock())

    assert "error" in outcome
    error_text = outcome["error"]
    assert "ConnectTimeout" in error_text
106+
107+
108+
@pytest.mark.asyncio
async def test_mcp_error_returns_error_dict():
    """McpError (raised by the MCP session on stream drop / read timeout) comes back as error text."""
    from mcp.types import ErrorData

    payload = ErrorData(code=-1, message="session read timeout")
    safe_tool = _make_connection_safe_tool(McpError(payload))

    with patch.object(McpTool, "run_async", safe_tool._parent_run_async):
        outcome = await safe_tool.run_async(args={}, tool_context=MagicMock())

    assert "error" in outcome
    error_text = outcome["error"]
    # Both the exception type and the wrapped message must be visible.
    for fragment in ("McpError", "session read timeout"):
        assert fragment in error_text
121+
122+
123+
@pytest.mark.asyncio
async def test_non_connection_error_still_raises():
    """Errors outside the connection-error set (e.g. ValueError) are not swallowed."""
    safe_tool = _make_connection_safe_tool(ValueError("bad argument"))
    parent_patch = patch.object(McpTool, "run_async", safe_tool._parent_run_async)

    with parent_patch, pytest.raises(ValueError, match="bad argument"):
        await safe_tool.run_async(args={}, tool_context=MagicMock())
131+
132+
133+
@pytest.mark.asyncio
async def test_cancelled_error_still_raises():
    """CancelledError is not a connection error and must reach the caller."""
    safe_tool = _make_connection_safe_tool(asyncio.CancelledError("cancelled"))
    parent_patch = patch.object(McpTool, "run_async", safe_tool._parent_run_async)

    with parent_patch, pytest.raises(asyncio.CancelledError):
        await safe_tool.run_async(args={}, tool_context=MagicMock())
141+
142+
143+
@pytest.mark.asyncio
async def test_get_tools_wraps_mcp_tools():
    """KAgentMcpToolset.get_tools re-types McpTool instances and leaves other tools alone."""
    # __new__ skips McpTool.__init__ but still yields a genuine McpTool,
    # which is what the isinstance checks in get_tools need.
    plain_mcp_tool = McpTool.__new__(McpTool)
    plain_mcp_tool.name = "wrapped-tool"
    plain_mcp_tool._some_attr = "value"

    # Not an McpTool, so it should come back as the very same object.
    unrelated_tool = MagicMock()
    unrelated_tool.name = "other-tool"

    async def fake_parent_get_tools(self_arg, readonly_context=None):
        return [plain_mcp_tool, unrelated_tool]

    toolset = KAgentMcpToolset.__new__(KAgentMcpToolset)
    with patch.object(McpToolset, "get_tools", fake_parent_get_tools):
        returned = await toolset.get_tools()

    assert len(returned) == 2
    wrapped, passthrough = returned
    assert isinstance(wrapped, ConnectionSafeMcpTool)
    assert wrapped.name == "wrapped-tool"
    assert wrapped._some_attr == "value"
    assert passthrough is unrelated_tool

0 commit comments

Comments
 (0)