Skip to content

Commit c56bec8

Browse files
google-genai-botDeanChensj
authored andcommitted
fix(runners): Preserve state_delta in NodeRunner path
Merge a0d90de into e623b3b Merge #5767 Resolves #5763 ORIGINAL_AUTHOR=trongthanht3 <trongthanht3@gmail.com> GitOrigin-RevId: 3c3fb06 Change-Id: I4281c6d6d68e5beb2998cf7698f3210854fd5199
1 parent 2d465aa commit c56bec8

2 files changed

Lines changed: 153 additions & 7 deletions

File tree

src/google/adk/runners.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -455,6 +455,7 @@ async def _run_node_async(
455455
user_id: str,
456456
session_id: str,
457457
new_message: Optional[types.Content] = None,
458+
state_delta: Optional[dict[str, Any]] = None,
458459
run_config: Optional[RunConfig] = None,
459460
yield_user_message: bool = False,
460461
node: Optional['BaseNode'] = None,
@@ -512,7 +513,9 @@ async def _run_node_async(
512513

513514
# Append user message to session for history
514515
if new_message:
515-
user_event = await self._append_user_event(ic, new_message)
516+
user_event = await self._append_user_event(
517+
ic, new_message, state_delta=state_delta
518+
)
516519
if yield_user_message and user_event:
517520
yield user_event
518521

@@ -706,14 +709,26 @@ def _resolve_invocation_id_from_fr(
706709
return invocation_ids.pop()
707710

708711
async def _append_user_event(
709-
self, ic: InvocationContext, content: types.Content
712+
self,
713+
ic: InvocationContext,
714+
content: types.Content,
715+
*,
716+
state_delta: Optional[dict[str, Any]] = None,
710717
) -> Event:
711718
"""Append a user message event to the session and return it."""
712-
event = Event(
713-
invocation_id=ic.invocation_id,
714-
author='user',
715-
content=content,
716-
)
719+
if state_delta:
720+
event = Event(
721+
invocation_id=ic.invocation_id,
722+
author='user',
723+
actions=EventActions(state_delta=state_delta),
724+
content=content,
725+
)
726+
else:
727+
event = Event(
728+
invocation_id=ic.invocation_id,
729+
author='user',
730+
content=content,
731+
)
717732
# when a paused task delegation is in flight, stamp
718733
# the new user message with that task's isolation_scope so the
719734
# task agent's content-build (scoped to <fc_id>) sees it.
@@ -989,6 +1004,7 @@ async def run_async(
9891004
user_id=user_id,
9901005
session_id=session_id,
9911006
new_message=new_message,
1007+
state_delta=state_delta,
9921008
run_config=run_config,
9931009
yield_user_message=yield_user_message,
9941010
node=agent_to_run,
@@ -1008,6 +1024,7 @@ async def run_async(
10081024
user_id=user_id,
10091025
session_id=session_id,
10101026
new_message=new_message,
1027+
state_delta=state_delta,
10111028
run_config=run_config,
10121029
yield_user_message=yield_user_message,
10131030
)

tests/unittests/runners/test_runner_node.py

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,9 @@
2424
from typing import Any
2525
from typing import AsyncGenerator
2626

27+
from google.adk.agents.callback_context import CallbackContext
2728
from google.adk.agents.context import Context
29+
from google.adk.agents.llm_agent import LlmAgent
2830
from google.adk.events.event import Event
2931
from google.adk.runners import Runner
3032
from google.adk.sessions.in_memory_session_service import InMemorySessionService
@@ -49,6 +51,10 @@ async def _run_impl(
4951
yield f'Echo: {text}'
5052

5153

54+
def _user_message(text: str = 'hello') -> types.Content:
55+
return types.Content(parts=[types.Part(text=text)], role='user')
56+
57+
5258
async def _run_node(node, message='hello'):
5359
"""Run a BaseNode via Runner(node=...) and return (events, ss, session)."""
5460
ss = InMemorySessionService()
@@ -288,6 +294,129 @@ async def test_yield_user_message_false_by_default():
288294
assert user_events == []
289295

290296

297+
@pytest.mark.asyncio
298+
async def test_node_runner_applies_state_delta_before_base_node_runs():
299+
"""A BaseNode sees run_async state_delta as session state."""
300+
301+
class _StateReaderNode(BaseNode):
302+
303+
async def _run_impl(
304+
self, *, ctx: Context, node_input: Any
305+
) -> AsyncGenerator[Any, None]:
306+
yield f'state:{ctx.state["test_state"]}'
307+
308+
session_service = InMemorySessionService()
309+
runner = Runner(
310+
app_name='test',
311+
node=_StateReaderNode(name='reader'),
312+
session_service=session_service,
313+
)
314+
session = await session_service.create_session(app_name='test', user_id='u')
315+
316+
events: list[Event] = []
317+
async for event in runner.run_async(
318+
user_id='u',
319+
session_id=session.id,
320+
new_message=_user_message(),
321+
state_delta={'test_state': 'must_change'},
322+
):
323+
events.append(event)
324+
325+
updated = await session_service.get_session(
326+
app_name='test', user_id='u', session_id=session.id
327+
)
328+
user_events = [event for event in updated.events if event.author == 'user']
329+
330+
assert [event.output for event in events if event.output is not None] == [
331+
'state:must_change'
332+
]
333+
assert updated.state['test_state'] == 'must_change'
334+
assert user_events[0].actions.state_delta == {'test_state': 'must_change'}
335+
336+
337+
@pytest.mark.asyncio
338+
async def test_node_runner_yields_user_event_with_state_delta():
339+
"""yield_user_message=True yields the user event with state_delta."""
340+
341+
class _NoopNode(BaseNode):
342+
343+
async def _run_impl(
344+
self, *, ctx: Context, node_input: Any
345+
) -> AsyncGenerator[Any, None]:
346+
yield 'done'
347+
348+
session_service = InMemorySessionService()
349+
runner = Runner(
350+
app_name='test',
351+
node=_NoopNode(name='noop'),
352+
session_service=session_service,
353+
)
354+
session = await session_service.create_session(app_name='test', user_id='u')
355+
356+
events: list[Event] = []
357+
async for event in runner.run_async(
358+
user_id='u',
359+
session_id=session.id,
360+
new_message=_user_message(),
361+
state_delta={'test_state': 'must_change'},
362+
yield_user_message=True,
363+
):
364+
events.append(event)
365+
366+
assert events[0].author == 'user'
367+
assert events[0].actions.state_delta == {'test_state': 'must_change'}
368+
369+
370+
@pytest.mark.asyncio
371+
async def test_node_runner_applies_state_delta_before_llm_agent_runs():
372+
"""An LlmAgent callback sees run_async state_delta before model execution."""
373+
374+
captured_state_value = None
375+
376+
def _before_agent_callback(
377+
callback_context: CallbackContext,
378+
) -> types.Content:
379+
nonlocal captured_state_value
380+
captured_state_value = callback_context.state['test_state']
381+
return types.Content(
382+
role='model',
383+
parts=[types.Part(text=f'state:{captured_state_value}')],
384+
)
385+
386+
session_service = InMemorySessionService()
387+
agent = LlmAgent(
388+
name='state_agent',
389+
before_agent_callback=_before_agent_callback,
390+
)
391+
runner = Runner(app_name='test', agent=agent, session_service=session_service)
392+
session = await session_service.create_session(app_name='test', user_id='u')
393+
394+
events: list[Event] = []
395+
async for event in runner.run_async(
396+
user_id='u',
397+
session_id=session.id,
398+
new_message=_user_message(),
399+
state_delta={'test_state': 'must_change'},
400+
):
401+
events.append(event)
402+
403+
updated = await session_service.get_session(
404+
app_name='test', user_id='u', session_id=session.id
405+
)
406+
user_events = [event for event in updated.events if event.author == 'user']
407+
response_texts = [
408+
part.text
409+
for event in events
410+
if event.content
411+
for part in event.content.parts
412+
if part.text
413+
]
414+
415+
assert captured_state_value == 'must_change'
416+
assert 'state:must_change' in response_texts
417+
assert user_events[0].actions.state_delta == {'test_state': 'must_change'}
418+
419+
291420
# ---------------------------------------------------------------------------
292421
# Resume (HITL)
293422
# ---------------------------------------------------------------------------

0 commit comments

Comments
 (0)