forked from google/adk-python
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathparallel_agent.py
More file actions
230 lines (191 loc) · 7.39 KB
/
parallel_agent.py
File metadata and controls
230 lines (191 loc) · 7.39 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Parallel agent implementation."""
from __future__ import annotations
import asyncio
import sys
from typing import AsyncGenerator
from typing import ClassVar
from typing_extensions import override
from ..events.event import Event
from ..utils.context_utils import Aclosing
from .base_agent import BaseAgent
from .base_agent import BaseAgentState
from .base_agent_config import BaseAgentConfig
from .invocation_context import InvocationContext
from .parallel_agent_config import ParallelAgentConfig
def _create_branch_ctx_for_sub_agent(
agent: BaseAgent,
sub_agent: BaseAgent,
invocation_context: InvocationContext,
) -> InvocationContext:
"""Create isolated branch for every sub-agent."""
invocation_context = invocation_context.model_copy()
branch_suffix = f'{agent.name}.{sub_agent.name}'
invocation_context.branch = (
f'{invocation_context.branch}.{branch_suffix}'
if invocation_context.branch
else branch_suffix
)
return invocation_context
# TODO - remove once Python <3.11 is no longer supported.
async def _merge_agent_run_pre_3_11(
agent_runs: list[AsyncGenerator[Event, None]],
) -> AsyncGenerator[Event, None]:
"""Merges the agent run event generator.
This version works in Python 3.9 and 3.10 and uses custom replacement for
asyncio.TaskGroup for tasks cancellation and exception handling.
This implementation guarantees for each agent, it won't move on until the
generated event is processed by upstream runner.
Args:
agent_runs: A list of async generators that yield events from each agent.
Yields:
Event: The next event from the merged generator.
"""
sentinel = object()
queue = asyncio.Queue()
def propagate_exceptions(tasks):
# Propagate exceptions and errors from tasks.
for task in tasks:
if task.done():
# Ignore the result (None) of correctly finished tasks and re-raise
# exceptions and errors.
task.result()
# Agents are processed in parallel.
# Events for each agent are put on queue sequentially.
async def process_an_agent(events_for_one_agent):
try:
async for event in events_for_one_agent:
resume_signal = asyncio.Event()
await queue.put((event, resume_signal))
# Wait for upstream to consume event before generating new events.
await resume_signal.wait()
finally:
# Mark agent as finished.
await queue.put((sentinel, None))
tasks = []
try:
for events_for_one_agent in agent_runs:
tasks.append(asyncio.create_task(process_an_agent(events_for_one_agent)))
sentinel_count = 0
# Run until all agents finished processing.
while sentinel_count < len(agent_runs):
propagate_exceptions(tasks)
event, resume_signal = await queue.get()
# Agent finished processing.
if event is sentinel:
sentinel_count += 1
else:
yield event
# Signal to agent that event has been processed by runner and it can
# continue now.
resume_signal.set()
finally:
for task in tasks:
task.cancel()
async def _merge_agent_run(
agent_runs: list[AsyncGenerator[Event, None]],
) -> AsyncGenerator[Event, None]:
"""Merges the agent run event generator.
This implementation guarantees for each agent, it won't move on until the
generated event is processed by upstream runner.
Args:
agent_runs: A list of async generators that yield events from each agent.
Yields:
Event: The next event from the merged generator.
"""
sentinel = object()
queue = asyncio.Queue()
# Agents are processed in parallel.
# Events for each agent are put on queue sequentially.
async def process_an_agent(events_for_one_agent):
try:
async for event in events_for_one_agent:
resume_signal = asyncio.Event()
await queue.put((event, resume_signal))
# Wait for upstream to consume event before generating new events.
await resume_signal.wait()
finally:
# Mark agent as finished.
await queue.put((sentinel, None))
async with asyncio.TaskGroup() as tg:
for events_for_one_agent in agent_runs:
tg.create_task(process_an_agent(events_for_one_agent))
sentinel_count = 0
# Run until all agents finished processing.
while sentinel_count < len(agent_runs):
event, resume_signal = await queue.get()
# Agent finished processing.
if event is sentinel:
sentinel_count += 1
else:
yield event
# Signal to agent that it should generate next event.
resume_signal.set()
class ParallelAgent(BaseAgent):
  """A shell agent that runs its sub-agents in parallel in an isolated manner.

  Each sub-agent gets its own branch of the invocation context, and their
  events are merged into a single stream. This is useful when a task benefits
  from several perspectives or attempts, for example:
    - Running different algorithms simultaneously.
    - Generating multiple responses for review by a subsequent evaluation
      agent.
  """

  # The config type for this agent.
  config_type: ClassVar[type[BaseAgentConfig]] = ParallelAgentConfig

  @override
  async def _run_async_impl(
      self, ctx: InvocationContext
  ) -> AsyncGenerator[Event, None]:
    # No sub-agents: nothing to run and no events to emit.
    if not self.sub_agents:
      return

    # On a resumable invocation with no saved state yet, persist an initial
    # state and announce it before any sub-agent starts.
    saved_state = self._load_agent_state(ctx, BaseAgentState)
    if ctx.is_resumable and saved_state is None:
      ctx.set_agent_state(self.name, agent_state=BaseAgentState())
      yield self._create_agent_state_event(ctx)

    # Build one event stream per sub-agent, each on its own branch. Sub-agents
    # already marked in `end_of_agents` (finished in a previous run) are
    # skipped.
    pending_runs = []
    for child in self.sub_agents:
      child_ctx = _create_branch_ctx_for_sub_agent(self, child, ctx)
      if child_ctx.end_of_agents.get(child.name):
        continue
      pending_runs.append(child.run_async(child_ctx))

    saw_pause = False
    try:
      # TODO remove if once Python <3.11 is no longer supported.
      if sys.version_info >= (3, 11):
        merge = _merge_agent_run
      else:
        merge = _merge_agent_run_pre_3_11

      async with Aclosing(merge(pending_runs)) as merged:
        async for event in merged:
          yield event
          # Keep draining the remaining sub-agents even after a pause
          # request; only remember that one was seen.
          if ctx.should_pause_invocation(event):
            saw_pause = True

      # Mark this agent as finished only when no pause was requested, the
      # invocation is resumable, and every sub-agent is recorded as finished.
      if not saw_pause and ctx.is_resumable and all(
          ctx.end_of_agents.get(child.name) for child in self.sub_agents
      ):
        ctx.set_agent_state(self.name, end_of_agent=True)
        yield self._create_agent_state_event(ctx)
    finally:
      # Close every sub-agent stream, whether the merge completed, raised, or
      # this generator was closed early.
      for run in pending_runs:
        await run.aclose()

  @override
  async def _run_live_impl(
      self, ctx: InvocationContext
  ) -> AsyncGenerator[Event, None]:
    raise NotImplementedError('This is not supported yet for ParallelAgent.')
    yield  # AsyncGenerator requires having at least one yield statement