diff --git a/src/agents/mcp/util.py b/src/agents/mcp/util.py index f7b0b3be8b..33bea065c5 100644 --- a/src/agents/mcp/util.py +++ b/src/agents/mcp/util.py @@ -177,6 +177,28 @@ def create_static_tool_filter( class MCPUtil: """Set of utilities for interop between MCP and Agents SDK tools.""" + @staticmethod + def _extract_static_meta(tool: Any) -> dict[str, Any] | None: + meta = getattr(tool, "meta", None) + if isinstance(meta, dict): + return copy.deepcopy(meta) + + model_extra = getattr(tool, "model_extra", None) + if isinstance(model_extra, dict): + extra_meta = model_extra.get("meta") + if isinstance(extra_meta, dict): + return copy.deepcopy(extra_meta) + + model_dump = getattr(tool, "model_dump", None) + if callable(model_dump): + dumped = model_dump() + if isinstance(dumped, dict): + dumped_meta = dumped.get("meta") + if isinstance(dumped_meta, dict): + return copy.deepcopy(dumped_meta) + + return None + @classmethod async def get_all_function_tools( cls, @@ -251,7 +273,13 @@ def to_function_tool( policies. If the server uses a callable approval policy, approvals default to required to avoid bypassing dynamic checks. """ - invoke_func_impl = functools.partial(cls.invoke_mcp_tool, server, tool) + static_meta = cls._extract_static_meta(tool) + invoke_func_impl = functools.partial( + cls.invoke_mcp_tool, + server, + tool, + meta=static_meta, + ) effective_failure_error_function = server._get_failure_error_function( failure_error_function ) diff --git a/tests/mcp/test_mcp_util.py b/tests/mcp/test_mcp_util.py index 0c33a3d313..5a9cbd140c 100644 --- a/tests/mcp/test_mcp_util.py +++ b/tests/mcp/test_mcp_util.py @@ -149,6 +149,62 @@ def resolve_meta(context): assert args == {"foo": "bar"} +@pytest.mark.asyncio +async def test_to_function_tool_passes_static_mcp_meta(): + server = FakeMCPServer() + tool = MCPTool( + name="test_tool_1", + inputSchema={}, + _meta={"locale": "en", "extra": "value"}, + ) + + function_tool = MCPUtil.to_function_tool(tool, server, convert_schemas_to_strict=False) + tool_context = ToolContext( + context=None, + tool_name="test_tool_1", + tool_call_id="test_call_static_meta", + tool_arguments="{}", + ) + + await function_tool.on_invoke_tool(tool_context, "{}") + + assert server.tool_metas[-1] == {"locale": "en", "extra": "value"} + + +@pytest.mark.asyncio +async def test_to_function_tool_merges_static_mcp_meta_with_resolver(): + captured: dict[str, Any] = {} + + def resolve_meta(context): + captured["run_context"] = context.run_context + captured["server_name"] = context.server_name + captured["tool_name"] = context.tool_name + captured["arguments"] = context.arguments + return {"request_id": "req-123", "locale": "ja"} + + server = FakeMCPServer(tool_meta_resolver=resolve_meta) + tool = MCPTool( + name="test_tool_1", + inputSchema={}, + _meta={"locale": "en", "extra": "value"}, + ) + + function_tool = MCPUtil.to_function_tool(tool, server, convert_schemas_to_strict=False) + tool_context = ToolContext( + context={"request_id": "req-123"}, + tool_name="test_tool_1", + tool_call_id="test_call_static_meta_with_resolver", + tool_arguments="{}", + ) + + await function_tool.on_invoke_tool(tool_context, "{}") + + assert server.tool_metas[-1] == {"request_id": "req-123", "locale": "en", "extra": "value"} + assert captured["server_name"] == server.name + assert captured["tool_name"] == "test_tool_1" + assert captured["arguments"] == {} + + @pytest.mark.asyncio async def test_mcp_invoke_bad_json_errors(caplog: pytest.LogCaptureFixture): caplog.set_level(logging.DEBUG)