1212from google .adk .tools .mcp_tool .mcp_tool import McpTool
1313from google .adk .tools .mcp_tool .mcp_toolset import McpToolset
1414from mcp .shared .exceptions import McpError
15+ from mcp .types import ErrorData
1516
1617from kagent .adk ._mcp_toolset import ConnectionSafeMcpTool , KAgentMcpToolset
1718
1819
1920def _make_connection_safe_tool (side_effect ):
20- """Create a ConnectionSafeMcpTool with a mocked super().run_async."""
21- tool = ConnectionSafeMcpTool .__new__ (ConnectionSafeMcpTool )
22- tool .name = "test-tool"
23- tool ._mcp_tool = MagicMock ()
24- tool ._mcp_tool .name = "test-tool"
25- tool ._mcp_session_manager = AsyncMock ()
26- tool ._header_provider = None
27- tool ._auth_config = None
28- tool ._confirmation_config = None
29- tool ._progress_callback = None
30- tool ._parent_run_async = AsyncMock (side_effect = side_effect )
31- return tool
21+ """Create a ConnectionSafeMcpTool wrapping a mock McpTool."""
22+ inner_tool = MagicMock (spec = McpTool )
23+ inner_tool .name = "test-tool"
24+ inner_tool .run_async = AsyncMock (side_effect = side_effect )
25+ return ConnectionSafeMcpTool (inner_tool )
3226
3327
3428@pytest .mark .asyncio
3529async def test_connection_reset_error_returns_error_dict ():
3630 """ConnectionResetError should be caught and returned as error text."""
3731 tool = _make_connection_safe_tool (ConnectionResetError ("Connection reset by peer" ))
3832
39- with patch .object (McpTool , "run_async" , tool ._parent_run_async ):
40- result = await tool .run_async (args = {"key" : "value" }, tool_context = MagicMock ())
33+ result = await tool .run_async (args = {"key" : "value" }, tool_context = MagicMock ())
4134
4235 assert "error" in result
4336 assert "ConnectionResetError" in result ["error" ]
@@ -50,8 +43,7 @@ async def test_connection_refused_error_returns_error_dict():
5043 """ConnectionRefusedError should be caught and returned as error text."""
5144 tool = _make_connection_safe_tool (ConnectionRefusedError ("Connection refused" ))
5245
53- with patch .object (McpTool , "run_async" , tool ._parent_run_async ):
54- result = await tool .run_async (args = {}, tool_context = MagicMock ())
46+ result = await tool .run_async (args = {}, tool_context = MagicMock ())
5547
5648 assert "error" in result
5749 assert "ConnectionRefusedError" in result ["error" ]
@@ -62,8 +54,7 @@ async def test_timeout_error_returns_error_dict():
6254 """TimeoutError should be caught and returned as error text."""
6355 tool = _make_connection_safe_tool (TimeoutError ("timed out" ))
6456
65- with patch .object (McpTool , "run_async" , tool ._parent_run_async ):
66- result = await tool .run_async (args = {}, tool_context = MagicMock ())
57+ result = await tool .run_async (args = {}, tool_context = MagicMock ())
6758
6859 assert "error" in result
6960 assert "TimeoutError" in result ["error" ]
@@ -74,8 +65,7 @@ async def test_httpx_connect_error_returns_error_dict():
7465 """httpx.ConnectError should be caught via httpx.TransportError."""
7566 tool = _make_connection_safe_tool (httpx .ConnectError ("connection refused" ))
7667
77- with patch .object (McpTool , "run_async" , tool ._parent_run_async ):
78- result = await tool .run_async (args = {}, tool_context = MagicMock ())
68+ result = await tool .run_async (args = {}, tool_context = MagicMock ())
7969
8070 assert "error" in result
8171 assert "ConnectError" in result ["error" ]
@@ -86,8 +76,7 @@ async def test_httpx_read_error_returns_error_dict():
8676 """httpx.ReadError (connection reset by peer) should be caught."""
8777 tool = _make_connection_safe_tool (httpx .ReadError ("peer closed connection" ))
8878
89- with patch .object (McpTool , "run_async" , tool ._parent_run_async ):
90- result = await tool .run_async (args = {}, tool_context = MagicMock ())
79+ result = await tool .run_async (args = {}, tool_context = MagicMock ())
9180
9281 assert "error" in result
9382 assert "ReadError" in result ["error" ]
@@ -98,57 +87,62 @@ async def test_httpx_connect_timeout_returns_error_dict():
9887 """httpx.ConnectTimeout should be caught via httpx.TransportError."""
9988 tool = _make_connection_safe_tool (httpx .ConnectTimeout ("timed out" ))
10089
101- with patch .object (McpTool , "run_async" , tool ._parent_run_async ):
102- result = await tool .run_async (args = {}, tool_context = MagicMock ())
90+ result = await tool .run_async (args = {}, tool_context = MagicMock ())
10391
10492 assert "error" in result
10593 assert "ConnectTimeout" in result ["error" ]
10694
10795
10896@pytest .mark .asyncio
109- async def test_mcp_error_returns_error_dict ():
110- """McpError (raised by MCP session on stream drop / read timeout) should be caught."""
111- from mcp .types import ErrorData
97+ async def test_transport_mcp_error_returns_error_dict ():
98+ """McpError with a transport-level message (e.g., session read timeout) should be caught."""
99+ tool = _make_connection_safe_tool (
100+ McpError (ErrorData (code = - 1 , message = "session read timeout" ))
101+ )
112102
113- tool = _make_connection_safe_tool (McpError (ErrorData (code = - 1 , message = "session read timeout" )))
114-
115- with patch .object (McpTool , "run_async" , tool ._parent_run_async ):
116- result = await tool .run_async (args = {}, tool_context = MagicMock ())
103+ result = await tool .run_async (args = {}, tool_context = MagicMock ())
117104
118105 assert "error" in result
119106 assert "McpError" in result ["error" ]
120107 assert "session read timeout" in result ["error" ]
121108
122109
110+ @pytest .mark .asyncio
111+ async def test_protocol_mcp_error_still_raises ():
112+ """McpError with a protocol-level message (e.g., invalid arguments) should propagate."""
113+ tool = _make_connection_safe_tool (
114+ McpError (ErrorData (code = - 32602 , message = "Invalid params: unknown tool" ))
115+ )
116+
117+ with pytest .raises (McpError , match = "Invalid params" ):
118+ await tool .run_async (args = {}, tool_context = MagicMock ())
119+
120+
123121@pytest .mark .asyncio
124122async def test_non_connection_error_still_raises ():
125123 """Non-connection errors (e.g. ValueError) should still propagate."""
126124 tool = _make_connection_safe_tool (ValueError ("bad argument" ))
127125
128- with patch .object (McpTool , "run_async" , tool ._parent_run_async ):
129- with pytest .raises (ValueError , match = "bad argument" ):
130- await tool .run_async (args = {}, tool_context = MagicMock ())
126+ with pytest .raises (ValueError , match = "bad argument" ):
127+ await tool .run_async (args = {}, tool_context = MagicMock ())
131128
132129
133130@pytest .mark .asyncio
134131async def test_cancelled_error_still_raises ():
135132 """CancelledError must propagate — it's not a connection error."""
136133 tool = _make_connection_safe_tool (asyncio .CancelledError ("cancelled" ))
137134
138- with patch .object (McpTool , "run_async" , tool ._parent_run_async ):
139- with pytest .raises (asyncio .CancelledError ):
140- await tool .run_async (args = {}, tool_context = MagicMock ())
135+ with pytest .raises (asyncio .CancelledError ):
136+ await tool .run_async (args = {}, tool_context = MagicMock ())
141137
142138
143139@pytest .mark .asyncio
144140async def test_get_tools_wraps_mcp_tools ():
145141 """KAgentMcpToolset.get_tools should wrap McpTool instances with ConnectionSafeMcpTool."""
146- # Create a real McpTool instance (bypassing __init__) so isinstance checks work
147142 fake_mcp_tool = McpTool .__new__ (McpTool )
148143 fake_mcp_tool .name = "wrapped-tool"
149144 fake_mcp_tool ._some_attr = "value"
150145
151- # A non-McpTool object that should pass through unchanged
152146 fake_other_tool = MagicMock ()
153147 fake_other_tool .name = "other-tool"
154148
@@ -164,5 +158,4 @@ async def mock_super_get_tools(self_arg, readonly_context=None):
164158 assert isinstance (tools [0 ], ConnectionSafeMcpTool )
165159 assert tools [0 ].name == "wrapped-tool"
166160 assert tools [0 ]._some_attr == "value"
167- # Non-McpTool should pass through unchanged
168161 assert tools [1 ] is fake_other_tool
0 commit comments