Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,11 @@ class MCPConfig(TypedDict):
default_tool_error_function.
"""

include_server_in_tool_names: NotRequired[bool]
"""If True, MCP tools are exposed with an unambiguous server-specific prefix to avoid
collisions across servers that publish the same tool names. Defaults to False.
"""


@dataclass
class AgentBase(Generic[TContext]):
Expand Down Expand Up @@ -186,12 +191,14 @@ async def get_mcp_tools(self, run_context: RunContextWrapper[TContext]) -> list[
failure_error_function = self.mcp_config.get(
"failure_error_function", default_tool_error_function
)
include_server_in_tool_names = self.mcp_config.get("include_server_in_tool_names", False)
return await MCPUtil.get_all_function_tools(
self.mcp_servers,
convert_schemas_to_strict,
run_context,
self,
failure_error_function=failure_error_function,
include_server_in_tool_names=include_server_in_tool_names,
)

async def get_all_tools(self, run_context: RunContextWrapper[TContext]) -> list[Tool]:
Expand Down
92 changes: 81 additions & 11 deletions src/agents/mcp/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import asyncio
import copy
import functools
import hashlib
import inspect
import json
from collections.abc import Awaitable
Expand Down Expand Up @@ -207,17 +208,23 @@ async def get_all_function_tools(
run_context: RunContextWrapper[Any],
agent: AgentBase,
failure_error_function: ToolErrorFunction | None = default_tool_error_function,
include_server_in_tool_names: bool = False,
) -> list[Tool]:
"""Get all function tools from a list of MCP servers."""
tools = []
tool_names: set[str] = set()
server_tool_name_prefixes = (
cls._server_tool_name_prefixes(servers) if include_server_in_tool_names else {}
)
for server in servers:
server_tools = await cls.get_function_tools(
server,
convert_schemas_to_strict,
run_context,
agent,
failure_error_function=failure_error_function,
include_server_in_tool_names=include_server_in_tool_names,
tool_name_prefix=server_tool_name_prefixes.get(id(server)),
)
server_tool_names = {tool.name for tool in server_tools}
if len(server_tool_names & tool_names) > 0:
Expand All @@ -238,24 +245,81 @@ async def get_function_tools(
run_context: RunContextWrapper[Any],
agent: AgentBase,
failure_error_function: ToolErrorFunction | None = default_tool_error_function,
include_server_in_tool_names: bool = False,
tool_name_prefix: str | None = None,
) -> list[Tool]:
"""Get all function tools from a single MCP server."""

with mcp_tools_span(server=server.name) as span:
tools = await server.list_tools(run_context, agent)
span.span_data.result = [tool.name for tool in tools]

if tool_name_prefix is None:
tool_name_prefix = (
cls._server_tool_name_prefix(server.name) if include_server_in_tool_names else ""
)
return [
cls.to_function_tool(
tool,
server,
convert_schemas_to_strict,
agent,
failure_error_function=failure_error_function,
tool_name_override=(
cls._prefixed_tool_name(tool_name_prefix, tool.name)
if tool_name_prefix
else None
),
)
for tool in tools
]

@staticmethod
def _server_tool_name_prefix(server_name: str) -> str:
normalized = "".join(
char if (char.isascii() and char.isalnum()) or char in ("_", "-") else "_"
for char in server_name
)
normalized = normalized.strip("_-")
if not normalized:
normalized = "server"
return f"{normalized}_"

@staticmethod
def _prefixed_tool_name(tool_name_prefix: str, tool_name: str) -> str:
full_name = f"mcp_{len(tool_name_prefix)}_{tool_name_prefix}{tool_name}"
if len(full_name) <= 64:
return full_name
# Truncate to 64 chars using a deterministic hash suffix to avoid collisions
hash_suffix = hashlib.sha1(full_name.encode("utf-8")).hexdigest()[:8]
# Reserve 9 chars for "_" + 8-char hash
truncated = full_name[: 64 - 9]
return f"{truncated}_{hash_suffix}"

@classmethod
def _server_tool_name_prefixes(cls, servers: list[MCPServer]) -> dict[int, str]:
normalized_to_servers: dict[str, list[MCPServer]] = {}
for server in servers:
normalized_prefix = cls._server_tool_name_prefix(server.name)[:-1]
normalized_to_servers.setdefault(normalized_prefix, []).append(server)

prefixes: dict[int, str] = {}
for normalized_prefix, grouped_servers in normalized_to_servers.items():
if len(grouped_servers) == 1:
prefixes[id(grouped_servers[0])] = f"{normalized_prefix}_"
continue

seen_prefixes: set[str] = set()
for index, server in enumerate(grouped_servers, start=1):
hash_suffix = hashlib.sha1(server.name.encode("utf-8")).hexdigest()[:8]
prefix = f"{normalized_prefix}_{hash_suffix}"
if prefix in seen_prefixes:
Comment on lines +315 to +316
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Separate hashed prefixes from literal normalized prefixes

The disambiguation format for colliding server names (f"{normalized_prefix}_{hash_suffix}") is in the same namespace as normal single-server prefixes, so it can still collide in real inputs. For example, with server names foo, foo!, and foo_0beec7b5, the hashed prefix for foo becomes foo_0beec7b5_, which is identical to the literal normalized prefix for foo_0beec7b5; if both expose create_issue, get_all_function_tools still raises Duplicate tool names found across MCP servers even when include_server_in_tool_names=True.

Useful? React with 👍 / 👎.

prefix = f"{prefix}_{index}"
seen_prefixes.add(prefix)
prefixes[id(server)] = f"{prefix}_"

return prefixes

@classmethod
def to_function_tool(
cls,
Expand All @@ -264,6 +328,7 @@ def to_function_tool(
convert_schemas_to_strict: bool,
agent: AgentBase | None = None,
failure_error_function: ToolErrorFunction | None = default_tool_error_function,
tool_name_override: str | None = None,
) -> FunctionTool:
"""Convert an MCP tool to an Agents SDK function tool.

Expand All @@ -273,11 +338,13 @@ def to_function_tool(
policies. If the server uses a callable approval policy, approvals default
to required to avoid bypassing dynamic checks.
"""
tool_name = tool_name_override or tool.name
static_meta = cls._extract_static_meta(tool)
invoke_func_impl = functools.partial(
cls.invoke_mcp_tool,
server,
tool,
tool_display_name=tool_name,
meta=static_meta,
)
effective_failure_error_function = server._get_failure_error_function(
Expand All @@ -301,7 +368,7 @@ def to_function_tool(
) = server._get_needs_approval_for_tool(tool, agent)

function_tool = _build_wrapped_function_tool(
name=tool.name,
name=tool_name,
description=resolve_mcp_tool_description_for_model(tool),
params_json_schema=schema,
invoke_tool_impl=invoke_func_impl,
Expand Down Expand Up @@ -367,25 +434,28 @@ async def invoke_mcp_tool(
input_json: str,
*,
meta: dict[str, Any] | None = None,
tool_display_name: str | None = None,
) -> ToolOutput:
"""Invoke an MCP tool and return the result as ToolOutput."""
tool_name = tool_display_name or tool.name
try:
json_data: dict[str, Any] = json.loads(input_json) if input_json else {}
except Exception as e:
if _debug.DONT_LOG_TOOL_DATA:
logger.debug(f"Invalid JSON input for tool {tool.name}")
logger.debug(f"Invalid JSON input for tool {tool_name}")
else:
logger.debug(f"Invalid JSON input for tool {tool.name}: {input_json}")
logger.debug(f"Invalid JSON input for tool {tool_name}: {input_json}")
raise ModelBehaviorError(
f"Invalid JSON input for tool {tool.name}: {input_json}"
f"Invalid JSON input for tool {tool_name}: {input_json}"
) from e

if _debug.DONT_LOG_TOOL_DATA:
logger.debug(f"Invoking MCP tool {tool.name}")
logger.debug(f"Invoking MCP tool {tool_name}")
else:
logger.debug(f"Invoking MCP tool {tool.name} with input {input_json}")
logger.debug(f"Invoking MCP tool {tool_name} with input {input_json}")

try:
# Keep meta resolution and server routing keyed by the original MCP tool name.
resolved_meta = await cls._resolve_meta(server, context, tool.name, json_data)
merged_meta = cls._merge_mcp_meta(resolved_meta, meta)
call_task = asyncio.create_task(
Expand Down Expand Up @@ -423,20 +493,20 @@ async def invoke_mcp_tool(
# failure_error_function=None will have the error raised as documented.
error_text = e.error.message if hasattr(e, "error") and e.error else str(e)
logger.warning(
f"MCP tool {tool.name} on server '{server.name}' returned an error: "
f"MCP tool {tool_name} on server '{server.name}' returned an error: "
f"{error_text}"
)
raise

logger.error(f"Error invoking MCP tool {tool.name} on server '{server.name}': {e}")
logger.error(f"Error invoking MCP tool {tool_name} on server '{server.name}': {e}")
raise AgentsException(
f"Error invoking MCP tool {tool.name} on server '{server.name}': {e}"
f"Error invoking MCP tool {tool_name} on server '{server.name}': {e}"
) from e

if _debug.DONT_LOG_TOOL_DATA:
logger.debug(f"MCP tool {tool.name} completed.")
logger.debug(f"MCP tool {tool_name} completed.")
else:
logger.debug(f"MCP tool {tool.name} returned {result}")
logger.debug(f"MCP tool {tool_name} returned {result}")

# If structured content is requested and available, use it exclusively
tool_output: ToolOutput
Expand Down
Loading