4 changes: 2 additions & 2 deletions contributing/samples/live_bidi_streaming_multi_agent/agent.py
@@ -100,8 +100,8 @@ def get_current_weather(location: str):

root_agent = Agent(
# find supported models here: https://google.github.io/adk-docs/get-started/streaming/quickstart-streaming/
# model='gemini-live-2.5-flash-preview-native-audio', # for Vertex project
model="gemini-live-2.5-flash-preview", # for AI studio key
model="gemini-2.0-flash-live-preview-04-09", # for Vertex project
# model="gemini-live-2.5-flash-preview", # for AI studio key
name="root_agent",
instruction="""
You are a helpful assistant that can check time, roll dice and check if numbers are prime.
@@ -121,7 +121,9 @@ def stop_streaming(function_name: str):


root_agent = Agent(
model="gemini-live-2.5-flash-preview",
# find supported models here: https://google.github.io/adk-docs/get-started/streaming/quickstart-streaming/
model="gemini-2.0-flash-live-preview-04-09", # for Vertex project
# model="gemini-live-2.5-flash-preview", # for AI studio key
name="video_streaming_agent",
instruction="""
You are a monitoring agent. You can do video monitoring and stock price monitoring
3 changes: 2 additions & 1 deletion contributing/samples/live_tool_callbacks_agent/agent.py
@@ -217,8 +217,9 @@ async def after_tool_async_callback(

# Create the agent with tool callbacks
root_agent = Agent(
# find supported models here: https://google.github.io/adk-docs/get-started/streaming/quickstart-streaming/
model="gemini-2.0-flash-live-preview-04-09", # for Vertex project
# model="gemini-2.0-flash-live-001", # for AI studio key
# model="gemini-live-2.5-flash-preview", # for AI studio key
name="tool_callbacks_agent",
description=(
"Live streaming agent that demonstrates tool callbacks functionality. "
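The three samples above all switch between the Vertex and AI Studio live models by commenting lines in and out. A minimal alternative sketch, assuming the `GOOGLE_GENAI_USE_VERTEXAI` environment variable that the google-genai client reads also distinguishes these setups, would select the model at runtime:

import os

# Pick the live model to match the active backend. The env var name is an
# assumption borrowed from the google-genai client; adjust to your setup.
if os.environ.get("GOOGLE_GENAI_USE_VERTEXAI", "").lower() in ("1", "true"):
    MODEL = "gemini-2.0-flash-live-preview-04-09"  # Vertex project
else:
    MODEL = "gemini-live-2.5-flash-preview"  # AI Studio key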
3 changes: 3 additions & 0 deletions src/google/adk/agents/invocation_context.py
@@ -153,6 +153,9 @@ class InvocationContext(BaseModel):
transcription_cache: Optional[list[TranscriptionEntry]] = None
"""Caches necessary data, audio or contents, that are needed by transcription."""

live_session_resumption_handle: Optional[str] = None
"""The handle for live session resumption."""

run_config: Optional[RunConfig] = None
"""Configurations for live agents under this invocation."""

188 changes: 116 additions & 72 deletions src/google/adk/flows/llm_flows/base_llm_flow.py
@@ -25,6 +25,7 @@
from typing import TYPE_CHECKING

from google.genai import types
from websockets.exceptions import ConnectionClosed
from websockets.exceptions import ConnectionClosedOK

from . import functions
@@ -86,80 +87,115 @@ async def run_live(
invocation_context.agent.name,
llm_request,
)
async with llm.connect(llm_request) as llm_connection:
if llm_request.contents:
# Sends the conversation history to the model.
with tracer.start_as_current_span('send_data'):

if invocation_context.transcription_cache:
from . import audio_transcriber

audio_transcriber = audio_transcriber.AudioTranscriber(
init_client=True
if invocation_context.run_config.input_audio_transcription
is None
else False
)
contents = audio_transcriber.transcribe_file(invocation_context)
logger.debug('Sending history to model: %s', contents)
await llm_connection.send_history(contents)
invocation_context.transcription_cache = None
trace_send_data(invocation_context, event_id, contents)
else:
await llm_connection.send_history(llm_request.contents)
trace_send_data(invocation_context, event_id, llm_request.contents)

send_task = asyncio.create_task(
self._send_to_model(llm_connection, invocation_context)
)

attempt = 1
while True:
try:
async for event in self._receive_from_model(
llm_connection,
event_id,
invocation_context,
llm_request,
):
# Empty event means the queue is closed.
if not event:
break
logger.debug('Receive new event: %s', event)
yield event
# send back the function response
if event.get_function_responses():
logger.debug('Sending back last function response event: %s', event)
invocation_context.live_request_queue.send_content(event.content)
if (
event.content
and event.content.parts
and event.content.parts[0].function_response
and event.content.parts[0].function_response.name
== 'transfer_to_agent'
):
await asyncio.sleep(1)
# Cancel the tasks that belong to the closed connection.
send_task.cancel()
await llm_connection.close()
if (
event.content
and event.content.parts
and event.content.parts[0].function_response
and event.content.parts[0].function_response.name
== 'task_completed'
):
# This is used by a sequential agent to signal the end of the agent.
await asyncio.sleep(1)
# Cancel the tasks that belong to the closed connection.
send_task.cancel()
return
finally:
# Clean up
if not send_task.done():
send_task.cancel()
try:
await send_task
except asyncio.CancelledError:
pass
# On subsequent attempts, use the saved resumption handle to reconnect.
if invocation_context.live_session_resumption_handle:
logger.info('Attempting to reconnect (Attempt %s)...', attempt)
attempt += 1
if not llm_request.live_connect_config:
llm_request.live_connect_config = types.LiveConnectConfig()
# Ensure a resumption config exists before attaching the handle.
if not llm_request.live_connect_config.session_resumption:
llm_request.live_connect_config.session_resumption = (
types.SessionResumptionConfig()
)
llm_request.live_connect_config.session_resumption.handle = (
invocation_context.live_session_resumption_handle
)
llm_request.live_connect_config.session_resumption.transparent = True

logger.info(
'Establishing live connection for agent: %s',
invocation_context.agent.name,
)
async with llm.connect(llm_request) as llm_connection:
if llm_request.contents:
# Sends the conversation history to the model.
with tracer.start_as_current_span('send_data'):

if invocation_context.transcription_cache:
from . import audio_transcriber

audio_transcriber = audio_transcriber.AudioTranscriber(
init_client=True
if invocation_context.run_config.input_audio_transcription
is None
else False
)
contents = audio_transcriber.transcribe_file(invocation_context)
logger.debug('Sending history to model: %s', contents)
await llm_connection.send_history(contents)
invocation_context.transcription_cache = None
trace_send_data(invocation_context, event_id, contents)
else:
await llm_connection.send_history(llm_request.contents)
trace_send_data(
invocation_context, event_id, llm_request.contents
)

send_task = asyncio.create_task(
self._send_to_model(llm_connection, invocation_context)
)

try:
async for event in self._receive_from_model(
llm_connection,
event_id,
invocation_context,
llm_request,
):
# Empty event means the queue is closed.
if not event:
break
logger.debug('Receive new event: %s', event)
yield event
# send back the function response
if event.get_function_responses():
logger.debug(
'Sending back last function response event: %s', event
)
invocation_context.live_request_queue.send_content(
event.content
)
if (
event.content
and event.content.parts
and event.content.parts[0].function_response
and event.content.parts[0].function_response.name
== 'transfer_to_agent'
):
await asyncio.sleep(1)
# Cancel the tasks that belong to the closed connection.
send_task.cancel()
await llm_connection.close()
if (
event.content
and event.content.parts
and event.content.parts[0].function_response
and event.content.parts[0].function_response.name
== 'task_completed'
):
# This is used by a sequential agent to signal the end of the agent.
await asyncio.sleep(1)
# Cancel the tasks that belong to the closed connection.
send_task.cancel()
return
finally:
# Clean up
if not send_task.done():
send_task.cancel()
try:
await send_task
except asyncio.CancelledError:
pass
except (ConnectionClosed, ConnectionClosedOK) as e:
# A session timeout closes the connection without raising an exception,
# so reaching this handler indicates an abnormal closure.
logger.error(f'Connection closed: {e}.')
raise
except Exception as e:
logger.error(
f'An unexpected error occurred in live flow: {e}', exc_info=True
)
raise

async def _send_to_model(
self,
@@ -246,6 +282,14 @@ def get_author_for_event(llm_response):
try:
while True:
async for llm_response in llm_connection.receive():
if llm_response.live_session_resumption_update:
logger.info(
'Updating session resumption handle:'
f' {llm_response.live_session_resumption_update}.'
)
invocation_context.live_session_resumption_handle = (
llm_response.live_session_resumption_update.new_handle
)
model_response_event = Event(
id=Event.new_id(),
invocation_id=invocation_context.invocation_id,
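For context, here is a minimal, self-contained sketch of the reconnect pattern the flow above implements, written directly against the google-genai live API. The model name, the prompt, and the send_client_content call are illustrative assumptions; the session_resumption config and new_handle field mirror what this diff uses.

import asyncio

from google import genai
from google.genai import types
from websockets.exceptions import ConnectionClosed


async def run_with_resumption() -> None:
    client = genai.Client()
    handle = None  # Latest resumption handle; reused on reconnect.
    while True:
        config = types.LiveConnectConfig(
            response_modalities=['TEXT'],
            session_resumption=types.SessionResumptionConfig(handle=handle),
        )
        try:
            async with client.aio.live.connect(
                model='gemini-2.0-flash-live-preview-04-09', config=config
            ) as session:
                await session.send_client_content(
                    turns=types.Content(
                        role='user', parts=[types.Part(text='Hello')]
                    )
                )
                async for message in session.receive():
                    if message.session_resumption_update:
                        # Save the newest handle so the next connect resumes
                        # this session instead of starting a fresh one.
                        handle = message.session_resumption_update.new_handle
                    if (
                        message.server_content
                        and message.server_content.turn_complete
                    ):
                        return
        except ConnectionClosed:
            if handle is None:
                raise  # Nothing to resume; surface the failure.
            # Otherwise loop and reconnect with the saved handle.


if __name__ == '__main__':
    asyncio.run(run_with_resumption())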
7 changes: 7 additions & 0 deletions src/google/adk/models/gemini_llm_connection.py
@@ -219,6 +219,13 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
for function_call in message.tool_call.function_calls
]
yield LlmResponse(content=types.Content(role='model', parts=parts))
if message.session_resumption_update:
logger.info('Received session resumption message: %s', message)
yield (
LlmResponse(
live_session_resumption_update=message.session_resumption_update
)
)

async def close(self):
"""Closes the llm server connection."""
1 change: 1 addition & 0 deletions src/google/adk/models/google_llm.py
@@ -289,6 +289,7 @@ async def connect(self, llm_request: LlmRequest) -> BaseLlmConnection:
],
)
llm_request.live_connect_config.tools = llm_request.config.tools
logger.info('Connecting to live with llm_request: %s', llm_request)
async with self._live_api_client.aio.live.connect(
model=llm_request.model, config=llm_request.live_connect_config
) as live_session:
5 changes: 5 additions & 0 deletions src/google/adk/models/llm_response.py
@@ -89,6 +89,11 @@ class LlmResponse(BaseModel):
usage_metadata: Optional[types.GenerateContentResponseUsageMetadata] = None
"""The usage metadata of the LlmResponse"""

live_session_resumption_update: Optional[
types.LiveServerSessionResumptionUpdate
] = None
"""The session resumption update of the LlmResponse"""

@staticmethod
def create(
generate_content_response: types.GenerateContentResponse,
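A quick illustration of the new field (the handle value here is made up; only the field and type names come from this diff):

from google.genai import types

from google.adk.models.llm_response import LlmResponse

update = types.LiveServerSessionResumptionUpdate(
    new_handle='handle-123', resumable=True
)
response = LlmResponse(live_session_resumption_update=update)
assert response.live_session_resumption_update.new_handle == 'handle-123'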