1717from __future__ import annotations
1818
1919import asyncio
20+ import logging
2021import sys
2122from typing import AsyncGenerator
2223from typing import ClassVar
24+ from typing import Optional
2325
2426from typing_extensions import override
2527
3133from .invocation_context import InvocationContext
3234from .parallel_agent_config import ParallelAgentConfig
3335
36+ logger = logging .getLogger ('google_adk.' + __name__ )
37+
3438
3539def _create_branch_ctx_for_sub_agent (
3640 agent : BaseAgent ,
@@ -48,29 +52,75 @@ def _create_branch_ctx_for_sub_agent(
4852 return invocation_context
4953
5054
async def _iter_with_idle_timeout(
    gen: AsyncGenerator[Event, None],
    timeout_secs: float,
    branch_name: str,
) -> AsyncGenerator[Event, None]:
  """Re-yield events from *gen*, failing if it stays idle for too long.

  Each ``gen.__anext__()`` call is bounded by ``asyncio.wait_for`` so that a
  branch whose upstream stream silently stalls (connection open, no chunks)
  is detected and cancelled instead of blocking the consumer forever. The
  timeout is per event (an idle timeout), not a deadline for the whole run.

  The wrapped generator is always closed when this wrapper finishes, whether
  it is exhausted, times out, raises, or is closed early by its consumer, so
  the wrapped generator's own ``finally`` cleanup runs promptly rather than
  being deferred to garbage collection.

  Args:
    gen: The async generator to wrap.
    timeout_secs: Maximum number of seconds to wait for each event.
    branch_name: Name of the branch; used in the log and error messages only.

  Yields:
    Each event produced by *gen*, in order.

  Raises:
    asyncio.TimeoutError: If *gen* produces no event within *timeout_secs*.
      NOTE: an ``asyncio.TimeoutError`` raised from inside *gen* itself is
      caught by the same handler and reported as an idle timeout.
  """
  try:
    while True:
      try:
        # On timeout, wait_for cancels the pending __anext__() call before
        # raising, so the wrapped generator is not left running.
        event = await asyncio.wait_for(
            gen.__anext__(),
            timeout=timeout_secs,
        )
      except StopAsyncIteration:
        # The wrapped generator is exhausted; end this wrapper as well.
        return
      except asyncio.TimeoutError as exc:
        logger.warning(
            'ParallelAgent branch %r has not produced an event for %.1fs. '
            'The upstream model stream may have stalled. Cancelling the branch '
            'so the parent agent is not blocked indefinitely.',
            branch_name,
            timeout_secs,
        )
        # Re-raise with a message that names the stalled branch, keeping the
        # original timeout as the cause.
        raise asyncio.TimeoutError(
            f'Branch {branch_name!r} idle for >{timeout_secs}s without an event'
        ) from exc
      yield event
  finally:
    # Closing an already-finished generator is a no-op, so this is safe on
    # every exit path. It matters when this wrapper is closed while suspended
    # at the yield above: without it the wrapped generator would stay
    # suspended and its cleanup would not run until it is garbage collected.
    await gen.aclose()
86+
87+
5188async def _merge_agent_run (
5289 agent_runs : list [AsyncGenerator [Event , None ]],
90+ * ,
91+ branch_names : list [str ] | None = None ,
92+ branch_idle_timeout_secs : float | None = None ,
5393) -> AsyncGenerator [Event , None ]:
5494 """Merges agent runs using asyncio.TaskGroup on Python 3.11+."""
5595 sentinel = object ()
56- queue = asyncio .Queue ()
96+ queue : asyncio .Queue = asyncio .Queue ()
97+ names = branch_names or [f'branch-{ i } ' for i in range (len (agent_runs ))]
5798
5899 # Agents are processed in parallel.
59100 # Events for each agent are put on queue sequentially.
60- async def process_an_agent (events_for_one_agent ):
101+ async def process_an_agent (events_for_one_agent , branch_name : str ):
61102 try :
62- async for event in events_for_one_agent :
103+ gen = (
104+ _iter_with_idle_timeout (
105+ events_for_one_agent , branch_idle_timeout_secs , branch_name
106+ )
107+ if branch_idle_timeout_secs is not None
108+ else events_for_one_agent
109+ )
110+ async for event in gen :
63111 resume_signal = asyncio .Event ()
64- await queue .put ((event , resume_signal ))
112+ # put_nowait: the queue is unbounded so this never blocks, and it is
113+ # safe to call from a finally block that may run during cancellation.
114+ queue .put_nowait ((event , resume_signal ))
65115 # Wait for upstream to consume event before generating new events.
66116 await resume_signal .wait ()
67117 finally :
68- # Mark agent as finished.
69- await queue .put ((sentinel , None ))
118+ # Mark agent as finished. put_nowait is cancellation-safe (see above).
119+ queue .put_nowait ((sentinel , None ))
70120
71121 async with asyncio .TaskGroup () as tg :
72- for events_for_one_agent in agent_runs :
73- tg .create_task (process_an_agent (events_for_one_agent ))
122+ for events_for_one_agent , name in zip ( agent_runs , names ) :
123+ tg .create_task (process_an_agent (events_for_one_agent , name ))
74124
75125 sentinel_count = 0
76126 # Run until all agents finished processing.
@@ -88,6 +138,9 @@ async def process_an_agent(events_for_one_agent):
88138# TODO - remove once Python <3.11 is no longer supported.
89139async def _merge_agent_run_pre_3_11 (
90140 agent_runs : list [AsyncGenerator [Event , None ]],
141+ * ,
142+ branch_names : list [str ] | None = None ,
143+ branch_idle_timeout_secs : float | None = None ,
91144) -> AsyncGenerator [Event , None ]:
92145 """Merges agent runs for Python 3.10 without asyncio.TaskGroup.
93146
@@ -96,12 +149,16 @@ async def _merge_agent_run_pre_3_11(
96149
97150 Args:
98151 agent_runs: Async generators that yield events from each agent.
152+ branch_names: Optional names for each branch, used in log messages.
153+ branch_idle_timeout_secs: If set, cancel a branch that produces no event
154+ for this many seconds (guards against silently stalled model streams).
99155
100156 Yields:
101157 Event: The next event from the merged generator.
102158 """
103159 sentinel = object ()
104- queue = asyncio .Queue ()
160+ queue : asyncio .Queue = asyncio .Queue ()
161+ names = branch_names or [f'branch-{ i } ' for i in range (len (agent_runs ))]
105162
106163 def propagate_exceptions (tasks ):
107164 # Propagate exceptions and errors from tasks.
@@ -113,21 +170,31 @@ def propagate_exceptions(tasks):
113170
114171 # Agents are processed in parallel.
115172 # Events for each agent are put on queue sequentially.
116- async def process_an_agent (events_for_one_agent ):
173+ async def process_an_agent (events_for_one_agent , branch_name : str ):
117174 try :
118- async for event in events_for_one_agent :
175+ gen = (
176+ _iter_with_idle_timeout (
177+ events_for_one_agent , branch_idle_timeout_secs , branch_name
178+ )
179+ if branch_idle_timeout_secs is not None
180+ else events_for_one_agent
181+ )
182+ async for event in gen :
119183 resume_signal = asyncio .Event ()
120- await queue .put ((event , resume_signal ))
184+ queue .put_nowait ((event , resume_signal ))
121185 # Wait for upstream to consume event before generating new events.
122186 await resume_signal .wait ()
123187 finally :
124- # Mark agent as finished.
125- await queue .put ((sentinel , None ))
188+ # put_nowait is cancellation-safe: the queue is unbounded so it never
189+ # blocks, and it will not raise even if the task is being cancelled.
190+ queue .put_nowait ((sentinel , None ))
126191
127192 tasks = []
128193 try :
129- for events_for_one_agent in agent_runs :
130- tasks .append (asyncio .create_task (process_an_agent (events_for_one_agent )))
194+ for events_for_one_agent , name in zip (agent_runs , names ):
195+ tasks .append (
196+ asyncio .create_task (process_an_agent (events_for_one_agent , name ))
197+ )
131198
132199 sentinel_count = 0
133200 # Run until all agents finished processing.
@@ -142,6 +209,10 @@ async def process_an_agent(events_for_one_agent):
142209 # Signal to agent that event has been processed by runner and it can
143210 # continue now.
144211 resume_signal .set ()
212+ # A task may have put its sentinel AND raised an exception; if the sentinel
213+ # was consumed before propagate_exceptions ran in the loop body, the error
214+ # would be silently lost. Check once more after the loop to surface it.
215+ propagate_exceptions (tasks )
145216 finally :
146217 for task in tasks :
147218 task .cancel ()
@@ -155,11 +226,23 @@ class ParallelAgent(BaseAgent):
155226
156227 - Running different algorithms simultaneously.
157228 - Generating multiple responses for review by a subsequent evaluation agent.
229+
230+ Attributes:
231+ branch_idle_timeout_secs: Optional per-branch idle timeout in seconds.
232+ When set, any branch that produces no event for this many seconds is
233+ cancelled and raises ``asyncio.TimeoutError``, which unblocks the
234+ parent agent instead of hanging indefinitely. This guards against
235+ upstream model streams that stall silently (connection open, no
236+ chunks arriving). ``None`` (the default) disables the timeout and
237+ preserves the original unbounded-wait behaviour.
158238 """
159239
160240 config_type : ClassVar [type [BaseAgentConfig ]] = ParallelAgentConfig
161241 """The config type for this agent."""
162242
243+ branch_idle_timeout_secs : Optional [float ] = None
244+ """Per-branch idle timeout in seconds; None disables the guard."""
245+
163246 @override
164247 async def _run_async_impl (
165248 self , ctx : InvocationContext
@@ -173,13 +256,15 @@ async def _run_async_impl(
173256 yield self ._create_agent_state_event (ctx )
174257
175258 agent_runs = []
259+ branch_names = []
176260 # Prepare and collect async generators for each sub-agent.
177261 for sub_agent in self .sub_agents :
178262 sub_agent_ctx = _create_branch_ctx_for_sub_agent (self , sub_agent , ctx )
179263
180264 # Only include sub-agents that haven't finished in a previous run.
181265 if not sub_agent_ctx .end_of_agents .get (sub_agent .name ):
182266 agent_runs .append (sub_agent .run_async (sub_agent_ctx ))
267+ branch_names .append (sub_agent .name )
183268
184269 pause_invocation = False
185270 try :
@@ -188,7 +273,13 @@ async def _run_async_impl(
188273 if sys .version_info >= (3 , 11 )
189274 else _merge_agent_run_pre_3_11
190275 )
191- async with Aclosing (merge_func (agent_runs )) as agen :
276+ async with Aclosing (
277+ merge_func (
278+ agent_runs ,
279+ branch_names = branch_names ,
280+ branch_idle_timeout_secs = self .branch_idle_timeout_secs ,
281+ )
282+ ) as agen :
192283 async for event in agen :
193284 yield event
194285 if ctx .should_pause_invocation (event ):
0 commit comments