-
Notifications
You must be signed in to change notification settings - Fork 24
Expand file tree
/
Copy pathmcp_utils.py
More file actions
282 lines (235 loc) · 9.86 KB
/
mcp_utils.py
File metadata and controls
282 lines (235 loc) · 9.86 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
# SPDX-FileCopyrightText: GitHub, Inc.
# SPDX-License-Identifier: MIT
"""MCP client utilities.
Provides tool-name compression, namespace-aware MCP wrappers with
confirmation support, and toolbox parameter resolution.
"""
from __future__ import annotations
# Names exported by ``from .mcp_utils import *``. The transport classes and
# ``mcp_system_prompt`` re-exported further down (and ``ClientParamsMap``)
# are not listed here.
__all__ = [
"COMPRESSED_NAME_LENGTH",
"DEFAULT_MCP_CLIENT_SESSION_TIMEOUT",
"MCPNamespaceWrap",
"compress_name",
"mcp_client_params",
]
import hashlib
import json
import logging
import shutil
from typing import Any
from mcp.types import CallToolResult, TextContent
from .available_tools import AvailableTools
from .env_utils import swap_env
# Re-export transport classes and prompt builder so that existing
# ``from .mcp_utils import …`` statements continue to work.
from .mcp_prompt import mcp_system_prompt as mcp_system_prompt # noqa: F401
from .mcp_transport import ( # noqa: F401
AsyncDebugMCPServerStdio as AsyncDebugMCPServerStdio,
ReconnectingMCPServerStdio as ReconnectingMCPServerStdio,
StreamableMCPThread as StreamableMCPThread,
)
# Default MCP client session timeout; presumably in seconds — TODO confirm
# against the code that consumes it.
DEFAULT_MCP_CLIENT_SESSION_TIMEOUT: int = 120

# The OpenAI API rejects tool names longer than 64 characters, so long names
# are hashed down to this many hex characters.
COMPRESSED_NAME_LENGTH: int = 12


def compress_name(name: str) -> str:
    """Hash *name* to a short, stable identifier.

    Used to keep tool names within the OpenAI 64-character tool-name limit.

    Args:
        name: The original tool / toolbox name.

    Returns:
        The first ``COMPRESSED_NAME_LENGTH`` (12) lowercase hex characters of
        the SHA-256 digest of *name* encoded as UTF-8.
    """
    digest = hashlib.sha256(name.encode("utf-8")).hexdigest()
    return digest[:COMPRESSED_NAME_LENGTH]
class MCPNamespaceWrap:
    """MCP client wrapper that prefixes tool names with a namespace hash.

    ``list_tools`` reports every tool as ``<namespace><tool name>``, where the
    namespace is ``compress_name(obj.name)``; ``call_tool`` strips that prefix
    again before forwarding the call. Tools listed in *confirms* require
    interactive approval before they run. All other attribute access is
    delegated to the wrapped object.

    Args:
        confirms: Tool names (without the namespace prefix) that require user
            confirmation. An empty list disables prompting (headless mode).
        obj: The underlying MCP server/client object to wrap; it must have a
            ``name`` attribute.
    """

    def __init__(self, confirms: list[str], obj: Any) -> None:
        self.confirms: list[str] = confirms
        self._obj: Any = obj
        self.namespace: str = compress_name(obj.name)

    def __getattr__(self, name: str) -> Any:
        """Delegate attributes not found on the wrapper to the wrapped object.

        Python only calls ``__getattr__`` after normal lookup fails, so the
        ``list_tools``/``call_tool`` overrides defined on this class always
        take precedence and need no special-casing here.
        """
        # Without this guard, a wrapper whose ``_obj`` was never set (e.g. one
        # created via ``__new__`` by ``copy``/``pickle``) would recurse forever:
        # looking up ``self._obj`` would itself land back in ``__getattr__``.
        if name == "_obj":
            raise AttributeError(name)
        return getattr(self._obj, name)

    async def list_tools(self, *args: Any, **kwargs: Any) -> list[Any]:
        """List the wrapped server's tools with namespace-prefixed names.

        Each tool is copied before renaming, so the wrapped server's own tool
        objects keep their original names.
        """
        result = await self._obj.list_tools(*args, **kwargs)
        namespaced_tools: list[Any] = []
        for tool in result:
            # Pydantic v2 models deprecate ``copy()`` in favour of
            # ``model_copy()``; fall back to ``copy()`` for anything else.
            copier = getattr(tool, "model_copy", None) or tool.copy
            tool_copy = copier()
            tool_copy.name = f"{self.namespace}{tool.name}"
            namespaced_tools.append(tool_copy)
        return namespaced_tools

    def confirm_tool(self, tool_name: str, args: list[Any]) -> bool:
        """Interactively prompt the user for tool-call confirmation.

        Re-prompts until the answer is yes/y or no/n. Matching ignores case
        and surrounding whitespace.

        Args:
            tool_name: The tool being invoked.
            args: Arguments to display; each is rendered with ``json.dumps``.

        Returns:
            ``True`` if the user approved the call, ``False`` otherwise.
        """
        rendered_args = ",".join([json.dumps(arg) for arg in args])
        prompt = f"** 🤖❗ Allow tool call?: {tool_name}({rendered_args}) (yes/no): "
        while True:
            # Normalise so "Y", "Yes" or "yes " are not rejected forever.
            answer = input(prompt).strip().lower()
            if answer in ["yes", "y"]:
                return True
            if answer in ["no", "n"]:
                return False

    async def call_tool(self, *args: Any, **kwargs: Any) -> Any:
        """Call a tool, stripping the namespace prefix and optionally confirming.

        The (possibly prefixed) tool name is taken from the first positional
        argument or, if there are no positional arguments, from a
        ``tool_name`` keyword argument.

        Returns:
            The wrapped object's ``call_tool`` result, or a ``CallToolResult``
            with the text ``"Tool call not allowed."`` if the user declined.
        """
        call_args = list(args)
        if call_args:
            tool_name: str = call_args[0].removeprefix(self.namespace)
            call_args[0] = tool_name
        else:
            # Keyword form, e.g. call_tool(tool_name=..., arguments=...).
            tool_name = kwargs["tool_name"].removeprefix(self.namespace)
            kwargs["tool_name"] = tool_name
        # To run headless, just make confirms an empty list.
        if self.confirms and tool_name in self.confirms:
            shown_args = call_args[1:]
            if "arguments" in kwargs:
                shown_args.append(kwargs["arguments"])
            if not self.confirm_tool(tool_name, shown_args):
                return CallToolResult(
                    content=[TextContent(type="text", text="Tool call not allowed.", annotations=None, meta=None)]
                )
        return await self._obj.call_tool(*call_args, **kwargs)
ClientParamsMap = dict[str, tuple[dict[str, Any], list[str], str | None, int | None]]
def mcp_client_params(
available_tools: AvailableTools,
requested_toolboxes: list[str],
) -> ClientParamsMap:
"""Resolve toolbox configs into MCP server connection parameters.
Args:
available_tools: The tool registry that can look up toolbox configs.
requested_toolboxes: Module paths of the toolboxes to resolve.
Returns:
A mapping from toolbox name to a tuple of
``(server_params, confirms, server_prompt, client_session_timeout)``.
Raises:
ValueError: If the transport kind is not supported.
FileNotFoundError: If a streamable command cannot be found on ``$PATH``.
"""
client_params: ClientParamsMap = {}
for tb in requested_toolboxes:
toolbox = available_tools.get_toolbox(tb)
sp = toolbox.server_params
kind: str = sp.kind
reconnecting: bool = sp.reconnecting
server_params: dict[str, Any] = {"kind": kind, "reconnecting": reconnecting}
match kind:
case "stdio":
env = dict(sp.env) if sp.env else None
args = list(sp.args) if sp.args else None
logging.debug("Initializing toolbox: %s\nargs:\n%s\nenv:\n%s\n", tb, args, env)
if env:
for k, v in list(env.items()):
try:
env[k] = swap_env(v)
except LookupError as e:
logging.critical(e)
logging.info("Assuming toolbox has default configuration available")
del env[k]
logging.debug("Tool call environment: %s", env)
if args:
for i, v in enumerate(args):
args[i] = swap_env(v)
logging.debug("Tool call args: %s", args)
server_params["command"] = sp.command
server_params["args"] = args
server_params["env"] = env
case "sse":
headers = _resolve_headers(sp.headers, sp.optional_headers)
server_params["url"] = sp.url
server_params["headers"] = headers
server_params["timeout"] = sp.timeout
case "streamable":
headers = _resolve_headers(sp.headers, sp.optional_headers)
server_params["url"] = sp.url
server_params["headers"] = headers
server_params["timeout"] = sp.timeout
if sp.command is not None:
env = dict(sp.env) if sp.env else None
args = list(sp.args) if sp.args else None
logging.debug("Initializing streamable toolbox: %s\nargs:\n%s\nenv:\n%s\n", tb, args, env)
exe = shutil.which(sp.command)
if exe is None:
raise FileNotFoundError(f"Could not resolve path to {sp.command}")
start_cmd = [exe]
if args:
for i, v in enumerate(args):
args[i] = swap_env(v)
start_cmd += args
server_params["command"] = start_cmd
if env:
for k, v in list(env.items()):
try:
env[k] = swap_env(v)
except LookupError as e:
logging.critical(e)
logging.info("Assuming toolbox has default configuration available")
del env[k]
server_params["env"] = env
case _:
raise ValueError(f"Unsupported MCP transport {kind}")
client_params[tb] = (
server_params,
list(toolbox.confirm),
toolbox.server_prompt,
toolbox.client_session_timeout,
)
return client_params
def _resolve_headers(
    headers: dict[str, str] | None,
    optional_headers: dict[str, str] | None,
) -> dict[str, str] | None:
    """Expand env references in header values and merge required + optional.

    A required header whose referenced variable is missing makes
    ``swap_env`` raise ``LookupError``, which propagates. An optional header
    in the same situation is silently left out.

    Args:
        headers: Header dict whose values may contain ``{{ env('…') }}``.
        optional_headers: Like *headers*, but missing env vars are tolerated.

    Returns:
        The result of ``_merge_headers`` on the two expanded dicts. An empty
        or ``None`` input is passed on as ``None``, so this is ``None`` when
        neither input has entries.
    """
    required = (
        {name: swap_env(value) for name, value in dict(headers).items()}
        if headers
        else None
    )

    optional: dict[str, str] | None = None
    if optional_headers:
        optional = {}
        for name, value in dict(optional_headers).items():
            try:
                optional[name] = swap_env(value)
            except LookupError:
                # Missing variable: this optional header is simply not sent.
                continue

    return _merge_headers(required, optional)
def _merge_headers(
headers: dict[str, str] | None,
optional_headers: dict[str, str] | None,
) -> dict[str, str] | None:
"""Merge required and optional header dicts.
Args:
headers: Required headers (may be ``None``).
optional_headers: Optional headers (may be ``None``).
Returns:
Combined header dict, or ``None`` if both are ``None``.
"""
if headers and optional_headers:
headers.update(optional_headers)
return headers
return headers or optional_headers