@@ -104,11 +104,98 @@ async def create_mcp_tools_from_metadata(
104104 Each tool manages its own session lifecycle - creating, using, and cleaning up
105105 the MCP connection within the tool invocation.
106106 """
107+
108+ if config .is_enabled is False :
109+ return []
110+
107111 # Lazy import to improve cold start time
112+ import logging
113+
114+ import anyio
115+ from anyio .streams .memory import MemoryObjectReceiveStream , MemoryObjectSendStream
108116 from mcp import ClientSession
109- from mcp .client .streamable_http import streamable_http_client
117+ from mcp .client .streamable_http import GetSessionIdCallback , StreamableHTTPTransport
118+ from mcp .shared .message import SessionMessage
119+
120+ logger : logging .Logger = logging .getLogger (__name__ )
121+
122+ @asynccontextmanager
123+ async def streamable_http_client (
124+ url : str ,
125+ * ,
126+ http_client : httpx .AsyncClient ,
127+ session_id : str | None = None ,
128+ terminate_on_close : bool = False ,
129+ ) -> AsyncGenerator [
130+ tuple [
131+ MemoryObjectReceiveStream [SessionMessage | Exception ],
132+ MemoryObjectSendStream [SessionMessage ],
133+ GetSessionIdCallback ,
134+ ],
135+ None ,
136+ ]:
137+ """Client transport for StreamableHTTP.
138+
139+ Args:
140+ url: The MCP server endpoint URL.
141+ http_client: Pre-configured httpx.AsyncClient to use for HTTP requests.
142+ session_id: Optional session ID for reusing an existing MCP session. If provided,
143+ the client will reconnect to an existing session instead of creating a new one.
144+ If None, a new session will be created on initialization.
145+ terminate_on_close: If True, send a DELETE request to terminate the session when the context exits.
146+
147+ Yields:
148+ Tuple containing:
149+ - read_stream: Stream for reading messages from the server
150+ - write_stream: Stream for sending messages to the server
151+ - get_session_id_callback: Function to retrieve the current session ID
152+
153+ Example:
154+ See examples/snippets/clients/ for usage patterns.
155+ """
156+ read_stream_writer , read_stream = anyio .create_memory_object_stream [
157+ SessionMessage | Exception
158+ ](0 )
159+ write_stream , write_stream_reader = anyio .create_memory_object_stream [
160+ SessionMessage
161+ ](0 )
162+
163+ transport = StreamableHTTPTransport (url )
164+ transport .session_id = session_id # type: ignore[assignment]
165+
166+ async with anyio .create_task_group () as tg :
167+ try :
168+ logger .debug (f"Connecting to StreamableHTTP endpoint: { url } " )
169+
170+ def start_get_stream () -> None :
171+ tg .start_soon (
172+ transport .handle_get_stream , http_client , read_stream_writer
173+ )
174+
175+ tg .start_soon (
176+ transport .post_writer ,
177+ http_client ,
178+ write_stream_reader ,
179+ read_stream_writer ,
180+ write_stream ,
181+ start_get_stream ,
182+ tg ,
183+ )
184+
185+ try :
186+ yield (read_stream , write_stream , transport .get_session_id )
187+ finally :
188+ if transport .session_id and terminate_on_close :
189+ await transport .terminate_session (http_client )
190+ tg .cancel_scope .cancel ()
191+
192+ finally :
193+ await read_stream_writer .aclose ()
194+ await write_stream .aclose ()
110195
111196 tools : list [BaseTool ] = []
197+ session_id : str | None = None
198+ session_lock = asyncio .Lock ()
112199
113200 for mcp_tool in config .available_tools :
114201 tool_name = sanitize_tool_name (mcp_tool .name )
@@ -146,20 +233,30 @@ async def tool_fn(**kwargs: Any) -> Any:
146233 httpx .AsyncClient (** client_kwargs )
147234 )
148235
149- # Create streamable connection
150- read , write , _ = await stack .enter_async_context (
151- streamable_http_client (
152- url = f"{ mcpServer .mcp_url } " , http_client = http_client
236+ # Create streamable connection and initialize session with lock
237+ # to prevent race conditions when multiple tools are invoked concurrently
238+ nonlocal session_id
239+ async with session_lock :
240+ logger .debug (f"Connecting to session { session_id } " )
241+ read , write , getSessionId = await stack .enter_async_context (
242+ streamable_http_client (
243+ url = f"{ mcpServer .mcp_url } " ,
244+ http_client = http_client ,
245+ session_id = session_id ,
246+ )
153247 )
154- )
155248
156- # Create and initialize session
157- session = await stack .enter_async_context (
158- ClientSession (read , write )
159- )
160- await session .initialize ()
249+ # Create and initialize session
250+ session = await stack .enter_async_context (
251+ ClientSession (read , write )
252+ )
253+
254+ if not session_id :
255+ await session .initialize ()
256+ session_id = getSessionId ()
257+ logger .info (f"session { session_id } created" )
161258
162- # Call the tool
259+ # Call the tool (outside lock to allow concurrent tool calls)
163260 result = await session .call_tool (mcp_tool .name , arguments = kwargs )
164261 return result .content if hasattr (result , "content" ) else result
165262
0 commit comments