Skip to content

Commit 4309159

Browse files
sasha-gitgcopybara-github
authored andcommitted
fix(tools): Prevent AnyIO CancelScope task boundary violations during MCP session creation failure
Co-authored-by: Sasha Sobran <asobran@google.com> PiperOrigin-RevId: 916220631
1 parent ec54bd4 commit 4309159

2 files changed

Lines changed: 39 additions & 4 deletions

File tree

src/google/adk/tools/mcp_tool/mcp_session_manager.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -542,10 +542,13 @@ async def create_session(
542542
sampling_capabilities=self._sampling_capabilities,
543543
)
544544

545-
session = await asyncio.wait_for(
546-
exit_stack.enter_async_context(session_context),
547-
timeout=timeout_in_seconds,
548-
)
545+
if is_feature_enabled(FeatureName._MCP_GRACEFUL_ERROR_HANDLING): # pylint: disable=protected-access
546+
session = await exit_stack.enter_async_context(session_context)
547+
else:
548+
session = await asyncio.wait_for(
549+
exit_stack.enter_async_context(session_context),
550+
timeout=timeout_in_seconds,
551+
)
549552

550553
# Store session, exit stack, and loop in the pool. The pool storage
551554
# remains a tuple for backward-compatibility with downstream tests

tests/unittests/tools/mcp_tool/test_mcp_session_manager.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1043,3 +1043,35 @@ def test_env_var_disable_acts_as_kill_switch(self):
10431043
os.environ[disable] = saved_disable
10441044
if saved_enable is not None:
10451045
os.environ[enable] = saved_enable
1046+
1047+
@pytest.mark.asyncio
1048+
@patch("google.adk.tools.mcp_tool.mcp_session_manager.asyncio.wait_for")
1049+
async def test_create_session_does_not_use_wait_for_when_ge_is_enabled(
1050+
self, mock_wait_for
1051+
):
1052+
"""create_session must not wrap enter_async_context in asyncio.wait_for when GE is enabled."""
1053+
from google.adk.features import FeatureName
1054+
from google.adk.features._feature_registry import temporary_feature_override
1055+
1056+
manager = MCPSessionManager(
1057+
StdioConnectionParams(
1058+
server_params=StdioServerParameters(command="dummy", args=[]),
1059+
timeout=5.0,
1060+
)
1061+
)
1062+
with temporary_feature_override(
1063+
FeatureName._MCP_GRACEFUL_ERROR_HANDLING, True
1064+
):
1065+
with patch(
1066+
"google.adk.tools.mcp_tool.mcp_session_manager.AsyncExitStack"
1067+
) as mock_stack:
1068+
mock_stack.return_value.enter_async_context = AsyncMock()
1069+
with patch(
1070+
"google.adk.tools.mcp_tool.mcp_session_manager.SessionContext"
1071+
):
1072+
with patch(
1073+
"google.adk.tools.mcp_tool.mcp_session_manager.stdio_client"
1074+
):
1075+
await manager.create_session()
1076+
1077+
mock_wait_for.assert_not_called()

0 commit comments

Comments
 (0)