Skip to content

Commit 37b5a0f

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
feat: Pass state from first bidi_stream_query request to async_create_session
The `bidi_stream_query` method in AdkApp now correctly extracts the "state" field from the initial request in the stream and forwards it to `async_create_session` when a new session is being created. This ensures that session state is properly initialized for bidi-directional streaming calls. A unit test is added to verify this behavior. PiperOrigin-RevId: 895368130
1 parent 09794ba commit 37b5a0f

File tree

2 files changed

+31
-1
lines changed

2 files changed

+31
-1
lines changed

tests/unit/vertex_adk/test_reasoning_engine_templates_adk.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -648,6 +648,35 @@ async def test_async_bidi_stream_query(self):
648648
events.append(event)
649649
assert len(events) == 1
650650

651+
@pytest.mark.asyncio
652+
async def test_async_bidi_stream_query_with_state(self):
653+
app = reasoning_engines.AdkApp(
654+
agent=Agent(name=_TEST_AGENT_NAME, model=_TEST_MODEL)
655+
)
656+
assert app._tmpl_attrs.get("runner") is None
657+
app.set_up()
658+
app._tmpl_attrs["runner"] = _MockRunner()
659+
request_queue = asyncio.Queue()
660+
request_dict = {
661+
"user_id": _TEST_USER_ID,
662+
"state": {"test_key": "test_val"},
663+
"live_request": {
664+
"input": "What is the exchange rate from USD to SEK?",
665+
},
666+
}
667+
668+
await request_queue.put(request_dict)
669+
await request_queue.put(None) # Sentinel to end the stream.
670+
671+
with mock.patch.object(
672+
app, "async_create_session", wraps=app.async_create_session
673+
) as mock_create_session:
674+
async for _ in app.bidi_stream_query(request_queue):
675+
pass
676+
mock_create_session.assert_called_once_with(
677+
user_id=_TEST_USER_ID, state={"test_key": "test_val"}
678+
)
679+
651680
def test_create_session(self):
652681
app = reasoning_engines.AdkApp(
653682
agent=Agent(name=_TEST_AGENT_NAME, model=_TEST_MODEL)

vertexai/preview/reasoning_engines/templates/adk.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1178,7 +1178,8 @@ async def bidi_stream_query(
11781178
if not self._tmpl_attrs.get("runner"):
11791179
self.set_up()
11801180
if not session_id:
1181-
session = await self.async_create_session(user_id=user_id)
1181+
state = first_request.get("state")
1182+
session = await self.async_create_session(user_id=user_id, state=state)
11821183
session_id = session.id
11831184
run_config = _validate_run_config(run_config)
11841185

0 commit comments

Comments
 (0)