Commit c8a052c

contrib/strands: add cache_tools toggle to TemporalMCPClient (#1571)
* contrib/strands: add cache_tools toggle to TemporalMCPClient

  Replace worker-startup tool discovery with a per-server {server}-list-tools activity executed from inside the workflow. TemporalMCPClient.cache_tools (default True) lists tools once at the start of the workflow; cache_tools=False re-lists on every agent turn so a mid-workflow MCP server restart is picked up.

  Strands calls load_tools() once at agent construction on a separate run_async thread with no workflow runtime, so the activity is dispatched from a BeforeModelCallEvent hook (which runs on the workflow loop before the registry is read each turn) that reconciles added/removed/renamed tools.

* contrib/strands: default cache_tools to False
1 parent 9a95fef commit c8a052c

5 files changed

Lines changed: 247 additions & 76 deletions

temporalio/contrib/strands/README.md

Lines changed: 9 additions & 3 deletions
@@ -380,7 +380,7 @@ class ChatWorkflow:
 
 ## MCP
 
-`StrandsPlugin(mcp_clients=...)` takes a mapping of `name → MCPClient factory`, mirroring the `models=` pattern. The plugin registers a per-server `{name}-call-tool` activity and connects at worker startup to enumerate tools. Workflow-side, `TemporalMCPClient(server="name")` is a pure handle: it references the server by name and carries the per-call activity options.
+`StrandsPlugin(mcp_clients=...)` takes a mapping of `name → MCPClient factory`, mirroring the `models=` pattern. The plugin registers per-server `{name}-call-tool` and `{name}-list-tools` activities. Workflow-side, `TemporalMCPClient(server="name")` is a pure handle: it references the server by name, discovers tools by running `{name}-list-tools`, and carries the per-call activity options.
 
 ```python
 from mcp import StdioServerParameters, stdio_client
@@ -412,9 +412,15 @@ Worker(
 )
 ```
 
-Each factory returns a fully configured `MCPClient`, so you can pass options like `tool_filters`, `prefix`, `elicitation_callback`, or `tasks_config` to it. The plugin connects to each MCP server once at worker startup to enumerate tools. The schema is frozen for the worker's lifetime; restart workers to pick up MCP-server changes. If a server is unavailable at startup, the worker fails to start.
+Each factory returns a fully configured `MCPClient`, so you can pass options like `tool_filters`, `prefix`, `elicitation_callback`, or `tasks_config` to it.
 
-To amortize connection setup, the `{name}-call-tool` activity keeps a worker-process MCP connection open between calls and reuses it. The connection is disconnected after it sits idle for `mcp_connection_idle_timeout` (default 5 minutes); the timer resets on every reuse:
+By default, `TemporalMCPClient` re-lists the server's tools (via `{name}-list-tools`) on every agent turn, so an MCP server that is restarted mid-workflow — with tools added, removed, or renamed — is picked up. To list the tools just once at the beginning of the workflow and reuse that schema for the workflow's lifetime (one fewer activity per turn), set `cache_tools=True`:
+
+```python
+echo = TemporalMCPClient(server="echo", cache_tools=True, start_to_close_timeout=timedelta(seconds=30))
+```
+
+To amortize connection setup, the `{name}-call-tool` and `{name}-list-tools` activities share a worker-process MCP connection that is opened lazily and reused across calls. The connection is disconnected after it sits idle for `mcp_connection_idle_timeout` (default 5 minutes); the timer resets on every reuse:
 
 ```python
 StrandsPlugin(
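
The README text above says `cache_tools=True` saves one activity per turn. The following is a minimal sketch of that trade-off, not the integration's code: `FakeProvider`, `before_model_call`, and `run_turns` are hypothetical names, and a counter stands in for the `{name}-list-tools` activity.

```python
import asyncio


class FakeProvider:
    """Hypothetical stand-in for TemporalMCPClient; not the real class."""

    def __init__(self, cache_tools: bool = False) -> None:
        self.cache_tools = cache_tools
        self.fetched = False
        self.list_calls = 0  # how many times the pretend list-tools activity ran

    async def refresh(self) -> None:
        # The real integration would await an activity here; this only counts.
        self.list_calls += 1
        self.fetched = True


async def before_model_call(provider: FakeProvider) -> None:
    # Skip the listing only when caching is on AND a listing already happened.
    if provider.cache_tools and provider.fetched:
        return
    await provider.refresh()


async def run_turns(provider: FakeProvider, turns: int) -> int:
    # One hook invocation per agent turn.
    for _ in range(turns):
        await before_model_call(provider)
    return provider.list_calls


print(asyncio.run(run_turns(FakeProvider(cache_tools=True), 3)))   # 1
print(asyncio.run(run_turns(FakeProvider(cache_tools=False), 3)))  # 3
```

With caching enabled the listing cost is paid once per workflow; with it disabled the cost scales with the number of turns, which is the price of noticing a server restart.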

temporalio/contrib/strands/_plugin.py

Lines changed: 11 additions & 9 deletions
@@ -17,8 +17,7 @@
 from ._temporal_mcp_client import (
     _evict_connection,
     build_call_tool_activity,
-    clear_cache,
-    populate_cache,
+    build_list_tools_activity,
 )
 
 
@@ -31,10 +30,11 @@ class StrandsPlugin(SimplePlugin):
     on first use, then cached for the worker's lifetime. Use the same name in
     ``TemporalAgent(model=...)`` inside the workflow.
 
-    When ``mcp_clients`` is supplied, registers a per-server
-    ``{server}-call-tool`` activity for each entry and, at worker startup,
-    connects to each MCP server to cache its tool list. Workflow-side
-    ``TemporalMCPClient(server="...").load_tools()`` reads from the cache.
+    When ``mcp_clients`` is supplied, registers per-server
+    ``{server}-call-tool`` and ``{server}-list-tools`` activities for each
+    entry. Workflow-side ``TemporalMCPClient(server="...")`` discovers tools by
+    running ``{server}-list-tools``; whether it lists once per workflow or once
+    per agent turn is controlled by its ``cache_tools`` option.
 
     ``mcp_connection_idle_timeout`` controls how long a worker-process MCP
     connection is kept open between ``call-tool`` activities before it is
@@ -69,17 +69,19 @@ def __init__(
                     server, client_factory, mcp_connection_idle_timeout
                 )
             )
+            activities.append(
+                build_list_tools_activity(
+                    server, client_factory, mcp_connection_idle_timeout
+                )
+            )
 
         @asynccontextmanager
         async def run_context() -> AsyncGenerator[None, None]:
-            for server, client_factory in mcp_clients.items():
-                await populate_cache(server, client_factory)
             try:
                 yield
             finally:
                 for server in mcp_clients:
                     await _evict_connection(server)
-                    clear_cache(server)
 
         super().__init__(
             "aws.StrandsPlugin",

temporalio/contrib/strands/_temporal_agent.py

Lines changed: 47 additions & 0 deletions
@@ -2,10 +2,12 @@
 from typing import Any
 
 from strands import Agent
+from strands.hooks import BeforeModelCallEvent, HookCallback
 
 from temporalio.common import Priority, RetryPolicy
 from temporalio.workflow import ActivityCancellationType, VersioningIntent
 
+from ._temporal_mcp_client import TemporalMCPClient
 from ._temporal_model import TemporalModel
 
 _SNAPSHOT_DISABLED = (
@@ -76,6 +78,51 @@ def __init__(
         )
         super().__init__(model=temporal_model, **agent_kwargs)
 
+        # Strands invokes ToolProvider.load_tools() once at construction on a
+        # separate run_async thread that has no workflow runtime, so a
+        # TemporalMCPClient cannot list its tools there. Instead refresh from a
+        # BeforeModelCallEvent hook, which runs on the workflow loop just before
+        # the registry is read each turn. cache_tools=True lists once (guarded
+        # by _fetched); cache_tools=False re-lists every turn.
+        for provider in self.tool_registry._tool_providers:
+            if isinstance(provider, TemporalMCPClient):
+                self.hooks.add_callback(
+                    BeforeModelCallEvent, self._make_mcp_refresh_hook(provider)
+                )
+
+    def _make_mcp_refresh_hook(
+        self, provider: TemporalMCPClient
+    ) -> HookCallback[BeforeModelCallEvent]:
+        async def hook(event: BeforeModelCallEvent) -> None:
+            if provider._cache_tools and provider._fetched:
+                return
+            old_names = {tool.tool_name for tool in provider._tools}
+            await provider._refresh()
+            self._reconcile_mcp_tools(event, provider, old_names)
+
+        return hook
+
+    def _reconcile_mcp_tools(
+        self,
+        event: BeforeModelCallEvent,
+        provider: TemporalMCPClient,
+        old_names: set[str],
+    ) -> None:
+        reg = event.agent.tool_registry
+        new = {tool.tool_name: tool for tool in provider._tools}
+        # Tools the server dropped or renamed since the last listing. There is
+        # no public unregister, so remove them from the registry directly.
+        for name in old_names - set(new):
+            reg.registry.pop(name, None)
+            reg.dynamic_tools.pop(name, None)
+        # replace() swaps an existing tool in place (no hot-reload guard);
+        # register_tool() adds a newly-discovered one.
+        for name, tool in new.items():
+            if name in reg.registry:
+                reg.replace(tool)
+            else:
+                reg.register_tool(tool)
+
     def take_snapshot(self, *_args: Any, **_kwargs: Any) -> Any:
         """Disabled; Temporal's event history is the source of truth."""
         raise NotImplementedError(_SNAPSHOT_DISABLED)
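
The reconciliation step in `_reconcile_mcp_tools` amounts to a set difference followed by an upsert. Below is a simplified sketch of that idea using plain dictionaries; `reconcile` and the sample tool names are hypothetical, and a `dict` stands in for the Strands tool registry, whose real `replace()` and `register_tool()` calls are not reproduced here.

```python
def reconcile(
    registry: dict[str, object],
    old_names: set[str],
    new_tools: dict[str, object],
) -> None:
    """Update `registry` in place so one provider's tools match `new_tools`."""
    # Names listed last time but missing now were dropped or renamed.
    for name in old_names - set(new_tools):
        registry.pop(name, None)
    # Overwrite tools that still exist and add newly discovered ones.
    registry.update(new_tools)


# "local_calc" belongs to some other provider and must survive untouched.
registry = {"local_calc": "calc-v1", "echo": "echo-v1", "search": "search-v1"}
old_names = {"echo", "search"}
# The server restarted: "search" was renamed to "lookup", "echo" changed.
new_tools = {"echo": "echo-v2", "lookup": "lookup-v1"}

reconcile(registry, old_names, new_tools)
print(registry)
# {'local_calc': 'calc-v1', 'echo': 'echo-v2', 'lookup': 'lookup-v1'}
```

Capturing `old_names` before the refresh is what scopes the removal to this provider's tools: anything registered by another source is never in the difference.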

temporalio/contrib/strands/_temporal_mcp_client.py

Lines changed: 100 additions & 58 deletions
@@ -13,7 +13,7 @@
 from strands.tools.mcp.mcp_types import MCPToolResult
 from strands.types.tools import AgentTool
 
-from temporalio import activity
+from temporalio import activity, workflow
 from temporalio.common import Priority, RetryPolicy
 from temporalio.workflow import ActivityCancellationType, VersioningIntent
 
@@ -33,20 +33,21 @@ class _CallToolArgs:
     tool_use_id: str = ""
 
 
-# Server name -> cached tool list. Populated by ``_populate_cache`` at worker
-# startup and read by ``TemporalMCPClient.load_tools()`` inside the workflow
-# sandbox. ``temporalio`` is in the SDK's default sandbox passthrough, so this
-# dict is shared between worker process and workflow execution.
-_TOOL_CACHE: dict[str, list[_MCPToolInfo]] = {}
-
-
 class TemporalMCPClient(ToolProvider):
     """Workflow-side handle to an MCP server registered on the worker.
 
-    The transport factory and tool discovery live worker-side via
-    ``StrandsPlugin(mcp_clients={"server": lambda: ...})``. This handle only
-    carries the server name (which selects the registered factory) and the
-    per-call activity options.
+    The transport factory lives worker-side via
+    ``StrandsPlugin(mcp_clients={"server": lambda: ...})``. This handle carries
+    the server name (which selects the registered factory) and the per-call
+    activity options. Tool discovery runs as the ``{server}-list-tools``
+    activity, dispatched from inside the workflow by ``TemporalAgent`` before
+    each model call.
+
+    ``cache_tools`` controls how often that listing happens. When ``False``
+    (the default) the tools are re-listed on every agent turn, so an MCP server
+    restarted mid-workflow (with tools added, removed, or renamed) is picked up.
+    When ``True`` the tools are listed once at the beginning of the workflow and
+    reused for its lifetime.
 
     Construct once at module level and pass to ``TemporalAgent(tools=[...])``
     inside the workflow. Multiple handles may reference the same server name
@@ -57,6 +58,7 @@ def __init__(
         self,
         server: str,
         *,
+        cache_tools: bool = False,
         task_queue: str | None = None,
         schedule_to_close_timeout: timedelta | None = None,
         schedule_to_start_timeout: timedelta | None = None,
@@ -70,6 +72,9 @@ def __init__(
     ) -> None:
         """Configure the server name and activity options."""
         self._server = server
+        self._cache_tools = cache_tools
+        self._tools: list[AgentTool] = []
+        self._fetched = False
         self._options: dict[str, Any] = {
             "task_queue": task_queue,
             "schedule_to_close_timeout": schedule_to_close_timeout,
@@ -89,11 +94,33 @@ def server(self) -> str:
         return self._server
 
     async def load_tools(self, **_kwargs: Any) -> Sequence[AgentTool]:
-        """Return TemporalMCPTool wrappers for tools cached at worker startup."""
+        """Return the tools fetched by the most recent ``_refresh``.
+
+        This must stay free of any ``workflow`` API: Strands invokes it once at
+        ``Agent`` construction on a separate ``run_async`` thread that has no
+        workflow runtime. ``TemporalAgent`` populates the tools by calling
+        ``_refresh`` from a ``BeforeModelCallEvent`` hook before the registry is
+        first read.
+        """
+        return list(self._tools)
+
+    async def _refresh(self) -> None:
+        """List the server's tools via the ``{server}-list-tools`` activity.
+
+        Runs on the workflow event loop (dispatched from ``TemporalAgent``'s
+        hook), so the activity result is recorded in history and replay-safe.
+        """
         from ._temporal_mcp_tool import TemporalMCPTool
 
-        infos = _TOOL_CACHE.get(self._server, [])
-        return [TemporalMCPTool(self._server, info, self._options) for info in infos]
+        infos: list[_MCPToolInfo] = await workflow.execute_activity(
+            f"{self._server}-list-tools",
+            result_type=list[_MCPToolInfo],
+            **self._options,
+        )
+        self._tools = [
+            TemporalMCPTool(self._server, info, self._options) for info in infos
+        ]
+        self._fetched = True
 
     def add_consumer(self, consumer_id: Any, **_kwargs: Any) -> None:
         """No-op; consumer tracking is handled by the underlying MCP client."""
@@ -104,45 +131,37 @@ def remove_consumer(self, consumer_id: Any, **_kwargs: Any) -> None:
         return None
 
 
-# Use MCP sessions directly instead of MCPClient's background-thread helpers.
-# Those helpers route calls through cross-loop futures that are unreliable on
-# Python 3.10 when invoked from Temporal's async worker/activity event loops.
-async def _list_mcp_tools(client: MCPClient) -> Sequence[Tool]:
-    async with client._transport_callable() as (read_stream, write_stream, *_):
-        async with ClientSession(
-            read_stream,
-            write_stream,
-            elicitation_callback=client._elicitation_callback,
-        ) as session:
-            await session.initialize()
-            tools: list[Tool] = []
-            pagination_token = None
-            while True:
-                page = await session.list_tools(
-                    params=PaginatedRequestParams(cursor=pagination_token)
-                    if pagination_token is not None
-                    else None
-                )
-                tools.extend(page.tools)
-                pagination_token = page.nextCursor
-                if pagination_token is None:
-                    return tools
-
-
-def _agent_tool_for_filtering(client: MCPClient, tool: Tool) -> MCPAgentTool:
-    if client._prefix:
-        return MCPAgentTool(tool, client, name_override=f"{client._prefix}_{tool.name}")
-    return MCPAgentTool(tool, client)
-
-
-async def populate_cache(server: str, client_factory: Callable[[], MCPClient]) -> None:
-    """Connect to the MCP server, list tools, fill ``_TOOL_CACHE``."""
-    client = client_factory()
+# Use the MCP session directly instead of MCPClient's background-thread
+# helpers. Those helpers route calls through cross-loop futures that are
+# unreliable on Python 3.10 when invoked from Temporal's async worker/activity
+# event loops.
+async def _paginate_list_tools(session: ClientSession) -> list[Tool]:
+    tools: list[Tool] = []
+    pagination_token = None
+    while True:
+        page = await session.list_tools(
+            params=PaginatedRequestParams(cursor=pagination_token)
+            if pagination_token is not None
+            else None
+        )
+        tools.extend(page.tools)
+        pagination_token = page.nextCursor
+        if pagination_token is None:
+            return tools
+
+
+def _tool_infos(client: MCPClient, tools: Sequence[Tool]) -> list[_MCPToolInfo]:
+    """Apply the client's tool filters and project to serializable records."""
     infos: list[_MCPToolInfo] = []
-    for tool in await _list_mcp_tools(client):
+    for tool in tools:
+        if client._prefix:
+            agent_tool = MCPAgentTool(
+                tool, client, name_override=f"{client._prefix}_{tool.name}"
+            )
+        else:
+            agent_tool = MCPAgentTool(tool, client)
         if not client._should_include_tool_with_filters(
-            _agent_tool_for_filtering(client, tool),
-            client._tool_filters,
+            agent_tool, client._tool_filters
         ):
             continue
         infos.append(
@@ -153,12 +172,7 @@ async def populate_cache(server: str, client_factory: Callable[[], MCPClient]) -
                 output_schema=tool.outputSchema,
             )
         )
-    _TOOL_CACHE[server] = infos
-
-
-def clear_cache(server: str) -> None:
-    """Drop the cached tool list for ``server``."""
-    _TOOL_CACHE.pop(server, None)
+    return infos
 
 
 # Default for how long an idle MCP connection stays open before it is
@@ -324,3 +338,31 @@ async def call_tool(args: _CallToolArgs) -> MCPToolResult:
             record.release()
 
     return call_tool
+
+
+def build_list_tools_activity(
+    server: str,
+    client_factory: Callable[[], MCPClient],
+    idle_timeout: timedelta | None = None,
+) -> Callable:
+    """Return the per-server ``{server}-list-tools`` activity for registration.
+
+    Lists the server's tools (applying the client's tool filters) and reuses
+    the same lazily-opened, idle-evicted worker-process MCP session as
+    ``{server}-call-tool``.
+    """
+    idle = idle_timeout if idle_timeout is not None else _MCP_CONNECTION_IDLE
+
+    @activity.defn(name=f"{server}-list-tools")
+    async def list_tools() -> list[_MCPToolInfo]:
+        client, session, record = await get_connection(server, client_factory, idle)
+        try:
+            return _tool_infos(client, await _paginate_list_tools(session))
+        except Exception:
+            # The session may be broken; drop it so the next call reconnects.
+            await _evict_connection(server)
+            raise
+        finally:
+            record.release()
+
+    return list_tools
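
The cursor loop in `_paginate_list_tools` can be exercised without an MCP server. The following is a rough sketch under stated assumptions: `Page`, `FakeSession`, and `list_all` are hypothetical, the fake `list_tools` takes a bare cursor rather than the `PaginatedRequestParams` wrapper used in the diff, and tool names stand in for full tool objects.

```python
import asyncio
from dataclasses import dataclass
from typing import Optional


@dataclass
class Page:
    tools: list[str]
    nextCursor: Optional[str]  # same field name the diff reads from each page


class FakeSession:
    """Hypothetical in-memory session that serves fixed pages keyed by cursor."""

    def __init__(self) -> None:
        self.pages = {
            None: Page(["echo", "add"], "p2"),  # first page, more to come
            "p2": Page(["search"], None),       # last page, no cursor
        }

    async def list_tools(self, cursor: Optional[str] = None) -> Page:
        return self.pages[cursor]


async def list_all(session: FakeSession) -> list[str]:
    tools: list[str] = []
    cursor: Optional[str] = None
    while True:
        page = await session.list_tools(cursor)
        tools.extend(page.tools)
        cursor = page.nextCursor
        # A missing cursor marks the final page.
        if cursor is None:
            return tools


print(asyncio.run(list_all(FakeSession())))  # ['echo', 'add', 'search']
```

The first request carries no cursor, and each later request echoes the cursor from the previous page, so the loop always issues at least one call and stops on the first page without a cursor.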
