Skip to content

Commit 4c8449b

Browse files
hangfei authored and copybara-github committed
fix: fix Live Session Resumption
The previous implementation didn't pass the actual session-resumption handle to the server. Now we cache the handle and pass it along when a reconnection happens. To enable: run_config = RunConfig( session_resumption=types.SessionResumptionConfig(transparent=True) ) PiperOrigin-RevId: 789144709
1 parent 314d6a4 commit 4c8449b

5 files changed

Lines changed: 131 additions & 72 deletions

File tree

src/google/adk/agents/invocation_context.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,9 @@ class InvocationContext(BaseModel):
151151
transcription_cache: Optional[list[TranscriptionEntry]] = None
152152
"""Caches necessary data, audio or contents, that are needed by transcription."""
153153

154+
live_session_resumption_handle: Optional[str] = None
155+
"""The handle for live session resumption."""
156+
154157
run_config: Optional[RunConfig] = None
155158
"""Configurations for live agents under this invocation."""
156159

src/google/adk/flows/llm_flows/base_llm_flow.py

Lines changed: 115 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from typing import TYPE_CHECKING
2626

2727
from google.genai import types
28+
from websockets.exceptions import ConnectionClosed
2829
from websockets.exceptions import ConnectionClosedOK
2930

3031
from . import functions
@@ -86,80 +87,114 @@ async def run_live(
8687
invocation_context.agent.name,
8788
llm_request,
8889
)
89-
async with llm.connect(llm_request) as llm_connection:
90-
if llm_request.contents:
91-
# Sends the conversation history to the model.
92-
with tracer.start_as_current_span('send_data'):
93-
94-
if invocation_context.transcription_cache:
95-
from . import audio_transcriber
96-
97-
audio_transcriber = audio_transcriber.AudioTranscriber(
98-
init_client=True
99-
if invocation_context.run_config.input_audio_transcription
100-
is None
101-
else False
102-
)
103-
contents = audio_transcriber.transcribe_file(invocation_context)
104-
logger.debug('Sending history to model: %s', contents)
105-
await llm_connection.send_history(contents)
106-
invocation_context.transcription_cache = None
107-
trace_send_data(invocation_context, event_id, contents)
108-
else:
109-
await llm_connection.send_history(llm_request.contents)
110-
trace_send_data(invocation_context, event_id, llm_request.contents)
111-
112-
send_task = asyncio.create_task(
113-
self._send_to_model(llm_connection, invocation_context)
114-
)
11590

91+
attempt = 1
92+
93+
while True:
94+
# TODO: check multi-agent case
11695
try:
117-
async for event in self._receive_from_model(
118-
llm_connection,
119-
event_id,
120-
invocation_context,
121-
llm_request,
122-
):
123-
# Empty event means the queue is closed.
124-
if not event:
125-
break
126-
logger.debug('Receive new event: %s', event)
127-
yield event
128-
# send back the function response
129-
if event.get_function_responses():
130-
logger.debug('Sending back last function response event: %s', event)
131-
invocation_context.live_request_queue.send_content(event.content)
132-
if (
133-
event.content
134-
and event.content.parts
135-
and event.content.parts[0].function_response
136-
and event.content.parts[0].function_response.name
137-
== 'transfer_to_agent'
138-
):
139-
await asyncio.sleep(1)
140-
# cancel the tasks that belongs to the closed connection.
141-
send_task.cancel()
142-
await llm_connection.close()
143-
if (
144-
event.content
145-
and event.content.parts
146-
and event.content.parts[0].function_response
147-
and event.content.parts[0].function_response.name
148-
== 'task_completed'
149-
):
150-
# this is used for sequential agent to signal the end of the agent.
151-
await asyncio.sleep(1)
152-
# cancel the tasks that belongs to the closed connection.
153-
send_task.cancel()
154-
return
155-
finally:
156-
# Clean up
157-
if not send_task.done():
158-
send_task.cancel()
159-
try:
160-
await send_task
161-
except asyncio.CancelledError:
162-
pass
96+
# On subsequent attempts, use the saved token to reconnect
97+
if invocation_context.live_session_resumption_handle:
98+
logger.info('Attempting to reconnect (Attempt %s)...', attempt)
99+
attempt += 1
100+
if not llm_request.live_connect_config:
101+
llm_request.live_connect_config = types.LiveConnectConfig()
102+
llm_request.live_connect_config.session_resumption.handle = (
103+
invocation_context.live_session_resumption_handle
104+
)
105+
llm_request.live_connect_config.session_resumption.transparent = True
106+
107+
logger.info(
108+
'Establishing live connection for agent: %s',
109+
invocation_context.agent.name,
110+
)
111+
async with llm.connect(llm_request) as llm_connection:
112+
if llm_request.contents:
113+
# Sends the conversation history to the model.
114+
with tracer.start_as_current_span('send_data'):
115+
116+
if invocation_context.transcription_cache:
117+
from . import audio_transcriber
118+
119+
audio_transcriber = audio_transcriber.AudioTranscriber(
120+
init_client=True
121+
if invocation_context.run_config.input_audio_transcription
122+
is None
123+
else False
124+
)
125+
contents = audio_transcriber.transcribe_file(invocation_context)
126+
logger.debug('Sending history to model: %s', contents)
127+
await llm_connection.send_history(contents)
128+
invocation_context.transcription_cache = None
129+
trace_send_data(invocation_context, event_id, contents)
130+
else:
131+
await llm_connection.send_history(llm_request.contents)
132+
trace_send_data(
133+
invocation_context, event_id, llm_request.contents
134+
)
135+
136+
send_task = asyncio.create_task(
137+
self._send_to_model(llm_connection, invocation_context)
138+
)
139+
140+
try:
141+
async for event in self._receive_from_model(
142+
llm_connection,
143+
event_id,
144+
invocation_context,
145+
llm_request,
146+
):
147+
# Empty event means the queue is closed.
148+
if not event:
149+
break
150+
logger.debug('Receive new event: %s', event)
151+
yield event
152+
# send back the function response
153+
if event.get_function_responses():
154+
logger.debug(
155+
'Sending back last function response event: %s', event
156+
)
157+
invocation_context.live_request_queue.send_content(
158+
event.content
159+
)
160+
if (
161+
event.content
162+
and event.content.parts
163+
and event.content.parts[0].function_response
164+
and event.content.parts[0].function_response.name
165+
== 'transfer_to_agent'
166+
):
167+
await asyncio.sleep(1)
168+
# cancel the tasks that belongs to the closed connection.
169+
send_task.cancel()
170+
await llm_connection.close()
171+
if (
172+
event.content
173+
and event.content.parts
174+
and event.content.parts[0].function_response
175+
and event.content.parts[0].function_response.name
176+
== 'task_completed'
177+
):
178+
# this is used for sequential agent to signal the end of the agent.
179+
await asyncio.sleep(1)
180+
# cancel the tasks that belongs to the closed connection.
181+
send_task.cancel()
182+
return
183+
finally:
184+
# Clean up
185+
if not send_task.done():
186+
send_task.cancel()
187+
try:
188+
await send_task
189+
except asyncio.CancelledError:
190+
pass
191+
except (ConnectionClosed, ConnectionClosedOK) as e:
192+
logger.warning(f'Connection closed: {e}.')
193+
except Exception as e:
194+
logger.error(
195+
f'An unexpected error occurred in live flow: {e}', exc_info=True
196+
)
197+
raise
163198

164199
async def _send_to_model(
165200
self,
@@ -246,6 +281,14 @@ def get_author_for_event(llm_response):
246281
try:
247282
while True:
248283
async for llm_response in llm_connection.receive():
284+
if llm_response.live_session_resumption_update:
285+
logger.info(
286+
'Update session resumption hanlde:'
287+
f' {llm_response.live_session_resumption_update}.'
288+
)
289+
invocation_context.live_session_resumption_handle = (
290+
llm_response.live_session_resumption_update.new_handle
291+
)
249292
model_response_event = Event(
250293
id=Event.new_id(),
251294
invocation_id=invocation_context.invocation_id,

src/google/adk/models/gemini_llm_connection.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,13 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
219219
for function_call in message.tool_call.function_calls
220220
]
221221
yield LlmResponse(content=types.Content(role='model', parts=parts))
222+
if message.session_resumption_update:
223+
logger.info('Redeived session reassumption message: %s', message)
224+
yield (
225+
LlmResponse(
226+
live_session_resumption_update=message.session_resumption_update
227+
)
228+
)
222229

223230
async def close(self):
224231
"""Closes the llm server connection."""

src/google/adk/models/google_llm.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,7 @@ async def connect(self, llm_request: LlmRequest) -> BaseLlmConnection:
289289
],
290290
)
291291
llm_request.live_connect_config.tools = llm_request.config.tools
292+
logger.info('Connecting to live with llm_request:%s', llm_request)
292293
async with self._live_api_client.aio.live.connect(
293294
model=llm_request.model, config=llm_request.live_connect_config
294295
) as live_session:

src/google/adk/models/llm_response.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,11 @@ class LlmResponse(BaseModel):
8989
usage_metadata: Optional[types.GenerateContentResponseUsageMetadata] = None
9090
"""The usage metadata of the LlmResponse"""
9191

92+
live_session_resumption_update: Optional[
93+
types.LiveServerSessionResumptionUpdate
94+
] = None
95+
"""The session resumption update of the LlmResponse"""
96+
9297
@staticmethod
9398
def create(
9499
generate_content_response: types.GenerateContentResponse,

0 commit comments

Comments
 (0)