Skip to content

Commit b868aee

Browse files
committed
Balance hallucinated tool callback lifecycle
1 parent 7af4ea9 commit b868aee

File tree

2 files changed

+98
-8
lines changed

2 files changed

+98
-8
lines changed

src/google/adk/flows/llm_flows/functions.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -480,6 +480,7 @@ async def _run_on_tool_error_callbacks(
480480
invocation_context, function_call, tool_confirmation
481481
)
482482

483+
_tool_lookup_error: Exception | None = None
483484
try:
484485
tool = _get_tool(function_call, tools_dict)
485486
except ValueError as tool_error:
@@ -488,9 +489,7 @@ async def _run_on_tool_error_callbacks(
488489
# OTel span are created *before* on_tool_error_callback fires. This
489490
# keeps the callback lifecycle balanced (push/pop) and prevents plugins
490491
# like BigQueryAgentAnalyticsPlugin from corrupting their span stacks.
491-
_tool_lookup_error: Exception = tool_error
492-
else:
493-
_tool_lookup_error = None
492+
_tool_lookup_error = tool_error
494493

495494
async def _run_with_trace():
496495
nonlocal function_args
@@ -722,13 +721,12 @@ async def _run_on_tool_error_callbacks(
722721

723722
tool_context = _create_tool_context(invocation_context, function_call)
724723

724+
_tool_lookup_error: Exception | None = None
725725
try:
726726
tool = _get_tool(function_call, tools_dict)
727727
except ValueError as tool_error:
728728
tool = BaseTool(name=function_call.name, description='Tool not found')
729-
_tool_lookup_error: Exception = tool_error
730-
else:
731-
_tool_lookup_error = None
729+
_tool_lookup_error = tool_error
732730

733731
async def _run_with_trace():
734732
nonlocal function_args

tests/unittests/flows/llm_flows/test_plugin_tool_callbacks.py

Lines changed: 94 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -429,9 +429,101 @@ async def test_hallucinated_tool_raises_when_no_error_callback(
429429
agent=agent, user_content="", plugins=[mock_plugin]
430430
)
431431

432+
function_call = types.FunctionCall(name="nonexistent_tool", args={})
433+
content = types.Content(parts=[types.Part(function_call=function_call)])
434+
event = Event(
435+
invocation_id=invocation_context.invocation_id,
436+
author=agent.name,
437+
content=content,
438+
)
439+
tools_dict = {mock_tool.name: mock_tool}
440+
441+
with pytest.raises(ValueError, match="nonexistent_tool"):
442+
await handle_function_calls_async(
443+
invocation_context,
444+
event,
445+
tools_dict,
446+
)
447+
448+
449+
@pytest.mark.asyncio
450+
async def test_hallucinated_tool_fires_before_and_error_callbacks_live(
451+
mock_tool, mock_plugin
452+
):
453+
"""Live path regression test for hallucinated tool callback ordering."""
454+
mock_plugin.enable_before_tool_callback = True
455+
mock_plugin.enable_on_tool_error_callback = True
456+
457+
call_order = []
458+
original_before = mock_plugin.before_tool_callback
459+
original_error = mock_plugin.on_tool_error_callback
460+
461+
async def tracking_before(**kwargs):
462+
call_order.append("before_tool")
463+
return await original_before(**kwargs)
464+
465+
async def tracking_error(**kwargs):
466+
call_order.append("on_tool_error")
467+
return await original_error(**kwargs)
468+
469+
mock_plugin.before_tool_callback = tracking_before
470+
mock_plugin.on_tool_error_callback = tracking_error
471+
472+
model = testing_utils.MockModel.create(responses=[])
473+
agent = Agent(
474+
name="agent",
475+
model=model,
476+
tools=[mock_tool],
477+
)
478+
invocation_context = await testing_utils.create_invocation_context(
479+
agent=agent, user_content="", plugins=[mock_plugin]
480+
)
481+
432482
function_call = types.FunctionCall(
433-
name="nonexistent_tool", args={}
483+
name="hallucinated_tool_xyz", args={"query": "test"}
484+
)
485+
content = types.Content(parts=[types.Part(function_call=function_call)])
486+
event = Event(
487+
invocation_id=invocation_context.invocation_id,
488+
author=agent.name,
489+
content=content,
490+
)
491+
tools_dict = {mock_tool.name: mock_tool}
492+
493+
result_event = await handle_function_calls_live(
494+
invocation_context,
495+
event,
496+
tools_dict,
497+
)
498+
499+
assert result_event is not None
500+
part = result_event.content.parts[0]
501+
assert part.function_response.response == mock_plugin.on_tool_error_response
502+
503+
assert "before_tool" in call_order
504+
assert "on_tool_error" in call_order
505+
assert call_order.index("before_tool") < call_order.index("on_tool_error")
506+
507+
508+
@pytest.mark.asyncio
509+
async def test_hallucinated_tool_raises_when_no_error_callback_live(
510+
mock_tool, mock_plugin
511+
):
512+
"""Live path should propagate ValueError for hallucinated tools."""
513+
mock_plugin.enable_before_tool_callback = False
514+
mock_plugin.enable_on_tool_error_callback = False
515+
516+
model = testing_utils.MockModel.create(responses=[])
517+
agent = Agent(
518+
name="agent",
519+
model=model,
520+
tools=[mock_tool],
434521
)
522+
invocation_context = await testing_utils.create_invocation_context(
523+
agent=agent, user_content="", plugins=[mock_plugin]
524+
)
525+
526+
function_call = types.FunctionCall(name="nonexistent_tool", args={})
435527
content = types.Content(parts=[types.Part(function_call=function_call)])
436528
event = Event(
437529
invocation_id=invocation_context.invocation_id,
@@ -441,7 +533,7 @@ async def test_hallucinated_tool_raises_when_no_error_callback(
441533
tools_dict = {mock_tool.name: mock_tool}
442534

443535
with pytest.raises(ValueError, match="nonexistent_tool"):
444-
await handle_function_calls_async(
536+
await handle_function_calls_live(
445537
invocation_context,
446538
event,
447539
tools_dict,

0 commit comments

Comments
 (0)