Skip to content

Commit 1bc96cd

Browse files
authored
feat: #1788 add configurable MCP tool failure handlers (#2378)
1 parent 21ac1d2 commit 1bc96cd

5 files changed

Lines changed: 215 additions & 27 deletions

File tree

src/agents/agent.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,12 @@ class MCPConfig(TypedDict):
102102
best-effort conversion, so some schemas may not be convertible. Defaults to False.
103103
"""
104104

105+
failure_error_function: NotRequired[ToolErrorFunction | None]
106+
"""Optional function to convert MCP tool failures into model-visible messages. If explicitly
107+
set to None, tool errors will be raised instead. If unset, defaults to
108+
default_tool_error_function.
109+
"""
110+
105111

106112
@dataclass
107113
class AgentBase(Generic[TContext]):
@@ -135,8 +141,15 @@ class AgentBase(Generic[TContext]):
135141
async def get_mcp_tools(self, run_context: RunContextWrapper[TContext]) -> list[Tool]:
136142
"""Fetches the available tools from the MCP servers."""
137143
convert_schemas_to_strict = self.mcp_config.get("convert_schemas_to_strict", False)
144+
failure_error_function = self.mcp_config.get(
145+
"failure_error_function", default_tool_error_function
146+
)
138147
return await MCPUtil.get_all_function_tools(
139-
self.mcp_servers, convert_schemas_to_strict, run_context, self
148+
self.mcp_servers,
149+
convert_schemas_to_strict,
150+
run_context,
151+
self,
152+
failure_error_function=failure_error_function,
140153
)
141154

142155
async def get_all_tools(self, run_context: RunContextWrapper[TContext]) -> list[Tool]:

src/agents/mcp/server.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from contextlib import AbstractAsyncContextManager, AsyncExitStack
99
from datetime import timedelta
1010
from pathlib import Path
11-
from typing import TYPE_CHECKING, Any, Callable, Literal, TypeVar
11+
from typing import TYPE_CHECKING, Any, Callable, Literal, TypeVar, cast
1212

1313
import httpx
1414

@@ -26,6 +26,7 @@
2626
from ..exceptions import UserError
2727
from ..logger import logger
2828
from ..run_context import RunContextWrapper
29+
from ..tool import ToolErrorFunction
2930
from ..util._types import MaybeAwaitable
3031
from .util import HttpClientFactory, ToolFilter, ToolFilterContext, ToolFilterStatic
3132

@@ -48,6 +49,13 @@ class RequireApprovalObject(TypedDict, total=False):
4849

4950
T = TypeVar("T")
5051

52+
53+
class _UnsetType:
54+
pass
55+
56+
57+
_UNSET = _UnsetType()
58+
5159
if TYPE_CHECKING:
5260
from ..agent import AgentBase
5361

@@ -59,6 +67,7 @@ def __init__(
5967
self,
6068
use_structured_content: bool = False,
6169
require_approval: RequireApprovalSetting = None,
70+
failure_error_function: ToolErrorFunction | None | _UnsetType = _UNSET,
6271
):
6372
"""
6473
Args:
@@ -70,11 +79,16 @@ def __init__(
7079
require_approval: Approval policy for tools on this server. Accepts "always"/"never",
7180
a dict of tool names to those values, a boolean, or an object with always/never
7281
tool lists (mirroring TS requireApproval). Normalized into a needs_approval policy.
82+
failure_error_function: Optional function used to convert MCP tool failures into
83+
a model-visible error message. If explicitly set to None, tool errors will be
84+
raised instead of converted. If left unset, the agent-level configuration (or
85+
SDK default) will be used.
7386
"""
7487
self.use_structured_content = use_structured_content
7588
self._needs_approval_policy = self._normalize_needs_approval(
7689
require_approval=require_approval
7790
)
91+
self._failure_error_function = failure_error_function
7892

7993
@abc.abstractmethod
8094
async def connect(self):
@@ -207,6 +221,14 @@ async def _needs_approval(
207221

208222
return bool(policy)
209223

224+
def _get_failure_error_function(
225+
self, agent_failure_error_function: ToolErrorFunction | None
226+
) -> ToolErrorFunction | None:
227+
"""Return the effective error handler for MCP tool failures."""
228+
if self._failure_error_function is _UNSET:
229+
return agent_failure_error_function
230+
return cast(ToolErrorFunction | None, self._failure_error_function)
231+
210232

211233
class _MCPServerWithClientSession(MCPServer, abc.ABC):
212234
"""Base class for MCP servers that use a `ClientSession` to communicate with the server."""
@@ -221,6 +243,7 @@ def __init__(
221243
retry_backoff_seconds_base: float = 1.0,
222244
message_handler: MessageHandlerFnT | None = None,
223245
require_approval: RequireApprovalSetting = None,
246+
failure_error_function: ToolErrorFunction | None | _UnsetType = _UNSET,
224247
):
225248
"""
226249
Args:
@@ -247,10 +270,15 @@ def __init__(
247270
require_approval: Approval policy for tools on this server. Accepts "always"/"never",
248271
a dict of tool names to those values, a boolean, or an object with always/never
249272
tool lists.
273+
failure_error_function: Optional function used to convert MCP tool failures into
274+
a model-visible error message. If explicitly set to None, tool errors will be
275+
raised instead of converted. If left unset, the agent-level configuration (or
276+
SDK default) will be used.
250277
"""
251278
super().__init__(
252279
use_structured_content=use_structured_content,
253280
require_approval=require_approval,
281+
failure_error_function=failure_error_function,
254282
)
255283
self.session: ClientSession | None = None
256284
self.exit_stack: AsyncExitStack = AsyncExitStack()
@@ -682,6 +710,7 @@ def __init__(
682710
retry_backoff_seconds_base: float = 1.0,
683711
message_handler: MessageHandlerFnT | None = None,
684712
require_approval: RequireApprovalSetting = None,
713+
failure_error_function: ToolErrorFunction | None | _UnsetType = _UNSET,
685714
):
686715
"""Create a new MCP server based on the stdio transport.
687716
@@ -713,6 +742,10 @@ def __init__(
713742
ClientSession.
714743
require_approval: Approval policy for tools on this server. Accepts "always"/"never",
715744
a dict of tool names to those values, or an object with always/never tool lists.
745+
failure_error_function: Optional function used to convert MCP tool failures into
746+
a model-visible error message. If explicitly set to None, tool errors will be
747+
raised instead of converted. If left unset, the agent-level configuration (or
748+
SDK default) will be used.
716749
"""
717750
super().__init__(
718751
cache_tools_list,
@@ -723,6 +756,7 @@ def __init__(
723756
retry_backoff_seconds_base,
724757
message_handler=message_handler,
725758
require_approval=require_approval,
759+
failure_error_function=failure_error_function,
726760
)
727761

728762
self.params = StdioServerParameters(
@@ -788,6 +822,7 @@ def __init__(
788822
retry_backoff_seconds_base: float = 1.0,
789823
message_handler: MessageHandlerFnT | None = None,
790824
require_approval: RequireApprovalSetting = None,
825+
failure_error_function: ToolErrorFunction | None | _UnsetType = _UNSET,
791826
):
792827
"""Create a new MCP server based on the HTTP with SSE transport.
793828
@@ -821,6 +856,10 @@ def __init__(
821856
ClientSession.
822857
require_approval: Approval policy for tools on this server. Accepts "always"/"never",
823858
a dict of tool names to those values, or an object with always/never tool lists.
859+
failure_error_function: Optional function used to convert MCP tool failures into
860+
a model-visible error message. If explicitly set to None, tool errors will be
861+
raised instead of converted. If left unset, the agent-level configuration (or
862+
SDK default) will be used.
824863
"""
825864
super().__init__(
826865
cache_tools_list,
@@ -831,6 +870,7 @@ def __init__(
831870
retry_backoff_seconds_base,
832871
message_handler=message_handler,
833872
require_approval=require_approval,
873+
failure_error_function=failure_error_function,
834874
)
835875

836876
self.params = params
@@ -899,6 +939,7 @@ def __init__(
899939
retry_backoff_seconds_base: float = 1.0,
900940
message_handler: MessageHandlerFnT | None = None,
901941
require_approval: RequireApprovalSetting = None,
942+
failure_error_function: ToolErrorFunction | None | _UnsetType = _UNSET,
902943
):
903944
"""Create a new MCP server based on the Streamable HTTP transport.
904945
@@ -933,6 +974,10 @@ def __init__(
933974
ClientSession.
934975
require_approval: Approval policy for tools on this server. Accepts "always"/"never",
935976
a dict of tool names to those values, or an object with always/never tool lists.
977+
failure_error_function: Optional function used to convert MCP tool failures into
978+
a model-visible error message. If explicitly set to None, tool errors will be
979+
raised instead of converted. If left unset, the agent-level configuration (or
980+
SDK default) will be used.
936981
"""
937982
super().__init__(
938983
cache_tools_list,
@@ -943,6 +988,7 @@ def __init__(
943988
retry_backoff_seconds_base,
944989
message_handler=message_handler,
945990
require_approval=require_approval,
991+
failure_error_function=failure_error_function,
946992
)
947993

948994
self.params = params

src/agents/mcp/util.py

Lines changed: 57 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1+
from __future__ import annotations
2+
13
import functools
24
import inspect
35
import json
46
from dataclasses import dataclass
5-
from typing import TYPE_CHECKING, Any, Callable, Optional, Protocol, Union
7+
from typing import TYPE_CHECKING, Any, Callable, Protocol, Union
68

79
import httpx
810
from typing_extensions import NotRequired, TypedDict
@@ -15,6 +17,7 @@
1517
from ..tool import (
1618
FunctionTool,
1719
Tool,
20+
ToolErrorFunction,
1821
ToolOutputImageDict,
1922
ToolOutputTextDict,
2023
default_tool_error_function,
@@ -24,8 +27,12 @@
2427
from ..util import _error_tracing
2528
from ..util._types import MaybeAwaitable
2629

27-
ToolOutputItem = Union[ToolOutputTextDict, ToolOutputImageDict]
28-
ToolOutput = Union[str, ToolOutputItem, list[ToolOutputItem]]
30+
if TYPE_CHECKING:
31+
ToolOutputItem = ToolOutputTextDict | ToolOutputImageDict
32+
ToolOutput = str | ToolOutputItem | list[ToolOutputItem]
33+
else:
34+
ToolOutputItem = Union[ToolOutputTextDict, ToolOutputImageDict] # noqa: UP007
35+
ToolOutput = Union[str, ToolOutputItem, list[ToolOutputItem]] # noqa: UP007
2936

3037
if TYPE_CHECKING:
3138
from mcp.types import Tool as MCPTool
@@ -43,9 +50,9 @@ class HttpClientFactory(Protocol):
4350

4451
def __call__(
4552
self,
46-
headers: Optional[dict[str, str]] = None,
47-
timeout: Optional[httpx.Timeout] = None,
48-
auth: Optional[httpx.Auth] = None,
53+
headers: dict[str, str] | None = None,
54+
timeout: httpx.Timeout | None = None,
55+
auth: httpx.Auth | None = None,
4956
) -> httpx.AsyncClient: ...
5057

5158

@@ -56,14 +63,17 @@ class ToolFilterContext:
5663
run_context: RunContextWrapper[Any]
5764
"""The current run context."""
5865

59-
agent: "AgentBase"
66+
agent: AgentBase
6067
"""The agent that is requesting the tool list."""
6168

6269
server_name: str
6370
"""The name of the MCP server."""
6471

6572

66-
ToolFilterCallable = Callable[["ToolFilterContext", "MCPTool"], MaybeAwaitable[bool]]
73+
if TYPE_CHECKING:
74+
ToolFilterCallable = Callable[[ToolFilterContext, MCPTool], MaybeAwaitable[bool]]
75+
else:
76+
ToolFilterCallable = Callable[[ToolFilterContext, Any], MaybeAwaitable[bool]]
6777
"""A function that determines whether a tool should be available.
6878
6979
Args:
@@ -87,14 +97,17 @@ class ToolFilterStatic(TypedDict):
8797
If set, these tools will be filtered out."""
8898

8999

90-
ToolFilter = Union[ToolFilterCallable, ToolFilterStatic, None]
100+
if TYPE_CHECKING:
101+
ToolFilter = ToolFilterCallable | ToolFilterStatic | None
102+
else:
103+
ToolFilter = Union[ToolFilterCallable, ToolFilterStatic, None] # noqa: UP007
91104
"""A tool filter that can be either a function, static configuration, or None (no filtering)."""
92105

93106

94107
def create_static_tool_filter(
95-
allowed_tool_names: Optional[list[str]] = None,
96-
blocked_tool_names: Optional[list[str]] = None,
97-
) -> Optional[ToolFilterStatic]:
108+
allowed_tool_names: list[str] | None = None,
109+
blocked_tool_names: list[str] | None = None,
110+
) -> ToolFilterStatic | None:
98111
"""Create a static tool filter from allowlist and blocklist parameters.
99112
100113
This is a convenience function for creating a ToolFilterStatic.
@@ -124,17 +137,22 @@ class MCPUtil:
124137
@classmethod
125138
async def get_all_function_tools(
126139
cls,
127-
servers: list["MCPServer"],
140+
servers: list[MCPServer],
128141
convert_schemas_to_strict: bool,
129142
run_context: RunContextWrapper[Any],
130-
agent: "AgentBase",
143+
agent: AgentBase,
144+
failure_error_function: ToolErrorFunction | None = default_tool_error_function,
131145
) -> list[Tool]:
132146
"""Get all function tools from a list of MCP servers."""
133147
tools = []
134148
tool_names: set[str] = set()
135149
for server in servers:
136150
server_tools = await cls.get_function_tools(
137-
server, convert_schemas_to_strict, run_context, agent
151+
server,
152+
convert_schemas_to_strict,
153+
run_context,
154+
agent,
155+
failure_error_function=failure_error_function,
138156
)
139157
server_tool_names = {tool.name for tool in server_tools}
140158
if len(server_tool_names & tool_names) > 0:
@@ -150,10 +168,11 @@ async def get_all_function_tools(
150168
@classmethod
151169
async def get_function_tools(
152170
cls,
153-
server: "MCPServer",
171+
server: MCPServer,
154172
convert_schemas_to_strict: bool,
155173
run_context: RunContextWrapper[Any],
156-
agent: "AgentBase",
174+
agent: AgentBase,
175+
failure_error_function: ToolErrorFunction | None = default_tool_error_function,
157176
) -> list[Tool]:
158177
"""Get all function tools from a single MCP server."""
159178

@@ -162,19 +181,30 @@ async def get_function_tools(
162181
span.span_data.result = [tool.name for tool in tools]
163182

164183
return [
165-
cls.to_function_tool(tool, server, convert_schemas_to_strict, agent) for tool in tools
184+
cls.to_function_tool(
185+
tool,
186+
server,
187+
convert_schemas_to_strict,
188+
agent,
189+
failure_error_function=failure_error_function,
190+
)
191+
for tool in tools
166192
]
167193

168194
@classmethod
169195
def to_function_tool(
170196
cls,
171-
tool: "MCPTool",
172-
server: "MCPServer",
197+
tool: MCPTool,
198+
server: MCPServer,
173199
convert_schemas_to_strict: bool,
174-
agent: "AgentBase",
200+
agent: AgentBase,
201+
failure_error_function: ToolErrorFunction | None = default_tool_error_function,
175202
) -> FunctionTool:
176203
"""Convert an MCP tool to an Agents SDK function tool."""
177204
invoke_func_impl = functools.partial(cls.invoke_mcp_tool, server, tool)
205+
effective_failure_error_function = server._get_failure_error_function(
206+
failure_error_function
207+
)
178208
schema, is_strict = tool.inputSchema, False
179209

180210
# MCP spec doesn't require the inputSchema to have `properties`, but OpenAI spec does.
@@ -195,8 +225,11 @@ async def invoke_func(ctx: ToolContext[Any], input_json: str) -> ToolOutput:
195225
try:
196226
return await invoke_func_impl(ctx, input_json)
197227
except Exception as e:
198-
# Use default error handling function to convert exception to error message.
199-
result = default_tool_error_function(ctx, e)
228+
if effective_failure_error_function is None:
229+
raise
230+
231+
# Use configured error handling function to convert exception to error message.
232+
result = effective_failure_error_function(ctx, e)
200233
if inspect.isawaitable(result):
201234
result = await result
202235

@@ -233,7 +266,7 @@ async def invoke_func(ctx: ToolContext[Any], input_json: str) -> ToolOutput:
233266

234267
@classmethod
235268
async def invoke_mcp_tool(
236-
cls, server: "MCPServer", tool: "MCPTool", context: RunContextWrapper[Any], input_json: str
269+
cls, server: MCPServer, tool: MCPTool, context: RunContextWrapper[Any], input_json: str
237270
) -> ToolOutput:
238271
"""Invoke an MCP tool and return the result as a string."""
239272
try:

0 commit comments

Comments
 (0)