-
Notifications
You must be signed in to change notification settings - Fork 3.4k
feat(mcp): #1167 prefix colliding MCP tool names #2788
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
9179a99
b9f25db
d710186
fa92be2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,6 +3,7 @@ | |
| import asyncio | ||
| import copy | ||
| import functools | ||
| import hashlib | ||
| import inspect | ||
| import json | ||
| from collections.abc import Awaitable | ||
|
|
@@ -207,17 +208,23 @@ async def get_all_function_tools( | |
| run_context: RunContextWrapper[Any], | ||
| agent: AgentBase, | ||
| failure_error_function: ToolErrorFunction | None = default_tool_error_function, | ||
| include_server_in_tool_names: bool = False, | ||
| ) -> list[Tool]: | ||
| """Get all function tools from a list of MCP servers.""" | ||
| tools = [] | ||
| tool_names: set[str] = set() | ||
| server_tool_name_prefixes = ( | ||
| cls._server_tool_name_prefixes(servers) if include_server_in_tool_names else {} | ||
| ) | ||
| for server in servers: | ||
| server_tools = await cls.get_function_tools( | ||
| server, | ||
| convert_schemas_to_strict, | ||
| run_context, | ||
| agent, | ||
| failure_error_function=failure_error_function, | ||
| include_server_in_tool_names=include_server_in_tool_names, | ||
| tool_name_prefix=server_tool_name_prefixes.get(id(server)), | ||
| ) | ||
| server_tool_names = {tool.name for tool in server_tools} | ||
| if len(server_tool_names & tool_names) > 0: | ||
|
|
@@ -238,24 +245,81 @@ async def get_function_tools( | |
| run_context: RunContextWrapper[Any], | ||
| agent: AgentBase, | ||
| failure_error_function: ToolErrorFunction | None = default_tool_error_function, | ||
| include_server_in_tool_names: bool = False, | ||
| tool_name_prefix: str | None = None, | ||
| ) -> list[Tool]: | ||
| """Get all function tools from a single MCP server.""" | ||
|
|
||
| with mcp_tools_span(server=server.name) as span: | ||
| tools = await server.list_tools(run_context, agent) | ||
| span.span_data.result = [tool.name for tool in tools] | ||
|
|
||
| if tool_name_prefix is None: | ||
| tool_name_prefix = ( | ||
| cls._server_tool_name_prefix(server.name) if include_server_in_tool_names else "" | ||
| ) | ||
| return [ | ||
| cls.to_function_tool( | ||
| tool, | ||
| server, | ||
| convert_schemas_to_strict, | ||
| agent, | ||
| failure_error_function=failure_error_function, | ||
| tool_name_override=( | ||
| cls._prefixed_tool_name(tool_name_prefix, tool.name) | ||
| if tool_name_prefix | ||
| else None | ||
| ), | ||
| ) | ||
| for tool in tools | ||
| ] | ||
|
|
||
| @staticmethod | ||
| def _server_tool_name_prefix(server_name: str) -> str: | ||
| normalized = "".join( | ||
| char if (char.isascii() and char.isalnum()) or char in ("_", "-") else "_" | ||
| for char in server_name | ||
| ) | ||
| normalized = normalized.strip("_-") | ||
| if not normalized: | ||
| normalized = "server" | ||
| return f"{normalized}_" | ||
|
|
||
| @staticmethod | ||
| def _prefixed_tool_name(tool_name_prefix: str, tool_name: str) -> str: | ||
| full_name = f"mcp_{len(tool_name_prefix)}_{tool_name_prefix}{tool_name}" | ||
| if len(full_name) <= 64: | ||
| return full_name | ||
| # Truncate to 64 chars using a deterministic hash suffix to avoid collisions | ||
| hash_suffix = hashlib.sha1(full_name.encode("utf-8")).hexdigest()[:8] | ||
| # Reserve 9 chars for "_" + 8-char hash | ||
| truncated = full_name[: 64 - 9] | ||
| return f"{truncated}_{hash_suffix}" | ||
|
|
||
| @classmethod | ||
| def _server_tool_name_prefixes(cls, servers: list[MCPServer]) -> dict[int, str]: | ||
| normalized_to_servers: dict[str, list[MCPServer]] = {} | ||
| for server in servers: | ||
| normalized_prefix = cls._server_tool_name_prefix(server.name)[:-1] | ||
| normalized_to_servers.setdefault(normalized_prefix, []).append(server) | ||
|
|
||
| prefixes: dict[int, str] = {} | ||
| for normalized_prefix, grouped_servers in normalized_to_servers.items(): | ||
| if len(grouped_servers) == 1: | ||
| prefixes[id(grouped_servers[0])] = f"{normalized_prefix}_" | ||
| continue | ||
|
|
||
| seen_prefixes: set[str] = set() | ||
| for index, server in enumerate(grouped_servers, start=1): | ||
| hash_suffix = hashlib.sha1(server.name.encode("utf-8")).hexdigest()[:8] | ||
| prefix = f"{normalized_prefix}_{hash_suffix}" | ||
| if prefix in seen_prefixes: | ||
|
Comment on lines
+315
to
+316
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
The disambiguation format for colliding server names ( Useful? React with 👍 / 👎. |
||
| prefix = f"{prefix}_{index}" | ||
| seen_prefixes.add(prefix) | ||
| prefixes[id(server)] = f"{prefix}_" | ||
|
|
||
| return prefixes | ||
|
|
||
| @classmethod | ||
| def to_function_tool( | ||
| cls, | ||
|
|
@@ -264,6 +328,7 @@ def to_function_tool( | |
| convert_schemas_to_strict: bool, | ||
| agent: AgentBase | None = None, | ||
| failure_error_function: ToolErrorFunction | None = default_tool_error_function, | ||
| tool_name_override: str | None = None, | ||
| ) -> FunctionTool: | ||
| """Convert an MCP tool to an Agents SDK function tool. | ||
|
|
||
|
|
@@ -273,11 +338,13 @@ def to_function_tool( | |
| policies. If the server uses a callable approval policy, approvals default | ||
| to required to avoid bypassing dynamic checks. | ||
| """ | ||
| tool_name = tool_name_override or tool.name | ||
| static_meta = cls._extract_static_meta(tool) | ||
| invoke_func_impl = functools.partial( | ||
| cls.invoke_mcp_tool, | ||
| server, | ||
| tool, | ||
| tool_display_name=tool_name, | ||
| meta=static_meta, | ||
| ) | ||
| effective_failure_error_function = server._get_failure_error_function( | ||
|
|
@@ -301,7 +368,7 @@ def to_function_tool( | |
| ) = server._get_needs_approval_for_tool(tool, agent) | ||
|
|
||
| function_tool = _build_wrapped_function_tool( | ||
| name=tool.name, | ||
| name=tool_name, | ||
| description=resolve_mcp_tool_description_for_model(tool), | ||
| params_json_schema=schema, | ||
| invoke_tool_impl=invoke_func_impl, | ||
|
|
@@ -367,25 +434,28 @@ async def invoke_mcp_tool( | |
| input_json: str, | ||
| *, | ||
| meta: dict[str, Any] | None = None, | ||
| tool_display_name: str | None = None, | ||
| ) -> ToolOutput: | ||
| """Invoke an MCP tool and return the result as ToolOutput.""" | ||
| tool_name = tool_display_name or tool.name | ||
| try: | ||
| json_data: dict[str, Any] = json.loads(input_json) if input_json else {} | ||
| except Exception as e: | ||
| if _debug.DONT_LOG_TOOL_DATA: | ||
| logger.debug(f"Invalid JSON input for tool {tool.name}") | ||
| logger.debug(f"Invalid JSON input for tool {tool_name}") | ||
| else: | ||
| logger.debug(f"Invalid JSON input for tool {tool.name}: {input_json}") | ||
| logger.debug(f"Invalid JSON input for tool {tool_name}: {input_json}") | ||
| raise ModelBehaviorError( | ||
| f"Invalid JSON input for tool {tool.name}: {input_json}" | ||
| f"Invalid JSON input for tool {tool_name}: {input_json}" | ||
| ) from e | ||
|
|
||
| if _debug.DONT_LOG_TOOL_DATA: | ||
| logger.debug(f"Invoking MCP tool {tool.name}") | ||
| logger.debug(f"Invoking MCP tool {tool_name}") | ||
| else: | ||
| logger.debug(f"Invoking MCP tool {tool.name} with input {input_json}") | ||
| logger.debug(f"Invoking MCP tool {tool_name} with input {input_json}") | ||
|
|
||
| try: | ||
| # Keep meta resolution and server routing keyed by the original MCP tool name. | ||
| resolved_meta = await cls._resolve_meta(server, context, tool.name, json_data) | ||
| merged_meta = cls._merge_mcp_meta(resolved_meta, meta) | ||
| call_task = asyncio.create_task( | ||
|
|
@@ -423,20 +493,20 @@ async def invoke_mcp_tool( | |
| # failure_error_function=None will have the error raised as documented. | ||
| error_text = e.error.message if hasattr(e, "error") and e.error else str(e) | ||
| logger.warning( | ||
| f"MCP tool {tool.name} on server '{server.name}' returned an error: " | ||
| f"MCP tool {tool_name} on server '{server.name}' returned an error: " | ||
| f"{error_text}" | ||
| ) | ||
| raise | ||
|
|
||
| logger.error(f"Error invoking MCP tool {tool.name} on server '{server.name}': {e}") | ||
| logger.error(f"Error invoking MCP tool {tool_name} on server '{server.name}': {e}") | ||
| raise AgentsException( | ||
| f"Error invoking MCP tool {tool.name} on server '{server.name}': {e}" | ||
| f"Error invoking MCP tool {tool_name} on server '{server.name}': {e}" | ||
| ) from e | ||
|
|
||
| if _debug.DONT_LOG_TOOL_DATA: | ||
| logger.debug(f"MCP tool {tool.name} completed.") | ||
| logger.debug(f"MCP tool {tool_name} completed.") | ||
| else: | ||
| logger.debug(f"MCP tool {tool.name} returned {result}") | ||
| logger.debug(f"MCP tool {tool_name} returned {result}") | ||
|
|
||
| # If structured content is requested and available, use it exclusively | ||
| tool_output: ToolOutput | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.