Skip to content

Commit 49ada7c

Browse files
committed
fix: detect MCP session disconnect during tool call to avoid 5min hang
When the MCP server returns a non-2xx HTTP response (e.g. 403), the transport crashes in a background task without propagating the error to the pending send_request() call. This causes the agent to hang for ~5 minutes until sse_read_timeout expires. Add a concurrent health check that polls session stream state alongside tool calls. If the session disconnects mid-call, raise ConnectionError immediately instead of waiting for the timeout. Fixes #4901
1 parent 0e93faf commit 49ada7c

File tree

2 files changed

+174
-2
lines changed

2 files changed

+174
-2
lines changed

src/google/adk/tools/mcp_tool/mcp_tool.py

Lines changed: 89 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -373,8 +373,8 @@ async def _run_async_impl(
373373
# Resolve progress callback (may be a factory that needs runtime context)
374374
resolved_callback = self._resolve_progress_callback(tool_context)
375375

376-
response = await session.call_tool(
377-
self._mcp_tool.name,
376+
response = await self._call_tool_with_health_check(
377+
session=session,
378378
arguments=args,
379379
progress_callback=resolved_callback,
380380
meta=meta_trace_context,
@@ -396,6 +396,93 @@ async def _run_async_impl(
396396
)
397397
return result
398398

399+
async def _call_tool_with_health_check(
    self,
    *,
    session,
    arguments: dict[str, Any],
    progress_callback,
    meta,
    initial_delay: float = 0.1,
    poll_interval: float = 0.5,
) -> Any:
  """Calls an MCP tool while monitoring session health.

  When the MCP server returns a non-2xx HTTP response, the underlying
  transport can crash in a background task without propagating the error
  to the pending send_request() call. This causes send_request() to hang
  until the read timeout (~5 minutes) expires.

  This method races the tool call against a health check that polls the
  session's stream state. If the session disconnects mid-call, we fail
  fast with a ConnectionError instead of hanging.

  Both background tasks are always cancelled and awaited before this
  method exits, whether it returns, raises, or is itself cancelled.

  Args:
    session: The MCP ClientSession to use.
    arguments: The arguments to pass to the tool.
    progress_callback: Optional progress callback.
    meta: Optional trace context metadata.
    initial_delay: Seconds to wait before the first health probe, giving
      the tool call time to start.
    poll_interval: Seconds to wait between subsequent health probes.

  Returns:
    The tool call response.

  Raises:
    ConnectionError: If the session disconnects during the tool call.
    asyncio.CancelledError: If the caller is cancelled while waiting.
  """

  async def _health_check() -> None:
    """Polls session stream health and raises if disconnected."""
    await asyncio.sleep(initial_delay)
    while True:
      try:
        disconnected = self._mcp_session_manager._is_session_disconnected(
            session
        )
      except (AttributeError, TypeError):
        # If stream attributes are not accessible, skip the check.
        # This can happen with certain transport implementations.
        disconnected = False
      if disconnected:
        raise ConnectionError(
            'MCP session disconnected during tool call. This typically'
            ' happens when the MCP server returns a non-2xx HTTP'
            ' response (e.g. 403 Forbidden). Check the server URL'
            ' and authentication configuration.'
        )
      await asyncio.sleep(poll_interval)

  tool_task = asyncio.create_task(
      session.call_tool(
          self._mcp_tool.name,
          arguments=arguments,
          progress_callback=progress_callback,
          meta=meta,
      )
  )
  health_task = asyncio.create_task(_health_check())
  tasks = [tool_task, health_task]

  try:
    await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)

    # A finished tool call wins even if the health check fired in the same
    # loop iteration; result() re-raises any error the tool call produced.
    if tool_task.done():
      return tool_task.result()

    # Only the health check finished, which means the session disconnected:
    # result() re-raises its ConnectionError.
    health_task.result()
  finally:
    # cancel() is a no-op on tasks that already finished.
    for task in tasks:
      task.cancel()
    # Wait for both tasks to unwind and retrieve their outcomes so nothing
    # is left running or logged as "exception was never retrieved". Child
    # errors come back as values; a cancellation aimed at *this* coroutine
    # still propagates out of the gather instead of being swallowed.
    await asyncio.gather(*tasks, return_exceptions=True)
485+
399486
def _resolve_progress_callback(
400487
self, tool_context: ToolContext
401488
) -> Optional[ProgressFnT]:

tests/unittests/tools/mcp_tool/test_mcp_tool.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import asyncio
1516
import inspect
1617
from unittest.mock import AsyncMock
1718
from unittest.mock import Mock
@@ -1127,3 +1128,87 @@ def test_mcp_app_resource_uri_property_none(self):
11271128
mcp_session_manager=self.mock_session_manager,
11281129
)
11291130
assert tool2.mcp_app_resource_uri is None
1131+
1132+
@pytest.mark.asyncio
async def test_call_tool_with_health_check_disconnected_session(self):
  """Test that a disconnected session is detected quickly during tool call.

  When the MCP server returns a non-2xx HTTP response, the transport can
  crash in a background task without propagating the error to the pending
  send_request(). This test verifies that the health check detects the
  closed streams and raises a ConnectionError instead of hanging.

  Regression test for https://github.com/google/adk-python/issues/4901
  """
  tool = MCPTool(
      mcp_tool=self.mock_mcp_tool,
      mcp_session_manager=self.mock_session_manager,
  )

  # Create a mock session with streams that start open then close
  mock_session = AsyncMock()
  mock_read_stream = Mock()
  mock_write_stream = Mock()
  mock_read_stream._closed = False
  mock_write_stream._closed = False
  mock_session._read_stream = mock_read_stream
  mock_session._write_stream = mock_write_stream

  # NOTE(review): self.mock_session_manager is built in fixture code that is
  # not visible here. If it is a Mock, _is_session_disconnected() would
  # return a truthy Mock and the test would pass without the simulated
  # crash ever mattering. Tie the probe to the stream flags explicitly so
  # the ConnectionError can only come from the streams closing.
  self.mock_session_manager._is_session_disconnected = Mock(
      side_effect=lambda s: bool(
          s._read_stream._closed or s._write_stream._closed
      )
  )

  # Make call_tool hang indefinitely (simulating the bug)
  hang_event = asyncio.Event()

  async def hanging_call_tool(*args, **kwargs):
    await hang_event.wait()

  mock_session.call_tool = hanging_call_tool

  # After a short delay, simulate the transport crashing by closing streams
  async def simulate_transport_crash():
    await asyncio.sleep(0.2)
    mock_read_stream._closed = True

  crash_task = asyncio.create_task(simulate_transport_crash())

  try:
    with pytest.raises(ConnectionError, match='MCP session disconnected'):
      await tool._call_tool_with_health_check(
          session=mock_session,
          arguments={"param1": "test"},
          progress_callback=None,
          meta=None,
      )

    # The error must have been triggered by the simulated crash.
    assert mock_read_stream._closed is True
  finally:
    # Always reap the helper task, even when the assertion above fails, so
    # it does not outlive the test.
    crash_task.cancel()
    try:
      await crash_task
    except asyncio.CancelledError:
      pass
1185+
1186+
@pytest.mark.asyncio
async def test_call_tool_with_health_check_success(self):
  """Test that healthy sessions return tool results normally."""
  tool = MCPTool(
      mcp_tool=self.mock_mcp_tool,
      mcp_session_manager=self.mock_session_manager,
  )

  # Session whose streams report as open for the whole call.
  mock_session = AsyncMock()
  mock_read_stream = Mock()
  mock_write_stream = Mock()
  mock_read_stream._closed = False
  mock_write_stream._closed = False
  mock_session._read_stream = mock_read_stream
  mock_session._write_stream = mock_write_stream

  mcp_response = CallToolResult(
      content=[TextContent(type="text", text="success")]
  )
  mock_session.call_tool = AsyncMock(return_value=mcp_response)

  tool_arguments = {"param1": "test"}
  result = await tool._call_tool_with_health_check(
      session=mock_session,
      arguments=tool_arguments,
      progress_callback=None,
      meta=None,
  )

  assert result == mcp_response
  # The wrapper must forward the tool name and every keyword unchanged, and
  # must invoke the underlying call exactly once.
  mock_session.call_tool.assert_awaited_once_with(
      tool._mcp_tool.name,
      arguments=tool_arguments,
      progress_callback=None,
      meta=None,
  )

0 commit comments

Comments
 (0)