Skip to content

Commit 0477032

Browse files
committed
Support function_tool on instance methods
1 parent 65774ce commit 0477032

4 files changed

Lines changed: 128 additions & 12 deletions

File tree

docs/tools.md

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,25 @@ for tool in agent.tools:
309309
3. Functions can optionally take the `context` (must be the first argument). You can also set overrides, like the name of the tool, description, which docstring style to use, etc.
310310
4. You can pass the decorated functions to the list of tools.
311311

312+
You can also decorate instance methods. Access the tool from an instance before passing it to
313+
`Agent.tools`; the implicit `self` parameter is bound to that instance and omitted from the tool
314+
schema.
315+
316+
```python
317+
class CustomerTools:
318+
def __init__(self, tenant_id: str) -> None:
319+
self.tenant_id = tenant_id
320+
321+
@function_tool
322+
def lookup_customer(self, customer_id: str) -> str:
323+
"""Look up a customer by ID."""
324+
return f"{self.tenant_id}:{customer_id}"
325+
326+
327+
customer_tools = CustomerTools("tenant_123")
328+
agent = Agent(name="Assistant", tools=[customer_tools.lookup_customer])
329+
```
330+
312331
??? note "Expand to see output"
313332

314333
```

src/agents/function_schema.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ class FuncSchema:
4040
strict_json_schema: bool = True
4141
"""Whether the JSON schema is in strict mode. We **strongly** recommend setting this to True,
4242
as it increases the likelihood of correct JSON input."""
43+
omitted_parameter_names: tuple[str, ...] = ()
44+
"""Parameter names that are supplied by the SDK instead of model-generated JSON."""
4345

4446
def to_call_args(self, data: BaseModel) -> tuple[list[Any], dict[str, Any]]:
4547
"""
@@ -52,6 +54,8 @@ def to_call_args(self, data: BaseModel) -> tuple[list[Any], dict[str, Any]]:
5254

5355
# Use enumerate() so we can skip the first parameter if it's context.
5456
for idx, (name, param) in enumerate(self.signature.parameters.items()):
57+
if name in self.omitted_parameter_names:
58+
continue
5559
# If the function takes a RunContextWrapper and this is the first parameter, skip it.
5660
if self.takes_context and idx == 0:
5761
continue
@@ -228,6 +232,7 @@ def function_schema(
228232
description_override: str | None = None,
229233
use_docstring_info: bool = True,
230234
strict_json_schema: bool = True,
235+
skip_first_parameter: bool = False,
231236
) -> FuncSchema:
232237
"""
233238
Given a Python function, extracts a `FuncSchema` from it, capturing the name, description,
@@ -246,6 +251,8 @@ def function_schema(
246251
the schema adheres to the "strict" standard the OpenAI API expects. We **strongly**
247252
recommend setting this to True, as it increases the likelihood of the LLM producing
248253
correct JSON input.
254+
skip_first_parameter: If True, omit the first signature parameter from the tool schema and
255+
call arguments. This is used for instance methods decorated with `@function_tool`.
249256
250257
Returns:
251258
A `FuncSchema` object containing the function's name, description, parameter descriptions,
@@ -288,22 +295,29 @@ def function_schema(
288295
params = list(sig.parameters.items())
289296
takes_context = False
290297
filtered_params = []
298+
omitted_parameter_names: list[str] = []
299+
300+
params_to_check = params
301+
if skip_first_parameter and params:
302+
omitted_parameter_names.append(params[0][0])
303+
params_to_check = params[1:]
291304

292-
if params:
293-
first_name, first_param = params[0]
305+
if params_to_check:
306+
first_name, first_param = params_to_check[0]
294307
# Prefer the evaluated type hint if available
295308
ann = type_hints.get(first_name, first_param.annotation)
296309
if ann != inspect._empty:
297310
origin = get_origin(ann) or ann
298311
if origin is RunContextWrapper or origin is ToolContext:
299312
takes_context = True # Mark that the function takes context
313+
omitted_parameter_names.append(first_name)
300314
else:
301315
filtered_params.append((first_name, first_param))
302316
else:
303317
filtered_params.append((first_name, first_param))
304318

305319
# For parameters other than the first, raise error if any use RunContextWrapper or ToolContext.
306-
for name, param in params[1:]:
320+
for name, param in params_to_check[1:]:
307321
ann = type_hints.get(name, param.annotation)
308322
if ann != inspect._empty:
309323
origin = get_origin(ann) or ann
@@ -421,4 +435,5 @@ def function_schema(
421435
signature=sig,
422436
takes_context=takes_context,
423437
strict_json_schema=strict_json_schema,
438+
omitted_parameter_names=tuple(omitted_parameter_names),
424439
)

src/agents/tool.py

Lines changed: 41 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -389,6 +389,18 @@ class FunctionTool:
389389
_emit_tool_origin: bool = field(default=True, kw_only=True, repr=False)
390390
"""Whether runtime item generation should emit tool origin metadata for this tool."""
391391

392+
_method_tool_factory: Callable[[Any], FunctionTool] | None = field(
393+
default=None,
394+
kw_only=True,
395+
repr=False,
396+
)
397+
"""Internal descriptor hook used for instance methods decorated with `@function_tool`."""
398+
399+
def __get__(self, instance: Any, owner: type[Any] | None = None) -> FunctionTool:
400+
if instance is None or self._method_tool_factory is None:
401+
return self
402+
return self._method_tool_factory(instance)
403+
392404
@property
393405
def qualified_name(self) -> str:
394406
"""Return the public qualified name used to identify this function tool."""
@@ -1827,18 +1839,33 @@ def function_tool(
18271839
explicitly loads it.
18281840
"""
18291841

1830-
def _create_function_tool(the_func: ToolFunction[...]) -> FunctionTool:
1842+
def _is_instance_method_tool(the_func: ToolFunction[...]) -> bool:
1843+
parameters = tuple(inspect.signature(the_func).parameters.values())
1844+
return bool(parameters) and parameters[0].name == "self"
1845+
1846+
def _create_function_tool(
1847+
the_func: ToolFunction[...],
1848+
*,
1849+
method_tool_instance: Any | None = None,
1850+
) -> FunctionTool:
18311851
is_sync_function_tool = not inspect.iscoroutinefunction(the_func)
1852+
is_instance_method_tool = _is_instance_method_tool(the_func)
18321853
schema = function_schema(
18331854
func=the_func,
18341855
name_override=name_override,
18351856
description_override=description_override,
18361857
docstring_style=docstring_style,
18371858
use_docstring_info=use_docstring_info,
18381859
strict_json_schema=strict_mode,
1860+
skip_first_parameter=is_instance_method_tool,
18391861
)
18401862

18411863
async def _on_invoke_tool_impl(ctx: ToolContext[Any], input: str) -> Any:
1864+
if is_instance_method_tool and method_tool_instance is None:
1865+
raise UserError(
1866+
f"Instance method tool {schema.name} must be accessed from an instance"
1867+
)
1868+
18421869
tool_name = ctx.tool_name
18431870
json_data = _parse_function_tool_json_input(tool_name=tool_name, input_json=input)
18441871
_log_function_tool_invocation(tool_name=tool_name, input_json=input)
@@ -1857,16 +1884,16 @@ async def _on_invoke_tool_impl(ctx: ToolContext[Any], input: str) -> Any:
18571884
if not _debug.DONT_LOG_TOOL_DATA:
18581885
logger.debug(f"Tool call args: {args}, kwargs: {kwargs_dict}")
18591886

1887+
leading_args: list[Any] = []
1888+
if is_instance_method_tool:
1889+
leading_args.append(method_tool_instance)
1890+
if schema.takes_context:
1891+
leading_args.append(ctx)
1892+
18601893
if not is_sync_function_tool:
1861-
if schema.takes_context:
1862-
result = await the_func(ctx, *args, **kwargs_dict)
1863-
else:
1864-
result = await the_func(*args, **kwargs_dict)
1894+
result = await the_func(*leading_args, *args, **kwargs_dict)
18651895
else:
1866-
if schema.takes_context:
1867-
result = await asyncio.to_thread(the_func, ctx, *args, **kwargs_dict)
1868-
else:
1869-
result = await asyncio.to_thread(the_func, *args, **kwargs_dict)
1896+
result = await asyncio.to_thread(the_func, *leading_args, *args, **kwargs_dict)
18701897

18711898
if _debug.DONT_LOG_TOOL_DATA:
18721899
logger.debug(f"Tool {tool_name} completed.")
@@ -1897,6 +1924,11 @@ async def _on_invoke_tool_impl(ctx: ToolContext[Any], input: str) -> Any:
18971924
defer_loading=defer_loading,
18981925
sync_invoker=is_sync_function_tool,
18991926
)
1927+
if is_instance_method_tool and method_tool_instance is None:
1928+
function_tool._method_tool_factory = lambda instance: _create_function_tool(
1929+
the_func,
1930+
method_tool_instance=instance,
1931+
)
19001932
return function_tool
19011933

19021934
# If func is actually a callable, we were used as @function_tool with no parentheses

tests/test_function_tool.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,56 @@ async def test_simple_function():
153153
)
154154

155155

156+
@pytest.mark.asyncio
157+
async def test_instance_method_function_tool_binds_self():
158+
class AccountTools:
159+
def __init__(self, prefix: str) -> None:
160+
self.prefix = prefix
161+
162+
@function_tool
163+
def lookup(self, account_id: str) -> str:
164+
"""Look up an account."""
165+
return f"{self.prefix}:{account_id}"
166+
167+
tools = AccountTools("acct")
168+
tool = tools.lookup
169+
170+
assert isinstance(AccountTools.lookup, FunctionTool)
171+
assert tool.name == "lookup"
172+
assert "self" not in tool.params_json_schema["properties"]
173+
assert "account_id" in tool.params_json_schema["properties"]
174+
175+
result = await tool.on_invoke_tool(
176+
ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments=""),
177+
'{"account_id": "123"}',
178+
)
179+
180+
assert result == "acct:123"
181+
182+
183+
@pytest.mark.asyncio
184+
async def test_instance_method_function_tool_supports_context_after_self():
185+
class AccountTools:
186+
@function_tool
187+
def lookup(self, ctx: ToolContext[str], account_id: str) -> str:
188+
"""Look up an account with context."""
189+
return f"{ctx.context}:{account_id}"
190+
191+
tools = AccountTools()
192+
tool = tools.lookup
193+
194+
assert "self" not in tool.params_json_schema["properties"]
195+
assert "ctx" not in tool.params_json_schema["properties"]
196+
assert "account_id" in tool.params_json_schema["properties"]
197+
198+
result = await tool.on_invoke_tool(
199+
ToolContext("tenant", tool_name=tool.name, tool_call_id="1", tool_arguments=""),
200+
'{"account_id": "123"}',
201+
)
202+
203+
assert result == "tenant:123"
204+
205+
156206
@pytest.mark.asyncio
157207
async def test_sync_function_runs_via_to_thread(monkeypatch: pytest.MonkeyPatch) -> None:
158208
calls = {"to_thread": 0, "func": 0}

0 commit comments

Comments
 (0)