Skip to content

Commit a87789d

Browse files
brianstrauch and claude committed
contrib/strands: cache MCP connections across tool calls
The per-server {server}-call-tool activity opened a fresh MCP session on every invocation (open transport + initialize + call + teardown), so an agent making several successive MCP calls paid that handshake per call -- and for stdio servers, a subprocess spawn per call. Hold a lazily-opened MCP session per server in the activity worker process so successive call-tool activities reuse one connection. A dedicated owner task enters and exits the anyio transport/ClientSession context managers in the same task (the cancel-scope rule); call-tool activities on the same event loop invoke session.call_tool directly, which MCP multiplexes by request id. Evict on idle timeout, on a call error (so a broken session reconnects), and on worker shutdown -- scoped to the servers the plugin registered rather than every cached connection. A reused session now carries server-side state across workflows sharing a worker, a behavior change from the previous per-call isolation. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
1 parent a70a40a commit a87789d

3 files changed

Lines changed: 205 additions & 8 deletions

File tree

temporalio/contrib/strands/_plugin.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from ._failure_converter import StrandsFailureConverter
1515
from ._model_activity import ModelActivity
1616
from ._temporal_mcp_client import (
17+
_evict_connection,
1718
build_call_tool_activity,
1819
clear_cache,
1920
populate_cache,
@@ -67,6 +68,7 @@ async def run_context() -> AsyncGenerator[None, None]:
6768
yield
6869
finally:
6970
for server in mcp_clients:
71+
await _evict_connection(server)
7072
clear_cache(server)
7173

7274
super().__init__(

temporalio/contrib/strands/_temporal_mcp_client.py

Lines changed: 112 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
from __future__ import annotations
2+
3+
import asyncio
14
from collections.abc import Callable, Sequence
25
from dataclasses import dataclass, field
36
from datetime import timedelta
@@ -158,13 +161,36 @@ def clear_cache(server: str) -> None:
158161
_TOOL_CACHE.pop(server, None)
159162

160163

161-
def build_call_tool_activity(
162-
server: str, client_factory: Callable[[], MCPClient]
163-
) -> Callable:
164-
"""Return the per-server ``{server}-call-tool`` activity for registration."""
164+
# How long an idle MCP connection stays open before it is disconnected.
165+
_MCP_CONNECTION_IDLE = timedelta(minutes=5)
165166

166-
@activity.defn(name=f"{server}-call-tool")
167-
async def call_tool(args: _CallToolArgs) -> MCPToolResult:
167+
# Server name -> live connection held open in the activity worker process.
168+
# Activities run in the worker process, so this module state is shared across activity invocations on that worker.
169+
_CONNECTIONS: dict[str, _ConnectionRecord] = {}
170+
171+
172+
class _ConnectionRecord:
173+
"""A single MCP session held open by a dedicated owner task.
174+
175+
The MCP transport and ``ClientSession`` are anyio context managers whose
176+
cancel scope is bound to the task that enters them, so they must be entered
177+
and exited in the same task. ``_run`` owns that task for the connection's
178+
whole lifetime; ``call_tool`` activities on the same event loop invoke
179+
``session.call_tool`` directly (MCP multiplexes concurrent requests by id).
180+
"""
181+
182+
def __init__(self, server: str, client_factory: Callable[[], MCPClient]) -> None:
183+
loop = asyncio.get_running_loop()
184+
self._server = server
185+
self._stop = asyncio.Event()
186+
self._ready: asyncio.Future[tuple[MCPClient, ClientSession]] = (
187+
loop.create_future()
188+
)
189+
self._idle_handle: asyncio.TimerHandle | None = None
190+
self._idle_task: asyncio.Task[None] | None = None
191+
self._owner = asyncio.create_task(self._run(client_factory))
192+
193+
async def _run(self, client_factory: Callable[[], MCPClient]) -> None:
168194
client = client_factory()
169195
try:
170196
async with client._transport_callable() as (read_stream, write_stream, *_):
@@ -174,9 +200,87 @@ async def call_tool(args: _CallToolArgs) -> MCPToolResult:
174200
elicitation_callback=client._elicitation_callback,
175201
) as session:
176202
await session.initialize()
177-
result = await session.call_tool(args.tool_name, args.arguments)
178-
return client._handle_tool_result(args.tool_use_id, result)
203+
self._ready.set_result((client, session))
204+
await self._stop.wait()
205+
except BaseException as err:
206+
# A failed connect should not be cached; drop it so the next call
207+
# retries instead of awaiting a permanently rejected future.
208+
if not self._ready.done():
209+
self._ready.set_exception(err)
210+
_CONNECTIONS.pop(self._server, None)
211+
raise
212+
213+
def touch(self) -> None:
    """Re-arm the idle-eviction timer because the connection was just used.

    Any timer that is still pending is cancelled and a new one is scheduled
    on the running event loop; when it fires it invokes ``_on_idle``.
    """
    pending = self._idle_handle
    if pending is not None:
        pending.cancel()
    running_loop = asyncio.get_running_loop()
    idle_seconds = _MCP_CONNECTION_IDLE.total_seconds()
    self._idle_handle = running_loop.call_later(idle_seconds, self._on_idle)
221+
222+
def _on_idle(self) -> None:
    """Timer callback: start evicting this server's connection in the background."""
    eviction = _evict_connection(self._server)
    # The task is stored on the record; presumably this keeps a strong
    # reference so the eviction is not dropped before it finishes.
    self._idle_task = asyncio.ensure_future(eviction)
224+
225+
async def aclose(self) -> None:
    """Signal the owner task to exit its context managers and wait for it.

    Cancels any pending idle timer, sets the stop event the owner task is
    parked on, and then waits until the owner task has finished. Close is
    best-effort: an exception raised by the owner task is consumed here
    rather than re-raised.

    Raises:
        asyncio.CancelledError: if the *caller* is cancelled while waiting.
            The cancellation is propagated (not swallowed) and the owner
            task is left running so it can finish its own teardown.
    """
    if self._idle_handle is not None:
        self._idle_handle.cancel()
        self._idle_handle = None
    self._stop.set()
    # Unlike ``await self._owner``, asyncio.wait neither re-raises the owner's
    # exception nor cancels the owner when this coroutine is cancelled, so a
    # cancelled caller no longer aborts the owner mid-exit or loses its own
    # cancellation.
    await asyncio.wait([self._owner])
    if not self._owner.cancelled():
        # Retrieve the result so asyncio does not report the owner's
        # exception as "never retrieved".
        self._owner.exception()
235+
236+
async def session(self) -> tuple[MCPClient, ClientSession]:
    """Return the live client and session, or raise the connect failure.

    Waits on the record's shared ``_ready`` future. The wait is shielded
    because several callers may await the same future: without the shield,
    cancelling one waiting caller would cancel the shared future itself,
    failing every other waiter and making the owner task's later
    ``set_result`` raise ``InvalidStateError``.

    Raises:
        asyncio.CancelledError: if this caller is cancelled while waiting;
            the shared future is left untouched for other waiters.
        BaseException: whatever exception was stored on ``_ready``.
    """
    return await asyncio.shield(self._ready)
239+
240+
241+
async def get_connection(
    server: str, client_factory: Callable[[], MCPClient]
) -> tuple[MCPClient, ClientSession]:
    """Hand out the cached MCP session for ``server``, connecting on first use.

    When no record exists for ``server`` one is created and stored before
    anything is awaited, so later callers find the same record and wait on
    its session instead of opening another connection. Each call restarts
    the record's idle timer.
    """
    if _CONNECTIONS.get(server) is None:
        _CONNECTIONS[server] = _ConnectionRecord(server, client_factory)
    entry = _CONNECTIONS[server]
    entry.touch()
    return await entry.session()
255+
256+
257+
async def _evict_connection(server: str) -> None:
    """Remove ``server``'s cached connection, if any, and close it.

    The record is taken out of the cache before it is closed. Calling this
    for a server with no cached connection does nothing.
    """
    evicted = _CONNECTIONS.pop(server, None)
    if evicted is None:
        return
    await evicted.aclose()
261+
262+
263+
def build_call_tool_activity(
    server: str, client_factory: Callable[[], MCPClient]
) -> Callable:
    """Build the ``{server}-call-tool`` activity that the worker registers.

    The returned activity obtains the MCP session through ``get_connection``
    (which opens one via ``client_factory`` only when none is cached) and runs
    the requested tool on it. Failures are returned as tool error results
    rather than raised.
    """

    @activity.defn(name=f"{server}-call-tool")
    async def call_tool(args: _CallToolArgs) -> MCPToolResult:
        try:
            mcp_client, mcp_session = await get_connection(server, client_factory)
        except Exception as connect_err:
            # No connection could be obtained, so there is no client to report
            # through; build one from the factory just to format the error.
            fallback_client = client_factory()
            return fallback_client._handle_tool_execution_error(
                args.tool_use_id, connect_err
            )

        try:
            raw_result = await mcp_session.call_tool(args.tool_name, args.arguments)
            return mcp_client._handle_tool_result(args.tool_use_id, raw_result)
        except Exception as call_err:
            # The session may be unusable now; evict it so the next call
            # opens a new connection.
            await _evict_connection(server)
            return mcp_client._handle_tool_execution_error(
                args.tool_use_id, call_err
            )

    return call_tool

tests/contrib/strands/test_mcp.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,24 @@
1212
StrandsPlugin,
1313
TemporalAgent,
1414
TemporalMCPClient,
15+
_temporal_mcp_client,
1516
)
1617
from temporalio.worker import Replayer, Worker
1718
from tests.contrib.strands.common import get_activities
1819
from tests.contrib.strands.mock_model import MockModel
1920

2021

22+
def _echo_client_factory() -> MCPClient:
    """Create an ``MCPClient`` that launches ``echo_mcp_server.py`` over stdio."""

    def open_transport():
        # The server script sits next to this test module and is run with the
        # current interpreter.
        server_script = Path(__file__).parent / "echo_mcp_server.py"
        launch_params = StdioServerParameters(
            command=sys.executable,
            args=[str(server_script)],
        )
        return stdio_client(launch_params)

    return MCPClient(open_transport)
31+
32+
2133
@workflow.defn
2234
class MCPWorkflow:
2335
def __init__(self) -> None:
@@ -86,3 +98,82 @@ async def test_mcp(client: Client):
8698
workflows=[MCPWorkflow],
8799
plugins=[plugin],
88100
).replay_workflow(history)
101+
102+
103+
@workflow.defn
class MCPReuseWorkflow:
    """Workflow whose agent is given the ``echo_cached`` MCP server as a tool source."""

    def __init__(self) -> None:
        # Both the agent and the MCP client use a 30-second start-to-close timeout.
        self.agent = TemporalAgent(
            model="mock",
            start_to_close_timeout=timedelta(seconds=30),
            tools=[
                TemporalMCPClient(
                    server="echo_cached",
                    start_to_close_timeout=timedelta(seconds=30),
                )
            ],
        )

    @workflow.run
    async def run(self, prompt: str) -> str:
        """Invoke the agent with ``prompt`` and return its result as a string."""
        return str(await self.agent.invoke_async(prompt))
120+
121+
122+
async def test_mcp_reuses_connection(client: Client):
    """Two successive MCP tool calls share one cached worker-side connection."""
    queue_name = "test_mcp_reuses_connection"

    # Tally client-factory invocations. With caching the expected total is 2:
    # one connection for startup discovery plus one cached connection serving
    # both tool calls. Reconnecting for every call would make it 3.
    connections_opened = 0

    def counting_factory() -> MCPClient:
        nonlocal connections_opened
        connections_opened += 1
        return _echo_client_factory()

    def scripted_model() -> MockModel:
        # A fresh script per model instance: two echo tool calls, then the reply.
        return MockModel(
            [
                {"name": "echo", "input": {"message": "one"}},
                {"name": "echo", "input": {"message": "two"}},
                "Done!",
            ]
        )

    plugin = StrandsPlugin(
        models={"mock": scripted_model},
        mcp_clients={"echo_cached": counting_factory},
    )

    worker = Worker(
        client,
        task_queue=queue_name,
        workflows=[MCPReuseWorkflow],
        plugins=[plugin],
        max_cached_workflows=0,
    )
    async with worker:
        handle = await client.start_workflow(
            MCPReuseWorkflow.run,
            "echo twice",
            id=f"{queue_name}_{uuid4()}",
            task_queue=queue_name,
        )
        outcome = await handle.result()
        assert outcome == "Done!\n"

    # Leaving the worker context runs the plugin's run_context cleanup, which
    # evicts the cached connection.
    assert "echo_cached" not in _temporal_mcp_client._CONNECTIONS
    assert connections_opened == 2

    history = await handle.fetch_history()
    expected_activities = [
        "invoke_model",
        "echo_cached-call-tool",
        "invoke_model",
        "echo_cached-call-tool",
        "invoke_model",
    ]
    assert get_activities(history) == expected_activities

    replayer = Replayer(
        workflows=[MCPReuseWorkflow],
        plugins=[plugin],
    )
    await replayer.replay_workflow(history)

0 commit comments

Comments
 (0)