Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 36 additions & 6 deletions mellea/backends/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,13 @@
import json
import re
from collections import defaultdict
from collections.abc import Callable, Generator, Iterable, Mapping, Sequence
from collections.abc import Awaitable, Callable, Generator, Iterable, Mapping, Sequence
from typing import Any, Literal, ParamSpec, TypeVar, overload

from pydantic import BaseModel, ConfigDict, Field

from mellea.core.utils import MelleaLogger
from mellea.helpers.event_loop_helper import _run_async_in_thread

from ..core import CBlock, Component, TemplateRepresentation
from ..core.base import AbstractMelleaTool
Expand Down Expand Up @@ -179,17 +180,33 @@ def tool_call(*args, **kwargs):
"Please install mellea with tools support: pip install 'mellea[tools]'"
) from e

@overload
@classmethod
def from_callable(
cls, func: Callable[P, Awaitable[R]], name: str | None = None
) -> "MelleaTool[P, R]": ...

@overload
@classmethod
def from_callable(
cls, func: Callable[P, R], name: str | None = None
) -> "MelleaTool[P, R]": ...

@classmethod
def from_callable(
cls, func: Callable[P, R] | Callable[P, Awaitable[R]], name: str | None = None
) -> "MelleaTool[P, R]":
"""Create a MelleaTool from a plain Python callable.

Introspects the callable's signature and docstring to build an
OpenAI-compatible JSON schema automatically.
OpenAI-compatible JSON schema automatically. Async functions (defined
with ``async def``) are supported transparently: the coroutine is
awaited on mellea's shared event loop so sync callers of ``.run()``
receive the resolved value rather than an un-awaited coroutine.

Args:
func (Callable[P, R]): The Python callable to wrap as a tool.
func (Callable[P, R] | Callable[P, Awaitable[R]]): The Python
callable (sync or async) to wrap as a tool.
name (str | None): Optional name override; defaults to ``func.__name__``.

Returns:
Expand All @@ -200,10 +217,22 @@ def from_callable(
as_json = convert_function_to_ollama_tool(func, tool_name).model_dump(
exclude_none=True
)
tool_call = func
if inspect.iscoroutinefunction(func):
async_func = func

def tool_call(*args: P.args, **kwargs: P.kwargs) -> R:
return _run_async_in_thread(async_func(*args, **kwargs))
else:
tool_call = func # type: ignore[assignment]
return MelleaTool(tool_name, tool_call, as_json)


@overload
def tool(
func: Callable[P, Awaitable[R]], *, name: str | None = None
) -> MelleaTool[P, R]: ...


@overload
def tool(func: Callable[P, R], *, name: str | None = None) -> MelleaTool[P, R]: ...

Expand All @@ -215,7 +244,8 @@ def tool(


def tool(
func: Callable[P, R] | None = None, name: str | None = None
func: Callable[P, R] | Callable[P, Awaitable[R]] | None = None,
name: str | None = None,
) -> MelleaTool[P, R] | Callable[[Callable[P, R]], MelleaTool[P, R]]:
"""Decorator to mark a function as a Mellea tool with type-safe parameter and return types.

Expand Down Expand Up @@ -278,7 +308,7 @@ def decorator(f: Callable[P, R]) -> MelleaTool[P, R]:
return decorator
else:
# Called without arguments: @tool
return decorator(func)
return decorator(func) # type: ignore[arg-type]


def add_tools_from_model_options(
Expand Down
54 changes: 54 additions & 0 deletions test/backends/test_mellea_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@ def callable(input: int) -> str:
return str(input)


async def async_callable(input: int) -> str:
"""Common async callable to test tool functionality."""
return str(input)


@tool
def langchain_tool(input: int) -> str:
"""Common langchain tool to test functionality."""
Expand Down Expand Up @@ -258,5 +263,54 @@ class NotATool:
assert "smolagents Tool type" in error_msg


def test_from_callable_async():
"""Async function produces a MelleaTool whose .run() returns the awaited value."""
t = MelleaTool.from_callable(async_callable)
assert isinstance(t, MelleaTool)
assert t.name == async_callable.__name__

# .run() from sync code returns the awaited value, not a coroutine.
result = t.run(1)
assert result == "1"
assert t.run(input=2) == "2"


def test_from_callable_async_schema_matches_sync():
"""Schema introspection produces the same parameter schema for sync and async."""

def sync_fn(x: int, y: str = "default") -> str:
"""A sync function.

Args:
x: An integer.
y: A string with a default.
"""
return y * x

async def async_fn(x: int, y: str = "default") -> str:
"""A sync function.

Args:
x: An integer.
y: A string with a default.
"""
return y * x

sync_t = MelleaTool.from_callable(sync_fn, "same_name")
async_t = MelleaTool.from_callable(async_fn, "same_name")
assert async_t.as_json_tool == sync_t.as_json_tool


def test_from_callable_async_propagates_exceptions():
"""Exceptions from async functions propagate through .run() (not wrapped in a coroutine)."""

async def raiser() -> str:
raise RuntimeError("boom")

t = MelleaTool.from_callable(raiser)
with pytest.raises(RuntimeError, match="boom"):
t.run()


if __name__ == "__main__":
pytest.main([__file__])
12 changes: 12 additions & 0 deletions test/backends/test_tool_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,18 @@ def base_func(x: int) -> int:
assert tool1.name == "base_func"
assert tool2.name == "custom"

def test_decorator_on_async_function(self):
"""Test that @tool works end-to-end on an async function."""

@tool
async def decorated(input: int) -> str:
"""Async tool via decorator."""
return str(input * 2)

assert isinstance(decorated, MelleaTool)
assert decorated.name == "decorated"
assert decorated.run(3) == "6"


# ============================================================================
# Test Cases: Usage Patterns
Expand Down
34 changes: 34 additions & 0 deletions test/typing/check_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,3 +159,37 @@ def sample_func(x: int) -> str:
# Verify the return type is preserved through .run()
output = result.run(42)
assert_type(output, str)


# Test async support: from_callable and @tool should narrow Awaitable[R] to R
async def async_plain(a: str, b: int) -> list[str]:
"""An async plain function to wrap."""
return [a] * b


def check_from_callable_async_return_type() -> None:
"""Verify MelleaTool.from_callable narrows Awaitable[R] to R on .run()."""
wrapped = MelleaTool.from_callable(async_plain)
result = wrapped.run("test", 3)
# Same classmethod+generic inference limitation as the sync from_callable checks
# above; use an assignment to verify awaited-type compatibility.
_: list[str] = result # type: ignore[assignment]


def check_from_callable_async_with_name() -> None:
"""Verify async overload narrows when a custom name is supplied."""
wrapped = MelleaTool.from_callable(async_plain, name="custom")
result = wrapped.run("test", 3)
_: list[str] = result # type: ignore[assignment]


@tool
async def decorated_async(x: int) -> str:
"""Async function wrapped via the @tool decorator."""
return str(x)


def check_tool_decorator_async() -> None:
"""Verify @tool on an async function narrows to the awaited return type."""
result = decorated_async.run(42)
assert_type(result, str)
Loading