Skip to content

Commit 5269a6b

Browse files
seanzhougooglecopybara-github
authored andcommitted
chore: Register all streaming tool at runner
previously we only register streaming tool that accept stream input at runner, now uniformly register all streaming tool at runner. Co-authored-by: Xiang (Sean) Zhou <seanzhougoogle@google.com> PiperOrigin-RevId: 869447996
1 parent b53bc55 commit 5269a6b

3 files changed

Lines changed: 72 additions & 38 deletions

File tree

src/google/adk/flows/llm_flows/functions.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -825,17 +825,11 @@ async def run_tool_and_update_queue(tool, function_args, tool_context):
825825
run_tool_and_update_queue(tool, function_args, tool_context)
826826
)
827827

828-
# Register streaming tool using original logic
828+
# The tool is already registered in active_streaming_tools by
829+
# runners.py at startup (all async-generator tools are registered
830+
# there). Just attach the background task.
829831
async with streaming_lock:
830-
if invocation_context.active_streaming_tools is None:
831-
invocation_context.active_streaming_tools = {}
832-
833-
if tool.name in invocation_context.active_streaming_tools:
834-
invocation_context.active_streaming_tools[tool.name].task = task
835-
else:
836-
invocation_context.active_streaming_tools[tool.name] = (
837-
ActiveStreamingTool(task=task)
838-
)
832+
invocation_context.active_streaming_tools[tool.name].task = task
839833

840834
# Immediately return a pending response.
841835
# This is required by current live model.

src/google/adk/runners.py

Lines changed: 20 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1023,38 +1023,30 @@ async def run_live(
10231023
canonical_tools = await invocation_context.agent.canonical_tools(
10241024
invocation_context
10251025
)
1026+
# Register all async-generator tools as streaming tools.
1027+
# A streaming tool is any tool whose underlying function is an
1028+
# async generator (i.e. uses `yield`). There are two sub-types:
1029+
# 1. Input-streaming tools: accept a `input_stream:
1030+
# LiveRequestQueue` parameter to consume the live audio/video
1031+
# stream. The stream is created lazily in `_call_live` when
1032+
# the model actually calls the tool.
1033+
# 2. Output-streaming tools: async generators that yield results
1034+
# over time but don't consume the live stream. They are run
1035+
# as background tasks when called.
1036+
# Both types are registered here with `stream=None`. The
1037+
# distinction between them is made at call time.
10261038
for tool in canonical_tools:
1027-
# We use `inspect.signature()` to examine the tool's underlying function (`tool.func`).
1028-
# This approach is deliberately chosen over `typing.get_type_hints()` for robustness.
1029-
#
1030-
# The Problem with `get_type_hints()`:
1031-
# `get_type_hints()` attempts to resolve forward-referenced (string-based) type
1032-
# annotations. This resolution can easily fail with a `NameError` (e.g., "Union not found")
1033-
# if the type isn't available in the scope where `get_type_hints()` is called.
1034-
# This is a common and brittle issue in framework code that inspects functions
1035-
# defined in separate user modules.
1036-
#
1037-
# Why `inspect.signature()` is Better Here:
1038-
# `inspect.signature()` does NOT resolve the annotations; it retrieves the raw
1039-
# annotation object as it was defined on the function. This allows us to
1040-
# perform a direct and reliable identity check (`param.annotation is LiveRequestQueue`)
1041-
# without risking a `NameError`.
10421039
callable_to_inspect = tool.func if hasattr(tool, 'func') else tool
1043-
# Ensure the target is actually callable before inspecting to avoid errors.
10441040
if not callable(callable_to_inspect):
10451041
continue
1046-
for param in inspect.signature(callable_to_inspect).parameters.values():
1047-
if param.annotation is LiveRequestQueue:
1048-
if not invocation_context.active_streaming_tools:
1049-
invocation_context.active_streaming_tools = {}
1050-
1051-
logger.debug(
1052-
'Register streaming tool with input stream: %s', tool.name
1053-
)
1054-
active_streaming_tool = ActiveStreamingTool()
1055-
invocation_context.active_streaming_tools[tool.name] = (
1056-
active_streaming_tool
1057-
)
1042+
if inspect.isasyncgenfunction(callable_to_inspect):
1043+
if not invocation_context.active_streaming_tools:
1044+
invocation_context.active_streaming_tools = {}
1045+
logger.debug('Register streaming tool: %s', tool.name)
1046+
active_streaming_tool = ActiveStreamingTool()
1047+
invocation_context.active_streaming_tools[tool.name] = (
1048+
active_streaming_tool
1049+
)
10581050

10591051
async def execute(ctx: InvocationContext) -> AsyncGenerator[Event]:
10601052
async with Aclosing(ctx.agent.run_live(ctx)) as agen:

tests/unittests/streaming/test_streaming.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1497,3 +1497,51 @@ def capturing_method(*args, **kwargs):
14971497
assert (
14981498
active_tools['monitor_stock_price'].stream is None
14991499
), 'Expected stream to be reset to None after stop_streaming'
1500+
1501+
1502+
def test_output_streaming_tool_registered_at_startup():
1503+
"""Test that output-streaming tools (async generators without LiveRequestQueue) are registered at startup."""
1504+
response1 = LlmResponse(turn_complete=True)
1505+
1506+
mock_model = testing_utils.MockModel.create([response1])
1507+
1508+
async def monitor_stock_price(stock_symbol: str):
1509+
"""Yield periodic price updates."""
1510+
yield f'price for {stock_symbol}'
1511+
1512+
root_agent = Agent(
1513+
name='root_agent',
1514+
model=mock_model,
1515+
tools=[monitor_stock_price],
1516+
)
1517+
1518+
runner = _LiveTestRunner(root_agent=root_agent)
1519+
1520+
# Capture invocation context to verify registration.
1521+
captured_context = None
1522+
original_method = runner.runner._new_invocation_context_for_live
1523+
1524+
def capturing_method(*args, **kwargs):
1525+
nonlocal captured_context
1526+
ctx = original_method(*args, **kwargs)
1527+
captured_context = ctx
1528+
return ctx
1529+
1530+
runner.runner._new_invocation_context_for_live = capturing_method
1531+
1532+
live_request_queue = LiveRequestQueue()
1533+
live_request_queue.send_realtime(
1534+
blob=types.Blob(data=b'test', mime_type='audio/pcm')
1535+
)
1536+
1537+
runner.run_live(live_request_queue, max_responses=1)
1538+
1539+
# Output-streaming tool should be registered with stream=None.
1540+
assert captured_context is not None
1541+
active_tools = captured_context.active_streaming_tools or {}
1542+
assert (
1543+
'monitor_stock_price' in active_tools
1544+
), 'Expected output-streaming tool to be registered at startup'
1545+
assert (
1546+
active_tools['monitor_stock_price'].stream is None
1547+
), 'Expected stream to be None for output-streaming tool'

0 commit comments

Comments
 (0)