Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 37 additions & 3 deletions src/agents/function_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ class FuncSchema:
"""The signature of the function."""
takes_context: bool = False
"""Whether the function takes a RunContextWrapper argument (must be the first argument)."""
skips_receiver: bool = False
"""Whether the function's leading ``self`` or ``cls`` parameter was stripped from the schema.
When True, the tool is a *method tool* and must be called with a receiver prepended to the
argument list (see :meth:`FunctionTool.__get__`)."""
strict_json_schema: bool = True
"""Whether the JSON schema is in strict mode. We **strongly** recommend setting this to True,
as it increases the likelihood of correct JSON input."""
Expand Down Expand Up @@ -286,23 +290,46 @@ def function_schema(
sig = inspect.signature(func)
params = list(sig.parameters.items())
takes_context = False
skips_receiver = False
filtered_params = []
# Index into `params` where non-receiver, non-context processing begins.
_params_start = 0

if params:
first_name, first_param = params[0]
# Prefer the evaluated type hint if available
ann = type_hints.get(first_name, first_param.annotation)
if ann != inspect._empty:
if ann == inspect._empty and first_name in ("self", "cls"):
# Unannotated self/cls → this is an instance or class method receiver.
# Exclude it from the schema so the LLM never sees it; the tool's __get__
# descriptor will supply the receiver at call time.
skips_receiver = True
_params_start = 1
elif ann != inspect._empty:
origin = get_origin(ann) or ann
if origin is RunContextWrapper or origin is ToolContext:
takes_context = True # Mark that the function takes context
_params_start = 1
else:
filtered_params.append((first_name, first_param))
_params_start = 1
else:
filtered_params.append((first_name, first_param))
_params_start = 1

# When the first param is a method receiver, the *next* param may be a context arg.
if skips_receiver and len(params) > 1:
second_name, second_param = params[1]
second_ann = type_hints.get(second_name, second_param.annotation)
if second_ann != inspect._empty:
origin = get_origin(second_ann) or second_ann
if origin is RunContextWrapper or origin is ToolContext:
takes_context = True
_params_start = 2

# For parameters other than the first, raise error if any use RunContextWrapper or ToolContext.
for name, param in params[1:]:
# For parameters beyond the first (and optional context), raise an error if any use
# RunContextWrapper or ToolContext in an unsupported position.
for name, param in params[_params_start:]:
ann = type_hints.get(name, param.annotation)
if ann != inspect._empty:
origin = get_origin(ann) or ann
Expand All @@ -313,6 +340,12 @@ def function_schema(
)
filtered_params.append((name, param))

# If this is a method, strip the receiver from the stored signature so that
# to_call_args() never attempts to populate self/cls from LLM-supplied JSON.
if skips_receiver:
receiver_name = params[0][0]
sig = sig.replace(parameters=[p for n, p in sig.parameters.items() if n != receiver_name])

# We will collect field definitions for create_model as a dict:
# field_name -> (type_annotation, default_value_or_Field(...))
fields: dict[str, Any] = {}
Expand Down Expand Up @@ -419,5 +452,6 @@ def function_schema(
params_json_schema=json_schema,
signature=sig,
takes_context=takes_context,
skips_receiver=skips_receiver,
strict_json_schema=strict_json_schema,
)
122 changes: 117 additions & 5 deletions src/agents/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,31 @@ def __copy__(self) -> FunctionTool:
setattr(copied_tool, attr_name, attr_value)
return copied_tool

def __get__(self, obj: Any, objtype: Any = None) -> FunctionTool:
"""Descriptor protocol: bind this method tool to a class instance.

When a :func:`function_tool`-decorated method is accessed on an instance
(e.g. ``my_instance.my_tool``), this returns a new :class:`FunctionTool`
whose invocation automatically prepends ``my_instance`` as the receiver,
so the underlying method receives the correct ``self``/``cls`` argument.

Accessing the tool on the *class* (``MyClass.my_tool``) returns the
unbound :class:`FunctionTool` unchanged.
"""
if obj is None:
# Class-level access — return the unbound tool descriptor.
return self
make_impl = getattr(self, "_make_impl", None)
if make_impl is None:
# Not a method tool; behave as a plain attribute (no binding needed).
return self
# Build a copy and rewire its invoker to use the bound receiver.
bound_tool = copy.copy(self)
handler = bound_tool.on_invoke_tool
if isinstance(handler, _FailureHandlingFunctionToolInvoker):
handler._invoke_tool_impl = make_impl(obj)
return bound_tool


class _FailureHandlingFunctionToolInvoker:
"""Internal callable that rebinds wrapper error handling for copied FunctionTools."""
Expand Down Expand Up @@ -1669,6 +1694,97 @@ def _create_function_tool(the_func: ToolFunction[...]) -> FunctionTool:
strict_json_schema=strict_mode,
)

_on_handled_error = _build_handled_function_tool_error_handler(
span_message="Error running tool (non-fatal)",
span_message_for_json_decode_error="Error running tool",
log_label="Tool",
)

if schema.skips_receiver:
# The decorated function is an unbound instance/class method. We
# store a factory (_make_impl) on the returned FunctionTool so that
# the __get__ descriptor can produce a correctly-bound invoker when
# the tool is accessed via a class instance.
def _make_impl(
receiver: Any,
) -> Callable[[ToolContext[Any], str], Awaitable[Any]]:
async def _method_invoke_impl(ctx: ToolContext[Any], input: str) -> Any:
tool_name = ctx.tool_name
json_data = _parse_function_tool_json_input(
tool_name=tool_name, input_json=input
)
_log_function_tool_invocation(tool_name=tool_name, input_json=input)

try:
parsed = (
schema.params_pydantic_model(**json_data)
if json_data
else schema.params_pydantic_model()
)
except ValidationError as e:
raise ModelBehaviorError(
f"Invalid JSON input for tool {tool_name}: {e}"
) from e

args, kwargs_dict = schema.to_call_args(parsed)

if not _debug.DONT_LOG_TOOL_DATA:
logger.debug(f"Tool call args: {args}, kwargs: {kwargs_dict}")

if receiver is None:
raise UserError(
f"Tool '{schema.name}' was decorated on a class method and must be "
f"accessed via a class instance before being invoked. "
f"Use 'instance.{schema.name}' or bind the tool with "
f"'tool.__get__(instance)' before adding it to an agent."
)

if not is_sync_function_tool:
if schema.takes_context:
result = await the_func(receiver, ctx, *args, **kwargs_dict)
else:
result = await the_func(receiver, *args, **kwargs_dict)
else:
if schema.takes_context:
result = await asyncio.to_thread(
the_func, receiver, ctx, *args, **kwargs_dict
)
else:
result = await asyncio.to_thread(
the_func, receiver, *args, **kwargs_dict
)

if _debug.DONT_LOG_TOOL_DATA:
logger.debug(f"Tool {tool_name} completed.")
else:
logger.debug(f"Tool {tool_name} returned {result}")

return result

return _method_invoke_impl

function_tool = _build_wrapped_function_tool(
name=schema.name,
description=schema.description or "",
params_json_schema=schema.params_json_schema,
invoke_tool_impl=_make_impl(None), # unbound placeholder
on_handled_error=_on_handled_error,
failure_error_function=failure_error_function,
strict_json_schema=strict_mode,
is_enabled=is_enabled,
needs_approval=needs_approval,
tool_input_guardrails=tool_input_guardrails,
tool_output_guardrails=tool_output_guardrails,
timeout_seconds=timeout,
timeout_behavior=timeout_behavior,
timeout_error_function=timeout_error_function,
defer_loading=defer_loading,
sync_invoker=is_sync_function_tool,
)
# Store the factory so __get__ can bind a receiver on instance access.
function_tool._make_impl = _make_impl # type: ignore[attr-defined]
return function_tool

async def _on_invoke_tool_impl(ctx: ToolContext[Any], input: str) -> Any:
tool_name = ctx.tool_name
json_data = _parse_function_tool_json_input(tool_name=tool_name, input_json=input)
Expand Down Expand Up @@ -1711,11 +1827,7 @@ async def _on_invoke_tool_impl(ctx: ToolContext[Any], input: str) -> Any:
description=schema.description or "",
params_json_schema=schema.params_json_schema,
invoke_tool_impl=_on_invoke_tool_impl,
on_handled_error=_build_handled_function_tool_error_handler(
span_message="Error running tool (non-fatal)",
span_message_for_json_decode_error="Error running tool",
log_label="Tool",
),
on_handled_error=_on_handled_error,
failure_error_function=failure_error_function,
strict_json_schema=strict_mode,
is_enabled=is_enabled,
Expand Down
Loading