Skip to content

Commit ae95a97

Browse files
google-genai-botDeanChensj
authored andcommitted
perf(flows): Resolve agent tool unions in parallel
Merge 31eda49 into 6ce4b87 Merges #5875 ORIGINAL_AUTHOR=wenzhaoy-google <141370433+wenzhaoy-google@users.noreply.github.com> GitOrigin-RevId: 4362766 Change-Id: I2093e78fd71c20bb600b82a979cf4ae92d0b18ff
1 parent ce9011c commit ae95a97

2 files changed

Lines changed: 158 additions & 10 deletions

File tree

src/google/adk/flows/llm_flows/base_llm_flow.py

Lines changed: 33 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -420,6 +420,15 @@ async def _process_agent_tools(
420420
instances, and calls ``process_llm_request`` on each to register
421421
tool declarations in the request.
422422
423+
Tool-union resolution is dispatched concurrently via ``asyncio.gather``
424+
to overlap I/O-bound listings (e.g. MCP ``list_tools`` over the
425+
network). The subsequent ``process_llm_request`` calls are kept
426+
serial in the original ``agent.tools`` order: some tools read/write
427+
``llm_request`` state (e.g. ``GoogleSearchTool`` writes
428+
``llm_request.model``; ``ComputerUseToolset`` performs an idempotency
429+
check on ``llm_request.config.tools``) and rely on observing the
430+
post-state of earlier tools.
431+
423432
After this function returns, ``llm_request.tools_dict`` maps tool
424433
names to ``BaseTool`` instances ready for function call dispatch.
425434
@@ -429,12 +438,34 @@ async def _process_agent_tools(
429438
llm_request: The LLM request to populate with tool declarations.
430439
"""
431440
agent = invocation_context.agent
432-
if not hasattr(agent, 'tools') or not agent.tools:
441+
if agent is None or not hasattr(agent, 'tools') or not agent.tools:
433442
return
434443

435444
multiple_tools = len(agent.tools) > 1
436445
model = agent.canonical_model
437-
for tool_union in agent.tools:
446+
447+
from ...agents.llm_agent import _convert_tool_union_to_tools
448+
449+
# Resolve tool_unions in parallel. ``asyncio.gather`` preserves
450+
# input order in the returned list, so the serial commit phase below
451+
# still observes ``agent.tools`` order. If any resolution raises,
452+
# gather cancels the siblings and propagates -- same observable
453+
# behavior as the previous serial loop, which would propagate the
454+
# first exception and abandon the rest.
455+
resolved_tools_per_union = await asyncio.gather(*(
456+
_convert_tool_union_to_tools(
457+
tool_union,
458+
ReadonlyContext(invocation_context),
459+
model,
460+
multiple_tools,
461+
)
462+
for tool_union in agent.tools
463+
))
464+
465+
# Serial commit phase, in original ``agent.tools`` order. Mutations
466+
# to ``llm_request`` and reads of its state (model, config.tools,
467+
# tools_dict) preserve today's ordering semantics exactly.
468+
for tool_union, tools in zip(agent.tools, resolved_tools_per_union):
438469
tool_context = ToolContext(invocation_context)
439470

440471
# If it's a toolset, process it first
@@ -443,15 +474,7 @@ async def _process_agent_tools(
443474
tool_context=tool_context, llm_request=llm_request
444475
)
445476

446-
from ...agents.llm_agent import _convert_tool_union_to_tools
447-
448477
# Then process all tools from this tool union
449-
tools = await _convert_tool_union_to_tools(
450-
tool_union,
451-
ReadonlyContext(invocation_context),
452-
model,
453-
multiple_tools,
454-
)
455478
for tool in tools:
456479
await tool.process_llm_request(
457480
tool_context=tool_context, llm_request=llm_request

tests/unittests/flows/llm_flows/test_base_llm_flow.py

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
"""Unit tests for BaseLlmFlow toolset integration."""
1616

17+
import asyncio
1718
from unittest import mock
1819
from unittest.mock import AsyncMock
1920

@@ -243,6 +244,130 @@ def _my_tool(sides: int) -> int:
243244
)
244245

245246

247+
@pytest.mark.asyncio
248+
async def test_process_agent_tools_resolves_unions_in_parallel():
249+
"""``_convert_tool_union_to_tools`` is dispatched for every tool_union concurrently.
250+
251+
Each mocked resolution blocks until ``all_started`` is set; the event
252+
is only set once every call has been entered. If
253+
``_process_agent_tools`` were still serial, the first call would
254+
block forever waiting for the event the second call hasn't yet
255+
entered to set.
256+
"""
257+
num_tools = 5
258+
started_count = 0
259+
all_started = asyncio.Event()
260+
release = asyncio.Event()
261+
262+
async def blocking_convert(tool_union, *args, **kwargs):
263+
del args, kwargs
264+
nonlocal started_count
265+
started_count += 1
266+
if started_count == num_tools:
267+
all_started.set()
268+
await release.wait()
269+
return [_AsyncProcessLlmRequestTool(name=tool_union.__name__)]
270+
271+
def _make_func(i):
272+
def _f():
273+
"""Test function."""
274+
return i
275+
276+
_f.__name__ = f'fn_{i}'
277+
return _f
278+
279+
funcs = [_make_func(i) for i in range(num_tools)]
280+
281+
with mock.patch(
282+
'google.adk.agents.llm_agent._convert_tool_union_to_tools',
283+
side_effect=blocking_convert,
284+
):
285+
agent = Agent(name='test_agent', tools=funcs)
286+
invocation_context = await testing_utils.create_invocation_context(
287+
agent=agent, user_content='test message'
288+
)
289+
flow = BaseLlmFlowForTesting()
290+
llm_request = LlmRequest()
291+
292+
async def drive():
293+
async for _ in flow._preprocess_async(invocation_context, llm_request):
294+
pass
295+
296+
drive_task = asyncio.create_task(drive())
297+
try:
298+
# If resolution were serial this would hang; release the gate as
299+
# soon as every coroutine has entered.
300+
await asyncio.wait_for(all_started.wait(), timeout=5.0)
301+
finally:
302+
release.set()
303+
await asyncio.wait_for(drive_task, timeout=5.0)
304+
305+
assert started_count == num_tools
306+
307+
308+
@pytest.mark.asyncio
309+
async def test_process_agent_tools_preserves_order_when_later_unions_resolve_first():
310+
"""``process_llm_request`` is called in original ``agent.tools`` order even when later unions resolve first."""
311+
312+
resolution_started_evt = [asyncio.Event(), asyncio.Event()]
313+
process_call_order: list[str] = []
314+
315+
async def staggered_convert(tool_union, *args, **kwargs):
316+
del args, kwargs
317+
if tool_union.__name__ == 'fn_slow':
318+
# Resolve only after fn_fast's resolution has completed.
319+
await resolution_started_evt[1].wait()
320+
tool_name = 'slow_tool'
321+
else:
322+
tool_name = 'fast_tool'
323+
resolution_started_evt[1].set()
324+
return [
325+
_AsyncProcessLlmRequestTool(
326+
name=tool_name, on_process=process_call_order.append
327+
)
328+
]
329+
330+
def fn_slow():
331+
"""Slow-resolving function."""
332+
return 0
333+
334+
def fn_fast():
335+
"""Fast-resolving function."""
336+
return 0
337+
338+
with mock.patch(
339+
'google.adk.agents.llm_agent._convert_tool_union_to_tools',
340+
side_effect=staggered_convert,
341+
):
342+
# agent.tools order is [slow, fast]; resolution completes [fast, slow].
343+
agent = Agent(name='test_agent', tools=[fn_slow, fn_fast])
344+
invocation_context = await testing_utils.create_invocation_context(
345+
agent=agent, user_content='test message'
346+
)
347+
flow = BaseLlmFlowForTesting()
348+
llm_request = LlmRequest()
349+
350+
async for _ in flow._preprocess_async(invocation_context, llm_request):
351+
pass
352+
353+
# Even though fast_tool was resolved first, process_llm_request must
354+
# be invoked in agent.tools order (slow_tool first).
355+
assert process_call_order == ['slow_tool', 'fast_tool']
356+
357+
358+
class _AsyncProcessLlmRequestTool:
359+
"""Minimal stand-in for a BaseTool that records process_llm_request calls."""
360+
361+
def __init__(self, name: str, on_process=None):
362+
self.name = name
363+
self._on_process = on_process
364+
365+
async def process_llm_request(self, *, tool_context, llm_request):
366+
del tool_context, llm_request
367+
if self._on_process is not None:
368+
self._on_process(self.name)
369+
370+
246371
# TODO(b/448114567): Remove the following
247372
# test_handle_after_model_callback_grounding tests once the workaround
248373
# is no longer needed.

0 commit comments

Comments
 (0)