Skip to content

Commit ccb8138

Browse files
xuanyang15copybara-github
authored andcommitted
fix: Fix error swallowing in MCP session context
Co-authored-by: Xuan Yang <xygoogle@google.com> PiperOrigin-RevId: 936815704
1 parent 171ae9e commit ccb8138

2 files changed

Lines changed: 76 additions & 3 deletions

File tree

src/google/adk/tools/mcp_tool/session_context.py

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,34 @@
3636
_T = TypeVar('_T')
3737

3838

39+
def _format_exception(exc: BaseException | None) -> str:
40+
"""Formats an exception into a readable string representation.
41+
42+
This handles `ExceptionGroup` (by flattening inner exceptions) and optionally
43+
extracts HTTP response bodies for network-related errors, truncating them
44+
to 1000 characters to prevent log/context overflow.
45+
46+
Args:
47+
exc: The exception to format.
48+
49+
Returns:
50+
A formatted string representing the exception and its pertinent details.
51+
"""
52+
if exc is None:
53+
return 'None'
54+
if hasattr(exc, 'exceptions') and getattr(exc, 'exceptions'):
55+
return ' | '.join(_format_exception(e) for e in exc.exceptions)
56+
if hasattr(exc, 'response') and exc.response is not None:
57+
try:
58+
response_text = exc.response.text
59+
if len(response_text) > 1000:
60+
response_text = response_text[:1000] + '... [truncated]'
61+
return f'{exc} (Response: {response_text})'
62+
except Exception:
63+
pass
64+
return str(exc)
65+
66+
3967
class SessionContext:
4068
"""Represents the context of a single MCP session within a dedicated task.
4169
@@ -143,7 +171,8 @@ def _retrieve_exception(t: asyncio.Task):
143171

144172
if self._task.done() and self._task.exception():
145173
raise ConnectionError(
146-
f'Failed to create MCP session: {self._task.exception()}'
174+
'Failed to create MCP session:'
175+
f' {_format_exception(self._task.exception())}'
147176
) from self._task.exception()
148177

149178
# Pre-fix code returned `self._session` here directly (typed as
@@ -186,7 +215,7 @@ async def _run_guarded(self, coro: Coroutine[Any, Any, _T]) -> _T:
186215
# Close the coroutine to avoid "was never awaited" warnings.
187216
coro.close()
188217
raise ConnectionError(
189-
f'MCP session task has already terminated: {exc}'
218+
f'MCP session task has already terminated: {_format_exception(exc)}'
190219
) from exc
191220

192221
coro_task = asyncio.ensure_future(coro)
@@ -212,7 +241,9 @@ async def _run_guarded(self, coro: Coroutine[Any, Any, _T]) -> _T:
212241
pass
213242

214243
exc = self._task.exception() if not self._task.cancelled() else None
215-
raise ConnectionError(f'MCP session connection lost: {exc}') from exc
244+
raise ConnectionError(
245+
f'MCP session connection lost: {_format_exception(exc)}'
246+
) from exc
216247

217248
async def close(self):
218249
"""Signal the context task to close and wait for cleanup."""

tests/unittests/tools/mcp_tool/test_session_context.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,9 @@
2323

2424
from google.adk.features import FeatureName
2525
from google.adk.features._feature_registry import temporary_feature_override
26+
from google.adk.tools.mcp_tool.session_context import _format_exception
2627
from google.adk.tools.mcp_tool.session_context import SessionContext
28+
import httpx
2729
from mcp import ClientSession
2830
import pytest
2931

@@ -896,3 +898,43 @@ async def test_no_extra_none_check_when_flag_off(self):
896898
assert result is not None
897899
finally:
898900
await session_context.close()
901+
902+
903+
class TestFormatException:
904+
"""Test suite for _format_exception helper."""
905+
906+
def test_format_exception_normal(self):
907+
exc = ValueError('normal error')
908+
assert _format_exception(exc) == 'normal error'
909+
910+
def test_format_exception_http_status_error(self):
911+
request = httpx.Request('GET', 'http://test')
912+
response = httpx.Response(403, request=request, text='Forbidden access')
913+
exc = httpx.HTTPStatusError(
914+
'403 Forbidden', request=request, response=response
915+
)
916+
917+
formatted = _format_exception(exc)
918+
assert '403 Forbidden' in formatted
919+
assert 'Forbidden access' in formatted
920+
921+
def test_format_exception_group(self):
922+
class MockExceptionGroup(Exception):
923+
924+
def __init__(self, message, exceptions):
925+
super().__init__(message)
926+
self.exceptions = exceptions
927+
928+
request = httpx.Request('GET', 'http://test')
929+
response = httpx.Response(403, request=request, text='Forbidden access')
930+
exc1 = httpx.HTTPStatusError(
931+
'403 Forbidden', request=request, response=response
932+
)
933+
exc2 = ValueError('another error')
934+
935+
eg = MockExceptionGroup('Group', [exc1, exc2])
936+
formatted = _format_exception(eg)
937+
938+
assert '403 Forbidden' in formatted
939+
assert 'Forbidden access' in formatted
940+
assert 'another error' in formatted

0 commit comments

Comments
 (0)