Skip to content

Commit e452a8e

Browse files
committed
fix(agents): add branch_idle_timeout_secs to ParallelAgent to detect silent stalls
When an upstream model stream (LiteLLM/vLLM or any backend) silently stops producing chunks without closing the connection, process_an_agent in _merge_agent_run blocks forever on __anext__() with no exception or EOF. The queue.get() in the merge loop then also waits indefinitely, leaving the /run_sse HTTP stream open until an external gateway times out (issue #5455). Root cause: there is no per-chunk idle timeout anywhere in the parallel branch merge loop, so a silently stalled model connection is indistinguishable from a legitimately slow one. Fix: add _iter_with_idle_timeout(), which wraps each __anext__() call with asyncio.wait_for(). If no event arrives within branch_idle_timeout_secs, it logs a warning and raises asyncio.TimeoutError, unblocking both the branch task and the merge loop so the SSE stream can close with an explicit error rather than hanging. Additional fix (pre-3.11 path): _merge_agent_run_pre_3_11 could silently swallow a branch exception when the branch put its sentinel into the queue before propagate_exceptions() ran. Added a final propagate_exceptions() call after the while loop exits to catch that race. Also: changed await queue.put(sentinel) → queue.put_nowait(sentinel) in all finally blocks. The queue is unbounded so put_nowait never blocks, and it cannot be interrupted by CancelledError during task cancellation. Usage: ParallelAgent( name='recommend', sub_agents=[workout_agent, sleep_agent], branch_idle_timeout_secs=120.0, # raise if no event for 2 min ) Fixes #5455
1 parent 0928d23 commit e452a8e

2 files changed

Lines changed: 540 additions & 17 deletions

File tree

src/google/adk/agents/parallel_agent.py

Lines changed: 108 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,11 @@
1717
from __future__ import annotations
1818

1919
import asyncio
20+
import logging
2021
import sys
2122
from typing import AsyncGenerator
2223
from typing import ClassVar
24+
from typing import Optional
2325

2426
from typing_extensions import override
2527

@@ -31,6 +33,8 @@
3133
from .invocation_context import InvocationContext
3234
from .parallel_agent_config import ParallelAgentConfig
3335

36+
logger = logging.getLogger('google_adk.' + __name__)
37+
3438

3539
def _create_branch_ctx_for_sub_agent(
3640
agent: BaseAgent,
@@ -48,29 +52,75 @@ def _create_branch_ctx_for_sub_agent(
4852
return invocation_context
4953

5054

55+
async def _iter_with_idle_timeout(
56+
gen: AsyncGenerator[Event, None],
57+
timeout_secs: float,
58+
branch_name: str,
59+
) -> AsyncGenerator[Event, None]:
60+
"""Wrap *gen*, raising TimeoutError if no event arrives within *timeout_secs*.
61+
62+
Uses asyncio.wait_for on each __anext__ call so that a branch whose upstream
63+
model stream silently stalls (connection open, no chunks) is detected and
64+
cancelled rather than hanging the parent ParallelAgent indefinitely.
65+
"""
66+
while True:
67+
try:
68+
event = await asyncio.wait_for(
69+
gen.__anext__(),
70+
timeout=timeout_secs,
71+
)
72+
except StopAsyncIteration:
73+
return
74+
except asyncio.TimeoutError as exc:
75+
logger.warning(
76+
'ParallelAgent branch %r has not produced an event for %.1fs. '
77+
'The upstream model stream may have stalled. Cancelling the branch '
78+
'so the parent agent is not blocked indefinitely.',
79+
branch_name,
80+
timeout_secs,
81+
)
82+
raise asyncio.TimeoutError(
83+
f'Branch {branch_name!r} idle for >{timeout_secs}s without an event'
84+
) from exc
85+
yield event
86+
87+
5188
async def _merge_agent_run(
5289
agent_runs: list[AsyncGenerator[Event, None]],
90+
*,
91+
branch_names: list[str] | None = None,
92+
branch_idle_timeout_secs: float | None = None,
5393
) -> AsyncGenerator[Event, None]:
5494
"""Merges agent runs using asyncio.TaskGroup on Python 3.11+."""
5595
sentinel = object()
56-
queue = asyncio.Queue()
96+
queue: asyncio.Queue = asyncio.Queue()
97+
names = branch_names or [f'branch-{i}' for i in range(len(agent_runs))]
5798

5899
# Agents are processed in parallel.
59100
# Events for each agent are put on queue sequentially.
60-
async def process_an_agent(events_for_one_agent):
101+
async def process_an_agent(events_for_one_agent, branch_name: str):
61102
try:
62-
async for event in events_for_one_agent:
103+
gen = (
104+
_iter_with_idle_timeout(
105+
events_for_one_agent, branch_idle_timeout_secs, branch_name
106+
)
107+
if branch_idle_timeout_secs is not None
108+
else events_for_one_agent
109+
)
110+
async for event in gen:
63111
resume_signal = asyncio.Event()
64-
await queue.put((event, resume_signal))
112+
# put_nowait: the queue is unbounded so this never blocks, and it is
113+
# safe to call from a finally block that may run during cancellation.
114+
queue.put_nowait((event, resume_signal))
65115
# Wait for upstream to consume event before generating new events.
66116
await resume_signal.wait()
67117
finally:
68-
# Mark agent as finished.
69-
await queue.put((sentinel, None))
118+
# Mark agent as finished. put_nowait is cancellation-safe (see above).
119+
queue.put_nowait((sentinel, None))
70120

71121
async with asyncio.TaskGroup() as tg:
72-
for events_for_one_agent in agent_runs:
73-
tg.create_task(process_an_agent(events_for_one_agent))
122+
for events_for_one_agent, name in zip(agent_runs, names):
123+
tg.create_task(process_an_agent(events_for_one_agent, name))
74124

75125
sentinel_count = 0
76126
# Run until all agents finished processing.
@@ -88,6 +138,9 @@ async def process_an_agent(events_for_one_agent):
88138
# TODO - remove once Python <3.11 is no longer supported.
89139
async def _merge_agent_run_pre_3_11(
90140
agent_runs: list[AsyncGenerator[Event, None]],
141+
*,
142+
branch_names: list[str] | None = None,
143+
branch_idle_timeout_secs: float | None = None,
91144
) -> AsyncGenerator[Event, None]:
92145
"""Merges agent runs for Python 3.10 without asyncio.TaskGroup.
93146
@@ -96,12 +149,16 @@ async def _merge_agent_run_pre_3_11(
96149
97150
Args:
98151
agent_runs: Async generators that yield events from each agent.
152+
branch_names: Optional names for each branch, used in log messages.
153+
branch_idle_timeout_secs: If set, cancel a branch that produces no event
154+
for this many seconds (guards against silently stalled model streams).
99155
100156
Yields:
101157
Event: The next event from the merged generator.
102158
"""
103159
sentinel = object()
104-
queue = asyncio.Queue()
160+
queue: asyncio.Queue = asyncio.Queue()
161+
names = branch_names or [f'branch-{i}' for i in range(len(agent_runs))]
105162

106163
def propagate_exceptions(tasks):
107164
# Propagate exceptions and errors from tasks.
@@ -113,21 +170,31 @@ def propagate_exceptions(tasks):
113170

114171
# Agents are processed in parallel.
115172
# Events for each agent are put on queue sequentially.
116-
async def process_an_agent(events_for_one_agent):
173+
async def process_an_agent(events_for_one_agent, branch_name: str):
117174
try:
118-
async for event in events_for_one_agent:
175+
gen = (
176+
_iter_with_idle_timeout(
177+
events_for_one_agent, branch_idle_timeout_secs, branch_name
178+
)
179+
if branch_idle_timeout_secs is not None
180+
else events_for_one_agent
181+
)
182+
async for event in gen:
119183
resume_signal = asyncio.Event()
120-
await queue.put((event, resume_signal))
184+
queue.put_nowait((event, resume_signal))
121185
# Wait for upstream to consume event before generating new events.
122186
await resume_signal.wait()
123187
finally:
124-
# Mark agent as finished.
125-
await queue.put((sentinel, None))
188+
# put_nowait is cancellation-safe: the queue is unbounded so it never
189+
# blocks, and it will not raise even if the task is being cancelled.
190+
queue.put_nowait((sentinel, None))
126191

127192
tasks = []
128193
try:
129-
for events_for_one_agent in agent_runs:
130-
tasks.append(asyncio.create_task(process_an_agent(events_for_one_agent)))
194+
for events_for_one_agent, name in zip(agent_runs, names):
195+
tasks.append(
196+
asyncio.create_task(process_an_agent(events_for_one_agent, name))
197+
)
131198

132199
sentinel_count = 0
133200
# Run until all agents finished processing.
@@ -142,6 +209,10 @@ async def process_an_agent(events_for_one_agent):
142209
# Signal to agent that event has been processed by runner and it can
143210
# continue now.
144211
resume_signal.set()
212+
# A task may have put its sentinel AND raised an exception; if the sentinel
213+
# was consumed before propagate_exceptions ran in the loop body, the error
214+
# would be silently lost. Check once more after the loop to surface it.
215+
propagate_exceptions(tasks)
145216
finally:
146217
for task in tasks:
147218
task.cancel()
@@ -155,11 +226,23 @@ class ParallelAgent(BaseAgent):
155226
156227
- Running different algorithms simultaneously.
157228
- Generating multiple responses for review by a subsequent evaluation agent.
229+
230+
Attributes:
231+
branch_idle_timeout_secs: Optional per-branch idle timeout in seconds.
232+
When set, any branch that produces no event for this many seconds is
233+
cancelled and raises ``asyncio.TimeoutError``, which unblocks the
234+
parent agent instead of hanging indefinitely. This guards against
235+
upstream model streams that stall silently (connection open, no
236+
chunks arriving). ``None`` (the default) disables the timeout and
237+
preserves the original unbounded-wait behaviour.
158238
"""
159239

160240
config_type: ClassVar[type[BaseAgentConfig]] = ParallelAgentConfig
161241
"""The config type for this agent."""
162242

243+
branch_idle_timeout_secs: Optional[float] = None
244+
"""Per-branch idle timeout in seconds; None disables the guard."""
245+
163246
@override
164247
async def _run_async_impl(
165248
self, ctx: InvocationContext
@@ -173,13 +256,15 @@ async def _run_async_impl(
173256
yield self._create_agent_state_event(ctx)
174257

175258
agent_runs = []
259+
branch_names = []
176260
# Prepare and collect async generators for each sub-agent.
177261
for sub_agent in self.sub_agents:
178262
sub_agent_ctx = _create_branch_ctx_for_sub_agent(self, sub_agent, ctx)
179263

180264
# Only include sub-agents that haven't finished in a previous run.
181265
if not sub_agent_ctx.end_of_agents.get(sub_agent.name):
182266
agent_runs.append(sub_agent.run_async(sub_agent_ctx))
267+
branch_names.append(sub_agent.name)
183268

184269
pause_invocation = False
185270
try:
@@ -188,7 +273,13 @@ async def _run_async_impl(
188273
if sys.version_info >= (3, 11)
189274
else _merge_agent_run_pre_3_11
190275
)
191-
async with Aclosing(merge_func(agent_runs)) as agen:
276+
async with Aclosing(
277+
merge_func(
278+
agent_runs,
279+
branch_names=branch_names,
280+
branch_idle_timeout_secs=self.branch_idle_timeout_secs,
281+
)
282+
) as agen:
192283
async for event in agen:
193284
yield event
194285
if ctx.should_pause_invocation(event):

0 commit comments

Comments
 (0)