Skip to content

Commit 660bbd4

Browse files
committed
fix: populate user_content in resumed invocations
Change-Id: I85fac02e9a926ff02ee9e95b03e6d31fb51926a4
1 parent 0478b02 commit 660bbd4

2 files changed

Lines changed: 80 additions & 8 deletions

File tree

src/google/adk/runners.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -459,6 +459,7 @@ async def _run_node_async(
459459
*,
460460
user_id: str,
461461
session_id: str,
462+
invocation_id: Optional[str] = None,
462463
new_message: Optional[types.Content] = None,
463464
state_delta: Optional[dict[str, Any]] = None,
464465
run_config: Optional[RunConfig] = None,
@@ -481,11 +482,10 @@ async def _run_node_async(
481482
resume_inputs = self._extract_resume_inputs(new_message)
482483
self._validate_new_message(new_message, resume_inputs)
483484

484-
invocation_id = (
485-
self._resolve_invocation_id_from_fr(session, new_message)
486-
if new_message
487-
else None
488-
)
485+
if not invocation_id and new_message:
486+
invocation_id = self._resolve_invocation_id_from_fr(
487+
session, new_message
488+
)
489489

490490
ic = self._new_invocation_context(
491491
session,
@@ -496,12 +496,16 @@ async def _run_node_async(
496496
ic._event_queue = asyncio.Queue()
497497

498498
# 2. Append user message to session and resolve node_input
499-
if resume_inputs:
500-
# Resume: find original user message, use as node_input
499+
node_input = None
500+
if resume_inputs or invocation_id:
501+
# Resume: recover the original user content. new_message here is a
502+
# function response (or None), so it can't populate user_content.
501503
node_input = self._find_original_user_content(
502504
ic.session, ic.invocation_id
503505
)
504-
else:
506+
if node_input:
507+
ic.user_content = node_input
508+
if not node_input:
505509
# Fresh: use user message as node_input
506510
node_input = new_message
507511

@@ -1019,6 +1023,7 @@ async def run_async(
10191023
self._run_node_async(
10201024
user_id=user_id,
10211025
session_id=session_id,
1026+
invocation_id=invocation_id,
10221027
new_message=new_message,
10231028
state_delta=state_delta,
10241029
run_config=run_config,
@@ -1039,6 +1044,7 @@ async def run_async(
10391044
self._run_node_async(
10401045
user_id=user_id,
10411046
session_id=session_id,
1047+
invocation_id=invocation_id,
10421048
new_message=new_message,
10431049
state_delta=state_delta,
10441050
run_config=run_config,

tests/unittests/runners/test_runner_node.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -476,6 +476,72 @@ async def _run_impl(
476476
assert 'original:my original input' in outputs
477477

478478

479+
@pytest.mark.asyncio
480+
async def test_resume_populates_invocation_user_content():
481+
"""On resume via a function response, ic.user_content is the original turn."""
482+
seen: list[Any] = []
483+
484+
class _Node(BaseNode):
485+
486+
async def _run_impl(
487+
self, *, ctx: Context, node_input: Any
488+
) -> AsyncGenerator[Any, None]:
489+
if ctx.resume_inputs and 'fc-1' in ctx.resume_inputs:
490+
user_content = ctx.get_invocation_context().user_content
491+
seen.append(user_content.parts[0].text if user_content else None)
492+
yield 'resumed'
493+
return
494+
yield _make_interrupt_event(fc_name='tool')
495+
496+
await _run_two_turns(
497+
_Node(name='node'),
498+
'remember me',
499+
_make_resume_message(fc_name='tool', response={'v': 1}),
500+
)
501+
502+
assert seen == ['remember me']
503+
504+
505+
@pytest.mark.asyncio
506+
async def test_resume_by_invocation_id_populates_user_content():
507+
"""Resuming by invocation_id alone recovers the original user_content."""
508+
seen: list[Any] = []
509+
510+
class _Node(BaseNode):
511+
512+
async def _run_impl(
513+
self, *, ctx: Context, node_input: Any
514+
) -> AsyncGenerator[Any, None]:
515+
user_content = ctx.get_invocation_context().user_content
516+
seen.append(user_content.parts[0].text if user_content else None)
517+
yield _make_interrupt_event(fc_name='tool')
518+
519+
ss = InMemorySessionService()
520+
runner = Runner(app_name='test', node=_Node(name='node'), session_service=ss)
521+
session = await ss.create_session(app_name='test', user_id='u')
522+
523+
async for _ in runner.run_async(
524+
user_id='u',
525+
session_id=session.id,
526+
new_message=types.Content(
527+
parts=[types.Part(text='original text')], role='user'
528+
),
529+
):
530+
pass
531+
532+
updated = await ss.get_session(
533+
app_name='test', user_id='u', session_id=session.id
534+
)
535+
invocation_id = updated.events[0].invocation_id
536+
537+
async for _ in runner.run_async(
538+
user_id='u', session_id=session.id, invocation_id=invocation_id
539+
):
540+
pass
541+
542+
assert seen == ['original text', 'original text']
543+
544+
479545
@pytest.mark.asyncio
480546
async def test_plain_text_does_not_trigger_resume():
481547
"""Sending plain text (no FR) starts fresh, does not enter resume path."""

0 commit comments

Comments
 (0)