Skip to content

Commit e522fc7

Browse files
authored
feat: allow async functions as tools (#1041)
Detect coroutine functions in MelleaTool.from_callable and wrap them through mellea's shared event loop so sync .run() callers receive the resolved value rather than an un-awaited coroutine. Add overloads on both from_callable and the @tool decorator so Callable[P, Awaitable[R]] narrows to MelleaTool[P, R]. Closes part of #1032. Assisted-by: Claude Code Signed-off-by: Alex Bozarth <ajbozart@us.ibm.com>
1 parent 1fdb2c1 commit e522fc7

4 files changed

Lines changed: 136 additions & 6 deletions

File tree

mellea/backends/tools.py

Lines changed: 36 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,13 @@
1212
import json
1313
import re
1414
from collections import defaultdict
15-
from collections.abc import Callable, Generator, Iterable, Mapping, Sequence
15+
from collections.abc import Awaitable, Callable, Generator, Iterable, Mapping, Sequence
1616
from typing import Any, Literal, ParamSpec, TypeVar, overload
1717

1818
from pydantic import BaseModel, ConfigDict, Field
1919

2020
from mellea.core.utils import MelleaLogger
21+
from mellea.helpers.event_loop_helper import _run_async_in_thread
2122

2223
from ..core import CBlock, Component, TemplateRepresentation
2324
from ..core.base import AbstractMelleaTool
@@ -179,17 +180,33 @@ def tool_call(*args, **kwargs):
179180
"Please install mellea with tools support: pip install 'mellea[tools]'"
180181
) from e
181182

183+
@overload
184+
@classmethod
185+
def from_callable(
186+
cls, func: Callable[P, Awaitable[R]], name: str | None = None
187+
) -> "MelleaTool[P, R]": ...
188+
189+
@overload
182190
@classmethod
183191
def from_callable(
184192
cls, func: Callable[P, R], name: str | None = None
193+
) -> "MelleaTool[P, R]": ...
194+
195+
@classmethod
196+
def from_callable(
197+
cls, func: Callable[P, R] | Callable[P, Awaitable[R]], name: str | None = None
185198
) -> "MelleaTool[P, R]":
186199
"""Create a MelleaTool from a plain Python callable.
187200
188201
Introspects the callable's signature and docstring to build an
189-
OpenAI-compatible JSON schema automatically.
202+
OpenAI-compatible JSON schema automatically. Async functions (defined
203+
with ``async def``) are supported transparently: the coroutine is
204+
awaited on mellea's shared event loop so sync callers of ``.run()``
205+
receive the resolved value rather than an un-awaited coroutine.
190206
191207
Args:
192-
func (Callable[P, R]): The Python callable to wrap as a tool.
208+
func (Callable[P, R] | Callable[P, Awaitable[R]]): The Python
209+
callable (sync or async) to wrap as a tool.
193210
name (str | None): Optional name override; defaults to ``func.__name__``.
194211
195212
Returns:
@@ -200,10 +217,22 @@ def from_callable(
200217
as_json = convert_function_to_ollama_tool(func, tool_name).model_dump(
201218
exclude_none=True
202219
)
203-
tool_call = func
220+
if inspect.iscoroutinefunction(func):
221+
async_func = func
222+
223+
def tool_call(*args: P.args, **kwargs: P.kwargs) -> R:
224+
return _run_async_in_thread(async_func(*args, **kwargs))
225+
else:
226+
tool_call = func # type: ignore[assignment]
204227
return MelleaTool(tool_name, tool_call, as_json)
205228

206229

230+
@overload
231+
def tool(
232+
func: Callable[P, Awaitable[R]], *, name: str | None = None
233+
) -> MelleaTool[P, R]: ...
234+
235+
207236
@overload
208237
def tool(func: Callable[P, R], *, name: str | None = None) -> MelleaTool[P, R]: ...
209238

@@ -215,7 +244,8 @@ def tool(
215244

216245

217246
def tool(
218-
func: Callable[P, R] | None = None, name: str | None = None
247+
func: Callable[P, R] | Callable[P, Awaitable[R]] | None = None,
248+
name: str | None = None,
219249
) -> MelleaTool[P, R] | Callable[[Callable[P, R]], MelleaTool[P, R]]:
220250
"""Decorator to mark a function as a Mellea tool with type-safe parameter and return types.
221251
@@ -278,7 +308,7 @@ def decorator(f: Callable[P, R]) -> MelleaTool[P, R]:
278308
return decorator
279309
else:
280310
# Called without arguments: @tool
281-
return decorator(func)
311+
return decorator(func) # type: ignore[arg-type]
282312

283313

284314
def add_tools_from_model_options(

test/backends/test_mellea_tool.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,11 @@ def callable(input: int) -> str:
1313
return str(input)
1414

1515

16+
async def async_callable(input: int) -> str:
17+
"""Common async callable to test tool functionality."""
18+
return str(input)
19+
20+
1621
@tool
1722
def langchain_tool(input: int) -> str:
1823
"""Common langchain tool to test functionality."""
@@ -258,5 +263,54 @@ class NotATool:
258263
assert "smolagents Tool type" in error_msg
259264

260265

266+
def test_from_callable_async():
267+
"""Async function produces a MelleaTool whose .run() returns the awaited value."""
268+
t = MelleaTool.from_callable(async_callable)
269+
assert isinstance(t, MelleaTool)
270+
assert t.name == async_callable.__name__
271+
272+
# .run() from sync code returns the awaited value, not a coroutine.
273+
result = t.run(1)
274+
assert result == "1"
275+
assert t.run(input=2) == "2"
276+
277+
278+
def test_from_callable_async_schema_matches_sync():
279+
"""Schema introspection produces the same parameter schema for sync and async."""
280+
281+
def sync_fn(x: int, y: str = "default") -> str:
282+
"""A sync function.
283+
284+
Args:
285+
x: An integer.
286+
y: A string with a default.
287+
"""
288+
return y * x
289+
290+
async def async_fn(x: int, y: str = "default") -> str:
291+
"""A sync function.
292+
293+
Args:
294+
x: An integer.
295+
y: A string with a default.
296+
"""
297+
return y * x
298+
299+
sync_t = MelleaTool.from_callable(sync_fn, "same_name")
300+
async_t = MelleaTool.from_callable(async_fn, "same_name")
301+
assert async_t.as_json_tool == sync_t.as_json_tool
302+
303+
304+
def test_from_callable_async_propagates_exceptions():
305+
"""Exceptions from async functions propagate through .run() (not wrapped in a coroutine)."""
306+
307+
async def raiser() -> str:
308+
raise RuntimeError("boom")
309+
310+
t = MelleaTool.from_callable(raiser)
311+
with pytest.raises(RuntimeError, match="boom"):
312+
t.run()
313+
314+
261315
if __name__ == "__main__":
262316
pytest.main([__file__])

test/backends/test_tool_decorator.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,18 @@ def base_func(x: int) -> int:
268268
assert tool1.name == "base_func"
269269
assert tool2.name == "custom"
270270

271+
def test_decorator_on_async_function(self):
272+
"""Test that @tool works end-to-end on an async function."""
273+
274+
@tool
275+
async def decorated(input: int) -> str:
276+
"""Async tool via decorator."""
277+
return str(input * 2)
278+
279+
assert isinstance(decorated, MelleaTool)
280+
assert decorated.name == "decorated"
281+
assert decorated.run(3) == "6"
282+
271283

272284
# ============================================================================
273285
# Test Cases: Usage Patterns

test/typing/check_tools.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,3 +159,37 @@ def sample_func(x: int) -> str:
159159
# Verify the return type is preserved through .run()
160160
output = result.run(42)
161161
assert_type(output, str)
162+
163+
164+
# Test async support: from_callable and @tool should narrow Awaitable[R] to R
165+
async def async_plain(a: str, b: int) -> list[str]:
166+
"""An async plain function to wrap."""
167+
return [a] * b
168+
169+
170+
def check_from_callable_async_return_type() -> None:
171+
"""Verify MelleaTool.from_callable narrows Awaitable[R] to R on .run()."""
172+
wrapped = MelleaTool.from_callable(async_plain)
173+
result = wrapped.run("test", 3)
174+
# Same classmethod+generic inference limitation as the sync from_callable checks
175+
# above; use an assignment to verify awaited-type compatibility.
176+
_: list[str] = result # type: ignore[assignment]
177+
178+
179+
def check_from_callable_async_with_name() -> None:
180+
"""Verify async overload narrows when a custom name is supplied."""
181+
wrapped = MelleaTool.from_callable(async_plain, name="custom")
182+
result = wrapped.run("test", 3)
183+
_: list[str] = result # type: ignore[assignment]
184+
185+
186+
@tool
187+
async def decorated_async(x: int) -> str:
188+
"""Async function wrapped via the @tool decorator."""
189+
return str(x)
190+
191+
192+
def check_tool_decorator_async() -> None:
193+
"""Verify @tool on an async function narrows to the awaited return type."""
194+
result = decorated_async.run(42)
195+
assert_type(result, str)

0 commit comments

Comments
 (0)