1616
1717import asyncio
1818from collections import deque
19+ import concurrent .futures
1920from contextlib import AsyncExitStack
21+ from dataclasses import dataclass
2022from datetime import timedelta
2123import functools
2224import hashlib
2527import sys
2628import threading
2729from typing import Any
30+ from typing import Callable
31+ from typing import cast
2832from typing import Dict
2933from typing import Optional
3034from typing import Protocol
3135from typing import runtime_checkable
3236from typing import TextIO
37+ from typing import TYPE_CHECKING
38+ from typing import TypeVar
3339from typing import Union
3440
41+ if TYPE_CHECKING :
42+ from .session_context import SessionContext
43+
3544from mcp import ClientSession
3645from mcp import SamplingCapability
3746from mcp import StdioServerParameters
4453from pydantic import BaseModel
4554from pydantic import ConfigDict
4655
47- from .session_context import SessionContext
48-
4956logger = logging .getLogger ('google_adk.' + __name__ )
5057
5158
@@ -146,7 +153,10 @@ class StreamableHTTPConnectionParams(BaseModel):
146153 httpx_client_factory : CheckableMcpHttpClientFactory = create_mcp_http_client
147154
148155
149- def retry_on_errors (func ):
156+ _F = TypeVar ('_F' , bound = Callable [..., Any ])
157+
158+
159+ def retry_on_errors (func : _F ) -> _F :
150160 """Decorator to automatically retry action when MCP session errors occur.
151161
152162 When MCP session errors occur, the decorator will automatically retry the
@@ -165,7 +175,7 @@ def retry_on_errors(func):
165175 """
166176
167177 @functools .wraps (func ) # Preserves original function metadata
168- async def wrapper (self , * args , ** kwargs ) :
178+ async def wrapper (self : Any , * args : Any , ** kwargs : Any ) -> Any :
169179 try :
170180 return await func (self , * args , ** kwargs )
171181 except Exception as e :
@@ -182,7 +192,17 @@ async def wrapper(self, *args, **kwargs):
182192 logger .info ('Retrying %s due to error: %s' , func .__name__ , e )
183193 return await func (self , * args , ** kwargs )
184194
185- return wrapper
195+ return cast (_F , wrapper )
196+
197+
198+ @dataclass
199+ class _SessionEntry :
200+ """A dataclass to hold session information."""
201+
202+ session : ClientSession
203+ exit_stack : AsyncExitStack
204+ loop : asyncio .AbstractEventLoop
205+ context : SessionContext
186206
187207
188208class MCPSessionManager :
@@ -205,7 +225,7 @@ def __init__(
205225 * ,
206226 sampling_callback : Optional [SamplingFnT ] = None ,
207227 sampling_capabilities : Optional [SamplingCapability ] = None ,
208- ):
228+ ) -> None :
209229 """Initializes the MCP session manager.
210230
211231 Args:
@@ -237,10 +257,8 @@ def __init__(
237257 self ._connection_params = connection_params
238258 self ._errlog = errlog
239259
240- # Session pool: maps session keys to (session, exit_stack, loop) tuples
241- self ._sessions : Dict [
242- str , tuple [ClientSession , AsyncExitStack , asyncio .AbstractEventLoop ]
243- ] = {}
260+ # Session pool: maps session keys to _SessionEntry objects
261+ self ._sessions : Dict [str , _SessionEntry ] = {}
244262
245263 # Map of event loops to their respective locks to prevent race conditions
246264 # across different event loops in session creation.
@@ -312,35 +330,66 @@ def _merge_headers(
312330
313331 return base_headers
314332
315- def _is_session_disconnected (self , session : ClientSession ) -> bool :
333+ def _is_session_disconnected (
334+ self ,
335+ entry : _SessionEntry ,
336+ ) -> bool :
316337 """Checks if a session is disconnected or closed.
317338
318339 Args:
319- session : The ClientSession to check.
340+ entry : The _SessionEntry to check.
320341
321342 Returns:
322343 True if the session is disconnected, False otherwise.
323344 """
324- return session ._read_stream ._closed or session ._write_stream ._closed
345+ if (
346+ entry .session ._read_stream ._closed
347+ or entry .session ._write_stream ._closed
348+ ):
349+ return True
350+ if entry .context is not None and not entry .context ._is_task_alive : # pylint: disable=protected-access
351+ return True
352+ return False
353+
354+ def _get_session_context (
355+ self , headers : Optional [Dict [str , str ]] = None
356+ ) -> Optional ['SessionContext' ]:
357+ """Returns the SessionContext for the session matching the given headers.
358+
359+ Note: This method reads from the session pool without acquiring
360+ ``_session_lock``. This is safe because it is called immediately after
361+ ``create_session()`` (which populates the entry under the lock) within
362+ the same task, and dict reads are atomic in CPython.
363+
364+ Args:
365+ headers: Optional headers used to identify the session.
366+
367+ Returns:
368+ The SessionContext if a matching session exists, None otherwise.
369+ """
370+ merged_headers = self ._merge_headers (headers )
371+ session_key = self ._generate_session_key (merged_headers )
372+ entry = self ._sessions .get (session_key )
373+ if entry is not None :
374+ return entry .context
375+ return None
325376
326377 async def _cleanup_session (
327378 self ,
328379 session_key : str ,
329- exit_stack : AsyncExitStack ,
330- stored_loop : asyncio .AbstractEventLoop ,
331- ):
380+ entry : _SessionEntry ,
381+ ) -> None :
332382 """Cleans up a session, handling different event loops safely.
333383
334384 Args:
335385 session_key: The session key to clean up.
336- exit_stack: The AsyncExitStack managing the session resources.
337- stored_loop: The event loop on which the session was created.
386+ entry: The _SessionEntry managing the session resources.
338387 """
339388 current_loop = asyncio .get_running_loop ()
340389 try :
341- if stored_loop is current_loop :
342- await exit_stack .aclose ()
343- elif stored_loop .is_closed ():
390+ if entry . loop is current_loop :
391+ await entry . exit_stack .aclose ()
392+ elif entry . loop .is_closed ():
344393 logger .warning (
345394 f'Error cleaning up session { session_key } : original event loop'
346395 ' is closed, resources may be leaked.'
@@ -353,11 +402,11 @@ async def _cleanup_session(
353402 ' event loop.'
354403 )
355404 future = asyncio .run_coroutine_threadsafe (
356- exit_stack .aclose (), stored_loop
405+ entry . exit_stack .aclose (), entry . loop
357406 )
358407
359408 # Attach a callback so errors don't go unnoticed
360- def cleanup_done (f : asyncio . Future ) :
409+ def cleanup_done (f : 'concurrent.futures. Future[Any]' ) -> None :
361410 try :
362411 if f .exception ():
363412 logger .warning (
@@ -379,7 +428,9 @@ def cleanup_done(f: asyncio.Future):
379428 if session_key in self ._sessions :
380429 del self ._sessions [session_key ]
381430
382- def _create_client (self , merged_headers : Optional [Dict [str , str ]] = None ):
431+ def _create_client (
432+ self , merged_headers : Optional [Dict [str , str ]] = None
433+ ) -> Any :
383434 """Creates an MCP client based on the connection parameters.
384435
385436 Args:
@@ -451,22 +502,22 @@ async def create_session(
451502 async with self ._session_lock :
452503 # Check if we have an existing session
453504 if session_key in self ._sessions :
454- session , exit_stack , stored_loop = self ._sessions [session_key ]
505+ entry = self ._sessions [session_key ]
455506
456507 # Check if the existing session is still connected and bound to the current loop
457508 current_loop = asyncio .get_running_loop ()
458- if stored_loop is current_loop and not self ._is_session_disconnected (
459- session
509+ if entry . loop is current_loop and not self ._is_session_disconnected (
510+ entry
460511 ):
461512 # Session is still good, return it
462- return session
513+ return entry . session
463514 else :
464515 # Session is disconnected or from a different loop, clean it up
465516 logger .info (
466517 'Cleaning up session (disconnected or different loop): %s' ,
467518 session_key ,
468519 )
469- await self ._cleanup_session (session_key , exit_stack , stored_loop )
520+ await self ._cleanup_session (session_key , entry )
470521
471522 # Create a new session (either first time or replacing disconnected one)
472523 exit_stack = AsyncExitStack ()
@@ -482,28 +533,30 @@ async def create_session(
482533 )
483534
484535 try :
536+ from .session_context import SessionContext
537+
485538 client = self ._create_client (merged_headers )
486539 is_stdio = isinstance (self ._connection_params , StdioConnectionParams )
487540
541+ session_context = SessionContext (
542+ client = client ,
543+ timeout = timeout_in_seconds ,
544+ sse_read_timeout = sse_read_timeout_in_seconds ,
545+ is_stdio = is_stdio ,
546+ sampling_callback = self ._sampling_callback ,
547+ sampling_capabilities = self ._sampling_capabilities ,
548+ )
488549 session = await asyncio .wait_for (
489- exit_stack .enter_async_context (
490- SessionContext (
491- client = client ,
492- timeout = timeout_in_seconds ,
493- sse_read_timeout = sse_read_timeout_in_seconds ,
494- is_stdio = is_stdio ,
495- sampling_callback = self ._sampling_callback ,
496- sampling_capabilities = self ._sampling_capabilities ,
497- )
498- ),
550+ exit_stack .enter_async_context (session_context ),
499551 timeout = timeout_in_seconds ,
500552 )
501553
502- # Store session, exit stack, and loop in the pool
503- self ._sessions [session_key ] = (
504- session ,
505- exit_stack ,
506- asyncio .get_running_loop (),
554+ # Store session, exit stack, loop, and context in the pool
555+ self ._sessions [session_key ] = _SessionEntry (
556+ session = session ,
557+ exit_stack = exit_stack ,
558+ loop = asyncio .get_running_loop (),
559+ context = session_context ,
507560 )
508561 logger .debug ('Created new session: %s' , session_key )
509562 return session
@@ -519,7 +572,7 @@ async def create_session(
519572 )
520573 raise ConnectionError (f'Failed to create MCP session: { e } ' ) from e
521574
522- def __getstate__ (self ):
575+ def __getstate__ (self ) -> Dict [ str , Any ] :
523576 """Custom pickling to exclude non-picklable runtime objects."""
524577 state = self .__dict__ .copy ()
525578 # Remove unpicklable entries or those that shouldn't persist across pickle
@@ -532,7 +585,7 @@ def __getstate__(self):
532585
533586 return state
534587
535- def __setstate__ (self , state ) :
588+ def __setstate__ (self , state : Dict [ str , Any ]) -> None :
536589 """Custom unpickling to restore state."""
537590 self .__dict__ .update (state )
538591 # Re-initialize members that were not pickled
@@ -543,12 +596,12 @@ def __setstate__(self, state):
543596 if not hasattr (self , '_errlog' ) or self ._errlog is None :
544597 self ._errlog = sys .stderr
545598
546- async def close (self ):
599+ async def close (self ) -> None :
547600 """Closes all sessions and cleans up resources."""
548601 async with self ._session_lock :
549602 for session_key in list (self ._sessions .keys ()):
550- _ , exit_stack , stored_loop = self ._sessions [session_key ]
551- await self ._cleanup_session (session_key , exit_stack , stored_loop )
603+ entry = self ._sessions [session_key ]
604+ await self ._cleanup_session (session_key , entry )
552605
553606
554607SseServerParams = SseConnectionParams
0 commit comments