Skip to content

Commit 7edf47d

Browse files
m1lestonesclaude
andcommitted
feat(tools): support @function_tool on class instance methods via descriptor protocol
Fixes #94. Decorating a class method with @function_tool now works correctly: - `function_schema()` detects an unannotated leading `self` or `cls` parameter and sets `skips_receiver=True`, stripping the receiver from both the JSON schema and the stored signature so the LLM never sees it and `to_call_args()` never tries to populate it from model output. - `FunctionTool` gains a `__get__` descriptor method. Accessing a method tool on a class instance (e.g. `instance.my_tool`) returns a bound copy whose invoker prepends the instance as the first argument, so the underlying method receives the correct `self`. Accessing via the class returns the unbound tool. - A `_make_impl(receiver)` factory is stored on method tools; `__get__` calls it with the instance and wires the result into the copied tool's invoker. - A `RunContextWrapper`/`ToolContext` parameter immediately after `self`/`cls` is handled correctly (`takes_context=True`), as is the case where context is in the wrong position (still raises `UserError`). Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 86739b1 commit 7edf47d

3 files changed

Lines changed: 481 additions & 8 deletions

File tree

src/agents/function_schema.py

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,10 @@ class FuncSchema:
3636
"""The signature of the function."""
3737
takes_context: bool = False
3838
"""Whether the function takes a RunContextWrapper argument (must be the first argument)."""
39+
skips_receiver: bool = False
40+
"""Whether the function's leading ``self`` or ``cls`` parameter was stripped from the schema.
41+
When True, the tool is a *method tool* and must be called with a receiver prepended to the
42+
argument list (see :meth:`FunctionTool.__get__`)."""
3943
strict_json_schema: bool = True
4044
"""Whether the JSON schema is in strict mode. We **strongly** recommend setting this to True,
4145
as it increases the likelihood of correct JSON input."""
@@ -286,23 +290,46 @@ def function_schema(
286290
sig = inspect.signature(func)
287291
params = list(sig.parameters.items())
288292
takes_context = False
293+
skips_receiver = False
289294
filtered_params = []
295+
# Index into `params` where non-receiver, non-context processing begins.
296+
_params_start = 0
290297

291298
if params:
292299
first_name, first_param = params[0]
293300
# Prefer the evaluated type hint if available
294301
ann = type_hints.get(first_name, first_param.annotation)
295-
if ann != inspect._empty:
302+
if ann == inspect._empty and first_name in ("self", "cls"):
303+
# Unannotated self/cls → this is an instance or class method receiver.
304+
# Exclude it from the schema so the LLM never sees it; the tool's __get__
305+
# descriptor will supply the receiver at call time.
306+
skips_receiver = True
307+
_params_start = 1
308+
elif ann != inspect._empty:
296309
origin = get_origin(ann) or ann
297310
if origin is RunContextWrapper or origin is ToolContext:
298311
takes_context = True # Mark that the function takes context
312+
_params_start = 1
299313
else:
300314
filtered_params.append((first_name, first_param))
315+
_params_start = 1
301316
else:
302317
filtered_params.append((first_name, first_param))
318+
_params_start = 1
319+
320+
# When the first param is a method receiver, the *next* param may be a context arg.
321+
if skips_receiver and len(params) > 1:
322+
second_name, second_param = params[1]
323+
second_ann = type_hints.get(second_name, second_param.annotation)
324+
if second_ann != inspect._empty:
325+
origin = get_origin(second_ann) or second_ann
326+
if origin is RunContextWrapper or origin is ToolContext:
327+
takes_context = True
328+
_params_start = 2
303329

304-
# For parameters other than the first, raise error if any use RunContextWrapper or ToolContext.
305-
for name, param in params[1:]:
330+
# For parameters beyond the first (and optional context), raise an error if any use
331+
# RunContextWrapper or ToolContext in an unsupported position.
332+
for name, param in params[_params_start:]:
306333
ann = type_hints.get(name, param.annotation)
307334
if ann != inspect._empty:
308335
origin = get_origin(ann) or ann
@@ -313,6 +340,12 @@ def function_schema(
313340
)
314341
filtered_params.append((name, param))
315342

343+
# If this is a method, strip the receiver from the stored signature so that
344+
# to_call_args() never attempts to populate self/cls from LLM-supplied JSON.
345+
if skips_receiver:
346+
receiver_name = params[0][0]
347+
sig = sig.replace(parameters=[p for n, p in sig.parameters.items() if n != receiver_name])
348+
316349
# We will collect field definitions for create_model as a dict:
317350
# field_name -> (type_annotation, default_value_or_Field(...))
318351
fields: dict[str, Any] = {}
@@ -419,5 +452,6 @@ def function_schema(
419452
params_json_schema=json_schema,
420453
signature=sig,
421454
takes_context=takes_context,
455+
skips_receiver=skips_receiver,
422456
strict_json_schema=strict_json_schema,
423457
)

src/agents/tool.py

Lines changed: 117 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,31 @@ def __copy__(self) -> FunctionTool:
353353
setattr(copied_tool, attr_name, attr_value)
354354
return copied_tool
355355

356+
def __get__(self, obj: Any, objtype: Any = None) -> FunctionTool:
357+
"""Descriptor protocol: bind this method tool to a class instance.
358+
359+
When a :func:`function_tool`-decorated method is accessed on an instance
360+
(e.g. ``my_instance.my_tool``), this returns a new :class:`FunctionTool`
361+
whose invocation automatically prepends ``my_instance`` as the receiver,
362+
so the underlying method receives the correct ``self``/``cls`` argument.
363+
364+
Accessing the tool on the *class* (``MyClass.my_tool``) returns the
365+
unbound :class:`FunctionTool` unchanged.
366+
"""
367+
if obj is None:
368+
# Class-level access — return the unbound tool descriptor.
369+
return self
370+
make_impl = getattr(self, "_make_impl", None)
371+
if make_impl is None:
372+
# Not a method tool; behave as a plain attribute (no binding needed).
373+
return self
374+
# Build a copy and rewire its invoker to use the bound receiver.
375+
bound_tool = copy.copy(self)
376+
handler = bound_tool.on_invoke_tool
377+
if isinstance(handler, _FailureHandlingFunctionToolInvoker):
378+
handler._invoke_tool_impl = make_impl(obj)
379+
return bound_tool
380+
356381

357382
class _FailureHandlingFunctionToolInvoker:
358383
"""Internal callable that rebinds wrapper error handling for copied FunctionTools."""
@@ -1669,6 +1694,97 @@ def _create_function_tool(the_func: ToolFunction[...]) -> FunctionTool:
16691694
strict_json_schema=strict_mode,
16701695
)
16711696

1697+
_on_handled_error = _build_handled_function_tool_error_handler(
1698+
span_message="Error running tool (non-fatal)",
1699+
span_message_for_json_decode_error="Error running tool",
1700+
log_label="Tool",
1701+
)
1702+
1703+
if schema.skips_receiver:
1704+
# The decorated function is an unbound instance/class method. We
1705+
# store a factory (_make_impl) on the returned FunctionTool so that
1706+
# the __get__ descriptor can produce a correctly-bound invoker when
1707+
# the tool is accessed via a class instance.
1708+
def _make_impl(
1709+
receiver: Any,
1710+
) -> Callable[[ToolContext[Any], str], Awaitable[Any]]:
1711+
async def _method_invoke_impl(ctx: ToolContext[Any], input: str) -> Any:
1712+
tool_name = ctx.tool_name
1713+
json_data = _parse_function_tool_json_input(
1714+
tool_name=tool_name, input_json=input
1715+
)
1716+
_log_function_tool_invocation(tool_name=tool_name, input_json=input)
1717+
1718+
try:
1719+
parsed = (
1720+
schema.params_pydantic_model(**json_data)
1721+
if json_data
1722+
else schema.params_pydantic_model()
1723+
)
1724+
except ValidationError as e:
1725+
raise ModelBehaviorError(
1726+
f"Invalid JSON input for tool {tool_name}: {e}"
1727+
) from e
1728+
1729+
args, kwargs_dict = schema.to_call_args(parsed)
1730+
1731+
if not _debug.DONT_LOG_TOOL_DATA:
1732+
logger.debug(f"Tool call args: {args}, kwargs: {kwargs_dict}")
1733+
1734+
if receiver is None:
1735+
raise UserError(
1736+
f"Tool '{schema.name}' was decorated on a class method and must be "
1737+
f"accessed via a class instance before being invoked. "
1738+
f"Use 'instance.{schema.name}' or bind the tool with "
1739+
f"'tool.__get__(instance)' before adding it to an agent."
1740+
)
1741+
1742+
if not is_sync_function_tool:
1743+
if schema.takes_context:
1744+
result = await the_func(receiver, ctx, *args, **kwargs_dict)
1745+
else:
1746+
result = await the_func(receiver, *args, **kwargs_dict)
1747+
else:
1748+
if schema.takes_context:
1749+
result = await asyncio.to_thread(
1750+
the_func, receiver, ctx, *args, **kwargs_dict
1751+
)
1752+
else:
1753+
result = await asyncio.to_thread(
1754+
the_func, receiver, *args, **kwargs_dict
1755+
)
1756+
1757+
if _debug.DONT_LOG_TOOL_DATA:
1758+
logger.debug(f"Tool {tool_name} completed.")
1759+
else:
1760+
logger.debug(f"Tool {tool_name} returned {result}")
1761+
1762+
return result
1763+
1764+
return _method_invoke_impl
1765+
1766+
function_tool = _build_wrapped_function_tool(
1767+
name=schema.name,
1768+
description=schema.description or "",
1769+
params_json_schema=schema.params_json_schema,
1770+
invoke_tool_impl=_make_impl(None), # unbound placeholder
1771+
on_handled_error=_on_handled_error,
1772+
failure_error_function=failure_error_function,
1773+
strict_json_schema=strict_mode,
1774+
is_enabled=is_enabled,
1775+
needs_approval=needs_approval,
1776+
tool_input_guardrails=tool_input_guardrails,
1777+
tool_output_guardrails=tool_output_guardrails,
1778+
timeout_seconds=timeout,
1779+
timeout_behavior=timeout_behavior,
1780+
timeout_error_function=timeout_error_function,
1781+
defer_loading=defer_loading,
1782+
sync_invoker=is_sync_function_tool,
1783+
)
1784+
# Store the factory so __get__ can bind a receiver on instance access.
1785+
function_tool._make_impl = _make_impl # type: ignore[attr-defined]
1786+
return function_tool
1787+
16721788
async def _on_invoke_tool_impl(ctx: ToolContext[Any], input: str) -> Any:
16731789
tool_name = ctx.tool_name
16741790
json_data = _parse_function_tool_json_input(tool_name=tool_name, input_json=input)
@@ -1711,11 +1827,7 @@ async def _on_invoke_tool_impl(ctx: ToolContext[Any], input: str) -> Any:
17111827
description=schema.description or "",
17121828
params_json_schema=schema.params_json_schema,
17131829
invoke_tool_impl=_on_invoke_tool_impl,
1714-
on_handled_error=_build_handled_function_tool_error_handler(
1715-
span_message="Error running tool (non-fatal)",
1716-
span_message_for_json_decode_error="Error running tool",
1717-
log_label="Tool",
1718-
),
1830+
on_handled_error=_on_handled_error,
17191831
failure_error_function=failure_error_function,
17201832
strict_json_schema=strict_mode,
17211833
is_enabled=is_enabled,

0 commit comments

Comments
 (0)