Skip to content

Commit 243fbb9

Browse files
committed
Support detect context parameter with generic alias and forward reference
1 parent c2ca8e0 commit 243fbb9

2 files changed

Lines changed: 45 additions & 6 deletions

File tree

src/mcp/server/fastmcp/tools/base.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import inspect
44
from collections.abc import Callable
5-
from typing import TYPE_CHECKING, Any
5+
from typing import TYPE_CHECKING, Any, get_type_hints
66

77
from pydantic import BaseModel, Field
88

@@ -15,6 +15,18 @@
1515
from mcp.shared.context import LifespanContextT
1616

1717

18+
def _is_context_type(annotation: type[Any]) -> bool:
19+
from mcp.server.fastmcp import Context
20+
21+
if annotation is Context:
22+
return True
23+
if (
24+
generic_metadata := getattr(annotation, "__pydantic_generic_metadata__", None)
25+
) is not None:
26+
return _is_context_type(generic_metadata["origin"])
27+
return False
28+
29+
1830
class Tool(BaseModel):
1931
"""Internal tool registration info."""
2032

@@ -40,8 +52,6 @@ def from_function(
4052
context_kwarg: str | None = None,
4153
) -> Tool:
4254
"""Create a Tool from a function."""
43-
from mcp.server.fastmcp import Context
44-
4555
func_name = name or fn.__name__
4656

4757
if func_name == "<lambda>":
@@ -51,9 +61,11 @@ def from_function(
5161
is_async = inspect.iscoroutinefunction(fn)
5262

5363
if context_kwarg is None:
54-
sig = inspect.signature(fn)
55-
for param_name, param in sig.parameters.items():
56-
if param.annotation is Context:
64+
type_hints = get_type_hints(fn)
65+
for param_name, param_type in type_hints.items():
66+
if param_name == "return":
67+
continue
68+
if _is_context_type(param_type):
5769
context_kwarg = param_name
5870
break
5971

tests/server/fastmcp/test_server.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
if TYPE_CHECKING:
2424
from mcp.server.fastmcp import Context
25+
from mcp.server.session import ServerSession
2526

2627

2728
class TestServer:
@@ -480,6 +481,32 @@ def tool_with_context(x: int, ctx: Context) -> str:
480481
tool = mcp._tool_manager.add_tool(tool_with_context)
481482
assert tool.context_kwarg == "ctx"
482483

484+
@pytest.mark.anyio
485+
async def test_context_detection_forward_ref(self):
486+
"""
487+
Test that context parameters are properly detected with forward references.
488+
"""
489+
mcp = FastMCP()
490+
491+
def tool_with_context(x: int, ctx: "Context") -> str:
492+
return f"Request {ctx.request_id}: {x}"
493+
494+
tool = mcp._tool_manager.add_tool(tool_with_context)
495+
assert tool.context_kwarg == "ctx"
496+
497+
@pytest.mark.anyio
498+
async def test_context_detection_generic_alias(self):
499+
"""Test that context parameters are properly detected with generic alias."""
500+
mcp = FastMCP()
501+
502+
class AppContext: ...
503+
504+
def tool_with_context(x: int, ctx: Context["ServerSession", AppContext]) -> str:
505+
return f"Request {ctx.request_id}: {x}"
506+
507+
tool = mcp._tool_manager.add_tool(tool_with_context)
508+
assert tool.context_kwarg == "ctx"
509+
483510
@pytest.mark.anyio
484511
async def test_context_injection(self):
485512
"""Test that context is properly injected into tool calls."""

0 commit comments

Comments
 (0)