Skip to content

Commit 750b6ff

Browse files
authored
Merge branch 'main' into codex/adk-build-graph-image-route
2 parents 6d64415 + e3060ca commit 750b6ff

7 files changed

Lines changed: 1014 additions & 23 deletions

File tree

src/google/adk/features/_feature_registry.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,12 @@ class FeatureName(str, Enum):
4141
GOOGLE_CREDENTIALS_CONFIG = "GOOGLE_CREDENTIALS_CONFIG"
4242
GOOGLE_TOOL = "GOOGLE_TOOL"
4343
JSON_SCHEMA_FOR_FUNC_DECL = "JSON_SCHEMA_FOR_FUNC_DECL"
44+
# Private (leading underscore): not part of the public API surface.
45+
# GE flips this on by setting the env var
46+
# `ADK_ENABLE_MCP_GRACEFUL_ERROR_HANDLING=1`; nothing should import this
47+
# enum member by name. Keeping it private avoids a backward-compat
48+
# obligation for what is intended as a temporary, internal kill-switch.
49+
_MCP_GRACEFUL_ERROR_HANDLING = "MCP_GRACEFUL_ERROR_HANDLING"
4450
PROGRESSIVE_SSE_STREAMING = "PROGRESSIVE_SSE_STREAMING"
4551
PUBSUB_TOOL_CONFIG = "PUBSUB_TOOL_CONFIG"
4652
PUBSUB_TOOLSET = "PUBSUB_TOOLSET"
@@ -131,6 +137,9 @@ class FeatureConfig:
131137
FeatureName.JSON_SCHEMA_FOR_FUNC_DECL: FeatureConfig(
132138
FeatureStage.WIP, default_on=False
133139
),
140+
FeatureName._MCP_GRACEFUL_ERROR_HANDLING: FeatureConfig(
141+
FeatureStage.EXPERIMENTAL, default_on=False
142+
),
134143
FeatureName.PROGRESSIVE_SSE_STREAMING: FeatureConfig(
135144
FeatureStage.EXPERIMENTAL, default_on=True
136145
),

src/google/adk/tools/mcp_tool/mcp_session_manager.py

Lines changed: 72 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@
4444
from pydantic import BaseModel
4545
from pydantic import ConfigDict
4646

47+
from ...features import FeatureName
48+
from ...features import is_feature_enabled
4749
from .session_context import SessionContext
4850

4951
logger = logging.getLogger('google_adk.' + __name__)
@@ -237,11 +239,18 @@ def __init__(
237239
self._connection_params = connection_params
238240
self._errlog = errlog
239241

240-
# Session pool: maps session keys to (session, exit_stack, loop) tuples
242+
# Session pool: maps session keys to (session, exit_stack, loop) tuples.
243+
# Kept as a tuple for backward-compatibility with downstream tests
244+
# that construct or unpack entries directly.
241245
self._sessions: Dict[
242246
str, tuple[ClientSession, AsyncExitStack, asyncio.AbstractEventLoop]
243247
] = {}
244248

249+
# Sibling pool: maps session keys to their SessionContext. Stored
250+
# separately from `_sessions` so the tuple shape above stays stable.
251+
# Used by McpTool to access `_run_guarded` for transport-crash detection.
252+
self._session_contexts: Dict[str, SessionContext] = {}
253+
245254
# Map of event loops to their respective locks to prevent race conditions
246255
# across different event loops in session creation.
247256
self._session_lock_map: dict[asyncio.AbstractEventLoop, asyncio.Lock] = {}
@@ -323,6 +332,26 @@ def _is_session_disconnected(self, session: ClientSession) -> bool:
323332
"""
324333
return session._read_stream._closed or session._write_stream._closed
325334

335+
def _get_session_context(
336+
self, headers: Optional[Dict[str, str]] = None
337+
) -> Optional[SessionContext]:
338+
"""Returns the SessionContext for the session matching the given headers.
339+
340+
Note: This method reads from the session-context pool without acquiring
341+
``_session_lock``. This is safe because it is called immediately after
342+
``create_session()`` (which populates the entry under the lock) within
343+
the same task, and dict reads are atomic in CPython.
344+
345+
Args:
346+
headers: Optional headers used to identify the session.
347+
348+
Returns:
349+
The SessionContext if a matching session exists, None otherwise.
350+
"""
351+
merged_headers = self._merge_headers(headers)
352+
session_key = self._generate_session_key(merged_headers)
353+
return self._session_contexts.get(session_key)
354+
326355
async def _cleanup_session(
327356
self,
328357
session_key: str,
@@ -378,6 +407,10 @@ def cleanup_done(f: asyncio.Future):
378407
finally:
379408
if session_key in self._sessions:
380409
del self._sessions[session_key]
410+
# Also drop the SessionContext reference so we don't leak the
411+
# SessionContext after its underlying session is gone.
412+
if session_key in self._session_contexts:
413+
del self._session_contexts[session_key]
381414

382415
def _create_client(self, merged_headers: Optional[Dict[str, str]] = None):
383416
"""Creates an MCP client based on the connection parameters.
@@ -453,15 +486,30 @@ async def create_session(
453486
if session_key in self._sessions:
454487
session, exit_stack, stored_loop = self._sessions[session_key]
455488

456-
# Check if the existing session is still connected and bound to the current loop
489+
# Check if the existing session is still connected and bound to
490+
# the current loop. When the feature flag is on, we ALSO check the
491+
# SessionContext's background task: a crashed transport can leave
492+
# the session's read/write streams open even though the underlying
493+
# task has already died (e.g. after a 4xx/5xx HTTP response).
494+
# Without that extra check, callers would reuse a dead session and
495+
# hang on the next call. The check is gated because it triggers
496+
# session re-creation in some test mocks where `_task` looks
497+
# "not alive" but the streams are otherwise reusable.
457498
current_loop = asyncio.get_running_loop()
458-
if stored_loop is current_loop and not self._is_session_disconnected(
459-
session
499+
if is_feature_enabled(FeatureName._MCP_GRACEFUL_ERROR_HANDLING): # pylint: disable=protected-access
500+
ctx = self._session_contexts.get(session_key)
501+
ctx_alive = ctx is None or ctx._is_task_alive # pylint: disable=protected-access
502+
else:
503+
ctx_alive = True # Pre-fix: do not consult task aliveness
504+
if (
505+
stored_loop is current_loop
506+
and not self._is_session_disconnected(session)
507+
and ctx_alive
460508
):
461509
# Session is still good, return it
462510
return session
463511
else:
464-
# Session is disconnected or from a different loop, clean it up
512+
# Session is disconnected, dead, or from a different loop; clean up.
465513
logger.info(
466514
'Cleaning up session (disconnected or different loop): %s',
467515
session_key,
@@ -485,26 +533,32 @@ async def create_session(
485533
client = self._create_client(merged_headers)
486534
is_stdio = isinstance(self._connection_params, StdioConnectionParams)
487535

536+
session_context = SessionContext(
537+
client=client,
538+
timeout=timeout_in_seconds,
539+
sse_read_timeout=sse_read_timeout_in_seconds,
540+
is_stdio=is_stdio,
541+
sampling_callback=self._sampling_callback,
542+
sampling_capabilities=self._sampling_capabilities,
543+
)
544+
488545
session = await asyncio.wait_for(
489-
exit_stack.enter_async_context(
490-
SessionContext(
491-
client=client,
492-
timeout=timeout_in_seconds,
493-
sse_read_timeout=sse_read_timeout_in_seconds,
494-
is_stdio=is_stdio,
495-
sampling_callback=self._sampling_callback,
496-
sampling_capabilities=self._sampling_capabilities,
497-
)
498-
),
546+
exit_stack.enter_async_context(session_context),
499547
timeout=timeout_in_seconds,
500548
)
501549

502-
# Store session, exit stack, and loop in the pool
550+
# Store session, exit stack, and loop in the pool. The pool storage
551+
# remains a tuple for backward-compatibility with downstream tests
552+
# that construct or unpack entries directly.
503553
self._sessions[session_key] = (
504554
session,
505555
exit_stack,
506556
asyncio.get_running_loop(),
507557
)
558+
# Track the SessionContext in a sibling dict so McpTool can call
559+
# `_run_guarded` on it. Stored separately to avoid changing the
560+
# shape of `_sessions` (which is a public-ish internal surface).
561+
self._session_contexts[session_key] = session_context
508562
logger.debug('Created new session: %s', session_key)
509563
return session
510564

@@ -524,6 +578,7 @@ def __getstate__(self):
524578
state = self.__dict__.copy()
525579
# Remove unpicklable entries or those that shouldn't persist across pickle
526580
state['_sessions'] = {}
581+
state['_session_contexts'] = {}
527582
state['_session_lock_map'] = {}
528583

529584
# Locks and file-like objects cannot be pickled
@@ -537,6 +592,7 @@ def __setstate__(self, state):
537592
self.__dict__.update(state)
538593
# Re-initialize members that were not pickled
539594
self._sessions = {}
595+
self._session_contexts = {}
540596
self._session_lock_map = {}
541597
self._lock_map_lock = threading.Lock()
542598
# If _errlog was removed during pickling, default to sys.stderr

src/google/adk/tools/mcp_tool/mcp_tool.py

Lines changed: 56 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131

3232
from fastapi.openapi.models import APIKeyIn
3333
from google.genai.types import FunctionDeclaration
34+
from mcp.shared.exceptions import McpError
3435
from mcp.shared.session import ProgressFnT
3536
from mcp.types import Tool as McpBaseTool
3637
from opentelemetry import propagate
@@ -45,11 +46,18 @@
4546
from ...features import FeatureName
4647
from ...features import is_feature_enabled
4748
from ...utils.context_utils import find_context_parameter
49+
# `is_feature_enabled(FeatureName._MCP_GRACEFUL_ERROR_HANDLING)` gates the
50+
# error-boundary and transport-crash-detection behavior added in this module.
51+
# When the flag is off (default) or via ADK_DISABLE_MCP_GRACEFUL_ERROR_HANDLING=1
52+
# `run_async` and `_run_async_impl` fall back to the pre-fix behavior.
53+
# The enum member is intentionally private (leading underscore) so it is not
54+
# part of the ADK public API; consumers flip the env var, not the symbol.
4855
from .._gemini_schema_util import _to_gemini_schema
4956
from ..base_authenticated_tool import BaseAuthenticatedTool
5057
from ..tool_context import ToolContext
5158
from .mcp_session_manager import MCPSessionManager
5259
from .mcp_session_manager import retry_on_errors
60+
from .session_context import SessionContext
5361

5462
logger = logging.getLogger("google_adk." + __name__)
5563

@@ -339,7 +347,26 @@ async def run_async(
339347
}
340348
elif not tool_context.tool_confirmation.confirmed:
341349
return {"error": "This tool call is rejected."}
342-
return await super().run_async(args=args, tool_context=tool_context)
350+
351+
if not is_feature_enabled(FeatureName._MCP_GRACEFUL_ERROR_HANDLING): # pylint: disable=protected-access
352+
# Pre-fix behavior: exceptions bubble up to the agent runner.
353+
return await super().run_async(args=args, tool_context=tool_context)
354+
355+
# New behavior: convert MCP-level and unexpected errors into a
356+
# structured `{"error": "..."}` dict so the agent loop can continue
357+
# gracefully instead of being killed by an unhandled exception. This
358+
# is the primary fix for the 5-minute hang seen when Model Armor (or
359+
# any AGW policy) returns a 403 mid-tool-call.
360+
try:
361+
return await super().run_async(args=args, tool_context=tool_context)
362+
except McpError as e:
363+
logger.warning("MCP tool execution failed with McpError: %s", e)
364+
return {"error": f"MCP tool execution failed: {e}"}
365+
except Exception as e: # pylint: disable=broad-exception-caught
366+
logger.warning(
367+
"Unexpected error during MCP tool execution: %s", e, exc_info=True
368+
)
369+
return {"error": f"Unexpected error during MCP tool execution: {e}"}
343370

344371
@retry_on_errors
345372
@override
@@ -384,12 +411,39 @@ async def _run_async_impl(
384411
# Resolve progress callback (may be a factory that needs runtime context)
385412
resolved_callback = self._resolve_progress_callback(tool_context)
386413

387-
response = await session.call_tool(
414+
call_coro = session.call_tool(
388415
self._mcp_tool.name,
389416
arguments=args,
390417
progress_callback=resolved_callback,
391418
meta=meta_trace_context,
392419
)
420+
421+
if is_feature_enabled(FeatureName._MCP_GRACEFUL_ERROR_HANDLING): # pylint: disable=protected-access
422+
# Race the tool call against the background session task so that
423+
# transport crashes (e.g. non-2xx HTTP responses from an AGW with
424+
# Model Armor) surface immediately instead of hanging until
425+
# sse_read_timeout (default 5 minutes) expires. ConnectionError is
426+
# intentionally NOT caught here; it propagates to retry_on_errors,
427+
# which will create a fresh session and retry once before finally
428+
# surfacing the failure to the agent (where the run_async wrapper
429+
# converts it into an `{"error": ...}` dict).
430+
#
431+
# The isinstance check is intentional: tests and external subclasses
432+
# may inject mock session managers whose `_get_session_context`
433+
# returns a Mock instead of a real SessionContext (or None). Falling
434+
# back to the direct await keeps those callers working.
435+
session_context = self._mcp_session_manager._get_session_context( # pylint: disable=protected-access
436+
headers=final_headers
437+
)
438+
if isinstance(session_context, SessionContext):
439+
response = await session_context._run_guarded(call_coro) # pylint: disable=protected-access
440+
else:
441+
response = await call_coro
442+
else:
443+
# Pre-fix behavior: await the call directly. This is what causes the
444+
# ~300s hang when the underlying transport crashes.
445+
response = await call_coro
446+
393447
result = response.model_dump(exclude_none=True, mode="json")
394448

395449
# Push UI widget to the event actions if the tool supports it.

0 commit comments

Comments
 (0)