From 7edf47d44a88857ec8ef2a72ec3b6caa3b14b1e3 Mon Sep 17 00:00:00 2001 From: Juan Franco <91078895+m1lestones@users.noreply.github.com> Date: Mon, 13 Apr 2026 01:47:35 -0400 Subject: [PATCH] feat(tools): support @function_tool on class instance methods via descriptor protocol Fixes #94. Decorating a class method with @function_tool now works correctly: - `function_schema()` detects an unannotated leading `self` or `cls` parameter and sets `skips_receiver=True`, stripping the receiver from both the JSON schema and the stored signature so the LLM never sees it and `to_call_args()` never tries to populate it from model output. - `FunctionTool` gains a `__get__` descriptor method. Accessing a method tool on a class instance (e.g. `instance.my_tool`) returns a bound copy whose invoker prepends the instance as the first argument, so the underlying method receives the correct `self`. Accessing via the class returns the unbound tool. - A `_make_impl(receiver)` factory is stored on method tools; `__get__` calls it with the instance and wires the result into the copied tool's invoker. - A `RunContextWrapper`/`ToolContext` parameter immediately after `self`/`cls` is handled correctly (`takes_context=True`), as is the case where context is in the wrong position (still raises `UserError`). Co-Authored-By: Claude Sonnet 4.6 --- src/agents/function_schema.py | 40 ++++- src/agents/tool.py | 122 ++++++++++++- tests/test_method_tool.py | 327 ++++++++++++++++++++++++++++++++++ 3 files changed, 481 insertions(+), 8 deletions(-) create mode 100644 tests/test_method_tool.py diff --git a/src/agents/function_schema.py b/src/agents/function_schema.py index 881ebdf00f..e44989c2a7 100644 --- a/src/agents/function_schema.py +++ b/src/agents/function_schema.py @@ -36,6 +36,10 @@ class FuncSchema: """The signature of the function.""" takes_context: bool = False """Whether the function takes a RunContextWrapper argument (must be the first argument).""" + skips_receiver: bool = False + """Whether the function's leading ``self`` or ``cls`` parameter was stripped from the schema. + When True, the tool is a *method tool* and must be called with a receiver prepended to the + argument list (see :meth:`FunctionTool.__get__`).""" strict_json_schema: bool = True """Whether the JSON schema is in strict mode. We **strongly** recommend setting this to True, as it increases the likelihood of correct JSON input.""" @@ -286,23 +290,46 @@ def function_schema( sig = inspect.signature(func) params = list(sig.parameters.items()) takes_context = False + skips_receiver = False filtered_params = [] + # Index into `params` where non-receiver, non-context processing begins. + _params_start = 0 if params: first_name, first_param = params[0] # Prefer the evaluated type hint if available ann = type_hints.get(first_name, first_param.annotation) - if ann != inspect._empty: + if ann == inspect._empty and first_name in ("self", "cls"): + # Unannotated self/cls → this is an instance or class method receiver. + # Exclude it from the schema so the LLM never sees it; the tool's __get__ + # descriptor will supply the receiver at call time. + skips_receiver = True + _params_start = 1 + elif ann != inspect._empty: origin = get_origin(ann) or ann if origin is RunContextWrapper or origin is ToolContext: takes_context = True # Mark that the function takes context + _params_start = 1 else: filtered_params.append((first_name, first_param)) + _params_start = 1 else: filtered_params.append((first_name, first_param)) + _params_start = 1 + + # When the first param is a method receiver, the *next* param may be a context arg. + if skips_receiver and len(params) > 1: + second_name, second_param = params[1] + second_ann = type_hints.get(second_name, second_param.annotation) + if second_ann != inspect._empty: + origin = get_origin(second_ann) or second_ann + if origin is RunContextWrapper or origin is ToolContext: + takes_context = True + _params_start = 2 - # For parameters other than the first, raise error if any use RunContextWrapper or ToolContext. - for name, param in params[1:]: + # For parameters beyond the first (and optional context), raise an error if any use + # RunContextWrapper or ToolContext in an unsupported position. + for name, param in params[_params_start:]: ann = type_hints.get(name, param.annotation) if ann != inspect._empty: origin = get_origin(ann) or ann @@ -313,6 +340,12 @@ def function_schema( ) filtered_params.append((name, param)) + # If this is a method, strip the receiver from the stored signature so that + # to_call_args() never attempts to populate self/cls from LLM-supplied JSON. + if skips_receiver: + receiver_name = params[0][0] + sig = sig.replace(parameters=[p for n, p in sig.parameters.items() if n != receiver_name]) + # We will collect field definitions for create_model as a dict: # field_name -> (type_annotation, default_value_or_Field(...)) fields: dict[str, Any] = {} @@ -419,5 +452,6 @@ def function_schema( params_json_schema=json_schema, signature=sig, takes_context=takes_context, + skips_receiver=skips_receiver, strict_json_schema=strict_json_schema, ) diff --git a/src/agents/tool.py b/src/agents/tool.py index 1ac3c29ae3..0e6e01982d 100644 --- a/src/agents/tool.py +++ b/src/agents/tool.py @@ -353,6 +353,31 @@ def __copy__(self) -> FunctionTool: setattr(copied_tool, attr_name, attr_value) return copied_tool + def __get__(self, obj: Any, objtype: Any = None) -> FunctionTool: + """Descriptor protocol: bind this method tool to a class instance. + + When a :func:`function_tool`-decorated method is accessed on an instance + (e.g. ``my_instance.my_tool``), this returns a new :class:`FunctionTool` + whose invocation automatically prepends ``my_instance`` as the receiver, + so the underlying method receives the correct ``self``/``cls`` argument. + + Accessing the tool on the *class* (``MyClass.my_tool``) returns the + unbound :class:`FunctionTool` unchanged. + """ + if obj is None: + # Class-level access — return the unbound tool descriptor. + return self + make_impl = getattr(self, "_make_impl", None) + if make_impl is None: + # Not a method tool; behave as a plain attribute (no binding needed). + return self + # Build a copy and rewire its invoker to use the bound receiver. + bound_tool = copy.copy(self) + handler = bound_tool.on_invoke_tool + if isinstance(handler, _FailureHandlingFunctionToolInvoker): + handler._invoke_tool_impl = make_impl(obj) + return bound_tool + class _FailureHandlingFunctionToolInvoker: """Internal callable that rebinds wrapper error handling for copied FunctionTools.""" @@ -1669,6 +1694,97 @@ def _create_function_tool(the_func: ToolFunction[...]) -> FunctionTool: strict_json_schema=strict_mode, ) + _on_handled_error = _build_handled_function_tool_error_handler( + span_message="Error running tool (non-fatal)", + span_message_for_json_decode_error="Error running tool", + log_label="Tool", + ) + + if schema.skips_receiver: + # The decorated function is an unbound instance/class method. We + # store a factory (_make_impl) on the returned FunctionTool so that + # the __get__ descriptor can produce a correctly-bound invoker when + # the tool is accessed via a class instance. + def _make_impl( + receiver: Any, + ) -> Callable[[ToolContext[Any], str], Awaitable[Any]]: + async def _method_invoke_impl(ctx: ToolContext[Any], input: str) -> Any: + tool_name = ctx.tool_name + json_data = _parse_function_tool_json_input( + tool_name=tool_name, input_json=input + ) + _log_function_tool_invocation(tool_name=tool_name, input_json=input) + + try: + parsed = ( + schema.params_pydantic_model(**json_data) + if json_data + else schema.params_pydantic_model() + ) + except ValidationError as e: + raise ModelBehaviorError( + f"Invalid JSON input for tool {tool_name}: {e}" + ) from e + + args, kwargs_dict = schema.to_call_args(parsed) + + if not _debug.DONT_LOG_TOOL_DATA: + logger.debug(f"Tool call args: {args}, kwargs: {kwargs_dict}") + + if receiver is None: + raise UserError( + f"Tool '{schema.name}' was decorated on a class method and must be " + f"accessed via a class instance before being invoked. " + f"Use 'instance.{schema.name}' or bind the tool with " + f"'tool.__get__(instance)' before adding it to an agent." + ) + + if not is_sync_function_tool: + if schema.takes_context: + result = await the_func(receiver, ctx, *args, **kwargs_dict) + else: + result = await the_func(receiver, *args, **kwargs_dict) + else: + if schema.takes_context: + result = await asyncio.to_thread( + the_func, receiver, ctx, *args, **kwargs_dict + ) + else: + result = await asyncio.to_thread( + the_func, receiver, *args, **kwargs_dict + ) + + if _debug.DONT_LOG_TOOL_DATA: + logger.debug(f"Tool {tool_name} completed.") + else: + logger.debug(f"Tool {tool_name} returned {result}") + + return result + + return _method_invoke_impl + + function_tool = _build_wrapped_function_tool( + name=schema.name, + description=schema.description or "", + params_json_schema=schema.params_json_schema, + invoke_tool_impl=_make_impl(None), # unbound placeholder + on_handled_error=_on_handled_error, + failure_error_function=failure_error_function, + strict_json_schema=strict_mode, + is_enabled=is_enabled, + needs_approval=needs_approval, + tool_input_guardrails=tool_input_guardrails, + tool_output_guardrails=tool_output_guardrails, + timeout_seconds=timeout, + timeout_behavior=timeout_behavior, + timeout_error_function=timeout_error_function, + defer_loading=defer_loading, + sync_invoker=is_sync_function_tool, + ) + # Store the factory so __get__ can bind a receiver on instance access. + function_tool._make_impl = _make_impl # type: ignore[attr-defined] + return function_tool + async def _on_invoke_tool_impl(ctx: ToolContext[Any], input: str) -> Any: tool_name = ctx.tool_name json_data = _parse_function_tool_json_input(tool_name=tool_name, input_json=input) @@ -1711,11 +1827,7 @@ async def _on_invoke_tool_impl(ctx: ToolContext[Any], input: str) -> Any: description=schema.description or "", params_json_schema=schema.params_json_schema, invoke_tool_impl=_on_invoke_tool_impl, - on_handled_error=_build_handled_function_tool_error_handler( - span_message="Error running tool (non-fatal)", - span_message_for_json_decode_error="Error running tool", - log_label="Tool", - ), + on_handled_error=_on_handled_error, failure_error_function=failure_error_function, strict_json_schema=strict_mode, is_enabled=is_enabled, diff --git a/tests/test_method_tool.py b/tests/test_method_tool.py new file mode 100644 index 0000000000..27e781a958 --- /dev/null +++ b/tests/test_method_tool.py @@ -0,0 +1,327 @@ +"""Tests for @function_tool applied to class instance methods (issue #94). + +Covers the descriptor protocol (__get__), schema generation (self/cls stripped), +receiver binding at call time, and edge cases flagged in previous review rounds. +""" + +from __future__ import annotations + +import pytest + +from agents import RunContextWrapper +from agents.exceptions import UserError +from agents.function_schema import function_schema +from agents.tool import FunctionTool, function_tool +from agents.tool_context import ToolContext + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_tool_context(tool_name: str = "test_tool") -> ToolContext: + return ToolContext(context=None, tool_name=tool_name, tool_call_id="1", tool_arguments="") + + +# --------------------------------------------------------------------------- +# function_schema unit tests +# --------------------------------------------------------------------------- + + +class _BasicMethodClass: + multiplier: int + + def __init__(self, multiplier: int) -> None: + self.multiplier = multiplier + + def multiply(self, x: int) -> int: + """Multiply x by the instance multiplier. + + Args: + x: The value to multiply. + """ + return x * self.multiplier + + +def test_function_schema_strips_self() -> None: + """function_schema must not include self in the JSON schema.""" + schema = function_schema(_BasicMethodClass.multiply) + assert schema.skips_receiver is True + assert schema.takes_context is False + assert "self" not in schema.params_json_schema.get("properties", {}) + assert "x" in schema.params_json_schema.get("properties", {}) + + +def test_function_schema_strips_self_from_stored_signature() -> None: + """The stored signature must not include self so to_call_args never fetches it.""" + schema = function_schema(_BasicMethodClass.multiply) + assert "self" not in schema.signature.parameters + + +def test_function_schema_to_call_args_without_receiver() -> None: + """to_call_args must return only the non-receiver arguments.""" + schema = function_schema(_BasicMethodClass.multiply) + parsed = schema.params_pydantic_model(x=7) + args, kwargs = schema.to_call_args(parsed) + # args should be [7]; no None placeholder for self + assert args == [7] + assert kwargs == {} + + +def test_function_schema_cls_is_stripped() -> None: + """Leading cls parameter (unannotated) must also be treated as a receiver. + + We test the undecorated classmethod function directly (before @classmethod + binds cls) because @classmethod already hides cls from the signature. + """ + + # Define an unbound function that uses cls as its first unannotated param, + # as if it were the raw underlying function of a classmethod. + def greet(cls, name: str) -> str: + """Say hi. + + Args: + name: Who to greet. + """ + return f"hi {name}" + + schema = function_schema(greet) + assert schema.skips_receiver is True + assert "cls" not in schema.params_json_schema.get("properties", {}) + assert "name" in schema.params_json_schema.get("properties", {}) + + +def test_function_schema_annotated_self_not_stripped() -> None: + """A first parameter named self *with* a type annotation must not be stripped.""" + + def weird(self: int, y: int) -> int: + return self + y + + schema = function_schema(weird) + assert schema.skips_receiver is False + assert "self" in schema.params_json_schema.get("properties", {}) + + +def test_function_schema_self_with_context_param() -> None: + """self followed immediately by RunContextWrapper must set both flags correctly.""" + + class _WithCtx: + def act(self, ctx: RunContextWrapper[None], value: int) -> int: + """Do something. + + Args: + value: The input value. + """ + return value + + schema = function_schema(_WithCtx.act) + assert schema.skips_receiver is True + assert schema.takes_context is True + assert "self" not in schema.params_json_schema.get("properties", {}) + assert "ctx" not in schema.params_json_schema.get("properties", {}) + assert "value" in schema.params_json_schema.get("properties", {}) + + +def test_function_schema_context_in_wrong_position_raises() -> None: + """RunContextWrapper after self but not in position 1 must raise UserError.""" + + class _Bad: + def bad(self, x: int, ctx: RunContextWrapper[None]) -> int: + return x + + with pytest.raises(UserError, match="non-first position"): + function_schema(_Bad.bad) + + +def test_function_schema_regular_function_unchanged() -> None: + """function_schema on a plain function must behave exactly as before.""" + + def add(a: int, b: int) -> int: + """Add two numbers. + + Args: + a: First. + b: Second. + """ + return a + b + + schema = function_schema(add) + assert schema.skips_receiver is False + assert schema.takes_context is False + assert "a" in schema.params_json_schema.get("properties", {}) + assert "b" in schema.params_json_schema.get("properties", {}) + + +# --------------------------------------------------------------------------- +# FunctionTool descriptor / __get__ tests +# --------------------------------------------------------------------------- + + +class _Calculator: + def __init__(self, base: int) -> None: + self.base = base + + @function_tool + def add(self, x: int) -> int: + """Add x to the base. + + Args: + x: Value to add. + """ + return self.base + x + + @function_tool + async def async_add(self, x: int) -> int: + """Async-add x to the base. + + Args: + x: Value to add. + """ + return self.base + x + + +def test_class_level_access_returns_function_tool() -> None: + """Accessing the tool on the class should return a FunctionTool (unbound).""" + assert isinstance(_Calculator.add, FunctionTool) + + +def test_instance_access_returns_function_tool() -> None: + """Accessing the tool on an instance should also return a FunctionTool.""" + calc = _Calculator(base=10) + assert isinstance(calc.add, FunctionTool) + + +def test_instance_access_returns_different_object() -> None: + """Each instance access should produce a new (bound) FunctionTool.""" + calc = _Calculator(base=10) + bound1 = calc.add + bound2 = calc.add + assert bound1 is not bound2 + assert bound1 is not _Calculator.add + + +def test_bound_tool_schema_unchanged() -> None: + """The schema of the bound tool must be identical to the class-level tool.""" + calc = _Calculator(base=10) + assert calc.add.params_json_schema == _Calculator.add.params_json_schema + assert calc.add.name == _Calculator.add.name + + +@pytest.mark.asyncio +async def test_bound_tool_invokes_correct_instance() -> None: + """The bound tool must call the method on the correct instance.""" + calc5 = _Calculator(base=5) + calc20 = _Calculator(base=20) + + ctx = _make_tool_context("add") + result5 = await calc5.add.on_invoke_tool(ctx, '{"x": 3}') + result20 = await calc20.add.on_invoke_tool(ctx, '{"x": 3}') + + assert result5 == 8 # 5 + 3 + assert result20 == 23 # 20 + 3 + + +@pytest.mark.asyncio +async def test_async_bound_tool_invokes_correct_instance() -> None: + """The async variant also dispatches to the right instance.""" + calc = _Calculator(base=100) + ctx = _make_tool_context("async_add") + result = await calc.async_add.on_invoke_tool(ctx, '{"x": 1}') + assert result == 101 + + +@pytest.mark.asyncio +async def test_unbound_tool_returns_error_message() -> None: + """Calling the class-level (unbound) tool must produce an error message. + + The UserError raised internally is caught by the failure error handler and + returned as a string so the LLM receives a meaningful error rather than + crashing the run. + """ + ctx = _make_tool_context("add") + result = await _Calculator.add.on_invoke_tool(ctx, '{"x": 1}') + assert isinstance(result, str) + assert "class instance" in result + + +# --------------------------------------------------------------------------- +# Method tool with RunContextWrapper +# --------------------------------------------------------------------------- + + +class _ContextAwareService: + def __init__(self, prefix: str) -> None: + self.prefix = prefix + + @function_tool + def greet(self, ctx: RunContextWrapper[None], name: str) -> str: + """Greet someone. + + Args: + name: The person's name. + """ + return f"{self.prefix}: hello {name}" + + +@pytest.mark.asyncio +async def test_method_tool_with_context() -> None: + """Method tool that also takes RunContextWrapper must pass ctx correctly.""" + svc = _ContextAwareService(prefix="BOT") + tool_ctx = _make_tool_context("greet") + + result = await svc.greet.on_invoke_tool(tool_ctx, '{"name": "Alice"}') + assert result == "BOT: hello Alice" + + +# --------------------------------------------------------------------------- +# Only the leading self/cls is stripped (not all params named self) +# --------------------------------------------------------------------------- + + +def test_only_leading_self_is_stripped() -> None: + """A parameter named 'self' that is NOT the first parameter must appear in the schema.""" + + class _Tricky: + def method(self, value: int, self_count: int = 0) -> int: + """Do something. + + Args: + value: Main value. + self_count: Extra count (not a receiver). + """ + return value + self_count + + schema = function_schema(_Tricky.method) + props = schema.params_json_schema.get("properties", {}) + assert "self" not in props + assert "value" in props + assert "self_count" in props + + +# --------------------------------------------------------------------------- +# Decorator with arguments still works for methods +# --------------------------------------------------------------------------- + + +class _Described: + @function_tool(name_override="my_described_tool", description_override="A described tool.") + def compute(self, n: int) -> int: + """Fallback docstring. + + Args: + n: Input. + """ + return n * 2 + + +def test_function_tool_with_args_on_method() -> None: + assert _Described.compute.name == "my_described_tool" + assert _Described.compute.description == "A described tool." + + +@pytest.mark.asyncio +async def test_function_tool_with_args_on_method_binding() -> None: + obj = _Described() + ctx = _make_tool_context("my_described_tool") + result = await obj.compute.on_invoke_tool(ctx, '{"n": 4}') + assert result == 8