Skip to content

Commit 0726f5e

Browse files
committed
fix: detect context on callable tool objects
1 parent ac96f88 commit 0726f5e

2 files changed

Lines changed: 39 additions & 1 deletion

File tree

src/mcp/server/mcpserver/utilities/context_injection.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,13 @@ def find_context_parameter(fn: Callable[..., Any]) -> str | None:
2222
Returns:
2323
The name of the context parameter, or None if not found
2424
"""
25+
target = fn
26+
if not inspect.isfunction(fn) and not inspect.ismethod(fn):
27+
target = fn.__call__
28+
2529
# Get type hints to properly resolve string annotations
2630
try:
27-
hints = typing.get_type_hints(fn)
31+
hints = typing.get_type_hints(target)
2832
except Exception: # pragma: lax no cover
2933
# If we can't resolve type hints, we can't find the context parameter
3034
return None

tests/server/mcpserver/test_tool_manager.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,40 @@ def something(a: int, ctx: Context) -> int: # pragma: no cover
328328
assert "ctx" not in tool.fn_metadata.arg_model.model_fields
329329

330330

331+
def test_context_arg_excluded_from_callable_object_schema():
332+
class MyTool:
333+
def __init__(self):
334+
self.__name__ = "MyTool"
335+
336+
async def __call__(self, query: str, ctx: Context) -> str: # pragma: no cover
337+
return query
338+
339+
manager = ToolManager()
340+
tool = manager.add_tool(MyTool())
341+
342+
assert tool.context_kwarg == "ctx"
343+
assert "ctx" not in json.dumps(tool.parameters)
344+
assert "Context" not in json.dumps(tool.parameters)
345+
assert "ctx" not in tool.fn_metadata.arg_model.model_fields
346+
347+
348+
@pytest.mark.anyio
349+
async def test_context_injected_into_callable_object():
350+
class MyTool:
351+
def __init__(self):
352+
self.__name__ = "MyTool"
353+
354+
async def __call__(self, query: str, ctx: Context) -> str:
355+
assert isinstance(ctx, Context)
356+
return query
357+
358+
manager = ToolManager()
359+
manager.add_tool(MyTool())
360+
361+
result = await manager.call_tool("MyTool", {"query": "hello"}, context=Context())
362+
assert result == "hello"
363+
364+
331365
class TestContextHandling:
332366
"""Test context handling in the tool manager."""
333367

0 commit comments

Comments
 (0)