
Commit e25e725

feat(live): Refactor live SequentialAgent workflow and live event handling for reliability and cleaner session history
1 parent 16a15c8 commit e25e725

4 files changed: 122 additions & 62 deletions


src/google/adk/agents/sequential_agent.py

Lines changed: 11 additions & 24 deletions
@@ -14,6 +14,7 @@
 
 """Sequential agent implementation."""
 
+
 from __future__ import annotations
 
 from typing import AsyncGenerator
@@ -29,6 +30,8 @@
 from .sequential_agent_config import SequentialAgentConfig
 
 
+
+
 class SequentialAgent(BaseAgent):
   """A shell agent that runs its sub-agents in sequence."""
 
@@ -46,37 +49,21 @@ async def _run_live_impl(
   ) -> AsyncGenerator[Event, None]:
     """Implementation for live SequentialAgent.
 
-    Compared to the non-live case, live agents process a continuous stream of audio
-    or video, so there is no way to tell if it's finished and should pass
-    to the next agent or not. So we introduce a task_completed() function so the
-    model can call this function to signal that it's finished the task and we
-    can move on to the next agent.
+    In a live run, this agent executes its sub-agents one by one. It relies
+    on the `generation_complete` event from the underlying model to determine
+    when a sub-agent has finished its turn. Once a sub-agent's `run_live`
+    stream concludes (triggered by the `generation_complete` event), the
+    `SequentialAgent` will proceed to execute the next sub-agent in the
+    sequence.
 
     Args:
       ctx: The invocation context of the agent.
     """
-    # There is no way to know if it's using live during init phase so we have to init it here
-    for sub_agent in self.sub_agents:
-      # add tool
-      def task_completed():
-        """
-        Signals that the model has successfully completed the user's question
-        or task.
-        """
-        return 'Task completion signaled.'
-
-      if isinstance(sub_agent, LlmAgent):
-        # Use function name to dedupe.
-        if task_completed.__name__ not in sub_agent.tools:
-          sub_agent.tools.append(task_completed)
-          sub_agent.instruction += f"""If you finished the user's request
-          according to its description, call the {task_completed.__name__} function
-          to exit so the next agents can take over. When calling this function,
-          do not generate any text other than the function call."""
-
     for sub_agent in self.sub_agents:
       async for event in sub_agent.run_live(ctx):
         yield event
+        if event.generation_complete:
+          break
 
   @classmethod
   @override
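
The docstring above describes the whole handoff mechanism: drain each sub-agent's live stream and move on as soon as an event carries `generation_complete`, instead of injecting a synthetic task_completed tool into every LlmAgent. A minimal, self-contained sketch of that consume-until-complete loop, using hypothetical FakeEvent/FakeAgent stand-ins rather than the real ADK Event and BaseAgent classes:

import asyncio
from dataclasses import dataclass
from typing import AsyncGenerator, List


@dataclass
class FakeEvent:
  author: str
  text: str
  generation_complete: bool = False


class FakeAgent:
  """Toy stand-in for a live sub-agent that streams a few events."""

  def __init__(self, name: str, replies: List[str]):
    self.name = name
    self.replies = replies

  async def run_live(self) -> AsyncGenerator[FakeEvent, None]:
    for reply in self.replies:
      yield FakeEvent(author=self.name, text=reply)
    # Mark the end of this agent's turn, like generation_complete above.
    yield FakeEvent(author=self.name, text='', generation_complete=True)


async def run_sequentially(agents: List[FakeAgent]) -> None:
  # Same control flow as the new _run_live_impl: drain each sub-agent's
  # stream and break once generation_complete is observed.
  for agent in agents:
    async for event in agent.run_live():
      print(agent.name, repr(event.text), event.generation_complete)
      if event.generation_complete:
        break


asyncio.run(run_sequentially([
    FakeAgent('writer', ['drafting...']),
    FakeAgent('critic', ['reviewing...']),
]))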

src/google/adk/flows/llm_flows/base_llm_flow.py

Lines changed: 55 additions & 34 deletions
@@ -109,26 +109,45 @@ async def run_live(
         await llm_connection.send_history(llm_request.contents)
         trace_send_data(invocation_context, event_id, llm_request.contents)
 
-      send_task = asyncio.create_task(
-          self._send_to_model(llm_connection, invocation_context)
-      )
+      event_queue = asyncio.Queue()
 
-      try:
-        async for event in self._receive_from_model(
-            llm_connection,
-            event_id,
-            invocation_context,
-            llm_request,
+      async def send_handler():
+        """Handles sending user input and generating user text events."""
+        async for event in self._send_to_model(
+            llm_connection, invocation_context
         ):
-          # Empty event means the queue is closed.
-          if not event:
+          await event_queue.put(event)
+
+      async def receive_handler():
+        """Handles receiving model output and generating model events."""
+        try:
+          async for event in self._receive_from_model(
+              llm_connection, event_id, invocation_context, llm_request
+          ):
+            await event_queue.put(event)
+        finally:
+          # Signal that the receiving process is complete.
+          await event_queue.put(None)
+
+      send_task = asyncio.create_task(send_handler())
+      receive_task = asyncio.create_task(receive_handler())
+      tasks = {send_task, receive_task}
+
+      try:
+        while True:
+          # Consume events from the unified queue.
+          event = await event_queue.get()
+          if event is None:  # End of stream signal
             break
+
           logger.debug('Receive new event: %s', event)
           yield event
-          # send back the function response
+
+          # Forward function responses back to the model.
           if event.get_function_responses():
             logger.debug('Sending back last function response event: %s', event)
             invocation_context.live_request_queue.send_content(event.content)
+
           if (
               event.content
               and event.content.parts
@@ -140,33 +159,19 @@ async def run_live(
             # cancel the tasks that belongs to the closed connection.
             send_task.cancel()
             await llm_connection.close()
-          if (
-              event.content
-              and event.content.parts
-              and event.content.parts[0].function_response
-              and event.content.parts[0].function_response.name
-              == 'task_completed'
-          ):
-            # this is used for sequential agent to signal the end of the agent.
-            await asyncio.sleep(1)
-            # cancel the tasks that belongs to the closed connection.
-            send_task.cancel()
-            return
       finally:
-        # Clean up
-        if not send_task.done():
-          send_task.cancel()
-          try:
-            await send_task
-          except asyncio.CancelledError:
-            pass
+        # Clean up all running tasks.
+        for task in tasks:
+          if not task.done():
+            task.cancel()
+        await asyncio.gather(*tasks, return_exceptions=True)
 
   async def _send_to_model(
       self,
       llm_connection: BaseLlmConnection,
       invocation_context: InvocationContext,
-  ):
-    """Sends data to model."""
+  ) -> AsyncGenerator[Event, None]:
+    """Sends data to model and yields user events for text messages."""
     while True:
       live_request_queue = invocation_context.live_request_queue
       try:
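
This is the core of the run_live refactor: instead of one receive loop plus a background send task, both directions are wrapped in handler coroutines that feed a single asyncio.Queue, a None sentinel from the receive side ends the consumer loop, and cleanup cancels whatever is still running and awaits it with return_exceptions=True. A standalone sketch of that fan-in pattern with toy producers (none of the names below are ADK API):

import asyncio


async def fan_in_demo() -> None:
  queue: asyncio.Queue = asyncio.Queue()

  async def producer(name: str, count: int, closes_stream: bool) -> None:
    for i in range(count):
      await queue.put(f'{name}-{i}')
      await asyncio.sleep(0)  # yield control, as real I/O would
    if closes_stream:
      # Only one producer owns the end-of-stream signal, mirroring the
      # None put in receive_handler's finally block.
      await queue.put(None)

  tasks = {
      asyncio.create_task(producer('send', 2, closes_stream=False)),
      asyncio.create_task(producer('recv', 3, closes_stream=True)),
  }
  try:
    while True:
      item = await queue.get()
      if item is None:  # end-of-stream sentinel
        break
      print('consumed', item)
  finally:
    # Cancel whatever is still running and wait for everything to settle.
    for task in tasks:
      if not task.done():
        task.cancel()
    await asyncio.gather(*tasks, return_exceptions=True)


asyncio.run(fan_in_demo())
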
@@ -212,7 +217,23 @@ async def _send_to_model(
           )
           await llm_connection.send_realtime(live_request.blob)
 
-        if live_request.content:
+        # If the request is a user-sent text message, create and yield an event
+        # so it can be saved to the session history.
+        if (
+            live_request.content
+            and live_request.content.parts
+            and live_request.content.parts[0].text
+        ):
+          user_event = Event(
+              invocation_id=invocation_context.invocation_id,
+              author='user',
+              content=live_request.content,
+          )
+          yield user_event
+          await llm_connection.send_content(live_request.content)
+        elif live_request.content:
+          # Handle other content types, like function responses, without creating
+          # a user event.
           await llm_connection.send_content(live_request.content)
 
   async def _receive_from_model(
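
The second half of this file's change turns _send_to_model from a plain coroutine into an async generator: blobs and non-text content are still forwarded silently, but a text message from the live request queue is first wrapped in a user-authored event and yielded, so the caller can persist it to session history before sending it on to the model. A simplified sketch of that forward-and-yield shape, with toy request and event types that are not the ADK classes:

import asyncio
from dataclasses import dataclass
from typing import AsyncGenerator, List, Optional


@dataclass
class ToyRequest:
  text: Optional[str] = None
  blob: Optional[bytes] = None


@dataclass
class ToyUserEvent:
  author: str
  text: str


async def send_to_model(
    requests: List[ToyRequest], sent_log: List[str]
) -> AsyncGenerator[ToyUserEvent, None]:
  for request in requests:
    if request.blob is not None:
      # Realtime data: forward only, no session-history event.
      sent_log.append(f'blob({len(request.blob)} bytes)')
    elif request.text is not None:
      # Text message: yield a user event first so it can be saved, then forward.
      yield ToyUserEvent(author='user', text=request.text)
      sent_log.append(f'text({request.text!r})')


async def main() -> None:
  sent_log: List[str] = []
  requests = [ToyRequest(blob=b'\x00\x01'), ToyRequest(text='hello')]
  async for event in send_to_model(requests, sent_log):
    print('history event:', event)
  print('sent to model:', sent_log)


asyncio.run(main())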

src/google/adk/models/gemini_llm_connection.py

Lines changed: 49 additions & 4 deletions
@@ -140,10 +140,37 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
     Yields:
       LlmResponse: The model response.
     """
-
     text = ''
+    user_text = ''
     async for message in self._gemini_session.receive():
       logger.debug('Got LLM Live message: %s', message)
+
+      model_turn_has_content = False
+      if message.server_content and message.server_content.model_turn:
+        content = message.server_content.model_turn
+        if content and content.parts:
+          model_turn_has_content = any(
+              p.text or p.inline_data for p in content.parts
+          )
+
+      model_is_replying = (
+          message.tool_call
+          or (
+              message.server_content
+              and message.server_content.output_transcription
+          )
+          or model_turn_has_content
+      )
+
+      if user_text and model_is_replying:
+        yield LlmResponse(
+            content=types.Content(
+                role='user',
+                parts=[types.Part.from_text(text=user_text)],
+            )
+        )
+        user_text = ''
+
       if message.server_content:
         content = message.server_content.model_turn
         if content and content.parts:
@@ -153,6 +180,8 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
           if content.parts[0].text:
             text += content.parts[0].text
             llm_response.partial = True
+          if content.parts[0].inline_data:
+            llm_response.partial = True
           # don't yield the merged text event when receiving audio data
           elif text and not content.parts[0].inline_data:
             yield self.__build_full_text_response(text)
@@ -162,14 +191,15 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
             message.server_content.input_transcription
             and message.server_content.input_transcription.text
         ):
-          user_text = message.server_content.input_transcription.text
+          user_text_fragment = message.server_content.input_transcription.text
+          user_text += user_text_fragment
           parts = [
               types.Part.from_text(
-                  text=user_text,
+                  text=user_text_fragment,
               )
           ]
           llm_response = LlmResponse(
-              content=types.Content(role='user', parts=parts)
+              content=types.Content(role='user', parts=parts), partial=True
           )
           yield llm_response
         if (
@@ -202,6 +232,21 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
               turn_complete=True, interrupted=message.server_content.interrupted
           )
           break
+        if message.server_content.generation_complete:
+          if text:
+            yield self.__build_full_text_response(text)
+            text = ''
+          # yield LlmResponse(generation_complete=True, partial=True)
+          yield LlmResponse(
+              content=types.Content(
+                  role='model',
+                  parts=[
+                      types.Part.from_text(text='[SYSTEM] Hand off to second agent')
+                  ],
+              ),
+              generation_complete=True,
+              partial=True,
+          )
         # in case of empty content or parts, we sill surface it
         # in case it's an interrupted message, we merge the previous partial
         # text. Other we don't merge. because content can be none when model
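
Two related changes in receive(): input-transcription fragments are still yielded immediately, but now as partial user responses while also being accumulated into user_text, and the buffered text is flushed as one complete user message as soon as the model starts replying (a tool call, output transcription, or model-turn content); generation_complete from the server is additionally surfaced as its own LlmResponse so workflow agents can react to it. A tiny sketch of the buffer-then-flush idea on plain dicts, with no Gemini session involved:

from typing import Dict, List, Tuple


def merge_transcription(messages: List[Dict[str, str]]) -> List[Tuple[str, str, bool]]:
  """Returns (role, text, is_partial) tuples in the order they would be yielded."""
  out: List[Tuple[str, str, bool]] = []
  user_text = ''
  for message in messages:
    model_is_replying = bool(message.get('model_text'))
    if user_text and model_is_replying:
      # Flush the merged user turn as one complete event before the model's reply.
      out.append(('user', user_text, False))
      user_text = ''
    fragment = message.get('input_transcription')
    if fragment:
      user_text += fragment
      out.append(('user', fragment, True))  # partial fragment, like partial=True
    model_text = message.get('model_text')
    if model_text:
      out.append(('model', model_text, True))
  return out


for role, text, is_partial in merge_transcription([
    {'input_transcription': 'turn on '},
    {'input_transcription': 'the lights'},
    {'model_text': 'Okay, turning on the lights.'},
]):
  print(role, repr(text), 'partial' if is_partial else 'complete')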

src/google/adk/models/llm_response.py

Lines changed: 7 additions & 0 deletions
@@ -35,6 +35,8 @@ class LlmResponse(BaseModel):
       stream. Only used for streaming mode and when the content is plain text.
     turn_complete: Indicates whether the response from the model is complete.
       Only used for streaming mode.
+    generation_complete: Indicates that the model has finished generating content.
+      Only used for streaming mode.
     error_code: Error code if the response is an error. Code varies by model.
     error_message: Error message if the response is an error.
     interrupted: Flag indicating that LLM was interrupted when generating the
@@ -67,6 +69,11 @@ class LlmResponse(BaseModel):
   Only used for streaming mode.
   """
 
+  generation_complete: Optional[bool] = None
+  """Indicates that the model has finished generating content.
+  Only used for streaming mode.
+  """
+
   error_code: Optional[str] = None
   """Error code if the response is an error. Code varies by model."""
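
generation_complete joins the existing optional streaming flags (partial, turn_complete, interrupted), defaulting to None so non-streaming responses are unaffected; the SequentialAgent loop above is its first consumer. A minimal pydantic sketch of the same optional-flag shape, using a cut-down stand-in rather than the real LlmResponse model:

from typing import Optional

from pydantic import BaseModel


class MiniResponse(BaseModel):
  """Cut-down stand-in showing the optional streaming flags."""

  partial: Optional[bool] = None
  turn_complete: Optional[bool] = None
  generation_complete: Optional[bool] = None


final_chunk = MiniResponse(generation_complete=True, partial=True)
print(final_chunk.generation_complete)     # True
print(MiniResponse().generation_complete)  # None, flag is unset outside streaming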
