Skip to content

Commit 9566737

Browse files
committed
Support sub-agent escalation event in ParallelAgent
1 parent 1104523 commit 9566737

3 files changed

Lines changed: 251 additions & 10 deletions

File tree

src/google/adk/agents/parallel_agent.py

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -48,12 +48,25 @@ def _create_branch_ctx_for_sub_agent(
4848
return invocation_context
4949

5050

51+
def _has_escalate_action(event: Event) -> bool:
52+
"""Returns whether the event asks the parent workflow to exit early."""
53+
return bool(event.actions.escalate)
54+
55+
56+
def _cancel_tasks(tasks: list[asyncio.Task[None]]) -> None:
57+
"""Cancels still-running merge tasks."""
58+
for task in tasks:
59+
if not task.done():
60+
task.cancel()
61+
62+
5163
async def _merge_agent_run(
5264
agent_runs: list[AsyncGenerator[Event, None]],
5365
) -> AsyncGenerator[Event, None]:
5466
"""Merges agent runs using asyncio.TaskGroup on Python 3.11+."""
5567
sentinel = object()
5668
queue = asyncio.Queue()
69+
tasks: list[asyncio.Task[None]] = []
5770

5871
# Agents are processed in parallel.
5972
# Events for each agent are put on queue sequentially.
@@ -70,7 +83,7 @@ async def process_an_agent(events_for_one_agent):
7083

7184
async with asyncio.TaskGroup() as tg:
7285
for events_for_one_agent in agent_runs:
73-
tg.create_task(process_an_agent(events_for_one_agent))
86+
tasks.append(tg.create_task(process_an_agent(events_for_one_agent)))
7487

7588
sentinel_count = 0
7689
# Run until all agents finished processing.
@@ -81,6 +94,9 @@ async def process_an_agent(events_for_one_agent):
8194
sentinel_count += 1
8295
else:
8396
yield event
97+
if _has_escalate_action(event):
98+
_cancel_tasks(tasks)
99+
return
84100
# Signal to agent that it should generate next event.
85101
resume_signal.set()
86102

@@ -124,7 +140,7 @@ async def process_an_agent(events_for_one_agent):
124140
# Mark agent as finished.
125141
await queue.put((sentinel, None))
126142

127-
tasks = []
143+
tasks: list[asyncio.Task[None]] = []
128144
try:
129145
for events_for_one_agent in agent_runs:
130146
tasks.append(asyncio.create_task(process_an_agent(events_for_one_agent)))
@@ -139,12 +155,16 @@ async def process_an_agent(events_for_one_agent):
139155
sentinel_count += 1
140156
else:
141157
yield event
158+
if _has_escalate_action(event):
159+
_cancel_tasks(tasks)
160+
return
142161
# Signal to agent that event has been processed by runner and it can
143162
# continue now.
144163
resume_signal.set()
145164
finally:
146-
for task in tasks:
147-
task.cancel()
165+
_cancel_tasks(tasks)
166+
if tasks:
167+
await asyncio.gather(*tasks, return_exceptions=True)
148168

149169

150170
class ParallelAgent(BaseAgent):
@@ -181,6 +201,7 @@ async def _run_async_impl(
181201
if not sub_agent_ctx.end_of_agents.get(sub_agent.name):
182202
agent_runs.append(sub_agent.run_async(sub_agent_ctx))
183203

204+
escalated = False
184205
pause_invocation = False
185206
try:
186207
merge_func = (
@@ -191,15 +212,18 @@ async def _run_async_impl(
191212
async with Aclosing(merge_func(agent_runs)) as agen:
192213
async for event in agen:
193214
yield event
194-
if ctx.should_pause_invocation(event):
215+
if _has_escalate_action(event):
216+
escalated = True
217+
elif ctx.should_pause_invocation(event):
195218
pause_invocation = True
196219

197-
if pause_invocation:
220+
if pause_invocation and not escalated:
198221
return
199222

200223
# Once all sub-agents are done, mark the ParallelAgent as final.
201-
if ctx.is_resumable and all(
202-
ctx.end_of_agents.get(sub_agent.name) for sub_agent in self.sub_agents
224+
if ctx.is_resumable and (
225+
escalated
226+
or all(ctx.end_of_agents.get(sub_agent.name) for sub_agent in self.sub_agents)
203227
):
204228
ctx.set_agent_state(self.name, end_of_agent=True)
205229
yield self._create_agent_state_event(ctx)

tests/unittests/agents/test_parallel_agent.py

Lines changed: 109 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,10 @@
1515
"""Tests for the ParallelAgent."""
1616

1717
import asyncio
18+
from types import SimpleNamespace
1819
from typing import AsyncGenerator
1920

21+
from google.adk.agents import parallel_agent as parallel_agent_module
2022
from google.adk.agents.base_agent import BaseAgent
2123
from google.adk.agents.base_agent import BaseAgentState
2224
from google.adk.agents.invocation_context import InvocationContext
@@ -25,6 +27,7 @@
2527
from google.adk.agents.sequential_agent import SequentialAgentState
2628
from google.adk.apps.app import ResumabilityConfig
2729
from google.adk.events.event import Event
30+
from google.adk.events.event_actions import EventActions
2831
from google.adk.sessions.in_memory_session_service import InMemorySessionService
2932
from google.genai import types
3033
import pytest
@@ -36,14 +39,21 @@ class _TestingAgent(BaseAgent):
3639
delay: float = 0
3740
"""The delay before the agent generates an event."""
3841

39-
def event(self, ctx: InvocationContext):
42+
def event(
43+
self,
44+
ctx: InvocationContext,
45+
*,
46+
text: str | None = None,
47+
actions: EventActions | None = None,
48+
):
4049
return Event(
4150
author=self.name,
4251
branch=ctx.branch,
4352
invocation_id=ctx.invocation_id,
4453
content=types.Content(
45-
parts=[types.Part(text=f'Hello, async {self.name}!')]
54+
parts=[types.Part(text=text or f'Hello, async {self.name}!')]
4655
),
56+
actions=actions if actions is not None else EventActions(),
4757
)
4858

4959
@override
@@ -342,6 +352,22 @@ async def _run_async_impl(
342352
yield self.event(ctx)
343353

344354

355+
class _TestingAgentWithEscalateAction(_TestingAgent):
356+
"""Mock agent for testing escalation short-circuit behavior."""
357+
358+
@override
359+
async def _run_async_impl(
360+
self, ctx: InvocationContext
361+
) -> AsyncGenerator[Event, None]:
362+
await asyncio.sleep(self.delay)
363+
yield self.event(
364+
ctx,
365+
text=f'Escalating from {self.name}!',
366+
actions=EventActions(escalate=True),
367+
)
368+
yield self.event(ctx, text='This event should be cancelled after escalation.')
369+
370+
345371
@pytest.mark.asyncio
346372
async def test_stop_agent_if_sub_agent_fails(
347373
request: pytest.FixtureRequest,
@@ -373,3 +399,84 @@ async def test_stop_agent_if_sub_agent_fails(
373399
async for _ in agen:
374400
# The infinite agent could iterate a few times depending on scheduling.
375401
pass
402+
403+
404+
@pytest.mark.asyncio
405+
@pytest.mark.parametrize('is_resumable', [True, False])
406+
@pytest.mark.parametrize('use_pre_3_11_merge', [False, True])
407+
async def test_run_async_short_circuits_other_agents_on_escalate_action(
408+
request: pytest.FixtureRequest,
409+
monkeypatch: pytest.MonkeyPatch,
410+
is_resumable: bool,
411+
use_pre_3_11_merge: bool,
412+
):
413+
if use_pre_3_11_merge:
414+
monkeypatch.setattr(
415+
parallel_agent_module,
416+
'sys',
417+
SimpleNamespace(version_info=(3, 10)),
418+
)
419+
420+
fast_agent = _TestingAgent(
421+
name=f'{request.function.__name__}_test_fast_agent',
422+
delay=0.05,
423+
)
424+
escalating_agent = _TestingAgentWithEscalateAction(
425+
name=f'{request.function.__name__}_test_escalating_agent',
426+
delay=0.1,
427+
)
428+
slow_agent = _TestingAgent(
429+
name=f'{request.function.__name__}_test_slow_agent',
430+
delay=0.5,
431+
)
432+
parallel_agent = ParallelAgent(
433+
name=f'{request.function.__name__}_test_parallel_agent',
434+
sub_agents=[fast_agent, escalating_agent, slow_agent],
435+
)
436+
parent_ctx = await _create_parent_invocation_context(
437+
request.function.__name__, parallel_agent, is_resumable=is_resumable
438+
)
439+
440+
events = [e async for e in parallel_agent.run_async(parent_ctx)]
441+
442+
assert all(event.author != slow_agent.name for event in events)
443+
assert all(
444+
not event.content
445+
or not event.content.parts
446+
or event.content.parts[0].text
447+
!= 'This event should be cancelled after escalation.'
448+
for event in events
449+
)
450+
451+
if is_resumable:
452+
assert len(events) == 4
453+
454+
assert events[0].author == parallel_agent.name
455+
assert not events[0].actions.end_of_agent
456+
457+
assert events[1].author == fast_agent.name
458+
assert events[1].branch == f'{parallel_agent.name}.{fast_agent.name}'
459+
assert events[1].content.parts[0].text == f'Hello, async {fast_agent.name}!'
460+
461+
assert events[2].author == escalating_agent.name
462+
assert events[2].branch == f'{parallel_agent.name}.{escalating_agent.name}'
463+
assert events[2].content.parts[0].text == (
464+
f'Escalating from {escalating_agent.name}!'
465+
)
466+
assert events[2].actions.escalate
467+
468+
assert events[3].author == parallel_agent.name
469+
assert events[3].actions.end_of_agent
470+
else:
471+
assert len(events) == 2
472+
473+
assert events[0].author == fast_agent.name
474+
assert events[0].branch == f'{parallel_agent.name}.{fast_agent.name}'
475+
assert events[0].content.parts[0].text == f'Hello, async {fast_agent.name}!'
476+
477+
assert events[1].author == escalating_agent.name
478+
assert events[1].branch == f'{parallel_agent.name}.{escalating_agent.name}'
479+
assert events[1].content.parts[0].text == (
480+
f'Escalating from {escalating_agent.name}!'
481+
)
482+
assert events[1].actions.escalate

tests/unittests/runners/test_resume_invocation.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,20 @@
1313
# limitations under the License.
1414
"""Tests for edge cases of resuming invocations."""
1515

16+
import asyncio
1617
import copy
18+
from typing import AsyncGenerator
1719

20+
from google.adk.agents.base_agent import BaseAgent
21+
from google.adk.agents.invocation_context import InvocationContext
1822
from google.adk.agents.llm_agent import LlmAgent
23+
from google.adk.agents.parallel_agent import ParallelAgent
1924
from google.adk.apps.app import App
2025
from google.adk.apps.app import ResumabilityConfig
26+
from google.adk.events.event import Event
27+
from google.adk.events.event_actions import EventActions
2128
from google.adk.tools.long_running_tool import LongRunningFunctionTool
29+
from google.genai import types
2230
from google.genai.types import FunctionResponse
2331
from google.genai.types import Part
2432
import pytest
@@ -41,6 +49,43 @@ def test_tool() -> dict[str, str]:
4149
return {"result": "test tool result"}
4250

4351

52+
test_tool.__test__ = False
53+
54+
55+
class _ParallelEscalationTestingAgent(BaseAgent):
56+
"""A testing agent that emits a single event after a delay."""
57+
58+
delay: float = 0
59+
response_text: str = ""
60+
escalate: bool = False
61+
emit_follow_up_after_first_event: bool = False
62+
63+
def _create_event(
64+
self,
65+
ctx: InvocationContext,
66+
text: str,
67+
*,
68+
escalate: bool = False,
69+
) -> Event:
70+
return Event(
71+
author=self.name,
72+
branch=ctx.branch,
73+
invocation_id=ctx.invocation_id,
74+
content=types.Content(role="model", parts=[types.Part(text=text)]),
75+
actions=EventActions(escalate=True) if escalate else EventActions(),
76+
)
77+
78+
async def _run_async_impl(
79+
self, ctx: InvocationContext
80+
) -> AsyncGenerator[Event, None]:
81+
await asyncio.sleep(self.delay)
82+
yield self._create_event(
83+
ctx, self.response_text, escalate=self.escalate
84+
)
85+
if self.emit_follow_up_after_first_event:
86+
yield self._create_event(ctx, "This event should not be emitted.")
87+
88+
4489
@pytest.mark.asyncio
4590
async def test_resume_invocation_from_sub_agent():
4691
"""A test case for an edge case, where an invocation-to-resume starts from a sub-agent.
@@ -252,3 +297,68 @@ async def test_resume_any_invocation():
252297
),
253298
(root_agent.name, testing_utils.END_OF_AGENT),
254299
]
300+
301+
302+
@pytest.mark.asyncio
303+
async def test_resumable_parallel_agent_escalation_short_circuits_persisted_run():
304+
"""Runner persists fast+escalating events and marks the parent run complete."""
305+
fast_agent = _ParallelEscalationTestingAgent(
306+
name="fast_agent",
307+
delay=0.05,
308+
response_text="fast response",
309+
)
310+
escalating_agent = _ParallelEscalationTestingAgent(
311+
name="escalating_agent",
312+
delay=0.1,
313+
response_text="escalating response",
314+
escalate=True,
315+
emit_follow_up_after_first_event=True,
316+
)
317+
slow_agent = _ParallelEscalationTestingAgent(
318+
name="slow_agent",
319+
delay=0.5,
320+
response_text="slow response",
321+
)
322+
runner = testing_utils.InMemoryRunner(
323+
app=App(
324+
name="test_app",
325+
root_agent=ParallelAgent(
326+
name="root_agent",
327+
sub_agents=[fast_agent, escalating_agent, slow_agent],
328+
),
329+
resumability_config=ResumabilityConfig(is_resumable=True),
330+
)
331+
)
332+
333+
invocation_events = await runner.run_async("test user query")
334+
simplified_events = testing_utils.simplify_resumable_app_events(
335+
copy.deepcopy(invocation_events)
336+
)
337+
338+
assert simplified_events == [
339+
("root_agent", {}),
340+
("fast_agent", "fast response"),
341+
("escalating_agent", "escalating response"),
342+
("root_agent", testing_utils.END_OF_AGENT),
343+
]
344+
345+
session = await runner.runner.session_service.get_session(
346+
app_name=runner.app_name,
347+
user_id="test_user",
348+
session_id=runner.session_id,
349+
)
350+
persisted_events = [
351+
event
352+
for event in session.events
353+
if event.invocation_id == invocation_events[0].invocation_id
354+
and event.author != "user"
355+
]
356+
assert testing_utils.simplify_resumable_app_events(
357+
copy.deepcopy(persisted_events)
358+
) == simplified_events
359+
assert all(event.author != "slow_agent" for event in persisted_events)
360+
361+
# A completed resumable invocation should not restart cancelled siblings.
362+
assert not await runner.run_async(
363+
invocation_id=invocation_events[0].invocation_id
364+
)

0 commit comments

Comments
 (0)