Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
16 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 111 additions & 2 deletions astrbot/core/astr_agent_tool_exec.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,100 @@


class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue (complexity): Consider encapsulating all handoff-limit logic into a dedicated helper (e.g., a HandoffLimiter) instead of multiple executor classmethods to keep responsibilities localized and the API clearer.

You can keep the new functionality but reduce complexity by encapsulating all “handoff limit” concerns into a small dedicated helper, instead of spreading it across multiple generic classmethods on FunctionToolExecutor.

1. Encapsulate handoff-limiting logic

Move _resolve_handoff_call_limit, _get_event_extra, _set_event_extra, _coerce_int, and _check_and_increment_handoff_calls into a focused helper (or inner class) that holds the event/config context. That removes the cross-cutting helpers from the executor and turns the (bool, int) tuple into a clearer, self-documenting API.

Example:

@dataclass
class HandoffLimiter:
    ctx: AstrAgentContext
    event: Any
    max_calls: int

    _HANDOFF_CALL_COUNT_EXTRA_KEY = "_subagent_handoff_call_count"
    _DEFAULT_MAX_HANDOFF_CALLS_PER_RUN = 8
    _MAX_HANDOFF_CALL_COUNT_SANITY_LIMIT = 10_000

    @classmethod
    def from_run_context(cls, run_context: ContextWrapper[AstrAgentContext]) -> "HandoffLimiter":
        ctx = run_context.context.context
        event = run_context.context.event
        cfg = ctx.get_config(umo=event.unified_msg_origin)
        subagent_cfg = cfg.get("subagent_orchestrator", {})
        max_calls = cls._coerce_int(
            subagent_cfg.get("max_handoff_calls_per_run", cls._DEFAULT_MAX_HANDOFF_CALLS_PER_RUN),
            default=cls._DEFAULT_MAX_HANDOFF_CALLS_PER_RUN,
            minimum=1,
            maximum=128,
        )
        return cls(ctx=ctx, event=event, max_calls=max_calls)

    def try_increment(self) -> tuple[bool, int]:
        current = self._coerce_int(
            self._get_event_extra(self._HANDOFF_CALL_COUNT_EXTRA_KEY, 0),
            default=0,
            minimum=0,
            maximum=self._MAX_HANDOFF_CALL_COUNT_SANITY_LIMIT,
        )
        if current >= self.max_calls:
            return False, self.max_calls
        if not self._set_event_extra(self._HANDOFF_CALL_COUNT_EXTRA_KEY, current + 1):
            logger.warning(
                "Failed to persist handoff call counter `%s`; reject delegation to fail closed.",
                self._HANDOFF_CALL_COUNT_EXTRA_KEY,
            )
            return False, self.max_calls
        return True, self.max_calls

    # keep your existing _coerce_int/_get_event_extra/_set_event_extra logic here as instance methods

This keeps all existing behavior (including the defensive get_extra / set_extra handling and coercion) but localizes it to a single, discoverable place.

2. Simplify _execute_handoff call site

With the helper, _execute_handoff no longer needs to know how counting works or deal with a raw (bool, int) from a distant helper; it just asks the limiter.

@classmethod
async def _execute_handoff(
    cls,
    tool: HandoffTool,
    run_context: ContextWrapper[AstrAgentContext],
    *,
    image_urls_prepared: bool = False,
    **tool_args: Any,
):
    tool_args = dict(tool_args)
    input_ = tool_args.get("input")

    limiter = HandoffLimiter.from_run_context(run_context)
    allowed, max_handoff_calls = limiter.try_increment()
    if not allowed:
        yield mcp.types.CallToolResult(
            content=[
                mcp.types.TextContent(
                    type="text",
                    text=(
                        "error: handoff_call_limit_reached. "
                        f"max_handoff_calls_per_run={max_handoff_calls}. "
                        "Stop delegating and continue with current context."
                    ),
                )
            ]
        )
        return

    ctx = run_context.context.context
    event = run_context.context.event
    ...

Actionable benefits:

  • All handoff-limit concerns live in HandoffLimiter, not scattered across classmethods on the executor.
  • _execute_handoff reads linearly: “build args → check limiter → maybe early-return → continue.”
  • You retain the current, defensive behavior around config, event extras, and fail-closed semantics without reverting the feature.

Comment thread
sourcery-ai[bot] marked this conversation as resolved.
_HANDOFF_CALL_COUNT_EXTRA_KEY = "_subagent_handoff_call_count"
Comment thread
sourcery-ai[bot] marked this conversation as resolved.
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue (complexity): Consider replacing the multiple generic handoff counter helpers with a single, concrete helper that encapsulates limit resolution, counter incrementing, and fail-closed behavior, and then calling it directly from _execute_handoff for clearer intent.

You can simplify this without losing any behavior by collapsing the fragmented helpers into a single, concrete handoff counter helper and assuming the concrete event API.

1. Collapse _resolve_handoff_call_limit + _get_event_extra + _set_event_extra into one focused helper

Instead of _resolve_handoff_call_limit, _check_and_increment_handoff_calls, _get_event_extra, and _set_event_extra, you can have a single method that:

  • Reads config and clamps it.
  • Reads the counter from event.get_extra.
  • Increments and writes it back with event.set_extra.
  • Fails closed if anything looks wrong.

For example:

class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
    _HANDOFF_CALL_COUNT_EXTRA_KEY = "_subagent_handoff_call_count"
    _DEFAULT_MAX_HANDOFF_CALLS_PER_RUN = 8
    _MAX_HANDOFF_CALL_COUNT_SANITY_LIMIT = 10_000

    @classmethod
    def _increment_handoff_call_count_or_reject(
        cls,
        run_context: ContextWrapper[AstrAgentContext],
    ) -> tuple[bool, int]:
        ctx = run_context.context.context
        event = run_context.context.event

        cfg = ctx.get_config(umo=event.unified_msg_origin)
        subagent_cfg = cfg.get("subagent_orchestrator", {})
        if not isinstance(subagent_cfg, dict):
            max_handoff_calls = cls._DEFAULT_MAX_HANDOFF_CALLS_PER_RUN
        else:
            raw_limit = subagent_cfg.get(
                "max_handoff_calls_per_run",
                cls._DEFAULT_MAX_HANDOFF_CALLS_PER_RUN,
            )
            try:
                parsed = int(raw_limit)
            except (TypeError, ValueError):
                parsed = cls._DEFAULT_MAX_HANDOFF_CALLS_PER_RUN
            max_handoff_calls = max(1, min(128, parsed))

        try:
            current = event.get_extra(cls._HANDOFF_CALL_COUNT_EXTRA_KEY, 0)
        except TypeError:
            # Fail closed if the event API is not as expected
            logger.warning(
                "Failed to read handoff call counter `%s`; reject delegation to fail closed.",
                cls._HANDOFF_CALL_COUNT_EXTRA_KEY,
            )
            return False, max_handoff_calls

        try:
            current_int = int(current)
        except (TypeError, ValueError):
            current_int = 0

        current_int = max(
            0,
            min(cls._MAX_HANDOFF_CALL_COUNT_SANITY_LIMIT, current_int),
        )

        if current_int >= max_handoff_calls:
            return False, max_handoff_calls

        try:
            event.set_extra(
                cls._HANDOFF_CALL_COUNT_EXTRA_KEY,
                current_int + 1,
            )
        except TypeError:
            logger.warning(
                "Failed to persist handoff call counter `%s`; reject delegation to fail closed.",
                cls._HANDOFF_CALL_COUNT_EXTRA_KEY,
            )
            return False, max_handoff_calls

        return True, max_handoff_calls

This keeps all the existing safety/fail-closed behavior but removes the need for:

  • _coerce_int as a fully generic helper.
  • _get_event_extra and _set_event_extra with highly flexible signatures.
  • _resolve_handoff_call_limit as a separate layer.

If you still want a reusable clamping helper, keep _coerce_int but narrow its scope (e.g., type-hint value: int | str | None and use it only inside this method).

2. Make _execute_handoff intent clearer

Use the new helper in _execute_handoff with a name that describes the full behavior, so the call site is self-explanatory and not fragmented:

@classmethod
async def _execute_handoff(
    cls,
    tool: HandoffTool,
    run_context: ContextWrapper[AstrAgentContext],
    *,
    image_urls_prepared: bool = False,
    **tool_args: T.Any,
):
    tool_args = dict(tool_args)
    input_ = tool_args.get("input")

    allowed, max_handoff_calls = cls._increment_handoff_call_count_or_reject(
        run_context
    )
    if not allowed:
        yield mcp.types.CallToolResult(
            content=[
                mcp.types.TextContent(
                    type="text",
                    text=(
                        "error: handoff_call_limit_reached. "
                        f"max_handoff_calls_per_run={max_handoff_calls}. "
                        "Stop delegating and continue with current context."
                    ),
                )
            ]
        )
        return

    ctx = run_context.context.context
    event = run_context.context.event
    ...

This reduces cognitive overhead:

  • There is a single place to inspect for “how does the handoff counter work?”
  • _execute_handoff reads linearly: “increment or reject; if allowed, proceed with handoff.”

_DEFAULT_MAX_HANDOFF_CALLS_PER_RUN = 8
Comment thread
sourcery-ai[bot] marked this conversation as resolved.
Outdated
_MAX_HANDOFF_CALL_COUNT_SANITY_LIMIT = 10_000

@classmethod
def _coerce_int(
cls,
value: T.Any,
*,
default: int,
minimum: int,
maximum: int,
) -> int:
try:
parsed = int(value)
except (TypeError, ValueError):
return default
return max(minimum, min(maximum, parsed))

@classmethod
def _get_event_extra(
Comment thread
sourcery-ai[bot] marked this conversation as resolved.
Comment thread
sourcery-ai[bot] marked this conversation as resolved.
Comment thread
sourcery-ai[bot] marked this conversation as resolved.
Comment thread
sourcery-ai[bot] marked this conversation as resolved.
cls,
event: T.Any,
key: str,
default: T.Any = None,
) -> T.Any:
if event is None:
return default

get_extra = getattr(event, "get_extra", None)
Comment thread
sourcery-ai[bot] marked this conversation as resolved.
Comment thread
sourcery-ai[bot] marked this conversation as resolved.
if get_extra is None:
return default

try:
return get_extra(key, default)
except TypeError:
result = get_extra(key)
Comment thread
sourcery-ai[bot] marked this conversation as resolved.
Outdated
return default if result is None else result

@classmethod
def _set_event_extra(cls, event: T.Any, key: str, value: T.Any) -> bool:
set_extra = getattr(event, "set_extra", None)
if set_extra is None:
return False
try:
set_extra(key, value)
return True
Comment thread
sourcery-ai[bot] marked this conversation as resolved.
except TypeError:
return False

@classmethod
def _resolve_handoff_call_limit(
cls,
run_context: ContextWrapper[AstrAgentContext],
) -> int:
ctx = run_context.context.context
event = run_context.context.event
cfg = ctx.get_config(umo=event.unified_msg_origin)
subagent_cfg = cfg.get("subagent_orchestrator", {})
if not isinstance(subagent_cfg, dict):
return cls._DEFAULT_MAX_HANDOFF_CALLS_PER_RUN
return cls._coerce_int(
subagent_cfg.get(
"max_handoff_calls_per_run",
cls._DEFAULT_MAX_HANDOFF_CALLS_PER_RUN,
),
default=cls._DEFAULT_MAX_HANDOFF_CALLS_PER_RUN,
minimum=1,
maximum=128,
)
Comment on lines +119 to +137
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The default value 8 for the handoff call limit is repeated multiple times in this method. To improve maintainability and avoid magic numbers, it's a good practice to define it as a constant and reuse it. You could define it as a local variable within the method, or preferably as a class-level constant.

Suggested change
def _resolve_handoff_call_limit(
cls,
run_context: ContextWrapper[AstrAgentContext],
) -> int:
ctx = run_context.context.context
event = run_context.context.event
cfg = ctx.get_config(umo=event.unified_msg_origin)
subagent_cfg = cfg.get("subagent_orchestrator", {})
if not isinstance(subagent_cfg, dict):
return 8
return cls._coerce_int(
subagent_cfg.get("max_handoff_calls_per_run", 8),
default=8,
minimum=1,
maximum=128,
)
def _resolve_handoff_call_limit(
cls,
run_context: ContextWrapper[AstrAgentContext],
) -> int:
ctx = run_context.context.context
event = run_context.context.event
cfg = ctx.get_config(umo=event.unified_msg_origin)
subagent_cfg = cfg.get("subagent_orchestrator", {})
default_limit = 8
if not isinstance(subagent_cfg, dict):
return default_limit
return cls._coerce_int(
subagent_cfg.get("max_handoff_calls_per_run", default_limit),
default=default_limit,
minimum=1,
maximum=128,
)


@classmethod
def _check_and_increment_handoff_calls(
cls,
run_context: ContextWrapper[AstrAgentContext],
) -> tuple[bool, int]:
event = run_context.context.event
max_handoff_calls = cls._resolve_handoff_call_limit(run_context)
current_handoff_count = cls._coerce_int(
cls._get_event_extra(event, cls._HANDOFF_CALL_COUNT_EXTRA_KEY, 0),
default=0,
minimum=0,
maximum=cls._MAX_HANDOFF_CALL_COUNT_SANITY_LIMIT,
)
if current_handoff_count >= max_handoff_calls:
return False, max_handoff_calls

cls._set_event_extra(
event,
cls._HANDOFF_CALL_COUNT_EXTRA_KEY,
current_handoff_count + 1,
Comment thread
sourcery-ai[bot] marked this conversation as resolved.
Outdated
)
return True, max_handoff_calls

@classmethod
def _collect_image_urls_from_args(cls, image_urls_raw: T.Any) -> list[str]:
if image_urls_raw is None:
Expand Down Expand Up @@ -246,6 +340,23 @@ async def _execute_handoff(
):
tool_args = dict(tool_args)
input_ = tool_args.get("input")
ctx = run_context.context.context
allowed, max_handoff_calls = cls._check_and_increment_handoff_calls(run_context)
if not allowed:
yield mcp.types.CallToolResult(
content=[
mcp.types.TextContent(
type="text",
text=(
"error: handoff_call_limit_reached. "
f"max_handoff_calls_per_run={max_handoff_calls}. "
"Stop delegating and continue with current context."
),
)
]
)
return
event = run_context.context.event
if image_urls_prepared:
prepared_image_urls = tool_args.get("image_urls")
if isinstance(prepared_image_urls, list):
Expand All @@ -266,8 +377,6 @@ async def _execute_handoff(
# Build handoff toolset from registered tools plus runtime computer tools.
toolset = cls._build_handoff_toolset(run_context, tool.agent.tools)

ctx = run_context.context.context
event = run_context.context.event
umo = event.unified_msg_origin

# Use per-subagent provider override if configured; otherwise fall back
Expand Down
3 changes: 3 additions & 0 deletions astrbot/core/config/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,9 @@
"subagent_orchestrator": {
"main_enable": False,
"remove_main_duplicate_tools": False,
# Limits total handoff tool calls in one agent run to prevent
# runaway delegation loops.
"max_handoff_calls_per_run": 8,
Comment thread
sourcery-ai[bot] marked this conversation as resolved.
Outdated
"router_system_prompt": (
"You are a task router. Your job is to chat naturally, recognize user intent, "
"and delegate work to the most suitable subagent using transfer_to_* tools. "
Expand Down
10 changes: 10 additions & 0 deletions astrbot/dashboard/routes/subagent.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,15 @@

from astrbot.core import logger
from astrbot.core.agent.handoff import HandoffTool
from astrbot.core.config.default import DEFAULT_CONFIG
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle

from .route import Response, Route, RouteContext

DEFAULT_MAX_HANDOFF_CALLS_PER_RUN = int(
DEFAULT_CONFIG.get("subagent_orchestrator", {}).get("max_handoff_calls_per_run", 8)
)


class SubAgentRoute(Route):
def __init__(
Expand Down Expand Up @@ -36,6 +41,7 @@ async def get_config(self):
data = {
"main_enable": False,
"remove_main_duplicate_tools": False,
"max_handoff_calls_per_run": DEFAULT_MAX_HANDOFF_CALLS_PER_RUN,
"agents": [],
}
Comment on lines 37 to 42
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The default value 8 for max_handoff_calls_per_run is hardcoded here and again on line 54. This value is also defined in astrbot/core/config/default.py. To ensure consistency and avoid magic numbers, it would be better to read this default value from the DEFAULT_CONFIG object instead of hardcoding it in multiple places.


Expand All @@ -50,6 +56,10 @@ async def get_config(self):
# Ensure required keys exist.
data.setdefault("main_enable", False)
data.setdefault("remove_main_duplicate_tools", False)
data.setdefault(
"max_handoff_calls_per_run",
DEFAULT_MAX_HANDOFF_CALLS_PER_RUN,
)
data.setdefault("agents", [])

# Backward/forward compatibility: ensure each agent contains provider_id.
Expand Down
216 changes: 216 additions & 0 deletions tests/unit/test_astr_agent_tool_exec.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,3 +343,219 @@ async def _fake_convert_to_file_path(self):
)

assert image_urls == []


@pytest.mark.asyncio
async def test_execute_handoff_rejects_when_call_limit_reached():
captured: dict = {}

class _EventWithExtras:
def __init__(self) -> None:
self.unified_msg_origin = "webchat:FriendMessage:webchat!user!session"
self.message_obj = SimpleNamespace(message=[])
self._extras = {
FunctionToolExecutor._HANDOFF_CALL_COUNT_EXTRA_KEY: 1,
Comment on lines +353 to +362
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (testing): Add tests for default and malformed max_handoff_calls_per_run configurations

The new guardrails in _resolve_handoff_call_limit / _coerce_int (missing subagent_orchestrator, non-dict config, non-int values, and values outside [1, 128]) aren’t currently exercised.

To better validate the limit logic, consider parameterizing _execute_handoff tests to cover cases like:

  • subagent_orchestrator missing from the get_config result
  • subagent_orchestrator present but not a dict
  • max_handoff_calls_per_run set to 0, negative, or >128 and verified as clamped
  • max_handoff_calls_per_run set to a non-int string (e.g. "abc") or float-like string

These can be small @pytest.mark.asyncio tests that call _execute_handoff once and assert the effective limit (no premature handoff_call_limit_reached and proper counter increments).

Suggested implementation:

    assert image_urls == []


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "config, expected_limit",
    [
        # No subagent_orchestrator at all -> default
        ({}, 32),
        # subagent_orchestrator present but unusable -> default
        ({"subagent_orchestrator": None}, 32),
        ({"subagent_orchestrator": "not-a-dict"}, 32),
        # Values out of range should be clamped
        ({"subagent_orchestrator": {"max_handoff_calls_per_run": 0}}, 1),
        ({"subagent_orchestrator": {"max_handoff_calls_per_run": -5}}, 1),
        ({"subagent_orchestrator": {"max_handoff_calls_per_run": 129}}, 128),
        # Non-int values should fall back to default
        ({"subagent_orchestrator": {"max_handoff_calls_per_run": "abc"}}, 32),
        ({"subagent_orchestrator": {"max_handoff_calls_per_run": "3.14"}}, 32),
        # Valid in-range int should be honored
        ({"subagent_orchestrator": {"max_handoff_calls_per_run": 10}}, 10),
    ],
)
async def test_execute_handoff_resolves_call_limit_under_various_configs(
    config: dict,
    expected_limit: int,
):
    """
    Exercise _resolve_handoff_call_limit/_coerce_int indirectly via _execute_handoff.

    We run _execute_handoff exactly once and assert that:
      * the derived limit is stored on the event extras
      * the call count is incremented to 1
      * we do not prematurely hit a 'handoff_call_limit_reached' condition
    """
    # Event with extras storage, mirroring the shape used in other tests.
    class _EventWithExtras:
        def __init__(self) -> None:
            self.unified_msg_origin = "webchat:FriendMessage:webchat!user!session"
            self.message_obj = SimpleNamespace(message=[])
            self._extras = {}

        @property
        def extras(self) -> dict:
            return self._extras

        def get_extra(self, key, default=None):
            return self._extras.get(key, default)

        def set_extra(self, key, value):
            self._extras[key] = value

    event = _EventWithExtras()

    # Capture whether the handoff tool actually executes.
    handoff_calls = []

    async def _fake_handoff_tool(*args, **kwargs):
        handoff_calls.append((args, kwargs))

    # The concrete constructor signature for FunctionToolExecutor may include
    # more parameters; we pass only what is relevant to handoff behavior here.
    executor = FunctionToolExecutor(
        handoff_tool=_fake_handoff_tool,
        get_config=lambda *_args, **_kwargs: config,
    )

    # Act: execute one handoff
    await executor._execute_handoff(event=event, tool_input={})

    # Assert: we didn't prematurely stop due to the limit.
    assert len(handoff_calls) == 1

    # The resolved limit should be written to extras and respect the guardrails.
    assert (
        event.extras[FunctionToolExecutor._HANDOFF_CALL_LIMIT_EXTRA_KEY]
        == expected_limit
    )

    # Call count should be incremented to 1 on the first call.
    assert (
        event.extras[FunctionToolExecutor._HANDOFF_CALL_COUNT_EXTRA_KEY]
        == 1
    )

The above tests assume:

  1. FunctionToolExecutor accepts handoff_tool and get_config keyword arguments and exposes _execute_handoff(event=..., tool_input=...). If the actual signature differs, adjust the FunctionToolExecutor(...) construction and the _execute_handoff(...) call to match the existing tests in this file.
  2. The event type used by _execute_handoff provides accessors similar to extras, get_extra, and set_extra. If the real event API is different, update _EventWithExtras to satisfy the interface required by _execute_handoff, mirroring the _EventWithExtras / event objects used in your existing test_execute_handoff_rejects_when_call_limit_reached test.
  3. The constants _HANDOFF_CALL_LIMIT_EXTRA_KEY and _HANDOFF_CALL_COUNT_EXTRA_KEY exist on FunctionToolExecutor. If they are named differently, update the assertions to use the correct attribute names.
  4. If your project enforces type hints (e.g., with mypy), you may want to refine the type annotations on config and _EventWithExtras to match your existing conventions.

}

def get_extra(self, key: str, default=None):
return self._extras.get(key, default)

def set_extra(self, key: str, value):
self._extras[key] = value

async def _fake_get_current_chat_provider_id(_umo):
return "provider-id"

async def _fake_tool_loop_agent(**kwargs):
captured.update(kwargs)
return SimpleNamespace(completion_text="ok")

context = SimpleNamespace(
get_current_chat_provider_id=_fake_get_current_chat_provider_id,
tool_loop_agent=_fake_tool_loop_agent,
get_config=lambda **_kwargs: {
"provider_settings": {},
"subagent_orchestrator": {"max_handoff_calls_per_run": 1},
},
)
event = _EventWithExtras()
run_context = ContextWrapper(context=SimpleNamespace(event=event, context=context))
tool = SimpleNamespace(
name="transfer_to_subagent",
provider_id=None,
agent=SimpleNamespace(
name="subagent",
tools=[],
instructions="subagent-instructions",
begin_dialogs=[],
run_hooks=None,
),
)

results = []
async for result in FunctionToolExecutor._execute_handoff(
Comment thread
sourcery-ai[bot] marked this conversation as resolved.
tool,
run_context,
image_urls_prepared=True,
input="hello",
image_urls=[],
):
results.append(result)

assert len(results) == 1
assert "handoff_call_limit_reached" in results[0].content[0].text
assert captured == {}


@pytest.mark.asyncio
async def test_execute_handoff_increments_call_count_on_success():
class _EventWithExtras:
def __init__(self) -> None:
self.unified_msg_origin = "webchat:FriendMessage:webchat!user!session"
self.message_obj = SimpleNamespace(message=[])
self._extras: dict[str, int] = {}

def get_extra(self, key: str, default=None):
return self._extras.get(key, default)

def set_extra(self, key: str, value):
self._extras[key] = value

async def _fake_get_current_chat_provider_id(_umo):
return "provider-id"

async def _fake_tool_loop_agent(**_kwargs):
return SimpleNamespace(completion_text="ok")

context = SimpleNamespace(
get_current_chat_provider_id=_fake_get_current_chat_provider_id,
tool_loop_agent=_fake_tool_loop_agent,
get_config=lambda **_kwargs: {
"provider_settings": {},
"subagent_orchestrator": {"max_handoff_calls_per_run": 2},
},
)
event = _EventWithExtras()
run_context = ContextWrapper(context=SimpleNamespace(event=event, context=context))
tool = SimpleNamespace(
name="transfer_to_subagent",
provider_id=None,
agent=SimpleNamespace(
name="subagent",
tools=[],
instructions="subagent-instructions",
begin_dialogs=[],
run_hooks=None,
),
)

results = []
async for result in FunctionToolExecutor._execute_handoff(
tool,
run_context,
image_urls_prepared=True,
input="hello",
image_urls=[],
):
results.append(result)

assert len(results) == 1
assert (
event._extras[FunctionToolExecutor._HANDOFF_CALL_COUNT_EXTRA_KEY]
== 1
)


@pytest.mark.asyncio
async def test_execute_handoff_enforces_call_limit_across_multiple_calls():
call_count = {"tool_loop": 0}

class _EventWithExtras:
def __init__(self) -> None:
self.unified_msg_origin = "webchat:FriendMessage:webchat!user!session"
self.message_obj = SimpleNamespace(message=[])
self._extras: dict[str, int] = {}

def get_extra(self, key: str, default=None):
return self._extras.get(key, default)
Comment on lines +476 to +485
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (testing): Consider a test for the _MAX_HANDOFF_CALL_COUNT_SANITY_LIMIT behavior when event extras contain a very large count

You already verify the per-run limit and that the delegate is not invoked once the limit is reached. Since _check_and_increment_handoff_calls also clamps the stored count via _MAX_HANDOFF_CALL_COUNT_SANITY_LIMIT, please add a test that pre-populates the event extras with an extremely large value (e.g., via _set_event_extra or directly on _extras) and asserts that calls are still rejected even when the configured limit is higher. This will confirm the sanity limit protects against malformed event state and overflow.

Suggested implementation:

@pytest.mark.asyncio
async def test_execute_handoff_rejects_calls_when_event_extra_exceeds_sanity_limit():
    # Pre-populate event extras with an extremely large handoff call count and
    # verify that calls are rejected even when the configured limit is higher.
    class _EventWithExtras:
        def __init__(self) -> None:
            self.unified_msg_origin = "webchat:FriendMessage:webchat!user!session"
            self.message_obj = SimpleNamespace(message=[])
            self._extras: dict[str, int] = {
                FunctionToolExecutor._HANDOFF_CALL_COUNT_EXTRA_KEY: 10**12,
            }

        def get_extra(self, key: str, default=None):
            return self._extras.get(key, default)

        def set_extra(self, key: str, value) -> None:
            self._extras[key] = value

    event = _EventWithExtras()
    call_count = {"tool_loop": 0}

    async def _delegate(*args, **kwargs):
        call_count["tool_loop"] += 1

    executor = FunctionToolExecutor(
        tool_name="tool_loop",
        delegate=_delegate,
        # Set a very high configured limit so that only the sanity limit
        # can prevent additional calls.
        handoff_call_limit=10**6,
    )

    # Because the stored count is above the sanity limit, the executor should
    # clamp it and immediately reject further calls.
    results = await executor.execute_handoff(event=event)

    assert call_count["tool_loop"] == 0
    assert results == []
    # The stored count should not grow without bound and should be capped by the
    # sanity limit.
    assert (
        event._extras[FunctionToolExecutor._HANDOFF_CALL_COUNT_EXTRA_KEY]
        <= FunctionToolExecutor._MAX_HANDOFF_CALL_COUNT_SANITY_LIMIT
    )


@pytest.mark.asyncio
async def test_execute_handoff_enforces_call_limit_across_multiple_calls():

This patch assumes:

  1. FunctionToolExecutor is already imported into test_astr_agent_tool_exec.py. If not, add the appropriate import at the top of the file.
  2. FunctionToolExecutor exposes _HANDOFF_CALL_COUNT_EXTRA_KEY and _MAX_HANDOFF_CALL_COUNT_SANITY_LIMIT as attributes. If the constant uses a different name or visibility, update the test to match.
  3. If the project prefers using _set_event_extra over direct _extras mutation in tests, you can replace the inline _EventWithExtras with the shared helper or update it to use _set_event_extra instead of directly setting _extras.


def set_extra(self, key: str, value):
self._extras[key] = value

async def _fake_get_current_chat_provider_id(_umo):
return "provider-id"

async def _fake_tool_loop_agent(**_kwargs):
call_count["tool_loop"] += 1
return SimpleNamespace(completion_text="ok")

context = SimpleNamespace(
get_current_chat_provider_id=_fake_get_current_chat_provider_id,
tool_loop_agent=_fake_tool_loop_agent,
get_config=lambda **_kwargs: {
"provider_settings": {},
"subagent_orchestrator": {"max_handoff_calls_per_run": 2},
},
)
event = _EventWithExtras()
run_context = ContextWrapper(context=SimpleNamespace(event=event, context=context))
tool = SimpleNamespace(
name="transfer_to_subagent",
provider_id=None,
agent=SimpleNamespace(
name="subagent",
tools=[],
instructions="subagent-instructions",
begin_dialogs=[],
run_hooks=None,
),
)

first_results = []
async for result in FunctionToolExecutor._execute_handoff(
tool,
run_context,
image_urls_prepared=True,
input="first",
image_urls=[],
):
first_results.append(result)
assert len(first_results) == 1
assert "handoff_call_limit_reached" not in first_results[0].content[0].text
assert (
event._extras[FunctionToolExecutor._HANDOFF_CALL_COUNT_EXTRA_KEY]
== 1
)

second_results = []
async for result in FunctionToolExecutor._execute_handoff(
tool,
run_context,
image_urls_prepared=True,
input="second",
image_urls=[],
):
second_results.append(result)
assert len(second_results) == 1
assert "handoff_call_limit_reached" not in second_results[0].content[0].text
assert (
event._extras[FunctionToolExecutor._HANDOFF_CALL_COUNT_EXTRA_KEY]
== 2
)

third_results = []
async for result in FunctionToolExecutor._execute_handoff(
tool,
run_context,
image_urls_prepared=True,
input="third",
image_urls=[],
):
third_results.append(result)
assert len(third_results) == 1
assert "handoff_call_limit_reached" in third_results[0].content[0].text
assert (
event._extras[FunctionToolExecutor._HANDOFF_CALL_COUNT_EXTRA_KEY]
== 2
)
assert call_count["tool_loop"] == 2