Skip to content

Commit 41a5daa

Browse files
committed
feat: reuse mcp session between tool calls
1 parent 42667b6 commit 41a5daa

2 files changed

Lines changed: 835 additions & 12 deletions

File tree

src/uipath_langchain/agent/tools/mcp_tool.py

Lines changed: 109 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -104,11 +104,98 @@ async def create_mcp_tools_from_metadata(
104104
Each tool manages its own session lifecycle - creating, using, and cleaning up
105105
the MCP connection within the tool invocation.
106106
"""
107+
108+
if config.is_enabled is False:
109+
return []
110+
107111
# Lazy import to improve cold start time
112+
import logging
113+
114+
import anyio
115+
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
108116
from mcp import ClientSession
109-
from mcp.client.streamable_http import streamable_http_client
117+
from mcp.client.streamable_http import GetSessionIdCallback, StreamableHTTPTransport
118+
from mcp.shared.message import SessionMessage
119+
120+
logger: logging.Logger = logging.getLogger(__name__)
121+
122+
@asynccontextmanager
123+
async def streamable_http_client(
124+
url: str,
125+
*,
126+
http_client: httpx.AsyncClient,
127+
session_id: str | None = None,
128+
terminate_on_close: bool = False,
129+
) -> AsyncGenerator[
130+
tuple[
131+
MemoryObjectReceiveStream[SessionMessage | Exception],
132+
MemoryObjectSendStream[SessionMessage],
133+
GetSessionIdCallback,
134+
],
135+
None,
136+
]:
137+
"""Client transport for StreamableHTTP.
138+
139+
Args:
140+
url: The MCP server endpoint URL.
141+
http_client: Pre-configured httpx.AsyncClient to use for HTTP requests.
142+
session_id: Optional session ID for reusing an existing MCP session. If provided,
143+
the client will reconnect to an existing session instead of creating a new one.
144+
If None, a new session will be created on initialization.
145+
terminate_on_close: If True, send a DELETE request to terminate the session when the context exits.
146+
147+
Yields:
148+
Tuple containing:
149+
- read_stream: Stream for reading messages from the server
150+
- write_stream: Stream for sending messages to the server
151+
- get_session_id_callback: Function to retrieve the current session ID
152+
153+
Example:
154+
See examples/snippets/clients/ for usage patterns.
155+
"""
156+
read_stream_writer, read_stream = anyio.create_memory_object_stream[
157+
SessionMessage | Exception
158+
](0)
159+
write_stream, write_stream_reader = anyio.create_memory_object_stream[
160+
SessionMessage
161+
](0)
162+
163+
transport = StreamableHTTPTransport(url)
164+
transport.session_id = session_id # type: ignore[assignment]
165+
166+
async with anyio.create_task_group() as tg:
167+
try:
168+
logger.debug(f"Connecting to StreamableHTTP endpoint: {url}")
169+
170+
def start_get_stream() -> None:
171+
tg.start_soon(
172+
transport.handle_get_stream, http_client, read_stream_writer
173+
)
174+
175+
tg.start_soon(
176+
transport.post_writer,
177+
http_client,
178+
write_stream_reader,
179+
read_stream_writer,
180+
write_stream,
181+
start_get_stream,
182+
tg,
183+
)
184+
185+
try:
186+
yield (read_stream, write_stream, transport.get_session_id)
187+
finally:
188+
if transport.session_id and terminate_on_close:
189+
await transport.terminate_session(http_client)
190+
tg.cancel_scope.cancel()
191+
192+
finally:
193+
await read_stream_writer.aclose()
194+
await write_stream.aclose()
110195

111196
tools: list[BaseTool] = []
197+
session_id: str | None = None
198+
session_lock = asyncio.Lock()
112199

113200
for mcp_tool in config.available_tools:
114201
tool_name = sanitize_tool_name(mcp_tool.name)
@@ -146,20 +233,30 @@ async def tool_fn(**kwargs: Any) -> Any:
146233
httpx.AsyncClient(**client_kwargs)
147234
)
148235

149-
# Create streamable connection
150-
read, write, _ = await stack.enter_async_context(
151-
streamable_http_client(
152-
url=f"{mcpServer.mcp_url}", http_client=http_client
236+
# Create streamable connection and initialize session with lock
237+
# to prevent race conditions when multiple tools are invoked concurrently
238+
nonlocal session_id
239+
async with session_lock:
240+
logger.debug(f"Connecting to session {session_id}")
241+
read, write, getSessionId = await stack.enter_async_context(
242+
streamable_http_client(
243+
url=f"{mcpServer.mcp_url}",
244+
http_client=http_client,
245+
session_id=session_id,
246+
)
153247
)
154-
)
155248

156-
# Create and initialize session
157-
session = await stack.enter_async_context(
158-
ClientSession(read, write)
159-
)
160-
await session.initialize()
249+
# Create and initialize session
250+
session = await stack.enter_async_context(
251+
ClientSession(read, write)
252+
)
253+
254+
if not session_id:
255+
await session.initialize()
256+
session_id = getSessionId()
257+
logger.info(f"session {session_id} created")
161258

162-
# Call the tool
259+
# Call the tool (outside lock to allow concurrent tool calls)
163260
result = await session.call_tool(mcp_tool.name, arguments=kwargs)
164261
return result.content if hasattr(result, "content") else result
165262

0 commit comments

Comments
 (0)