4444from pydantic import BaseModel
4545from pydantic import ConfigDict
4646
47+ from ...features import FeatureName
48+ from ...features import is_feature_enabled
4749from .session_context import SessionContext
4850
4951logger = logging .getLogger ('google_adk.' + __name__ )
@@ -237,11 +239,18 @@ def __init__(
237239 self ._connection_params = connection_params
238240 self ._errlog = errlog
239241
240- # Session pool: maps session keys to (session, exit_stack, loop) tuples
242+ # Session pool: maps session keys to (session, exit_stack, loop) tuples.
243+ # Kept as a tuple for backward-compatibility with downstream tests
244+ # that construct or unpack entries directly.
241245 self ._sessions : Dict [
242246 str , tuple [ClientSession , AsyncExitStack , asyncio .AbstractEventLoop ]
243247 ] = {}
244248
249+ # Sibling pool: maps session keys to their SessionContext. Stored
250+ # separately from `_sessions` so the tuple shape above stays stable.
251+ # Used by McpTool to access `_run_guarded` for transport-crash detection.
252+ self ._session_contexts : Dict [str , SessionContext ] = {}
253+
245254 # Map of event loops to their respective locks to prevent race conditions
246255 # across different event loops in session creation.
247256 self ._session_lock_map : dict [asyncio .AbstractEventLoop , asyncio .Lock ] = {}
@@ -323,6 +332,26 @@ def _is_session_disconnected(self, session: ClientSession) -> bool:
323332 """
324333 return session ._read_stream ._closed or session ._write_stream ._closed
325334
335+ def _get_session_context (
336+ self , headers : Optional [Dict [str , str ]] = None
337+ ) -> Optional [SessionContext ]:
338+ """Returns the SessionContext for the session matching the given headers.
339+
340+ Note: This method reads from the session-context pool without acquiring
341+ ``_session_lock``. This is safe because it is called immediately after
342+ ``create_session()`` (which populates the entry under the lock) within
343+ the same task, and dict reads are atomic in CPython.
344+
345+ Args:
346+ headers: Optional headers used to identify the session.
347+
348+ Returns:
349+ The SessionContext if a matching session exists, None otherwise.
350+ """
351+ merged_headers = self ._merge_headers (headers )
352+ session_key = self ._generate_session_key (merged_headers )
353+ return self ._session_contexts .get (session_key )
354+
326355 async def _cleanup_session (
327356 self ,
328357 session_key : str ,
@@ -378,6 +407,10 @@ def cleanup_done(f: asyncio.Future):
378407 finally :
379408 if session_key in self ._sessions :
380409 del self ._sessions [session_key ]
410+ # Also drop the SessionContext reference so we don't leak the
411+ # SessionContext after its underlying session is gone.
412+ if session_key in self ._session_contexts :
413+ del self ._session_contexts [session_key ]
381414
382415 def _create_client (self , merged_headers : Optional [Dict [str , str ]] = None ):
383416 """Creates an MCP client based on the connection parameters.
@@ -453,15 +486,30 @@ async def create_session(
453486 if session_key in self ._sessions :
454487 session , exit_stack , stored_loop = self ._sessions [session_key ]
455488
456- # Check if the existing session is still connected and bound to the current loop
489+ # Check if the existing session is still connected and bound to
490+ # the current loop. When the feature flag is on, we ALSO check the
491+ # SessionContext's background task: a crashed transport can leave
492+ # the session's read/write streams open even though the underlying
493+ # task has already died (e.g. after a 4xx/5xx HTTP response).
494+ # Without that extra check, callers would reuse a dead session and
495+ # hang on the next call. The check is gated because it triggers
496+ # session re-creation in some test mocks where `_task` looks
497+ # "not alive" but the streams are otherwise reusable.
457498 current_loop = asyncio .get_running_loop ()
458- if stored_loop is current_loop and not self ._is_session_disconnected (
459- session
499+ if is_feature_enabled (FeatureName ._MCP_GRACEFUL_ERROR_HANDLING ): # pylint: disable=protected-access
500+ ctx = self ._session_contexts .get (session_key )
501+ ctx_alive = ctx is None or ctx ._is_task_alive # pylint: disable=protected-access
502+ else :
503+ ctx_alive = True # Pre-fix: do not consult task aliveness
504+ if (
505+ stored_loop is current_loop
506+ and not self ._is_session_disconnected (session )
507+ and ctx_alive
460508 ):
461509 # Session is still good, return it
462510 return session
463511 else :
464- # Session is disconnected or from a different loop, clean it up
512+ # Session is disconnected, dead, or from a different loop; clean up.
465513 logger .info (
466514 'Cleaning up session (disconnected or different loop): %s' ,
467515 session_key ,
@@ -485,26 +533,32 @@ async def create_session(
485533 client = self ._create_client (merged_headers )
486534 is_stdio = isinstance (self ._connection_params , StdioConnectionParams )
487535
536+ session_context = SessionContext (
537+ client = client ,
538+ timeout = timeout_in_seconds ,
539+ sse_read_timeout = sse_read_timeout_in_seconds ,
540+ is_stdio = is_stdio ,
541+ sampling_callback = self ._sampling_callback ,
542+ sampling_capabilities = self ._sampling_capabilities ,
543+ )
544+
488545 session = await asyncio .wait_for (
489- exit_stack .enter_async_context (
490- SessionContext (
491- client = client ,
492- timeout = timeout_in_seconds ,
493- sse_read_timeout = sse_read_timeout_in_seconds ,
494- is_stdio = is_stdio ,
495- sampling_callback = self ._sampling_callback ,
496- sampling_capabilities = self ._sampling_capabilities ,
497- )
498- ),
546+ exit_stack .enter_async_context (session_context ),
499547 timeout = timeout_in_seconds ,
500548 )
501549
502- # Store session, exit stack, and loop in the pool
550+ # Store session, exit stack, and loop in the pool. The pool storage
551+ # remains a tuple for backward-compatibility with downstream tests
552+ # that construct or unpack entries directly.
503553 self ._sessions [session_key ] = (
504554 session ,
505555 exit_stack ,
506556 asyncio .get_running_loop (),
507557 )
558+ # Track the SessionContext in a sibling dict so McpTool can call
559+ # `_run_guarded` on it. Stored separately to avoid changing the
560+ # shape of `_sessions` (which is a public-ish internal surface).
561+ self ._session_contexts [session_key ] = session_context
508562 logger .debug ('Created new session: %s' , session_key )
509563 return session
510564
@@ -524,6 +578,7 @@ def __getstate__(self):
524578 state = self .__dict__ .copy ()
525579 # Remove unpicklable entries or those that shouldn't persist across pickle
526580 state ['_sessions' ] = {}
581+ state ['_session_contexts' ] = {}
527582 state ['_session_lock_map' ] = {}
528583
529584 # Locks and file-like objects cannot be pickled
@@ -537,6 +592,7 @@ def __setstate__(self, state):
537592 self .__dict__ .update (state )
538593 # Re-initialize members that were not pickled
539594 self ._sessions = {}
595+ self ._session_contexts = {}
540596 self ._session_lock_map = {}
541597 self ._lock_map_lock = threading .Lock ()
542598 # If _errlog was removed during pickling, default to sys.stderr
0 commit comments