Skip to content

Commit 5f806ed

Browse files
DeanChensjcopybara-github
authored andcommitted
chore: Refactor runner to infer invocation_id from FunctionResponse Event for HITL resuming
invocation_id is no longer required in resuming case, unless no new_message is provided. Co-authored-by: Shangjie Chen <deanchen@google.com> PiperOrigin-RevId: 875432024
1 parent de4dee8 commit 5f806ed

4 files changed

Lines changed: 181 additions & 54 deletions

File tree

src/google/adk/agents/invocation_context.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -396,23 +396,20 @@ def should_pause_invocation(self, event: Event) -> bool:
396396
return False
397397

398398
# TODO: Move this method from invocation_context to a dedicated module.
399-
# TODO: Converge this method with find_matching_function_call in llm_flows.
400399
def _find_matching_function_call(
401400
self, function_response_event: Event
402401
) -> Optional[Event]:
403402
"""Finds the function call event in the current invocation that matches the function response id."""
403+
from ..flows.llm_flows.functions import find_event_by_function_call_id
404+
404405
function_responses = function_response_event.get_function_responses()
405406
if not function_responses:
406407
return None
407-
function_call_id = function_responses[0].id
408-
409-
events = self._get_events(current_invocation=True)
410-
# The last event is function_response_event, so we search backwards from the
411-
# one before it.
412-
for event in reversed(events[:-1]):
413-
if any(fc.id == function_call_id for fc in event.get_function_calls()):
414-
return event
415-
return None
408+
409+
# Search backwards from the event before the current response event.
410+
return find_event_by_function_call_id(
411+
self._get_events(current_invocation=True)[:-1], function_responses[0].id
412+
)
416413

417414

418415
def new_invocation_context_id() -> str:

src/google/adk/flows/llm_flows/functions.py

Lines changed: 17 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@
3737
from google.genai import types
3838

3939
from ...agents.active_streaming_tool import ActiveStreamingTool
40-
from ...agents.invocation_context import InvocationContext
4140
from ...agents.live_request_queue import LiveRequestQueue
4241
from ...auth.auth_tool import AuthConfig
4342
from ...auth.auth_tool import AuthToolArguments
@@ -52,6 +51,7 @@
5251
from ...utils.context_utils import Aclosing
5352

5453
if TYPE_CHECKING:
54+
from ...agents.invocation_context import InvocationContext
5555
from ...agents.llm_agent import LlmAgent
5656

5757
AF_FUNCTION_CALL_ID_PREFIX = 'adk-'
@@ -1157,6 +1157,18 @@ def merge_parallel_function_response_events(
11571157
return merged_event
11581158

11591159

1160+
def find_event_by_function_call_id(
1161+
events: list[Event],
1162+
function_call_id: str,
1163+
) -> Optional[Event]:
1164+
"""Finds the function call event that matches the function call id."""
1165+
for event in reversed(events):
1166+
for function_call in event.get_function_calls():
1167+
if function_call.id == function_call_id:
1168+
return event
1169+
return None
1170+
1171+
11601172
def find_matching_function_call(
11611173
events: list[Event],
11621174
) -> Optional[Event]:
@@ -1165,25 +1177,8 @@ def find_matching_function_call(
11651177
return None
11661178

11671179
last_event = events[-1]
1168-
if (
1169-
last_event.content
1170-
and last_event.content.parts
1171-
and any(part.function_response for part in last_event.content.parts)
1172-
):
1180+
function_responses = last_event.get_function_responses()
1181+
if not function_responses:
1182+
return None
11731183

1174-
function_call_id = next(
1175-
part.function_response.id
1176-
for part in last_event.content.parts
1177-
if part.function_response
1178-
)
1179-
for i in range(len(events) - 2, -1, -1):
1180-
event = events[i]
1181-
# looking for the system long-running request euc function call
1182-
function_calls = event.get_function_calls()
1183-
if not function_calls:
1184-
continue
1185-
1186-
for function_call in function_calls:
1187-
if function_call.id == function_call_id:
1188-
return event
1189-
return None
1184+
return find_event_by_function_call_id(events[:-1], function_responses[0].id)

src/google/adk/runners.py

Lines changed: 77 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
from .events.event import Event
5050
from .events.event import EventActions
5151
from .flows.llm_flows import contents
52+
from .flows.llm_flows.functions import find_event_by_function_call_id
5253
from .flows.llm_flows.functions import find_matching_function_call
5354
from .memory.base_memory_service import BaseMemoryService
5455
from .memory.in_memory_memory_service import InMemoryMemoryService
@@ -70,6 +71,16 @@ def _is_tool_call_or_response(event: Event) -> bool:
7071
return bool(event.get_function_calls() or event.get_function_responses())
7172

7273

74+
def _get_function_responses_from_content(
75+
content: types.Content,
76+
) -> list[types.FunctionResponse]:
77+
if not content:
78+
return []
79+
return [
80+
part.function_response for part in content.parts if part.function_response
81+
]
82+
83+
7384
def _is_transcription(event: Event) -> bool:
7485
return (
7586
event.input_transcription is not None
@@ -341,6 +352,35 @@ def _enforce_app_name_alignment(self) -> None:
341352
self._app_name_alignment_hint = f'{mismatch_details} {resolution}'
342353
logger.warning('App name mismatch detected. %s', mismatch_details)
343354

355+
def _resolve_invocation_id(
356+
self,
357+
session: Session,
358+
new_message: Optional[types.Content],
359+
invocation_id: Optional[str],
360+
) -> Optional[str]:
361+
"""Infers invocation_id from new_message if it is a function response."""
362+
function_responses = _get_function_responses_from_content(new_message)
363+
if not function_responses:
364+
return invocation_id
365+
366+
fc_event = find_event_by_function_call_id(
367+
session.events, function_responses[0].id
368+
)
369+
if not fc_event:
370+
raise ValueError(
371+
'Function call event not found for function response id:'
372+
f' {function_responses[0].id}'
373+
)
374+
375+
if invocation_id and invocation_id != fc_event.invocation_id:
376+
logger.warning(
377+
'Provided invocation_id %s is ignored because new_message has a '
378+
'function response with invocation_id %s.',
379+
invocation_id,
380+
fc_event.invocation_id,
381+
)
382+
return fc_event.invocation_id
383+
344384
def _format_session_not_found_message(self, session_id: str) -> str:
345385
message = f'Session not found: {session_id}'
346386
if not self._app_name_alignment_hint:
@@ -497,42 +537,57 @@ async def _run_with_trace(
497537
session = await self._get_or_create_session(
498538
user_id=user_id, session_id=session_id
499539
)
540+
500541
if not invocation_id and not new_message:
501542
raise ValueError(
502543
'Running an agent requires either a new_message or an '
503544
'invocation_id to resume a previous invocation. '
504545
f'Session: {session_id}, User: {user_id}'
505546
)
506547

507-
if invocation_id:
508-
if (
509-
not self.resumability_config
510-
or not self.resumability_config.is_resumable
511-
):
512-
raise ValueError(
513-
f'invocation_id: {invocation_id} is provided but the app is not'
514-
' resumable.'
515-
)
516-
invocation_context = await self._setup_context_for_resumed_invocation(
548+
is_resumable = (
549+
self.resumability_config and self.resumability_config.is_resumable
550+
)
551+
if not is_resumable and not new_message:
552+
raise ValueError(
553+
'Running an agent requires a new_message or a resumable app. '
554+
f'Session: {session_id}, User: {user_id}'
555+
)
556+
557+
if not is_resumable:
558+
invocation_context = await self._setup_context_for_new_invocation(
517559
session=session,
518560
new_message=new_message,
519-
invocation_id=invocation_id,
520561
run_config=run_config,
521562
state_delta=state_delta,
522563
)
523-
if invocation_context.end_of_agents.get(
524-
invocation_context.agent.name
525-
):
526-
# Directly return if the current agent in invocation context is
527-
# already final.
528-
return
529564
else:
530-
invocation_context = await self._setup_context_for_new_invocation(
531-
session=session,
532-
new_message=new_message, # new_message is not None.
533-
run_config=run_config,
534-
state_delta=state_delta,
565+
invocation_id = self._resolve_invocation_id(
566+
session, new_message, invocation_id
535567
)
568+
if not invocation_id:
569+
invocation_context = await self._setup_context_for_new_invocation(
570+
session=session,
571+
new_message=new_message,
572+
run_config=run_config,
573+
state_delta=state_delta,
574+
)
575+
else:
576+
invocation_context = (
577+
await self._setup_context_for_resumed_invocation(
578+
session=session,
579+
new_message=new_message,
580+
invocation_id=invocation_id,
581+
run_config=run_config,
582+
state_delta=state_delta,
583+
)
584+
)
585+
if invocation_context.end_of_agents.get(
586+
invocation_context.agent.name
587+
):
588+
# Directly return if the current agent in invocation context is
589+
# already final.
590+
return
536591

537592
async def execute(ctx: InvocationContext) -> AsyncGenerator[Event]:
538593
async with Aclosing(ctx.agent.run_async(ctx)) as agen:

tests/unittests/runners/test_run_tool_confirmation.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -502,6 +502,86 @@ async def test_pause_and_resume_on_request_confirmation(
502502
== expected_parts_final
503503
)
504504

505+
@pytest.mark.asyncio
506+
async def test_pause_and_resume_on_request_confirmation_without_invocation_id(
507+
self,
508+
runner: testing_utils.InMemoryRunner,
509+
agent: LlmAgent,
510+
):
511+
"""Tests HITL flow where all tool calls are confirmed."""
512+
events = runner.run("test user query")
513+
514+
# Verify that the invocation is paused when tool confirmation is requested.
515+
# The tool call returns error response, and summarization was skipped.
516+
assert testing_utils.simplify_resumable_app_events(
517+
copy.deepcopy(events)
518+
) == [
519+
(
520+
agent.name,
521+
Part(function_call=FunctionCall(name=agent.tools[0].name, args={})),
522+
),
523+
(
524+
agent.name,
525+
Part(
526+
function_call=FunctionCall(
527+
name=REQUEST_CONFIRMATION_FUNCTION_CALL_NAME,
528+
args={
529+
"originalFunctionCall": {
530+
"name": agent.tools[0].name,
531+
"id": mock.ANY,
532+
"args": {},
533+
},
534+
"toolConfirmation": {
535+
"hint": HINT_TEXT,
536+
"confirmed": False,
537+
},
538+
},
539+
)
540+
),
541+
),
542+
(
543+
agent.name,
544+
Part(
545+
function_response=FunctionResponse(
546+
name=agent.tools[0].name, response=TOOL_CALL_ERROR_RESPONSE
547+
)
548+
),
549+
),
550+
]
551+
ask_for_confirmation_function_call_id = (
552+
events[1].content.parts[0].function_call.id
553+
)
554+
invocation_id = events[1].invocation_id
555+
user_confirmation = testing_utils.UserContent(
556+
Part(
557+
function_response=FunctionResponse(
558+
id=ask_for_confirmation_function_call_id,
559+
name=REQUEST_CONFIRMATION_FUNCTION_CALL_NAME,
560+
response={"confirmed": True},
561+
)
562+
)
563+
)
564+
events = await runner.run_async(user_confirmation)
565+
expected_parts_final = [
566+
(
567+
agent.name,
568+
Part(
569+
function_response=FunctionResponse(
570+
name=agent.tools[0].name,
571+
response={"result": "confirmed=True"},
572+
)
573+
),
574+
),
575+
(agent.name, "test llm response after tool call"),
576+
(agent.name, testing_utils.END_OF_AGENT),
577+
]
578+
for event in events:
579+
assert event.invocation_id == invocation_id
580+
assert (
581+
testing_utils.simplify_resumable_app_events(copy.deepcopy(events))
582+
== expected_parts_final
583+
)
584+
505585

506586
class TestHITLConfirmationFlowWithSequentialAgentAndResumableApp:
507587
"""Tests the HITL confirmation flow with a resumable sequential agent app."""

0 commit comments

Comments
 (0)