Skip to content

Commit 7af4ea9

Browse files
committed
fix: balance callback lifecycle for hallucinated tool calls
When an LLM hallucinates a tool name, _get_tool() raises ValueError. Previously, on_tool_error_callback fired immediately — before before_tool_callback and outside the OTel tracer span. This caused plugins that push/pop spans (e.g. BigQueryAgentAnalyticsPlugin's TraceManager) to pop the parent agent span, corrupting the trace stack for all subsequent tool calls. Move the ValueError handling inside _run_with_trace() so that: 1. before_tool_callback always fires first (balanced push) 2. The error is surfaced within the OTel span context 3. on_tool_error_callback fires after before_tool_callback Applied to both handle_function_calls_async and handle_function_calls_live code paths. Fixes #4775
1 parent 4b677e7 commit 7af4ea9

File tree

2 files changed

+149
-23
lines changed

2 files changed

+149
-23
lines changed

src/google/adk/flows/llm_flows/functions.py

Lines changed: 41 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -484,18 +484,13 @@ async def _run_on_tool_error_callbacks(
484484
tool = _get_tool(function_call, tools_dict)
485485
except ValueError as tool_error:
486486
tool = BaseTool(name=function_call.name, description='Tool not found')
487-
error_response = await _run_on_tool_error_callbacks(
488-
tool=tool,
489-
tool_args=function_args,
490-
tool_context=tool_context,
491-
error=tool_error,
492-
)
493-
if error_response is not None:
494-
return __build_response_event(
495-
tool, error_response, tool_context, invocation_context
496-
)
497-
else:
498-
raise tool_error
487+
# Fall through to _run_with_trace so that before_tool_callback and the
488+
# OTel span are created *before* on_tool_error_callback fires. This
489+
# keeps the callback lifecycle balanced (push/pop) and prevents plugins
490+
# like BigQueryAgentAnalyticsPlugin from corrupting their span stacks.
491+
_tool_lookup_error: Exception = tool_error
492+
else:
493+
_tool_lookup_error = None
499494

500495
async def _run_with_trace():
501496
nonlocal function_args
@@ -520,6 +515,22 @@ async def _run_with_trace():
520515
if function_response:
521516
break
522517

518+
# Step 2.5: If the tool was not found (hallucinated), surface the error
519+
# *after* before_tool_callback so the lifecycle stays balanced.
520+
if _tool_lookup_error is not None:
521+
error_response = await _run_on_tool_error_callbacks(
522+
tool=tool,
523+
tool_args=function_args,
524+
tool_context=tool_context,
525+
error=_tool_lookup_error,
526+
)
527+
if error_response is not None:
528+
return __build_response_event(
529+
tool, error_response, tool_context, invocation_context
530+
)
531+
else:
532+
raise _tool_lookup_error
533+
523534
# Step 3: Otherwise, proceed calling the tool normally.
524535
if function_response is None:
525536
try:
@@ -715,17 +726,9 @@ async def _run_on_tool_error_callbacks(
715726
tool = _get_tool(function_call, tools_dict)
716727
except ValueError as tool_error:
717728
tool = BaseTool(name=function_call.name, description='Tool not found')
718-
error_response = await _run_on_tool_error_callbacks(
719-
tool=tool,
720-
tool_args=function_args,
721-
tool_context=tool_context,
722-
error=tool_error,
723-
)
724-
if error_response is not None:
725-
return __build_response_event(
726-
tool, error_response, tool_context, invocation_context
727-
)
728-
raise tool_error
729+
_tool_lookup_error: Exception = tool_error
730+
else:
731+
_tool_lookup_error = None
729732

730733
async def _run_with_trace():
731734
nonlocal function_args
@@ -755,6 +758,21 @@ async def _run_with_trace():
755758
if function_response:
756759
break
757760

761+
# Step 2.5: If the tool was not found (hallucinated), surface the error
762+
# *after* before_tool_callback so the lifecycle stays balanced.
763+
if _tool_lookup_error is not None:
764+
error_response = await _run_on_tool_error_callbacks(
765+
tool=tool,
766+
tool_args=function_args,
767+
tool_context=tool_context,
768+
error=_tool_lookup_error,
769+
)
770+
if error_response is not None:
771+
return __build_response_event(
772+
tool, error_response, tool_context, invocation_context
773+
)
774+
raise _tool_lookup_error
775+
758776
# Step 3: Otherwise, proceed calling the tool normally.
759777
if function_response is None:
760778
try:

tests/unittests/flows/llm_flows/test_plugin_tool_callbacks.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -340,5 +340,113 @@ def agent_after_cb(tool, args, tool_context, tool_response):
340340
assert part.function_response.response == mock_plugin.after_tool_response
341341

342342

343+
@pytest.mark.asyncio
344+
async def test_hallucinated_tool_fires_before_and_error_callbacks(
345+
mock_tool, mock_plugin
346+
):
347+
"""Regression test for https://github.com/google/adk-python/issues/4775.
348+
349+
When the LLM hallucinates a tool name, on_tool_error_callback used to fire
350+
*before* before_tool_callback, corrupting plugin span stacks (e.g.
351+
BigQueryAgentAnalyticsPlugin's TraceManager). After the fix, both
352+
callbacks should fire in order: before_tool → on_tool_error.
353+
"""
354+
mock_plugin.enable_before_tool_callback = True
355+
mock_plugin.enable_on_tool_error_callback = True
356+
357+
# Track callback invocation order
358+
call_order = []
359+
original_before = mock_plugin.before_tool_callback
360+
original_error = mock_plugin.on_tool_error_callback
361+
362+
async def tracking_before(**kwargs):
363+
call_order.append("before_tool")
364+
return await original_before(**kwargs)
365+
366+
async def tracking_error(**kwargs):
367+
call_order.append("on_tool_error")
368+
return await original_error(**kwargs)
369+
370+
mock_plugin.before_tool_callback = tracking_before
371+
mock_plugin.on_tool_error_callback = tracking_error
372+
373+
model = testing_utils.MockModel.create(responses=[])
374+
agent = Agent(
375+
name="agent",
376+
model=model,
377+
tools=[mock_tool],
378+
)
379+
invocation_context = await testing_utils.create_invocation_context(
380+
agent=agent, user_content="", plugins=[mock_plugin]
381+
)
382+
383+
# Build function call for a non-existent tool (hallucinated name)
384+
function_call = types.FunctionCall(
385+
name="hallucinated_tool_xyz", args={"query": "test"}
386+
)
387+
content = types.Content(parts=[types.Part(function_call=function_call)])
388+
event = Event(
389+
invocation_id=invocation_context.invocation_id,
390+
author=agent.name,
391+
content=content,
392+
)
393+
tools_dict = {mock_tool.name: mock_tool}
394+
395+
result_event = await handle_function_calls_async(
396+
invocation_context,
397+
event,
398+
tools_dict,
399+
)
400+
401+
# on_tool_error_callback returned a response, so we should get an event
402+
assert result_event is not None
403+
part = result_event.content.parts[0]
404+
assert part.function_response.response == mock_plugin.on_tool_error_response
405+
406+
# Verify that before_tool fired BEFORE on_tool_error
407+
assert "before_tool" in call_order
408+
assert "on_tool_error" in call_order
409+
assert call_order.index("before_tool") < call_order.index("on_tool_error")
410+
411+
412+
@pytest.mark.asyncio
413+
async def test_hallucinated_tool_raises_when_no_error_callback(
414+
mock_tool, mock_plugin
415+
):
416+
"""When a tool is hallucinated and no error callback handles it, ValueError
417+
should propagate — but only after before_tool_callback has had a chance to
418+
run (so plugin stacks remain balanced)."""
419+
mock_plugin.enable_before_tool_callback = False
420+
mock_plugin.enable_on_tool_error_callback = False
421+
422+
model = testing_utils.MockModel.create(responses=[])
423+
agent = Agent(
424+
name="agent",
425+
model=model,
426+
tools=[mock_tool],
427+
)
428+
invocation_context = await testing_utils.create_invocation_context(
429+
agent=agent, user_content="", plugins=[mock_plugin]
430+
)
431+
432+
function_call = types.FunctionCall(
433+
name="nonexistent_tool", args={}
434+
)
435+
content = types.Content(parts=[types.Part(function_call=function_call)])
436+
event = Event(
437+
invocation_id=invocation_context.invocation_id,
438+
author=agent.name,
439+
content=content,
440+
)
441+
tools_dict = {mock_tool.name: mock_tool}
442+
443+
with pytest.raises(ValueError, match="nonexistent_tool"):
444+
await handle_function_calls_async(
445+
invocation_context,
446+
event,
447+
tools_dict,
448+
)
449+
450+
343451
if __name__ == "__main__":
344452
pytest.main([__file__])

0 commit comments

Comments
 (0)