1515"""Tests for the ParallelAgent."""
1616
1717import asyncio
18+ from types import SimpleNamespace
1819from typing import AsyncGenerator
1920
21+ from google .adk .agents import parallel_agent as parallel_agent_module
2022from google .adk .agents .base_agent import BaseAgent
2123from google .adk .agents .base_agent import BaseAgentState
2224from google .adk .agents .invocation_context import InvocationContext
2527from google .adk .agents .sequential_agent import SequentialAgentState
2628from google .adk .apps .app import ResumabilityConfig
2729from google .adk .events .event import Event
30+ from google .adk .events .event_actions import EventActions
2831from google .adk .sessions .in_memory_session_service import InMemorySessionService
2932from google .genai import types
3033import pytest
@@ -36,14 +39,21 @@ class _TestingAgent(BaseAgent):
3639 delay : float = 0
3740 """The delay before the agent generates an event."""
3841
39- def event (self , ctx : InvocationContext ):
42+ def event (
43+ self ,
44+ ctx : InvocationContext ,
45+ * ,
46+ text : str | None = None ,
47+ actions : EventActions | None = None ,
48+ ):
4049 return Event (
4150 author = self .name ,
4251 branch = ctx .branch ,
4352 invocation_id = ctx .invocation_id ,
4453 content = types .Content (
45- parts = [types .Part (text = f'Hello, async { self .name } !' )]
54+ parts = [types .Part (text = text or f'Hello, async { self .name } !' )]
4655 ),
56+ actions = actions if actions is not None else EventActions (),
4757 )
4858
4959 @override
@@ -342,6 +352,22 @@ async def _run_async_impl(
342352 yield self .event (ctx )
343353
344354
355+ class _TestingAgentWithEscalateAction (_TestingAgent ):
356+ """Mock agent for testing escalation short-circuit behavior."""
357+
358+ @override
359+ async def _run_async_impl (
360+ self , ctx : InvocationContext
361+ ) -> AsyncGenerator [Event , None ]:
362+ await asyncio .sleep (self .delay )
363+ yield self .event (
364+ ctx ,
365+ text = f'Escalating from { self .name } !' ,
366+ actions = EventActions (escalate = True ),
367+ )
368+ yield self .event (ctx , text = 'This event should be cancelled after escalation.' )
369+
370+
345371@pytest .mark .asyncio
346372async def test_stop_agent_if_sub_agent_fails (
347373 request : pytest .FixtureRequest ,
@@ -373,3 +399,84 @@ async def test_stop_agent_if_sub_agent_fails(
373399 async for _ in agen :
374400 # The infinite agent could iterate a few times depending on scheduling.
375401 pass
402+
403+
404+ @pytest .mark .asyncio
405+ @pytest .mark .parametrize ('is_resumable' , [True , False ])
406+ @pytest .mark .parametrize ('use_pre_3_11_merge' , [False , True ])
407+ async def test_run_async_short_circuits_other_agents_on_escalate_action (
408+ request : pytest .FixtureRequest ,
409+ monkeypatch : pytest .MonkeyPatch ,
410+ is_resumable : bool ,
411+ use_pre_3_11_merge : bool ,
412+ ):
413+ if use_pre_3_11_merge :
414+ monkeypatch .setattr (
415+ parallel_agent_module ,
416+ 'sys' ,
417+ SimpleNamespace (version_info = (3 , 10 )),
418+ )
419+
420+ fast_agent = _TestingAgent (
421+ name = f'{ request .function .__name__ } _test_fast_agent' ,
422+ delay = 0.05 ,
423+ )
424+ escalating_agent = _TestingAgentWithEscalateAction (
425+ name = f'{ request .function .__name__ } _test_escalating_agent' ,
426+ delay = 0.1 ,
427+ )
428+ slow_agent = _TestingAgent (
429+ name = f'{ request .function .__name__ } _test_slow_agent' ,
430+ delay = 0.5 ,
431+ )
432+ parallel_agent = ParallelAgent (
433+ name = f'{ request .function .__name__ } _test_parallel_agent' ,
434+ sub_agents = [fast_agent , escalating_agent , slow_agent ],
435+ )
436+ parent_ctx = await _create_parent_invocation_context (
437+ request .function .__name__ , parallel_agent , is_resumable = is_resumable
438+ )
439+
440+ events = [e async for e in parallel_agent .run_async (parent_ctx )]
441+
442+ assert all (event .author != slow_agent .name for event in events )
443+ assert all (
444+ not event .content
445+ or not event .content .parts
446+ or event .content .parts [0 ].text
447+ != 'This event should be cancelled after escalation.'
448+ for event in events
449+ )
450+
451+ if is_resumable :
452+ assert len (events ) == 4
453+
454+ assert events [0 ].author == parallel_agent .name
455+ assert not events [0 ].actions .end_of_agent
456+
457+ assert events [1 ].author == fast_agent .name
458+ assert events [1 ].branch == f'{ parallel_agent .name } .{ fast_agent .name } '
459+ assert events [1 ].content .parts [0 ].text == f'Hello, async { fast_agent .name } !'
460+
461+ assert events [2 ].author == escalating_agent .name
462+ assert events [2 ].branch == f'{ parallel_agent .name } .{ escalating_agent .name } '
463+ assert events [2 ].content .parts [0 ].text == (
464+ f'Escalating from { escalating_agent .name } !'
465+ )
466+ assert events [2 ].actions .escalate
467+
468+ assert events [3 ].author == parallel_agent .name
469+ assert events [3 ].actions .end_of_agent
470+ else :
471+ assert len (events ) == 2
472+
473+ assert events [0 ].author == fast_agent .name
474+ assert events [0 ].branch == f'{ parallel_agent .name } .{ fast_agent .name } '
475+ assert events [0 ].content .parts [0 ].text == f'Hello, async { fast_agent .name } !'
476+
477+ assert events [1 ].author == escalating_agent .name
478+ assert events [1 ].branch == f'{ parallel_agent .name } .{ escalating_agent .name } '
479+ assert events [1 ].content .parts [0 ].text == (
480+ f'Escalating from { escalating_agent .name } !'
481+ )
482+ assert events [1 ].actions .escalate
0 commit comments