Skip to content

Commit 78a8851

Browse files
GWealecopybara-github
authored andcommitted
fix: avoid double-execution of sync FunctionTools returning None
_call_tool_in_thread_pool used None as a sentinel to distinguish "FunctionTool ran successfully" from "non-FunctionTool sync tool, needs async fallback". When a sync FunctionTool's function legitimately returned None, the sentinel check fell through to tool.run_async() and re-invoked the underlying function. Restructure the dispatch so the sync-FunctionTool path returns directly and the non-FunctionTool sync path falls through explicitly, removing the ambiguous sentinel. Close #5284 Co-authored-by: George Weale <gweale@google.com> PiperOrigin-RevId: 905258186
1 parent f0c787f commit 78a8851

2 files changed

Lines changed: 103 additions & 14 deletions

File tree

src/google/adk/flows/llm_flows/functions.py

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -145,9 +145,9 @@ async def _call_tool_in_thread_pool(
145145
executor = _get_tool_thread_pool(max_workers)
146146

147147
if _is_sync_tool(tool):
148-
# For sync FunctionTool, call the underlying function directly
149-
def run_sync_tool():
150-
if isinstance(tool, FunctionTool):
148+
if isinstance(tool, FunctionTool):
149+
# For sync FunctionTool, call the underlying function directly.
150+
def run_sync_tool():
151151
args_to_call = tool._preprocess_args(args)
152152
signature = inspect.signature(tool.func)
153153
valid_params = {param for param in signature.parameters}
@@ -157,15 +157,10 @@ def run_sync_tool():
157157
k: v for k, v in args_to_call.items() if k in valid_params
158158
}
159159
return tool.func(**args_to_call)
160-
else:
161-
# For other sync tool types, we can't easily run them in thread pool
162-
return None
163160

164-
result = await loop.run_in_executor(
165-
executor, lambda: ctx.run(run_sync_tool)
166-
)
167-
if result is not None:
168-
return result
161+
return await loop.run_in_executor(
162+
executor, lambda: ctx.run(run_sync_tool)
163+
)
169164
else:
170165
# For async tools, run them in a new event loop in a background thread.
171166
# This helps when async functions contain blocking I/O (common user mistake)
@@ -178,7 +173,7 @@ def run_async_tool_in_new_loop():
178173
executor, lambda: ctx.run(run_async_tool_in_new_loop)
179174
)
180175

181-
# Fall back to normal async execution for non-FunctionTool sync tools
176+
# Fall back to normal async execution for non-FunctionTool sync tools.
182177
return await tool.run_async(args=args, tool_context=tool_context)
183178

184179

tests/unittests/flows/llm_flows/test_functions_thread_pool.py

Lines changed: 96 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,12 @@
2525
from google.adk.flows.llm_flows.functions import _call_tool_in_thread_pool
2626
from google.adk.flows.llm_flows.functions import _get_tool_thread_pool
2727
from google.adk.flows.llm_flows.functions import _is_sync_tool
28+
from google.adk.tools.base_tool import BaseTool
2829
from google.adk.tools.function_tool import FunctionTool
30+
from google.adk.tools.set_model_response_tool import SetModelResponseTool
2931
from google.adk.tools.tool_context import ToolContext
3032
from google.genai import types
33+
from pydantic import BaseModel
3134
import pytest
3235

3336
from ... import testing_utils
@@ -76,8 +79,6 @@ async def async_gen_func(x: int):
7679

7780
def test_tool_without_func_returns_false(self):
7881
"""Test that a tool without func attribute returns False."""
79-
from google.adk.tools.base_tool import BaseTool
80-
8182
tool = BaseTool(name='test', description='test tool')
8283
assert _is_sync_tool(tool) is False
8384

@@ -398,6 +399,99 @@ async def async_func() -> dict[str, str]:
398399

399400
assert result == {'value': 'main_thread_value'}
400401

402+
@pytest.mark.asyncio
403+
async def test_sync_tool_returning_none_runs_exactly_once(self):
404+
"""Regression test for issue #5284.
405+
406+
A sync FunctionTool whose underlying function returns None must not
407+
be re-invoked through the run_async fallback path.
408+
"""
409+
call_count = 0
410+
411+
def side_effect_only_func() -> None:
412+
nonlocal call_count
413+
call_count += 1
414+
415+
tool = FunctionTool(side_effect_only_func)
416+
model = testing_utils.MockModel.create(responses=[])
417+
agent = Agent(name='test_agent', model=model, tools=[tool])
418+
invocation_context = await testing_utils.create_invocation_context(
419+
agent=agent, user_content=''
420+
)
421+
tool_context = ToolContext(
422+
invocation_context=invocation_context,
423+
function_call_id='test_id',
424+
)
425+
426+
result = await _call_tool_in_thread_pool(tool, {}, tool_context)
427+
428+
assert result is None
429+
assert call_count == 1
430+
431+
@pytest.mark.asyncio
432+
async def test_non_function_tool_sync_falls_back_to_run_async(self):
433+
"""Sync tools that aren't FunctionTool subclasses go through run_async.
434+
435+
Covers the fall-through path used by tools like SetModelResponseTool
436+
that have a sync ``func`` attribute but aren't FunctionTool instances.
437+
"""
438+
run_async_call_count = 0
439+
440+
class _SyncNonFunctionTool(BaseTool):
441+
442+
def __init__(self):
443+
super().__init__(name='custom_tool', description='desc')
444+
# Sync attribute so _is_sync_tool returns True.
445+
self.func = lambda: 'unused'
446+
447+
async def run_async(self, *, args, tool_context):
448+
nonlocal run_async_call_count
449+
run_async_call_count += 1
450+
return {'via': 'run_async'}
451+
452+
tool = _SyncNonFunctionTool()
453+
model = testing_utils.MockModel.create(responses=[])
454+
agent = Agent(name='test_agent', model=model, tools=[tool])
455+
invocation_context = await testing_utils.create_invocation_context(
456+
agent=agent, user_content=''
457+
)
458+
tool_context = ToolContext(
459+
invocation_context=invocation_context,
460+
function_call_id='test_id',
461+
)
462+
463+
result = await _call_tool_in_thread_pool(tool, {}, tool_context)
464+
465+
assert result == {'via': 'run_async'}
466+
assert run_async_call_count == 1
467+
468+
@pytest.mark.asyncio
469+
async def test_set_model_response_tool_falls_back_to_run_async(self):
470+
"""SetModelResponseTool — the real-world non-FunctionTool sync tool."""
471+
472+
class _Schema(BaseModel):
473+
answer: str
474+
475+
tool = SetModelResponseTool(output_schema=_Schema)
476+
# Precondition: this is the code path the bug report referenced.
477+
assert _is_sync_tool(tool)
478+
479+
model = testing_utils.MockModel.create(responses=[])
480+
agent = Agent(name='test_agent', model=model, tools=[tool])
481+
invocation_context = await testing_utils.create_invocation_context(
482+
agent=agent, user_content=''
483+
)
484+
tool_context = ToolContext(
485+
invocation_context=invocation_context,
486+
function_call_id='test_id',
487+
)
488+
489+
result = await _call_tool_in_thread_pool(
490+
tool, {'answer': 'hello'}, tool_context
491+
)
492+
493+
assert result == {'answer': 'hello'}
494+
401495

402496
class TestToolThreadPoolConfig:
403497
"""Tests for the tool_thread_pool_config in RunConfig."""

0 commit comments

Comments
 (0)