@@ -328,6 +328,40 @@ def something(a: int, ctx: Context) -> int: # pragma: no cover
328328 assert "ctx" not in tool .fn_metadata .arg_model .model_fields
329329
330330
331+ def test_context_arg_excluded_from_callable_object_schema ():
332+ class MyTool :
333+ def __init__ (self ):
334+ self .__name__ = "MyTool"
335+
336+ async def __call__ (self , query : str , ctx : Context ) -> str : # pragma: no cover
337+ return query
338+
339+ manager = ToolManager ()
340+ tool = manager .add_tool (MyTool ())
341+
342+ assert tool .context_kwarg == "ctx"
343+ assert "ctx" not in json .dumps (tool .parameters )
344+ assert "Context" not in json .dumps (tool .parameters )
345+ assert "ctx" not in tool .fn_metadata .arg_model .model_fields
346+
347+
348+ @pytest .mark .anyio
349+ async def test_context_injected_into_callable_object ():
350+ class MyTool :
351+ def __init__ (self ):
352+ self .__name__ = "MyTool"
353+
354+ async def __call__ (self , query : str , ctx : Context ) -> str :
355+ assert isinstance (ctx , Context )
356+ return query
357+
358+ manager = ToolManager ()
359+ manager .add_tool (MyTool ())
360+
361+ result = await manager .call_tool ("MyTool" , {"query" : "hello" }, context = Context ())
362+ assert result == "hello"
363+
364+
331365class TestContextHandling :
332366 """Test context handling in the tool manager."""
333367
0 commit comments