diff --git a/src/mcp/server/fastmcp/tools/base.py b/src/mcp/server/fastmcp/tools/base.py index e137e8456c..2ae30efe10 100644 --- a/src/mcp/server/fastmcp/tools/base.py +++ b/src/mcp/server/fastmcp/tools/base.py @@ -2,7 +2,7 @@ import inspect from collections.abc import Callable -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, get_type_hints from pydantic import BaseModel, Field @@ -15,6 +15,18 @@ from mcp.shared.context import LifespanContextT +def _is_context_type(annotation: type[Any]) -> bool: + from mcp.server.fastmcp import Context + + if annotation is Context: + return True + if ( + generic_metadata := getattr(annotation, "__pydantic_generic_metadata__", None) + ) is not None: + return _is_context_type(generic_metadata["origin"]) + return False + + class Tool(BaseModel): """Internal tool registration info.""" @@ -40,8 +52,6 @@ def from_function( context_kwarg: str | None = None, ) -> Tool: """Create a Tool from a function.""" - from mcp.server.fastmcp import Context - func_name = name or fn.__name__ if func_name == "": @@ -51,9 +61,11 @@ def from_function( is_async = inspect.iscoroutinefunction(fn) if context_kwarg is None: - sig = inspect.signature(fn) - for param_name, param in sig.parameters.items(): - if param.annotation is Context: + type_hints = get_type_hints(fn) + for param_name, param_type in type_hints.items(): + if param_name == "return": + continue + if _is_context_type(param_type): context_kwarg = param_name break diff --git a/tests/server/fastmcp/test_server.py b/tests/server/fastmcp/test_server.py index e76e59c52e..3ddd4f7ba6 100644 --- a/tests/server/fastmcp/test_server.py +++ b/tests/server/fastmcp/test_server.py @@ -22,6 +22,7 @@ if TYPE_CHECKING: from mcp.server.fastmcp import Context + from mcp.server.session import ServerSession class TestServer: @@ -480,6 +481,32 @@ def tool_with_context(x: int, ctx: Context) -> str: tool = mcp._tool_manager.add_tool(tool_with_context) assert tool.context_kwarg == "ctx" + @pytest.mark.anyio + async def test_context_detection_forward_ref(self): + """ + Test that context parameters are properly detected with forward references. + """ + mcp = FastMCP() + + def tool_with_context(x: int, ctx: "Context") -> str: + return f"Request {ctx.request_id}: {x}" + + tool = mcp._tool_manager.add_tool(tool_with_context) + assert tool.context_kwarg == "ctx" + + @pytest.mark.anyio + async def test_context_detection_generic_alias(self): + """Test that context parameters are properly detected with generic alias.""" + mcp = FastMCP() + + class AppContext: ... + + def tool_with_context(x: int, ctx: Context["ServerSession", AppContext]) -> str: + return f"Request {ctx.request_id}: {x}" + + tool = mcp._tool_manager.add_tool(tool_with_context) + assert tool.context_kwarg == "ctx" + @pytest.mark.anyio async def test_context_injection(self): """Test that context is properly injected into tool calls."""