Skip to content

Commit a6995d4

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
feat: Add bidirectional streaming support (bidi_stream_query) to stable AdkApp template.
This change implements and registers the bidirectional streaming query operation (`bidi_stream_query`) in the stable `AdkApp` template, resolving parity issues with the preview template and aligning with public documentation. This addresses issues where `bidi_stream_query` was not found in stable deployments, despite being documented. Fixes google/adk-python#5611 PiperOrigin-RevId: 913150508
1 parent 4ba222b commit a6995d4

2 files changed

Lines changed: 173 additions & 0 deletions

File tree

tests/unit/vertex_adk/test_agent_engine_templates_adk.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import importlib
1717
import json
1818
import os
19+
import asyncio
1920
import re
2021
import sys
2122
from typing import Optional
@@ -308,6 +309,34 @@ async def run_async(self, *args, **kwargs):
308309
}
309310
)
310311

312+
async def run_live(self, *args, **kwargs):
313+
from google.adk.events import event
314+
315+
yield event.Event(
316+
**{
317+
"author": "currency_exchange_agent",
318+
"content": {
319+
"parts": [
320+
{
321+
"thought_signature": b"test_signature",
322+
"function_call": {
323+
"args": {
324+
"currency_date": "2025-04-03",
325+
"currency_from": "USD",
326+
"currency_to": "SEK",
327+
},
328+
"id": "af-c5a57692-9177-4091-a3df-098f834ee849",
329+
"name": "get_exchange_rate",
330+
},
331+
}
332+
],
333+
"role": "model",
334+
},
335+
"id": "9aaItGK9",
336+
"invocation_id": "e-6543c213-6417-484b-9551-b67915d1d5f7",
337+
}
338+
)
339+
311340

312341
@pytest.mark.usefixtures("google_auth_mock")
313342
class TestAdkApp:
@@ -904,6 +933,62 @@ def test_span_content_capture_enabled_with_tracing(
904933
app.set_up()
905934
assert os.environ["ADK_CAPTURE_MESSAGE_CONTENT_IN_SPANS"] == "true"
906935

936+
@pytest.mark.asyncio
937+
async def test_async_bidi_stream_query(
938+
self,
939+
default_instrumentor_builder_mock: mock.Mock,
940+
get_project_id_mock: mock.Mock,
941+
):
942+
app = agent_engines.AdkApp(agent=_TEST_AGENT)
943+
assert app._tmpl_attrs.get("runner") is None
944+
app.set_up()
945+
app._tmpl_attrs["runner"] = _MockRunner()
946+
request_queue = asyncio.Queue()
947+
request_dict = {
948+
"user_id": _TEST_USER_ID,
949+
"live_request": {
950+
"input": "What is the exchange rate from USD to SEK?",
951+
},
952+
}
953+
954+
await request_queue.put(request_dict)
955+
await request_queue.put(None) # Sentinel to end the stream.
956+
events = []
957+
async for event in app.bidi_stream_query(request_queue):
958+
events.append(event)
959+
assert len(events) == 1
960+
961+
@pytest.mark.asyncio
962+
async def test_async_bidi_stream_query_with_state(
963+
self,
964+
default_instrumentor_builder_mock: mock.Mock,
965+
get_project_id_mock: mock.Mock,
966+
):
967+
app = agent_engines.AdkApp(agent=_TEST_AGENT)
968+
assert app._tmpl_attrs.get("runner") is None
969+
app.set_up()
970+
app._tmpl_attrs["runner"] = _MockRunner()
971+
request_queue = asyncio.Queue()
972+
request_dict = {
973+
"user_id": _TEST_USER_ID,
974+
"state": {"test_key": "test_val"},
975+
"live_request": {
976+
"input": "What is the exchange rate from USD to SEK?",
977+
},
978+
}
979+
980+
await request_queue.put(request_dict)
981+
await request_queue.put(None) # Sentinel to end the stream.
982+
983+
with mock.patch.object(
984+
app, "async_create_session", wraps=app.async_create_session
985+
) as mock_create_session:
986+
async for _ in app.bidi_stream_query(request_queue):
987+
pass
988+
mock_create_session.assert_called_once_with(
989+
user_id=_TEST_USER_ID, state={"test_key": "test_val"}
990+
)
991+
907992

908993
def test_dump_event_for_json():
909994
from google.adk.events import event

vertexai/agent_engines/templates/adk.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1754,6 +1754,93 @@ async def async_search_memory(self, *, user_id: str, query: str):
17541754
query=query,
17551755
)
17561756

1757+
1758+
async def bidi_stream_query(
1759+
self,
1760+
request_queue: Any,
1761+
) -> AsyncIterable[Any]:
1762+
"""Bidi streaming query the ADK application.
1763+
1764+
Args:
1765+
request_queue:
1766+
The queue of requests to stream responses for, with the type of
1767+
asyncio.Queue[Any].
1768+
1769+
Raises:
1770+
TypeError: If the request_queue is not an asyncio.Queue instance.
1771+
ValueError: If the first request does not have a user_id.
1772+
ValidationError: If failed to convert to LiveRequest.
1773+
1774+
Yields:
1775+
The stream responses of querying the ADK application.
1776+
"""
1777+
from google.adk.agents.live_request_queue import LiveRequest
1778+
from google.adk.agents.live_request_queue import LiveRequestQueue
1779+
from vertexai.agent_engines import _utils
1780+
1781+
# Manual type check needed as Pydantic doesn't support asyncio.Queue.
1782+
if not isinstance(request_queue, asyncio.Queue):
1783+
raise TypeError("request_queue must be an asyncio.Queue instance.")
1784+
1785+
first_request = await request_queue.get()
1786+
user_id = first_request.get("user_id")
1787+
if not user_id:
1788+
raise ValueError("The first request must have a user_id.")
1789+
1790+
session_id = first_request.get("session_id")
1791+
run_config = first_request.get("run_config")
1792+
first_live_request = first_request.get("live_request")
1793+
1794+
if not self._tmpl_attrs.get("runner"):
1795+
self.set_up()
1796+
if not session_id:
1797+
state = first_request.get("state")
1798+
session = await self.async_create_session(user_id=user_id, state=state)
1799+
session_id = session["id"] if isinstance(session, dict) else session.id
1800+
run_config = _validate_run_config(run_config)
1801+
1802+
live_request_queue = LiveRequestQueue()
1803+
1804+
if first_live_request and isinstance(first_live_request, Dict):
1805+
live_request_queue.send(LiveRequest.model_validate(first_live_request))
1806+
1807+
# Forwards live requests to the agent.
1808+
async def _forward_requests():
1809+
while True:
1810+
request = await request_queue.get()
1811+
live_request = LiveRequest.model_validate(request)
1812+
live_request_queue.send(live_request)
1813+
1814+
# Forwards events to the client.
1815+
async def _forward_events():
1816+
if run_config:
1817+
events_async = self._tmpl_attrs.get("runner").run_live(
1818+
user_id=user_id,
1819+
session_id=session_id,
1820+
live_request_queue=live_request_queue,
1821+
run_config=run_config,
1822+
)
1823+
else:
1824+
events_async = self._tmpl_attrs.get("runner").run_live(
1825+
user_id=user_id,
1826+
session_id=session_id,
1827+
live_request_queue=live_request_queue,
1828+
)
1829+
async for event in events_async:
1830+
yield _utils.dump_event_for_json(event)
1831+
1832+
requests_task = asyncio.create_task(_forward_requests())
1833+
1834+
try:
1835+
async for event in _forward_events():
1836+
yield event
1837+
finally:
1838+
requests_task.cancel()
1839+
try:
1840+
await requests_task
1841+
except asyncio.CancelledError:
1842+
pass
1843+
17571844
def register_operations(self) -> Dict[str, List[str]]:
17581845
"""Registers the operations of the ADK application."""
17591846
return {
@@ -1776,6 +1863,7 @@ def register_operations(self) -> Dict[str, List[str]]:
17761863
"async_stream_query",
17771864
"streaming_agent_run_with_events",
17781865
],
1866+
"bidi_stream": ["bidi_stream_query"],
17791867
}
17801868

17811869
def _telemetry_enabled(self) -> Optional[bool]:

0 commit comments

Comments
 (0)