1515"""Tests for the ParallelAgent."""
1616
1717import asyncio
18+ from types import SimpleNamespace
1819from typing import AsyncGenerator
1920
21+ from google .adk .agents import parallel_agent as parallel_agent_module
2022from google .adk .agents .base_agent import BaseAgent
2123from google .adk .agents .base_agent import BaseAgentState
2224from google .adk .agents .invocation_context import InvocationContext
2527from google .adk .agents .sequential_agent import SequentialAgentState
2628from google .adk .apps .app import ResumabilityConfig
2729from google .adk .events .event import Event
30+ from google .adk .events .event_actions import EventActions
2831from google .adk .sessions .in_memory_session_service import InMemorySessionService
2932from google .genai import types
3033import pytest
@@ -36,14 +39,21 @@ class _TestingAgent(BaseAgent):
3639 delay : float = 0
3740 """The delay before the agent generates an event."""
3841
39- def event (self , ctx : InvocationContext ):
42+ def event (
43+ self ,
44+ ctx : InvocationContext ,
45+ * ,
46+ text : str | None = None ,
47+ actions : EventActions | None = None ,
48+ ):
4049 return Event (
4150 author = self .name ,
4251 branch = ctx .branch ,
4352 invocation_id = ctx .invocation_id ,
4453 content = types .Content (
45- parts = [types .Part (text = f'Hello, async { self .name } !' )]
54+ parts = [types .Part (text = text or f'Hello, async { self .name } !' )]
4655 ),
56+ actions = actions if actions is not None else EventActions (),
4757 )
4858
4959 @override
@@ -342,6 +352,24 @@ async def _run_async_impl(
342352 yield self .event (ctx )
343353
344354
355+ class _TestingAgentWithEscalateAction (_TestingAgent ):
356+ """Mock agent for testing escalation short-circuit behavior."""
357+
358+ @override
359+ async def _run_async_impl (
360+ self , ctx : InvocationContext
361+ ) -> AsyncGenerator [Event , None ]:
362+ await asyncio .sleep (self .delay )
363+ yield self .event (
364+ ctx ,
365+ text = f'Escalating from { self .name } !' ,
366+ actions = EventActions (escalate = True ),
367+ )
368+ yield self .event (
369+ ctx , text = 'This event should be cancelled after escalation.'
370+ )
371+
372+
345373@pytest .mark .asyncio
346374async def test_stop_agent_if_sub_agent_fails (
347375 request : pytest .FixtureRequest ,
@@ -373,3 +401,84 @@ async def test_stop_agent_if_sub_agent_fails(
373401 async for _ in agen :
374402 # The infinite agent could iterate a few times depending on scheduling.
375403 pass
404+
405+
406+ @pytest .mark .asyncio
407+ @pytest .mark .parametrize ('is_resumable' , [True , False ])
408+ @pytest .mark .parametrize ('use_pre_3_11_merge' , [False , True ])
409+ async def test_run_async_short_circuits_other_agents_on_escalate_action (
410+ request : pytest .FixtureRequest ,
411+ monkeypatch : pytest .MonkeyPatch ,
412+ is_resumable : bool ,
413+ use_pre_3_11_merge : bool ,
414+ ):
415+ if use_pre_3_11_merge :
416+ monkeypatch .setattr (
417+ parallel_agent_module ,
418+ 'sys' ,
419+ SimpleNamespace (version_info = (3 , 10 )),
420+ )
421+
422+ fast_agent = _TestingAgent (
423+ name = f'{ request .function .__name__ } _test_fast_agent' ,
424+ delay = 0.05 ,
425+ )
426+ escalating_agent = _TestingAgentWithEscalateAction (
427+ name = f'{ request .function .__name__ } _test_escalating_agent' ,
428+ delay = 0.1 ,
429+ )
430+ slow_agent = _TestingAgent (
431+ name = f'{ request .function .__name__ } _test_slow_agent' ,
432+ delay = 0.5 ,
433+ )
434+ parallel_agent = ParallelAgent (
435+ name = f'{ request .function .__name__ } _test_parallel_agent' ,
436+ sub_agents = [fast_agent , escalating_agent , slow_agent ],
437+ )
438+ parent_ctx = await _create_parent_invocation_context (
439+ request .function .__name__ , parallel_agent , is_resumable = is_resumable
440+ )
441+
442+ events = [e async for e in parallel_agent .run_async (parent_ctx )]
443+
444+ assert all (event .author != slow_agent .name for event in events )
445+ assert all (
446+ not event .content
447+ or not event .content .parts
448+ or event .content .parts [0 ].text
449+ != 'This event should be cancelled after escalation.'
450+ for event in events
451+ )
452+
453+ if is_resumable :
454+ assert len (events ) == 4
455+
456+ assert events [0 ].author == parallel_agent .name
457+ assert not events [0 ].actions .end_of_agent
458+
459+ assert events [1 ].author == fast_agent .name
460+ assert events [1 ].branch == f'{ parallel_agent .name } .{ fast_agent .name } '
461+ assert events [1 ].content .parts [0 ].text == f'Hello, async { fast_agent .name } !'
462+
463+ assert events [2 ].author == escalating_agent .name
464+ assert events [2 ].branch == f'{ parallel_agent .name } .{ escalating_agent .name } '
465+ assert events [2 ].content .parts [0 ].text == (
466+ f'Escalating from { escalating_agent .name } !'
467+ )
468+ assert events [2 ].actions .escalate
469+
470+ assert events [3 ].author == parallel_agent .name
471+ assert events [3 ].actions .end_of_agent
472+ else :
473+ assert len (events ) == 2
474+
475+ assert events [0 ].author == fast_agent .name
476+ assert events [0 ].branch == f'{ parallel_agent .name } .{ fast_agent .name } '
477+ assert events [0 ].content .parts [0 ].text == f'Hello, async { fast_agent .name } !'
478+
479+ assert events [1 ].author == escalating_agent .name
480+ assert events [1 ].branch == f'{ parallel_agent .name } .{ escalating_agent .name } '
481+ assert events [1 ].content .parts [0 ].text == (
482+ f'Escalating from { escalating_agent .name } !'
483+ )
484+ assert events [1 ].actions .escalate
0 commit comments