Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 29 additions & 1 deletion src/agents/mcp/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,28 @@ def create_static_tool_filter(
class MCPUtil:
"""Set of utilities for interop between MCP and Agents SDK tools."""

@staticmethod
def _extract_static_meta(tool: Any) -> dict[str, Any] | None:
meta = getattr(tool, "meta", None)
if isinstance(meta, dict):
return copy.deepcopy(meta)

model_extra = getattr(tool, "model_extra", None)
if isinstance(model_extra, dict):
extra_meta = model_extra.get("meta")
if isinstance(extra_meta, dict):
return copy.deepcopy(extra_meta)

model_dump = getattr(tool, "model_dump", None)
if callable(model_dump):
dumped = model_dump()
if isinstance(dumped, dict):
dumped_meta = dumped.get("meta")
if isinstance(dumped_meta, dict):
return copy.deepcopy(dumped_meta)

return None

@classmethod
async def get_all_function_tools(
cls,
Expand Down Expand Up @@ -251,7 +273,13 @@ def to_function_tool(
policies. If the server uses a callable approval policy, approvals default
to required to avoid bypassing dynamic checks.
"""
invoke_func_impl = functools.partial(cls.invoke_mcp_tool, server, tool)
static_meta = cls._extract_static_meta(tool)
invoke_func_impl = functools.partial(
cls.invoke_mcp_tool,
server,
tool,
meta=static_meta,
)
effective_failure_error_function = server._get_failure_error_function(
failure_error_function
)
Expand Down
56 changes: 56 additions & 0 deletions tests/mcp/test_mcp_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,62 @@ def resolve_meta(context):
assert args == {"foo": "bar"}


@pytest.mark.asyncio
async def test_to_function_tool_passes_static_mcp_meta():
server = FakeMCPServer()
tool = MCPTool(
name="test_tool_1",
inputSchema={},
_meta={"locale": "en", "extra": "value"},
)

function_tool = MCPUtil.to_function_tool(tool, server, convert_schemas_to_strict=False)
tool_context = ToolContext(
context=None,
tool_name="test_tool_1",
tool_call_id="test_call_static_meta",
tool_arguments="{}",
)

await function_tool.on_invoke_tool(tool_context, "{}")

assert server.tool_metas[-1] == {"locale": "en", "extra": "value"}


@pytest.mark.asyncio
async def test_to_function_tool_merges_static_mcp_meta_with_resolver():
captured: dict[str, Any] = {}

def resolve_meta(context):
captured["run_context"] = context.run_context
captured["server_name"] = context.server_name
captured["tool_name"] = context.tool_name
captured["arguments"] = context.arguments
return {"request_id": "req-123", "locale": "ja"}

server = FakeMCPServer(tool_meta_resolver=resolve_meta)
tool = MCPTool(
name="test_tool_1",
inputSchema={},
_meta={"locale": "en", "extra": "value"},
)

function_tool = MCPUtil.to_function_tool(tool, server, convert_schemas_to_strict=False)
tool_context = ToolContext(
context={"request_id": "req-123"},
tool_name="test_tool_1",
tool_call_id="test_call_static_meta_with_resolver",
tool_arguments="{}",
)

await function_tool.on_invoke_tool(tool_context, "{}")

assert server.tool_metas[-1] == {"request_id": "req-123", "locale": "en", "extra": "value"}
assert captured["server_name"] == server.name
assert captured["tool_name"] == "test_tool_1"
assert captured["arguments"] == {}


@pytest.mark.asyncio
async def test_mcp_invoke_bad_json_errors(caplog: pytest.LogCaptureFixture):
caplog.set_level(logging.DEBUG)
Expand Down
Loading