|
14 | 14 |
|
15 | 15 | """Unit tests for BaseLlmFlow toolset integration.""" |
16 | 16 |
|
| 17 | +import asyncio |
17 | 18 | from unittest import mock |
18 | 19 | from unittest.mock import AsyncMock |
19 | 20 |
|
@@ -243,6 +244,130 @@ def _my_tool(sides: int) -> int: |
243 | 244 | ) |
244 | 245 |
|
245 | 246 |
|
| 247 | +@pytest.mark.asyncio |
| 248 | +async def test_process_agent_tools_resolves_unions_in_parallel(): |
| 249 | + """``_convert_tool_union_to_tools`` is dispatched for every tool_union concurrently. |
| 250 | +
|
| 251 | + Each mocked resolution blocks until ``all_started`` is set; the event |
| 252 | + is only set once every call has been entered. If |
| 253 | + ``_process_agent_tools`` were still serial, the first call would |
| 254 | + block forever waiting for the event the second call hasn't yet |
| 255 | + entered to set. |
| 256 | + """ |
| 257 | + num_tools = 5 |
| 258 | + started_count = 0 |
| 259 | + all_started = asyncio.Event() |
| 260 | + release = asyncio.Event() |
| 261 | + |
| 262 | + async def blocking_convert(tool_union, *args, **kwargs): |
| 263 | + del args, kwargs |
| 264 | + nonlocal started_count |
| 265 | + started_count += 1 |
| 266 | + if started_count == num_tools: |
| 267 | + all_started.set() |
| 268 | + await release.wait() |
| 269 | + return [_AsyncProcessLlmRequestTool(name=tool_union.__name__)] |
| 270 | + |
| 271 | + def _make_func(i): |
| 272 | + def _f(): |
| 273 | + """Test function.""" |
| 274 | + return i |
| 275 | + |
| 276 | + _f.__name__ = f'fn_{i}' |
| 277 | + return _f |
| 278 | + |
| 279 | + funcs = [_make_func(i) for i in range(num_tools)] |
| 280 | + |
| 281 | + with mock.patch( |
| 282 | + 'google.adk.agents.llm_agent._convert_tool_union_to_tools', |
| 283 | + side_effect=blocking_convert, |
| 284 | + ): |
| 285 | + agent = Agent(name='test_agent', tools=funcs) |
| 286 | + invocation_context = await testing_utils.create_invocation_context( |
| 287 | + agent=agent, user_content='test message' |
| 288 | + ) |
| 289 | + flow = BaseLlmFlowForTesting() |
| 290 | + llm_request = LlmRequest() |
| 291 | + |
| 292 | + async def drive(): |
| 293 | + async for _ in flow._preprocess_async(invocation_context, llm_request): |
| 294 | + pass |
| 295 | + |
| 296 | + drive_task = asyncio.create_task(drive()) |
| 297 | + try: |
| 298 | + # If resolution were serial this would hang; release the gate as |
| 299 | + # soon as every coroutine has entered. |
| 300 | + await asyncio.wait_for(all_started.wait(), timeout=5.0) |
| 301 | + finally: |
| 302 | + release.set() |
| 303 | + await asyncio.wait_for(drive_task, timeout=5.0) |
| 304 | + |
| 305 | + assert started_count == num_tools |
| 306 | + |
| 307 | + |
| 308 | +@pytest.mark.asyncio |
| 309 | +async def test_process_agent_tools_preserves_order_when_later_unions_resolve_first(): |
| 310 | + """``process_llm_request`` is called in original ``agent.tools`` order even when later unions resolve first.""" |
| 311 | + |
| 312 | + resolution_started_evt = [asyncio.Event(), asyncio.Event()] |
| 313 | + process_call_order: list[str] = [] |
| 314 | + |
| 315 | + async def staggered_convert(tool_union, *args, **kwargs): |
| 316 | + del args, kwargs |
| 317 | + if tool_union.__name__ == 'fn_slow': |
| 318 | + # Resolve only after fn_fast's resolution has completed. |
| 319 | + await resolution_started_evt[1].wait() |
| 320 | + tool_name = 'slow_tool' |
| 321 | + else: |
| 322 | + tool_name = 'fast_tool' |
| 323 | + resolution_started_evt[1].set() |
| 324 | + return [ |
| 325 | + _AsyncProcessLlmRequestTool( |
| 326 | + name=tool_name, on_process=process_call_order.append |
| 327 | + ) |
| 328 | + ] |
| 329 | + |
| 330 | + def fn_slow(): |
| 331 | + """Slow-resolving function.""" |
| 332 | + return 0 |
| 333 | + |
| 334 | + def fn_fast(): |
| 335 | + """Fast-resolving function.""" |
| 336 | + return 0 |
| 337 | + |
| 338 | + with mock.patch( |
| 339 | + 'google.adk.agents.llm_agent._convert_tool_union_to_tools', |
| 340 | + side_effect=staggered_convert, |
| 341 | + ): |
| 342 | + # agent.tools order is [slow, fast]; resolution completes [fast, slow]. |
| 343 | + agent = Agent(name='test_agent', tools=[fn_slow, fn_fast]) |
| 344 | + invocation_context = await testing_utils.create_invocation_context( |
| 345 | + agent=agent, user_content='test message' |
| 346 | + ) |
| 347 | + flow = BaseLlmFlowForTesting() |
| 348 | + llm_request = LlmRequest() |
| 349 | + |
| 350 | + async for _ in flow._preprocess_async(invocation_context, llm_request): |
| 351 | + pass |
| 352 | + |
| 353 | + # Even though fast_tool was resolved first, process_llm_request must |
| 354 | + # be invoked in agent.tools order (slow_tool first). |
| 355 | + assert process_call_order == ['slow_tool', 'fast_tool'] |
| 356 | + |
| 357 | + |
| 358 | +class _AsyncProcessLlmRequestTool: |
| 359 | + """Minimal stand-in for a BaseTool that records process_llm_request calls.""" |
| 360 | + |
| 361 | + def __init__(self, name: str, on_process=None): |
| 362 | + self.name = name |
| 363 | + self._on_process = on_process |
| 364 | + |
| 365 | + async def process_llm_request(self, *, tool_context, llm_request): |
| 366 | + del tool_context, llm_request |
| 367 | + if self._on_process is not None: |
| 368 | + self._on_process(self.name) |
| 369 | + |
| 370 | + |
246 | 371 | # TODO(b/448114567): Remove the following |
247 | 372 | # test_handle_after_model_callback_grounding tests once the workaround |
248 | 373 | # is no longer needed. |
|
0 commit comments