Skip to content

Commit 48764d8

Browse files
committed
fix: narrow McpError handling, use composition for ConnectionSafeMcpTool
Address review feedback on #1531:

- Only catch transport-level McpErrors (timeouts, stream drops) via keyword inspection; protocol-level McpErrors (invalid args, validation) now propagate so the LLM can correct its behavior
- Replace fragile __new__ + __dict__ copy with composition pattern: store inner McpTool and delegate via __getattr__
- Add exc_info to logger.error() for operator-visible tracebacks
- Remove unused Dict import, use lowercase dict[str, Any]
- Simplify test setup: mock inner tool directly, remove patch boilerplate
- Add test_protocol_mcp_error_still_raises to verify narrowing

Signed-off-by: Jaison Paul <paul.jaison@gmail.com>
1 parent c5c20e5 commit 48764d8

2 files changed

Lines changed: 92 additions & 63 deletions

File tree

python/packages/kagent-adk/src/kagent/adk/_mcp_toolset.py

Lines changed: 59 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import asyncio
44
import logging
5-
from typing import Any, Dict, Optional
5+
from typing import Any, Optional
66

77
import httpx
88
from google.adk.tools import BaseTool
@@ -23,16 +23,37 @@
2323
# - httpx.TransportError: covers httpx.NetworkError (ConnectError, ReadError,
2424
# WriteError, CloseError), httpx.TimeoutException, httpx.ProtocolError, etc.
2525
# These do NOT inherit from stdlib ConnectionError/OSError.
26-
# - McpError: raised by mcp.shared.session.send_request() when the underlying
27-
# SSE/HTTP stream drops or a tool call hits the session read timeout. The MCP
28-
# client wraps the transport-level error into McpError before it reaches us.
26+
#
27+
# McpError is handled separately in ConnectionSafeMcpTool.run_async() because
28+
# it is the general MCP protocol error class. Only transport-level McpErrors
29+
# (e.g., session read timeouts) should be caught; protocol-level McpErrors
30+
# (e.g., invalid tool arguments) must propagate so the LLM can correct itself.
2931
_CONNECTION_ERROR_TYPES = (
3032
ConnectionError,
3133
TimeoutError,
3234
httpx.TransportError,
33-
McpError,
3435
)
3536

# Keywords in McpError messages that indicate transport-level failures
# (as opposed to protocol-level errors like invalid arguments).
_TRANSPORT_MCP_ERROR_KEYWORDS = (
    "timeout", "timed out", "connection", "eof", "reset",
    "closed", "transport", "stream", "unreachable",
)

# JSON-RPC 2.0 reserved codes that always describe a bad request (parse error,
# invalid request, method not found, invalid params). An error carrying one of
# these codes is never a transport failure, even if its message happens to
# contain a transport keyword (e.g. "Invalid params: connection_id is required").
_PROTOCOL_JSONRPC_ERROR_CODES = frozenset({-32700, -32600, -32601, -32602})


def _is_transport_mcp_error(error: "McpError") -> bool:
    """Check if an McpError represents a transport-level failure.

    McpError wraps all MCP protocol errors, but only transport-level failures
    (e.g., session read timeouts, stream closures) should be caught and
    returned to the LLM as non-retryable errors. Protocol-level errors
    (e.g., invalid tool arguments, server validation failures) should
    propagate so the LLM can correct its behavior.

    Args:
        error: The McpError to classify; its ``error`` attribute carries the
            ``code`` and ``message`` that are inspected.

    Returns:
        True when the message contains one of ``_TRANSPORT_MCP_ERROR_KEYWORDS``
        and the code is not a JSON-RPC request-error code; False otherwise,
        including when the message is missing or empty.
    """
    data = error.error
    # The error code is authoritative: a request-level JSON-RPC code means the
    # caller sent something wrong, so keyword matching must not override it.
    if getattr(data, "code", None) in _PROTOCOL_JSONRPC_ERROR_CODES:
        return False
    # Tolerate a missing/None message instead of raising AttributeError while
    # we are already handling another exception.
    message = (getattr(data, "message", None) or "").lower()
    return any(keyword in message for keyword in _TRANSPORT_MCP_ERROR_KEYWORDS)
56+
3657

3758
def _enrich_cancelled_error(error: BaseException) -> asyncio.CancelledError:
3859
message = "Failed to create MCP session: operation cancelled"
@@ -49,26 +70,47 @@ class ConnectionSafeMcpTool(McpTool):
4970
peer") causes the LLM to retry the tool call in a tight loop, burning
5071
100% CPU for up to max_llm_calls iterations.
5172
73+
Uses composition: delegates to an inner McpTool instance via __getattr__,
74+
avoiding the fragile __new__ + __dict__ copy pattern that would break if
75+
upstream McpTool adds __slots__, properties, or post-init hooks.
76+
5277
See: https://github.com/kagent-dev/kagent/issues/1530
5378
"""
5479

80+
_inner_tool: McpTool
81+
82+
def __init__(self, inner_tool: McpTool):
83+
# Store the inner tool without calling McpTool.__init__
84+
# (which requires connection params we don't have).
85+
object.__setattr__(self, "_inner_tool", inner_tool)
86+
87+
def __getattr__(self, name: str) -> Any:
    """Delegate attribute lookups that miss on the wrapper to the inner tool.

    Python calls ``__getattr__`` only after normal attribute lookup has
    failed, so attributes set on the wrapper itself take precedence.

    Args:
        name: The attribute being looked up.

    Returns:
        The attribute value from the wrapped ``McpTool``.

    Raises:
        AttributeError: If ``_inner_tool`` has not been set yet, or the
            inner tool does not have the attribute.
    """
    # If _inner_tool itself is missing (instance built via __new__, e.g. while
    # copy/pickle probe for __setstate__ / __deepcopy__), reading
    # self._inner_tool below would re-enter __getattr__ forever. Fail fast
    # with AttributeError so those protocols fall back to their defaults.
    if name == "_inner_tool":
        raise AttributeError(name)
    return getattr(self._inner_tool, name)
89+
90+
def _connection_error_response(self, error: Exception) -> dict[str, Any]:
91+
error_message = (
92+
f"MCP tool '{self.name}' failed due to a connection error: "
93+
f"{type(error).__name__}: {error}. "
94+
"The MCP server may be unreachable. "
95+
"Do not retry this tool — inform the user about the failure."
96+
)
97+
logger.error(error_message, exc_info=error)
98+
return {"error": error_message}
99+
55100
async def run_async(
56101
self,
57102
*,
58-
args: Dict[str, Any],
103+
args: dict[str, Any],
59104
tool_context: ToolContext,
60-
) -> Dict[str, Any]:
105+
) -> dict[str, Any]:
61106
try:
62-
return await super().run_async(args=args, tool_context=tool_context)
107+
return await self._inner_tool.run_async(args=args, tool_context=tool_context)
63108
except _CONNECTION_ERROR_TYPES as error:
64-
error_message = (
65-
f"MCP tool '{self.name}' failed due to a connection error: "
66-
f"{type(error).__name__}: {error}. "
67-
"The MCP server may be unreachable. "
68-
"Do not retry this tool — inform the user about the failure."
69-
)
70-
logger.error(error_message)
71-
return {"error": error_message}
109+
return self._connection_error_response(error)
110+
except McpError as error:
111+
if not _is_transport_mcp_error(error):
112+
raise
113+
return self._connection_error_response(error)
72114

73115

74116
class KAgentMcpToolset(McpToolset):
@@ -87,16 +129,10 @@ async def get_tools(self, readonly_context: Optional[ReadonlyContext] = None) ->
87129

88130
# Wrap each McpTool with ConnectionSafeMcpTool so that connection
89131
# errors are returned as error text instead of raised.
90-
# Uses __new__ + __dict__ copy to re-type the instance without calling
91-
# McpTool.__init__ (which requires connection params we don't have).
92-
# This is safe because McpTool uses plain instance attributes, not
93-
# __slots__ or descriptors.
94132
wrapped_tools: list[BaseTool] = []
95133
for tool in tools:
96134
if isinstance(tool, McpTool) and not isinstance(tool, ConnectionSafeMcpTool):
97-
safe_tool = ConnectionSafeMcpTool.__new__(ConnectionSafeMcpTool)
98-
safe_tool.__dict__.update(tool.__dict__)
99-
wrapped_tools.append(safe_tool)
135+
wrapped_tools.append(ConnectionSafeMcpTool(tool))
100136
else:
101137
wrapped_tools.append(tool)
102138
return wrapped_tools

python/packages/kagent-adk/tests/unittests/test_mcp_connection_error_handling.py

Lines changed: 33 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -12,32 +12,25 @@
1212
from google.adk.tools.mcp_tool.mcp_tool import McpTool
1313
from google.adk.tools.mcp_tool.mcp_toolset import McpToolset
1414
from mcp.shared.exceptions import McpError
15+
from mcp.types import ErrorData
1516

1617
from kagent.adk._mcp_toolset import ConnectionSafeMcpTool, KAgentMcpToolset
1718

1819

1920
def _make_connection_safe_tool(side_effect):
20-
"""Create a ConnectionSafeMcpTool with a mocked super().run_async."""
21-
tool = ConnectionSafeMcpTool.__new__(ConnectionSafeMcpTool)
22-
tool.name = "test-tool"
23-
tool._mcp_tool = MagicMock()
24-
tool._mcp_tool.name = "test-tool"
25-
tool._mcp_session_manager = AsyncMock()
26-
tool._header_provider = None
27-
tool._auth_config = None
28-
tool._confirmation_config = None
29-
tool._progress_callback = None
30-
tool._parent_run_async = AsyncMock(side_effect=side_effect)
31-
return tool
21+
"""Create a ConnectionSafeMcpTool wrapping a mock McpTool."""
22+
inner_tool = MagicMock(spec=McpTool)
23+
inner_tool.name = "test-tool"
24+
inner_tool.run_async = AsyncMock(side_effect=side_effect)
25+
return ConnectionSafeMcpTool(inner_tool)
3226

3327

3428
@pytest.mark.asyncio
3529
async def test_connection_reset_error_returns_error_dict():
3630
"""ConnectionResetError should be caught and returned as error text."""
3731
tool = _make_connection_safe_tool(ConnectionResetError("Connection reset by peer"))
3832

39-
with patch.object(McpTool, "run_async", tool._parent_run_async):
40-
result = await tool.run_async(args={"key": "value"}, tool_context=MagicMock())
33+
result = await tool.run_async(args={"key": "value"}, tool_context=MagicMock())
4134

4235
assert "error" in result
4336
assert "ConnectionResetError" in result["error"]
@@ -50,8 +43,7 @@ async def test_connection_refused_error_returns_error_dict():
5043
"""ConnectionRefusedError should be caught and returned as error text."""
5144
tool = _make_connection_safe_tool(ConnectionRefusedError("Connection refused"))
5245

53-
with patch.object(McpTool, "run_async", tool._parent_run_async):
54-
result = await tool.run_async(args={}, tool_context=MagicMock())
46+
result = await tool.run_async(args={}, tool_context=MagicMock())
5547

5648
assert "error" in result
5749
assert "ConnectionRefusedError" in result["error"]
@@ -62,8 +54,7 @@ async def test_timeout_error_returns_error_dict():
6254
"""TimeoutError should be caught and returned as error text."""
6355
tool = _make_connection_safe_tool(TimeoutError("timed out"))
6456

65-
with patch.object(McpTool, "run_async", tool._parent_run_async):
66-
result = await tool.run_async(args={}, tool_context=MagicMock())
57+
result = await tool.run_async(args={}, tool_context=MagicMock())
6758

6859
assert "error" in result
6960
assert "TimeoutError" in result["error"]
@@ -74,8 +65,7 @@ async def test_httpx_connect_error_returns_error_dict():
7465
"""httpx.ConnectError should be caught via httpx.TransportError."""
7566
tool = _make_connection_safe_tool(httpx.ConnectError("connection refused"))
7667

77-
with patch.object(McpTool, "run_async", tool._parent_run_async):
78-
result = await tool.run_async(args={}, tool_context=MagicMock())
68+
result = await tool.run_async(args={}, tool_context=MagicMock())
7969

8070
assert "error" in result
8171
assert "ConnectError" in result["error"]
@@ -86,8 +76,7 @@ async def test_httpx_read_error_returns_error_dict():
8676
"""httpx.ReadError (connection reset by peer) should be caught."""
8777
tool = _make_connection_safe_tool(httpx.ReadError("peer closed connection"))
8878

89-
with patch.object(McpTool, "run_async", tool._parent_run_async):
90-
result = await tool.run_async(args={}, tool_context=MagicMock())
79+
result = await tool.run_async(args={}, tool_context=MagicMock())
9180

9281
assert "error" in result
9382
assert "ReadError" in result["error"]
@@ -98,57 +87,62 @@ async def test_httpx_connect_timeout_returns_error_dict():
9887
"""httpx.ConnectTimeout should be caught via httpx.TransportError."""
9988
tool = _make_connection_safe_tool(httpx.ConnectTimeout("timed out"))
10089

101-
with patch.object(McpTool, "run_async", tool._parent_run_async):
102-
result = await tool.run_async(args={}, tool_context=MagicMock())
90+
result = await tool.run_async(args={}, tool_context=MagicMock())
10391

10492
assert "error" in result
10593
assert "ConnectTimeout" in result["error"]
10694

10795

10896
@pytest.mark.asyncio
109-
async def test_mcp_error_returns_error_dict():
110-
"""McpError (raised by MCP session on stream drop / read timeout) should be caught."""
111-
from mcp.types import ErrorData
97+
async def test_transport_mcp_error_returns_error_dict():
98+
"""McpError with a transport-level message (e.g., session read timeout) should be caught."""
99+
tool = _make_connection_safe_tool(
100+
McpError(ErrorData(code=-1, message="session read timeout"))
101+
)
112102

113-
tool = _make_connection_safe_tool(McpError(ErrorData(code=-1, message="session read timeout")))
114-
115-
with patch.object(McpTool, "run_async", tool._parent_run_async):
116-
result = await tool.run_async(args={}, tool_context=MagicMock())
103+
result = await tool.run_async(args={}, tool_context=MagicMock())
117104

118105
assert "error" in result
119106
assert "McpError" in result["error"]
120107
assert "session read timeout" in result["error"]
121108

122109

110+
@pytest.mark.asyncio
111+
async def test_protocol_mcp_error_still_raises():
112+
"""McpError with a protocol-level message (e.g., invalid arguments) should propagate."""
113+
tool = _make_connection_safe_tool(
114+
McpError(ErrorData(code=-32602, message="Invalid params: unknown tool"))
115+
)
116+
117+
with pytest.raises(McpError, match="Invalid params"):
118+
await tool.run_async(args={}, tool_context=MagicMock())
119+
120+
123121
@pytest.mark.asyncio
124122
async def test_non_connection_error_still_raises():
125123
"""Non-connection errors (e.g. ValueError) should still propagate."""
126124
tool = _make_connection_safe_tool(ValueError("bad argument"))
127125

128-
with patch.object(McpTool, "run_async", tool._parent_run_async):
129-
with pytest.raises(ValueError, match="bad argument"):
130-
await tool.run_async(args={}, tool_context=MagicMock())
126+
with pytest.raises(ValueError, match="bad argument"):
127+
await tool.run_async(args={}, tool_context=MagicMock())
131128

132129

133130
@pytest.mark.asyncio
134131
async def test_cancelled_error_still_raises():
135132
"""CancelledError must propagate — it's not a connection error."""
136133
tool = _make_connection_safe_tool(asyncio.CancelledError("cancelled"))
137134

138-
with patch.object(McpTool, "run_async", tool._parent_run_async):
139-
with pytest.raises(asyncio.CancelledError):
140-
await tool.run_async(args={}, tool_context=MagicMock())
135+
with pytest.raises(asyncio.CancelledError):
136+
await tool.run_async(args={}, tool_context=MagicMock())
141137

142138

143139
@pytest.mark.asyncio
144140
async def test_get_tools_wraps_mcp_tools():
145141
"""KAgentMcpToolset.get_tools should wrap McpTool instances with ConnectionSafeMcpTool."""
146-
# Create a real McpTool instance (bypassing __init__) so isinstance checks work
147142
fake_mcp_tool = McpTool.__new__(McpTool)
148143
fake_mcp_tool.name = "wrapped-tool"
149144
fake_mcp_tool._some_attr = "value"
150145

151-
# A non-McpTool object that should pass through unchanged
152146
fake_other_tool = MagicMock()
153147
fake_other_tool.name = "other-tool"
154148

@@ -164,5 +158,4 @@ async def mock_super_get_tools(self_arg, readonly_context=None):
164158
assert isinstance(tools[0], ConnectionSafeMcpTool)
165159
assert tools[0].name == "wrapped-tool"
166160
assert tools[0]._some_attr == "value"
167-
# Non-McpTool should pass through unchanged
168161
assert tools[1] is fake_other_tool

0 commit comments

Comments
 (0)