Skip to content

Commit 6d00943

Browse files
committed
fix(flows): resume long-running tools after matching responses
1 parent f973673 commit 6d00943

6 files changed

Lines changed: 181 additions & 14 deletions

File tree

src/google/adk/agents/invocation_context.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -396,6 +396,34 @@ def should_pause_invocation(self, event: Event) -> bool:
396396

397397
return False
398398

399+
def has_unresolved_long_running_tool_calls(
400+
self, events: list[Event]
401+
) -> bool:
402+
"""Returns whether any long-running tool call in events is unresolved."""
403+
if not self.is_resumable or not events:
404+
return False
405+
406+
function_response_ids = {
407+
function_response.id
408+
for event in events
409+
for function_response in event.get_function_responses()
410+
if function_response.id
411+
}
412+
413+
for event in reversed(events):
414+
if not self.should_pause_invocation(event):
415+
continue
416+
417+
paused_function_call_ids = {
418+
function_call.id
419+
for function_call in event.get_function_calls()
420+
if function_call.id in event.long_running_tool_ids
421+
}
422+
if paused_function_call_ids - function_response_ids:
423+
return True
424+
425+
return False
426+
399427
# TODO: Move this method from invocation_context to a dedicated module.
400428
def _find_matching_function_call(
401429
self, function_response_event: Event

src/google/adk/agents/llm_agent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -496,7 +496,7 @@ async def _run_async_impl(
496496

497497
if ctx.is_resumable:
498498
events = ctx._get_events(current_invocation=True, current_branch=True)
499-
if events and any(ctx.should_pause_invocation(e) for e in events[-2:]):
499+
if ctx.has_unresolved_long_running_tool_calls(events):
500500
return
501501
# Only yield an end state if the last event is no longer a long-running
502502
# tool call.

src/google/adk/flows/llm_flows/base_llm_flow.py

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,9 @@ def _finalize_model_response_event(
9999
if finalized_event.content:
100100
function_calls = finalized_event.get_function_calls()
101101
if function_calls:
102+
functions.preserve_existing_function_call_ids(
103+
model_response_event, finalized_event
104+
)
102105
functions.populate_client_function_call_id(finalized_event)
103106
finalized_event.long_running_tool_ids = (
104107
functions.get_long_running_function_calls(
@@ -785,19 +788,7 @@ async def _run_one_step_async(
785788
# Long running tool calls should have been handled before this point.
786789
# If there are still long running tool calls, it means the agent is paused
787790
# before, and its branch hasn't been resumed yet.
788-
if (
789-
invocation_context.is_resumable
790-
and events
791-
and len(events) > 1
792-
# TODO: here we are using the last 2 events to decide whether to pause
793-
# the invocation. But this is just being optimistic, we should find a
794-
# way to pause when the long running tool call is followed by more than
795-
# one text responses.
796-
and (
797-
invocation_context.should_pause_invocation(events[-1])
798-
or invocation_context.should_pause_invocation(events[-2])
799-
)
800-
):
791+
if invocation_context.has_unresolved_long_running_tool_calls(events):
801792
return
802793

803794
if (

src/google/adk/flows/llm_flows/functions.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,35 @@ def populate_client_function_call_id(model_response_event: Event) -> None:
189189
function_call.id = generate_client_function_call_id()
190190

191191

192+
def preserve_existing_function_call_ids(
193+
previous_event: Event, model_response_event: Event
194+
) -> None:
195+
"""Carries forward function call IDs from a previous streaming event.
196+
197+
Streaming responses may emit partial and final events for the same function
198+
call sequence. The partial event is sent to clients first, while only the
199+
final event is persisted. Preserving IDs across those events keeps
200+
functionResponse routing stable when the client resumes a long-running tool.
201+
202+
Args:
203+
previous_event: The in-flight model response event from an earlier chunk.
204+
model_response_event: The newly finalized event for the current chunk.
205+
"""
206+
previous_function_calls = previous_event.get_function_calls()
207+
current_function_calls = model_response_event.get_function_calls()
208+
if not previous_function_calls or not current_function_calls:
209+
return
210+
211+
for previous_function_call, current_function_call in zip(
212+
previous_function_calls, current_function_calls
213+
):
214+
if current_function_call.id:
215+
continue
216+
if previous_function_call.name != current_function_call.name:
217+
continue
218+
current_function_call.id = previous_function_call.id
219+
220+
192221
def remove_client_function_call_id(content: Optional[types.Content]) -> None:
193222
"""Removes ADK-generated function call IDs from content before sending to LLM.
194223

tests/unittests/agents/test_invocation_context.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from google.adk.sessions.session import Session
2525
from google.genai.types import Content
2626
from google.genai.types import FunctionCall
27+
from google.genai.types import FunctionResponse
2728
from google.genai.types import Part
2829
import pytest
2930

@@ -210,6 +211,82 @@ def test_should_not_pause_invocation_with_no_function_calls(
210211
nonpausable_event
211212
)
212213

214+
def test_has_unresolved_long_running_tool_calls_with_matching_response(self):
215+
"""Tests that matching function responses resolve the pause."""
216+
invocation_context = self._create_test_invocation_context(
217+
ResumabilityConfig(is_resumable=True)
218+
)
219+
function_call = FunctionCall(
220+
id='tool_call_id_1',
221+
name='long_running_function_call',
222+
args={},
223+
)
224+
paused_event = Event(
225+
invocation_id='inv_1',
226+
author='agent',
227+
content=testing_utils.ModelContent([Part(function_call=function_call)]),
228+
long_running_tool_ids={function_call.id},
229+
)
230+
resolved_event = Event(
231+
invocation_id='inv_1',
232+
author='user',
233+
content=Content(
234+
role='user',
235+
parts=[
236+
Part(
237+
function_response=FunctionResponse(
238+
name='long_running_function_call',
239+
response={'result': 'done'},
240+
id=function_call.id,
241+
)
242+
)
243+
],
244+
),
245+
)
246+
247+
assert not invocation_context.has_unresolved_long_running_tool_calls(
248+
[paused_event, resolved_event]
249+
)
250+
251+
def test_has_unresolved_long_running_tool_calls_without_matching_response(
252+
self,
253+
):
254+
"""Tests that unmatched long-running calls still pause the invocation."""
255+
invocation_context = self._create_test_invocation_context(
256+
ResumabilityConfig(is_resumable=True)
257+
)
258+
function_call = FunctionCall(
259+
id='tool_call_id_1',
260+
name='long_running_function_call',
261+
args={},
262+
)
263+
paused_event = Event(
264+
invocation_id='inv_1',
265+
author='agent',
266+
content=testing_utils.ModelContent([Part(function_call=function_call)]),
267+
long_running_tool_ids={function_call.id},
268+
)
269+
unrelated_response_event = Event(
270+
invocation_id='inv_1',
271+
author='user',
272+
content=Content(
273+
role='user',
274+
parts=[
275+
Part(
276+
function_response=FunctionResponse(
277+
name='long_running_function_call',
278+
response={'result': 'done'},
279+
id='different_tool_call_id',
280+
)
281+
)
282+
],
283+
),
284+
)
285+
286+
assert invocation_context.has_unresolved_long_running_tool_calls(
287+
[paused_event, unrelated_response_event]
288+
)
289+
213290
def test_is_resumable_true(self):
214291
"""Tests that is_resumable is True when resumability is enabled."""
215292
invocation_context = self._create_test_invocation_context(

tests/unittests/flows/llm_flows/test_base_llm_flow.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
from google.adk.agents.llm_agent import Agent
2121
from google.adk.events.event import Event
22+
from google.adk.flows.llm_flows.base_llm_flow import _finalize_model_response_event
2223
from google.adk.flows.llm_flows.base_llm_flow import _handle_after_model_callback
2324
from google.adk.flows.llm_flows.base_llm_flow import BaseLlmFlow
2425
from google.adk.models.google_llm import Gemini
@@ -41,6 +42,47 @@ class BaseLlmFlowForTesting(BaseLlmFlow):
4142
pass
4243

4344

45+
def test_finalize_model_response_event_preserves_function_call_ids():
46+
"""Test that streaming finalization keeps function call IDs stable."""
47+
previous_event = Event(
48+
id=Event.new_id(),
49+
invocation_id='test_invocation',
50+
author='test_agent',
51+
content=types.Content(
52+
role='model',
53+
parts=[
54+
types.Part(
55+
function_call=types.FunctionCall(
56+
name='track_execution',
57+
args={'call_id': 'partial'},
58+
id='adk-existing-id',
59+
)
60+
)
61+
],
62+
),
63+
partial=True,
64+
)
65+
llm_response = LlmResponse(
66+
content=types.Content(
67+
role='model',
68+
parts=[
69+
types.Part.from_function_call(
70+
name='track_execution', args={'call_id': 'final'}
71+
)
72+
],
73+
),
74+
partial=False,
75+
)
76+
77+
finalized_event = _finalize_model_response_event(
78+
LlmRequest(), llm_response, previous_event
79+
)
80+
81+
function_calls = finalized_event.get_function_calls()
82+
assert len(function_calls) == 1
83+
assert function_calls[0].id == 'adk-existing-id'
84+
85+
4486
@pytest.mark.asyncio
4587
async def test_preprocess_calls_toolset_process_llm_request():
4688
"""Test that _preprocess_async calls process_llm_request on toolsets."""

0 commit comments

Comments
 (0)