@@ -1754,6 +1754,93 @@ async def async_search_memory(self, *, user_id: str, query: str):
17541754 query = query ,
17551755 )
17561756
1757+
1758+ async def bidi_stream_query (
1759+ self ,
1760+ request_queue : Any ,
1761+ ) -> AsyncIterable [Any ]:
1762+ """Bidi streaming query the ADK application.
1763+
1764+ Args:
1765+ request_queue:
1766+ The queue of requests to stream responses for, with the type of
1767+ asyncio.Queue[Any].
1768+
1769+ Raises:
1770+ TypeError: If the request_queue is not an asyncio.Queue instance.
1771+ ValueError: If the first request does not have a user_id.
1772+ ValidationError: If failed to convert to LiveRequest.
1773+
1774+ Yields:
1775+ The stream responses of querying the ADK application.
1776+ """
1777+ from google .adk .agents .live_request_queue import LiveRequest
1778+ from google .adk .agents .live_request_queue import LiveRequestQueue
1779+ from vertexai .agent_engines import _utils
1780+
1781+ # Manual type check needed as Pydantic doesn't support asyncio.Queue.
1782+ if not isinstance (request_queue , asyncio .Queue ):
1783+ raise TypeError ("request_queue must be an asyncio.Queue instance." )
1784+
1785+ first_request = await request_queue .get ()
1786+ user_id = first_request .get ("user_id" )
1787+ if not user_id :
1788+ raise ValueError ("The first request must have a user_id." )
1789+
1790+ session_id = first_request .get ("session_id" )
1791+ run_config = first_request .get ("run_config" )
1792+ first_live_request = first_request .get ("live_request" )
1793+
1794+ if not self ._tmpl_attrs .get ("runner" ):
1795+ self .set_up ()
1796+ if not session_id :
1797+ state = first_request .get ("state" )
1798+ session = await self .async_create_session (user_id = user_id , state = state )
1799+ session_id = session ["id" ] if isinstance (session , dict ) else session .id
1800+ run_config = _validate_run_config (run_config )
1801+
1802+ live_request_queue = LiveRequestQueue ()
1803+
1804+ if first_live_request and isinstance (first_live_request , Dict ):
1805+ live_request_queue .send (LiveRequest .model_validate (first_live_request ))
1806+
1807+ # Forwards live requests to the agent.
1808+ async def _forward_requests ():
1809+ while True :
1810+ request = await request_queue .get ()
1811+ live_request = LiveRequest .model_validate (request )
1812+ live_request_queue .send (live_request )
1813+
1814+ # Forwards events to the client.
1815+ async def _forward_events ():
1816+ if run_config :
1817+ events_async = self ._tmpl_attrs .get ("runner" ).run_live (
1818+ user_id = user_id ,
1819+ session_id = session_id ,
1820+ live_request_queue = live_request_queue ,
1821+ run_config = run_config ,
1822+ )
1823+ else :
1824+ events_async = self ._tmpl_attrs .get ("runner" ).run_live (
1825+ user_id = user_id ,
1826+ session_id = session_id ,
1827+ live_request_queue = live_request_queue ,
1828+ )
1829+ async for event in events_async :
1830+ yield _utils .dump_event_for_json (event )
1831+
1832+ requests_task = asyncio .create_task (_forward_requests ())
1833+
1834+ try :
1835+ async for event in _forward_events ():
1836+ yield event
1837+ finally :
1838+ requests_task .cancel ()
1839+ try :
1840+ await requests_task
1841+ except asyncio .CancelledError :
1842+ pass
1843+
17571844 def register_operations (self ) -> Dict [str , List [str ]]:
17581845 """Registers the operations of the ADK application."""
17591846 return {
@@ -1776,6 +1863,7 @@ def register_operations(self) -> Dict[str, List[str]]:
17761863 "async_stream_query" ,
17771864 "streaming_agent_run_with_events" ,
17781865 ],
1866+ "bidi_stream" : ["bidi_stream_query" ],
17791867 }
17801868
17811869 def _telemetry_enabled (self ) -> Optional [bool ]:
0 commit comments