Skip to content

Commit 8dca09f

Browse files
authored
Merge branch 'main' into fix/anthropic-thinking-streaming-continuity
2 parents 03c8681 + bdece00 commit 8dca09f

14 files changed

Lines changed: 1055 additions & 186 deletions

File tree

src/google/adk/agents/base_agent.py

Lines changed: 1 addition & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
import inspect
1818
import logging
19-
import sys
2019
from typing import Any
2120
from typing import AsyncGenerator
2221
from typing import Awaitable
@@ -286,9 +285,7 @@ async def run_async(
286285
Event: the events generated by the agent.
287286
"""
288287

289-
cm = tracer.start_as_current_span(f'invoke_agent {self.name}')
290-
span = cm.__enter__()
291-
try:
288+
with tracer.start_as_current_span(f'invoke_agent {self.name}') as span:
292289
ctx = self._create_invocation_context(parent_context)
293290
tracing.trace_agent_invocation(span, self, ctx)
294291
if event := await self._handle_before_agent_callback(ctx):
@@ -305,23 +302,6 @@ async def run_async(
305302

306303
if event := await self._handle_after_agent_callback(ctx):
307304
yield event
308-
except BaseException:
309-
try:
310-
cm.__exit__(*sys.exc_info())
311-
except ValueError:
312-
logger.warning(
313-
'Failed to detach context during generator cleanup, likely due to'
314-
' cancellation.'
315-
)
316-
raise
317-
else:
318-
try:
319-
cm.__exit__(None, None, None)
320-
except ValueError:
321-
logger.warning(
322-
'Failed to detach context during generator cleanup, likely due to'
323-
' cancellation.'
324-
)
325305

326306
@final
327307
async def run_live(

src/google/adk/features/_feature_registry.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,12 @@ class FeatureName(str, Enum):
4141
GOOGLE_CREDENTIALS_CONFIG = "GOOGLE_CREDENTIALS_CONFIG"
4242
GOOGLE_TOOL = "GOOGLE_TOOL"
4343
JSON_SCHEMA_FOR_FUNC_DECL = "JSON_SCHEMA_FOR_FUNC_DECL"
44+
# Private (leading underscore): not part of the public API surface.
45+
# GE flips this on by setting the env var
46+
# `ADK_ENABLE_MCP_GRACEFUL_ERROR_HANDLING=1`; nothing should import this
47+
# enum member by name. Keeping it private avoids a backward-compat
48+
# obligation for what is intended as a temporary, internal kill-switch.
49+
_MCP_GRACEFUL_ERROR_HANDLING = "MCP_GRACEFUL_ERROR_HANDLING"
4450
PROGRESSIVE_SSE_STREAMING = "PROGRESSIVE_SSE_STREAMING"
4551
PUBSUB_TOOL_CONFIG = "PUBSUB_TOOL_CONFIG"
4652
PUBSUB_TOOLSET = "PUBSUB_TOOLSET"
@@ -131,6 +137,9 @@ class FeatureConfig:
131137
FeatureName.JSON_SCHEMA_FOR_FUNC_DECL: FeatureConfig(
132138
FeatureStage.WIP, default_on=False
133139
),
140+
FeatureName._MCP_GRACEFUL_ERROR_HANDLING: FeatureConfig(
141+
FeatureStage.EXPERIMENTAL, default_on=False
142+
),
134143
FeatureName.PROGRESSIVE_SSE_STREAMING: FeatureConfig(
135144
FeatureStage.EXPERIMENTAL, default_on=True
136145
),

src/google/adk/flows/llm_flows/base_llm_flow.py

Lines changed: 1 addition & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
import asyncio
1919
import inspect
2020
import logging
21-
import sys
2221
from typing import AsyncGenerator
2322
from typing import Optional
2423
from typing import TYPE_CHECKING
@@ -1169,9 +1168,7 @@ async def _call_llm_async(
11691168
) -> AsyncGenerator[LlmResponse, None]:
11701169

11711170
async def _call_llm_with_tracing() -> AsyncGenerator[LlmResponse, None]:
1172-
cm = tracer.start_as_current_span('call_llm')
1173-
span = cm.__enter__()
1174-
try:
1171+
with tracer.start_as_current_span('call_llm') as span:
11751172
# Runs before_model_callback inside the call_llm span so
11761173
# plugins observe the same span as after/error callbacks.
11771174
if response := await self._handle_before_model_callback(
@@ -1264,23 +1261,6 @@ async def _call_llm_with_tracing() -> AsyncGenerator[LlmResponse, None]:
12641261
llm_response = altered
12651262

12661263
yield llm_response
1267-
except BaseException:
1268-
try:
1269-
cm.__exit__(*sys.exc_info())
1270-
except ValueError:
1271-
logger.warning(
1272-
'Failed to detach context during generator cleanup, likely due to'
1273-
' cancellation.'
1274-
)
1275-
raise
1276-
else:
1277-
try:
1278-
cm.__exit__(None, None, None)
1279-
except ValueError:
1280-
logger.warning(
1281-
'Failed to detach context during generator cleanup, likely due to'
1282-
' cancellation.'
1283-
)
12841264

12851265
async with Aclosing(_call_llm_with_tracing()) as agen:
12861266
async for event in agen:

src/google/adk/runners.py

Lines changed: 1 addition & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -543,9 +543,7 @@ async def _run_with_trace(
543543
new_message: Optional[types.Content] = None,
544544
invocation_id: Optional[str] = None,
545545
) -> AsyncGenerator[Event, None]:
546-
cm = tracer.start_as_current_span('invocation')
547-
span = cm.__enter__()
548-
try:
546+
with tracer.start_as_current_span('invocation'):
549547
session = await self._get_or_create_session(
550548
user_id=user_id,
551549
session_id=session_id,
@@ -629,23 +627,6 @@ async def execute(ctx: InvocationContext) -> AsyncGenerator[Event]:
629627
self.session_service,
630628
skip_token_compaction=invocation_context.token_compaction_checked,
631629
)
632-
except BaseException:
633-
try:
634-
cm.__exit__(*sys.exc_info())
635-
except ValueError:
636-
logger.warning(
637-
'Failed to detach context during generator cleanup, likely due to'
638-
' cancellation.'
639-
)
640-
raise
641-
else:
642-
try:
643-
cm.__exit__(None, None, None)
644-
except ValueError:
645-
logger.warning(
646-
'Failed to detach context during generator cleanup, likely due to'
647-
' cancellation.'
648-
)
649630

650631
async with Aclosing(_run_with_trace(new_message, invocation_id)) as agen:
651632
async for event in agen:

src/google/adk/sessions/vertex_ai_session_service.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,12 @@
4848
_USAGE_METADATA_CUSTOM_METADATA_KEY = '_usage_metadata'
4949

5050

51+
def _quote_filter_literal(value: str) -> str:
52+
"""Quotes filter values so embedded metacharacters stay inside the literal."""
53+
escaped_value = value.replace('\\', '\\\\').replace('"', '\\"')
54+
return f'"{escaped_value}"'
55+
56+
5157
def _set_internal_custom_metadata(
5258
metadata_dict: dict[str, Any], *, key: str, value: dict[str, Any]
5359
) -> None:
@@ -228,7 +234,7 @@ async def list_sessions(
228234
sessions = []
229235
config = {}
230236
if user_id is not None:
231-
config['filter'] = f'user_id="{user_id}"'
237+
config['filter'] = f'user_id={_quote_filter_literal(user_id)}'
232238
sessions_iterator = await api_client.agent_engines.sessions.list(
233239
name=f'reasoningEngines/{reasoning_engine_id}',
234240
config=config,

src/google/adk/tools/mcp_tool/mcp_session_manager.py

Lines changed: 72 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@
4444
from pydantic import BaseModel
4545
from pydantic import ConfigDict
4646

47+
from ...features import FeatureName
48+
from ...features import is_feature_enabled
4749
from .session_context import SessionContext
4850

4951
logger = logging.getLogger('google_adk.' + __name__)
@@ -237,11 +239,18 @@ def __init__(
237239
self._connection_params = connection_params
238240
self._errlog = errlog
239241

240-
# Session pool: maps session keys to (session, exit_stack, loop) tuples
242+
# Session pool: maps session keys to (session, exit_stack, loop) tuples.
243+
# Kept as a tuple for backward-compatibility with downstream tests
244+
# that construct or unpack entries directly.
241245
self._sessions: Dict[
242246
str, tuple[ClientSession, AsyncExitStack, asyncio.AbstractEventLoop]
243247
] = {}
244248

249+
# Sibling pool: maps session keys to their SessionContext. Stored
250+
# separately from `_sessions` so the tuple shape above stays stable.
251+
# Used by McpTool to access `_run_guarded` for transport-crash detection.
252+
self._session_contexts: Dict[str, SessionContext] = {}
253+
245254
# Map of event loops to their respective locks to prevent race conditions
246255
# across different event loops in session creation.
247256
self._session_lock_map: dict[asyncio.AbstractEventLoop, asyncio.Lock] = {}
@@ -323,6 +332,26 @@ def _is_session_disconnected(self, session: ClientSession) -> bool:
323332
"""
324333
return session._read_stream._closed or session._write_stream._closed
325334

335+
def _get_session_context(
    self, headers: Optional[Dict[str, str]] = None
) -> Optional[SessionContext]:
  """Looks up the pooled SessionContext for the session these headers map to.

  Intentionally lock-free: callers invoke this immediately after
  ``create_session()`` (which populates the pool entry while holding
  ``_session_lock``) on the same task, and dict reads are atomic in
  CPython, so no additional synchronization is required here.

  Args:
    headers: Optional headers used to identify the session.

  Returns:
    The matching SessionContext, or None when no such session is pooled.
  """
  session_key = self._generate_session_key(self._merge_headers(headers))
  return self._session_contexts.get(session_key)
354+
326355
async def _cleanup_session(
327356
self,
328357
session_key: str,
@@ -378,6 +407,10 @@ def cleanup_done(f: asyncio.Future):
378407
finally:
379408
if session_key in self._sessions:
380409
del self._sessions[session_key]
410+
# Also drop the SessionContext reference so we don't leak the
411+
# SessionContext after its underlying session is gone.
412+
if session_key in self._session_contexts:
413+
del self._session_contexts[session_key]
381414

382415
def _create_client(self, merged_headers: Optional[Dict[str, str]] = None):
383416
"""Creates an MCP client based on the connection parameters.
@@ -453,15 +486,30 @@ async def create_session(
453486
if session_key in self._sessions:
454487
session, exit_stack, stored_loop = self._sessions[session_key]
455488

456-
# Check if the existing session is still connected and bound to the current loop
489+
# Check if the existing session is still connected and bound to
490+
# the current loop. When the feature flag is on, we ALSO check the
491+
# SessionContext's background task: a crashed transport can leave
492+
# the session's read/write streams open even though the underlying
493+
# task has already died (e.g. after a 4xx/5xx HTTP response).
494+
# Without that extra check, callers would reuse a dead session and
495+
# hang on the next call. The check is gated because it triggers
496+
# session re-creation in some test mocks where `_task` looks
497+
# "not alive" but the streams are otherwise reusable.
457498
current_loop = asyncio.get_running_loop()
458-
if stored_loop is current_loop and not self._is_session_disconnected(
459-
session
499+
if is_feature_enabled(FeatureName._MCP_GRACEFUL_ERROR_HANDLING): # pylint: disable=protected-access
500+
ctx = self._session_contexts.get(session_key)
501+
ctx_alive = ctx is None or ctx._is_task_alive # pylint: disable=protected-access
502+
else:
503+
ctx_alive = True # Pre-fix: do not consult task aliveness
504+
if (
505+
stored_loop is current_loop
506+
and not self._is_session_disconnected(session)
507+
and ctx_alive
460508
):
461509
# Session is still good, return it
462510
return session
463511
else:
464-
# Session is disconnected or from a different loop, clean it up
512+
# Session is disconnected, dead, or from a different loop; clean up.
465513
logger.info(
466514
'Cleaning up session (disconnected or different loop): %s',
467515
session_key,
@@ -485,26 +533,32 @@ async def create_session(
485533
client = self._create_client(merged_headers)
486534
is_stdio = isinstance(self._connection_params, StdioConnectionParams)
487535

536+
session_context = SessionContext(
537+
client=client,
538+
timeout=timeout_in_seconds,
539+
sse_read_timeout=sse_read_timeout_in_seconds,
540+
is_stdio=is_stdio,
541+
sampling_callback=self._sampling_callback,
542+
sampling_capabilities=self._sampling_capabilities,
543+
)
544+
488545
session = await asyncio.wait_for(
489-
exit_stack.enter_async_context(
490-
SessionContext(
491-
client=client,
492-
timeout=timeout_in_seconds,
493-
sse_read_timeout=sse_read_timeout_in_seconds,
494-
is_stdio=is_stdio,
495-
sampling_callback=self._sampling_callback,
496-
sampling_capabilities=self._sampling_capabilities,
497-
)
498-
),
546+
exit_stack.enter_async_context(session_context),
499547
timeout=timeout_in_seconds,
500548
)
501549

502-
# Store session, exit stack, and loop in the pool
550+
# Store session, exit stack, and loop in the pool. The pool storage
551+
# remains a tuple for backward-compatibility with downstream tests
552+
# that construct or unpack entries directly.
503553
self._sessions[session_key] = (
504554
session,
505555
exit_stack,
506556
asyncio.get_running_loop(),
507557
)
558+
# Track the SessionContext in a sibling dict so McpTool can call
559+
# `_run_guarded` on it. Stored separately to avoid changing the
560+
# shape of `_sessions` (which is a public-ish internal surface).
561+
self._session_contexts[session_key] = session_context
508562
logger.debug('Created new session: %s', session_key)
509563
return session
510564

@@ -524,6 +578,7 @@ def __getstate__(self):
524578
state = self.__dict__.copy()
525579
# Remove unpicklable entries or those that shouldn't persist across pickle
526580
state['_sessions'] = {}
581+
state['_session_contexts'] = {}
527582
state['_session_lock_map'] = {}
528583

529584
# Locks and file-like objects cannot be pickled
@@ -537,6 +592,7 @@ def __setstate__(self, state):
537592
self.__dict__.update(state)
538593
# Re-initialize members that were not pickled
539594
self._sessions = {}
595+
self._session_contexts = {}
540596
self._session_lock_map = {}
541597
self._lock_map_lock = threading.Lock()
542598
# If _errlog was removed during pickling, default to sys.stderr

0 commit comments

Comments
 (0)