2424from typing import Any
2525from typing import AsyncGenerator
2626
27+ from google .adk .agents .callback_context import CallbackContext
2728from google .adk .agents .context import Context
29+ from google .adk .agents .llm_agent import LlmAgent
2830from google .adk .events .event import Event
2931from google .adk .runners import Runner
3032from google .adk .sessions .in_memory_session_service import InMemorySessionService
@@ -49,6 +51,10 @@ async def _run_impl(
4951 yield f'Echo: { text } '
5052
5153
54+ def _user_message (text : str = 'hello' ) -> types .Content :
55+ return types .Content (parts = [types .Part (text = text )], role = 'user' )
56+
57+
5258async def _run_node (node , message = 'hello' ):
5359 """Run a BaseNode via Runner(node=...) and return (events, ss, session)."""
5460 ss = InMemorySessionService ()
@@ -288,6 +294,129 @@ async def test_yield_user_message_false_by_default():
288294 assert user_events == []
289295
290296
297+ @pytest .mark .asyncio
298+ async def test_node_runner_applies_state_delta_before_base_node_runs ():
299+ """A BaseNode sees run_async state_delta as session state."""
300+
301+ class _StateReaderNode (BaseNode ):
302+
303+ async def _run_impl (
304+ self , * , ctx : Context , node_input : Any
305+ ) -> AsyncGenerator [Any , None ]:
306+ yield f'state:{ ctx .state ["test_state" ]} '
307+
308+ session_service = InMemorySessionService ()
309+ runner = Runner (
310+ app_name = 'test' ,
311+ node = _StateReaderNode (name = 'reader' ),
312+ session_service = session_service ,
313+ )
314+ session = await session_service .create_session (app_name = 'test' , user_id = 'u' )
315+
316+ events : list [Event ] = []
317+ async for event in runner .run_async (
318+ user_id = 'u' ,
319+ session_id = session .id ,
320+ new_message = _user_message (),
321+ state_delta = {'test_state' : 'must_change' },
322+ ):
323+ events .append (event )
324+
325+ updated = await session_service .get_session (
326+ app_name = 'test' , user_id = 'u' , session_id = session .id
327+ )
328+ user_events = [event for event in updated .events if event .author == 'user' ]
329+
330+ assert [event .output for event in events if event .output is not None ] == [
331+ 'state:must_change'
332+ ]
333+ assert updated .state ['test_state' ] == 'must_change'
334+ assert user_events [0 ].actions .state_delta == {'test_state' : 'must_change' }
335+
336+
337+ @pytest .mark .asyncio
338+ async def test_node_runner_yields_user_event_with_state_delta ():
339+ """yield_user_message=True yields the user event with state_delta."""
340+
341+ class _NoopNode (BaseNode ):
342+
343+ async def _run_impl (
344+ self , * , ctx : Context , node_input : Any
345+ ) -> AsyncGenerator [Any , None ]:
346+ yield 'done'
347+
348+ session_service = InMemorySessionService ()
349+ runner = Runner (
350+ app_name = 'test' ,
351+ node = _NoopNode (name = 'noop' ),
352+ session_service = session_service ,
353+ )
354+ session = await session_service .create_session (app_name = 'test' , user_id = 'u' )
355+
356+ events : list [Event ] = []
357+ async for event in runner .run_async (
358+ user_id = 'u' ,
359+ session_id = session .id ,
360+ new_message = _user_message (),
361+ state_delta = {'test_state' : 'must_change' },
362+ yield_user_message = True ,
363+ ):
364+ events .append (event )
365+
366+ assert events [0 ].author == 'user'
367+ assert events [0 ].actions .state_delta == {'test_state' : 'must_change' }
368+
369+
370+ @pytest .mark .asyncio
371+ async def test_node_runner_applies_state_delta_before_llm_agent_runs ():
372+ """An LlmAgent callback sees run_async state_delta before model execution."""
373+
374+ captured_state_value = None
375+
376+ def _before_agent_callback (
377+ callback_context : CallbackContext ,
378+ ) -> types .Content :
379+ nonlocal captured_state_value
380+ captured_state_value = callback_context .state ['test_state' ]
381+ return types .Content (
382+ role = 'model' ,
383+ parts = [types .Part (text = f'state:{ captured_state_value } ' )],
384+ )
385+
386+ session_service = InMemorySessionService ()
387+ agent = LlmAgent (
388+ name = 'state_agent' ,
389+ before_agent_callback = _before_agent_callback ,
390+ )
391+ runner = Runner (app_name = 'test' , agent = agent , session_service = session_service )
392+ session = await session_service .create_session (app_name = 'test' , user_id = 'u' )
393+
394+ events : list [Event ] = []
395+ async for event in runner .run_async (
396+ user_id = 'u' ,
397+ session_id = session .id ,
398+ new_message = _user_message (),
399+ state_delta = {'test_state' : 'must_change' },
400+ ):
401+ events .append (event )
402+
403+ updated = await session_service .get_session (
404+ app_name = 'test' , user_id = 'u' , session_id = session .id
405+ )
406+ user_events = [event for event in updated .events if event .author == 'user' ]
407+ response_texts = [
408+ part .text
409+ for event in events
410+ if event .content
411+ for part in event .content .parts
412+ if part .text
413+ ]
414+
415+ assert captured_state_value == 'must_change'
416+ assert 'state:must_change' in response_texts
417+ assert user_events [0 ].actions .state_delta == {'test_state' : 'must_change' }
418+
419+
291420# ---------------------------------------------------------------------------
292421# Resume (HITL)
293422# ---------------------------------------------------------------------------
0 commit comments