Skip to content

Commit ae20d9e

Browse files
committed
Address proxy review feedback
1 parent 3c31190 commit ae20d9e

File tree

3 files changed

+25
-16
lines changed

3 files changed

+25
-16
lines changed

src/mcp/proxy.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,15 @@
22

33
from __future__ import annotations
44

5-
import inspect
65
from collections.abc import AsyncGenerator, Awaitable, Callable
76
from contextlib import asynccontextmanager
7+
from functools import partial
8+
from typing import cast
89

910
import anyio
11+
from anyio import to_thread
1012

13+
from mcp.shared._callable_inspection import is_async_callable
1114
from mcp.shared._stream_protocols import ReadStream, WriteStream
1215
from mcp.shared.message import SessionMessage
1316

@@ -19,6 +22,7 @@
1922
async def mcp_proxy(
2023
transport_to_client: MessageStream,
2124
transport_to_server: MessageStream,
25+
*,
2226
on_error: ErrorHandler | None = None,
2327
) -> AsyncGenerator[None]:
2428
"""Proxy messages bidirectionally between two MCP transports."""
@@ -60,8 +64,9 @@ async def _run_error_handler(error: Exception, on_error: ErrorHandler | None) ->
6064
return
6165

6266
try:
63-
result = on_error(error)
64-
if inspect.isawaitable(result):
65-
await result
67+
if is_async_callable(on_error):
68+
await cast(Awaitable[None], on_error(error))
69+
else:
70+
await to_thread.run_sync(partial(on_error, error))
6671
except Exception:
6772
return

src/mcp/server/mcpserver/tools/base.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
from __future__ import annotations
22

3-
import functools
4-
import inspect
53
from collections.abc import Callable
64
from functools import cached_property
75
from typing import TYPE_CHECKING, Any
@@ -11,6 +9,7 @@
119
from mcp.server.mcpserver.exceptions import ToolError
1210
from mcp.server.mcpserver.utilities.context_injection import find_context_parameter
1311
from mcp.server.mcpserver.utilities.func_metadata import FuncMetadata, func_metadata
12+
from mcp.shared._callable_inspection import is_async_callable
1413
from mcp.shared.exceptions import UrlElicitationRequiredError
1514
from mcp.shared.tool_name_validation import validate_and_warn_tool_name
1615
from mcp.types import Icon, ToolAnnotations
@@ -63,7 +62,7 @@ def from_function(
6362
raise ValueError("You must provide a name for lambda functions")
6463

6564
func_doc = description or fn.__doc__ or ""
66-
is_async = _is_async_callable(fn)
65+
is_async = is_async_callable(fn)
6766

6867
if context_kwarg is None: # pragma: no branch
6968
context_kwarg = find_context_parameter(fn)
@@ -118,12 +117,3 @@ async def run(
118117
raise
119118
except Exception as e:
120119
raise ToolError(f"Error executing tool {self.name}: {e}") from e
121-
122-
123-
def _is_async_callable(obj: Any) -> bool:
124-
while isinstance(obj, functools.partial): # pragma: lax no cover
125-
obj = obj.func
126-
127-
return inspect.iscoroutinefunction(obj) or (
128-
callable(obj) and inspect.iscoroutinefunction(getattr(obj, "__call__", None))
129-
)
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
from __future__ import annotations
2+
3+
import functools
4+
import inspect
5+
from typing import Any
6+
7+
8+
def is_async_callable(obj: Any) -> bool:
9+
while isinstance(obj, functools.partial): # pragma: lax no cover
10+
obj = obj.func
11+
12+
return inspect.iscoroutinefunction(obj) or (
13+
callable(obj) and inspect.iscoroutinefunction(getattr(obj, "__call__", None))
14+
)

0 commit comments

Comments
 (0)