|
25 | 25 | from google.adk.flows.llm_flows.functions import _call_tool_in_thread_pool |
26 | 26 | from google.adk.flows.llm_flows.functions import _get_tool_thread_pool |
27 | 27 | from google.adk.flows.llm_flows.functions import _is_sync_tool |
| 28 | +from google.adk.tools.base_tool import BaseTool |
28 | 29 | from google.adk.tools.function_tool import FunctionTool |
| 30 | +from google.adk.tools.set_model_response_tool import SetModelResponseTool |
29 | 31 | from google.adk.tools.tool_context import ToolContext |
30 | 32 | from google.genai import types |
| 33 | +from pydantic import BaseModel |
31 | 34 | import pytest |
32 | 35 |
|
33 | 36 | from ... import testing_utils |
@@ -76,8 +79,6 @@ async def async_gen_func(x: int): |
76 | 79 |
|
77 | 80 | def test_tool_without_func_returns_false(self): |
78 | 81 | """Test that a tool without func attribute returns False.""" |
79 | | - from google.adk.tools.base_tool import BaseTool |
80 | | - |
81 | 82 | tool = BaseTool(name='test', description='test tool') |
82 | 83 | assert _is_sync_tool(tool) is False |
83 | 84 |
|
@@ -398,6 +399,99 @@ async def async_func() -> dict[str, str]: |
398 | 399 |
|
399 | 400 | assert result == {'value': 'main_thread_value'} |
400 | 401 |
|
| 402 | + @pytest.mark.asyncio |
| 403 | + async def test_sync_tool_returning_none_runs_exactly_once(self): |
| 404 | + """Regression test for issue #5284. |
| 405 | +
|
| 406 | + A sync FunctionTool whose underlying function returns None must not |
| 407 | + be re-invoked through the run_async fallback path. |
| 408 | + """ |
| 409 | + call_count = 0 |
| 410 | + |
| 411 | + def side_effect_only_func() -> None: |
| 412 | + nonlocal call_count |
| 413 | + call_count += 1 |
| 414 | + |
| 415 | + tool = FunctionTool(side_effect_only_func) |
| 416 | + model = testing_utils.MockModel.create(responses=[]) |
| 417 | + agent = Agent(name='test_agent', model=model, tools=[tool]) |
| 418 | + invocation_context = await testing_utils.create_invocation_context( |
| 419 | + agent=agent, user_content='' |
| 420 | + ) |
| 421 | + tool_context = ToolContext( |
| 422 | + invocation_context=invocation_context, |
| 423 | + function_call_id='test_id', |
| 424 | + ) |
| 425 | + |
| 426 | + result = await _call_tool_in_thread_pool(tool, {}, tool_context) |
| 427 | + |
| 428 | + assert result is None |
| 429 | + assert call_count == 1 |
| 430 | + |
| 431 | + @pytest.mark.asyncio |
| 432 | + async def test_non_function_tool_sync_falls_back_to_run_async(self): |
| 433 | + """Sync tools that aren't FunctionTool subclasses go through run_async. |
| 434 | +
|
| 435 | + Covers the fall-through path used by tools like SetModelResponseTool |
| 436 | + that have a sync ``func`` attribute but aren't FunctionTool instances. |
| 437 | + """ |
| 438 | + run_async_call_count = 0 |
| 439 | + |
| 440 | + class _SyncNonFunctionTool(BaseTool): |
| 441 | + |
| 442 | + def __init__(self): |
| 443 | + super().__init__(name='custom_tool', description='desc') |
| 444 | + # Sync attribute so _is_sync_tool returns True. |
| 445 | + self.func = lambda: 'unused' |
| 446 | + |
| 447 | + async def run_async(self, *, args, tool_context): |
| 448 | + nonlocal run_async_call_count |
| 449 | + run_async_call_count += 1 |
| 450 | + return {'via': 'run_async'} |
| 451 | + |
| 452 | + tool = _SyncNonFunctionTool() |
| 453 | + model = testing_utils.MockModel.create(responses=[]) |
| 454 | + agent = Agent(name='test_agent', model=model, tools=[tool]) |
| 455 | + invocation_context = await testing_utils.create_invocation_context( |
| 456 | + agent=agent, user_content='' |
| 457 | + ) |
| 458 | + tool_context = ToolContext( |
| 459 | + invocation_context=invocation_context, |
| 460 | + function_call_id='test_id', |
| 461 | + ) |
| 462 | + |
| 463 | + result = await _call_tool_in_thread_pool(tool, {}, tool_context) |
| 464 | + |
| 465 | + assert result == {'via': 'run_async'} |
| 466 | + assert run_async_call_count == 1 |
| 467 | + |
| 468 | + @pytest.mark.asyncio |
| 469 | + async def test_set_model_response_tool_falls_back_to_run_async(self): |
| 470 | + """SetModelResponseTool — the real-world non-FunctionTool sync tool.""" |
| 471 | + |
| 472 | + class _Schema(BaseModel): |
| 473 | + answer: str |
| 474 | + |
| 475 | + tool = SetModelResponseTool(output_schema=_Schema) |
| 476 | + # Precondition: this is the code path the bug report referenced. |
| 477 | + assert _is_sync_tool(tool) |
| 478 | + |
| 479 | + model = testing_utils.MockModel.create(responses=[]) |
| 480 | + agent = Agent(name='test_agent', model=model, tools=[tool]) |
| 481 | + invocation_context = await testing_utils.create_invocation_context( |
| 482 | + agent=agent, user_content='' |
| 483 | + ) |
| 484 | + tool_context = ToolContext( |
| 485 | + invocation_context=invocation_context, |
| 486 | + function_call_id='test_id', |
| 487 | + ) |
| 488 | + |
| 489 | + result = await _call_tool_in_thread_pool( |
| 490 | + tool, {'answer': 'hello'}, tool_context |
| 491 | + ) |
| 492 | + |
| 493 | + assert result == {'answer': 'hello'} |
| 494 | + |
401 | 495 |
|
402 | 496 | class TestToolThreadPoolConfig: |
403 | 497 | """Tests for the tool_thread_pool_config in RunConfig.""" |
|
0 commit comments