Skip to content

Commit 9e3b43f

Browse files
DeanChensjcopybara-github
authored andcommitted
fix: Fix branch parsing and unify branch stamping
- Fix branch parsing bug in should_pause_invocation to correctly handle dot-separated nested branch paths and avoid substring collisions. - Unify branch stamping in runners.py by replacing manual branch matching with InvocationContext.stamp_event_branch_context(). Co-authored-by: Shangjie Chen <deanchen@google.com> PiperOrigin-RevId: 938841304
1 parent c007a87 commit 9e3b43f

3 files changed

Lines changed: 131 additions & 9 deletions

File tree

src/google/adk/agents/invocation_context.py

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from ..artifacts.base_artifact_service import BaseArtifactService
3232
from ..auth.auth_credential import AuthCredential
3333
from ..auth.credential_service.base_credential_service import BaseCredentialService
34+
from ..events._branch_path import _BranchPath
3435
from ..events.event import Event
3536
from ..memory.base_memory_service import BaseMemoryService
3637
from ..plugins.plugin_manager import PluginManager
@@ -466,9 +467,28 @@ def should_pause_invocation(self, event: Event) -> bool:
466467
if not event.long_running_tool_ids or not event.get_function_calls():
467468
return False
468469

470+
events = self.session.events if self.session else []
469471
for fc in event.get_function_calls():
470472
if fc.id in event.long_running_tool_ids:
471-
return True
473+
# Check if there is a newer user event in the session that belongs to a sub-branch of this tool call.
474+
# This indicates the tool call is resuming to process that nested input.
475+
is_resolving_sub_branch = False
476+
event_index = -1
477+
# Search backwards since the checked event is typically near the end of history.
478+
for i in range(len(events) - 1, -1, -1):
479+
if events[i].id == event.id:
480+
event_index = i
481+
break
482+
if event_index != -1:
483+
is_resolving_sub_branch = any(
484+
e.author == "user"
485+
and e.branch
486+
and fc.id in _BranchPath.from_string(e.branch).run_ids
487+
for e in events[event_index + 1 :]
488+
)
489+
490+
if not is_resolving_sub_branch:
491+
return True
472492

473493
return False
474494

@@ -483,11 +503,23 @@ def _find_matching_function_call(
483503
if not function_responses:
484504
return None
485505

486-
# Search backwards from the event before the current response event.
506+
events = self._get_events(current_invocation=True)
507+
if events and events[-1].id == function_response_event.id:
508+
search_space = events[:-1]
509+
else:
510+
search_space = events
511+
487512
return find_event_by_function_call_id(
488-
self._get_events(current_invocation=True)[:-1], function_responses[0].id
513+
search_space, function_responses[0].id
489514
)
490515

516+
def stamp_event_branch_context(self, event: Event) -> None:
517+
"""Stamps the event with the branch and isolation scope of its matching function call."""
518+
if function_call := self._find_matching_function_call(event):
519+
event.branch = function_call.branch
520+
if function_call.isolation_scope is not None:
521+
event.isolation_scope = function_call.isolation_scope
522+
491523

492524
def new_invocation_context_id() -> str:
493525
return "e-" + cast(str, platform_uuid.new_uuid())

src/google/adk/runners.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -740,6 +740,7 @@ async def _append_user_event(
740740
if iso is not None:
741741
event.isolation_scope = iso
742742
_apply_run_config_custom_metadata(event, ic.run_config)
743+
ic.stamp_event_branch_context(event)
743744
return await self.session_service.append_event(
744745
session=ic.session, event=event
745746
)
@@ -1482,10 +1483,7 @@ async def _append_new_message_to_session(
14821483
content=new_message,
14831484
)
14841485
_apply_run_config_custom_metadata(event, invocation_context.run_config)
1485-
# If new_message is a function response, find the matching function call
1486-
# and use its branch as the new event's branch.
1487-
if function_call := invocation_context._find_matching_function_call(event):
1488-
event.branch = function_call.branch
1486+
invocation_context.stamp_event_branch_context(event)
14891487

14901488
await self.session_service.append_event(
14911489
session=invocation_context.session, event=event

tests/unittests/agents/test_invocation_context.py

Lines changed: 94 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -152,14 +152,14 @@ def event_to_pause(self, long_running_function_call) -> Event:
152152
)
153153

154154
def _create_test_invocation_context(
155-
self, resumability_config
155+
self, resumability_config: ResumabilityConfig | None = None
156156
) -> InvocationContext:
157157
"""Create a mock invocation context for testing."""
158158
ctx = InvocationContext(
159159
session_service=Mock(spec=BaseSessionService),
160160
agent=Mock(spec=BaseAgent),
161161
invocation_id='inv_1',
162-
session=Mock(spec=Session),
162+
session=Mock(spec=Session, events=[]),
163163
resumability_config=resumability_config,
164164
)
165165
return ctx
@@ -208,6 +208,69 @@ def test_should_not_pause_invocation_with_no_function_calls(
208208
nonpausable_event
209209
)
210210

211+
def test_should_not_pause_when_user_resumes_in_sub_branch(
212+
self, event_to_pause, long_running_function_call
213+
):
214+
"""We do not pause the invocation if a subsequent user event belongs to a sub-branch."""
215+
# Arrange
216+
mock_invocation_context = self._create_test_invocation_context()
217+
user_event = Event(
218+
invocation_id='inv_1',
219+
author='user',
220+
branch=f'agent@{long_running_function_call.id}.child',
221+
)
222+
mock_invocation_context.session.events = [event_to_pause, user_event]
223+
224+
# Act
225+
should_pause = mock_invocation_context.should_pause_invocation(
226+
event_to_pause
227+
)
228+
229+
# Assert
230+
assert not should_pause
231+
232+
def test_should_not_pause_when_user_resumes_in_deeply_nested_sub_branch(
233+
self, event_to_pause, long_running_function_call
234+
):
235+
"""We do not pause if the user resumes in a deeply nested sub-branch containing the tool call."""
236+
# Arrange
237+
mock_invocation_context = self._create_test_invocation_context()
238+
user_event = Event(
239+
invocation_id='inv_1',
240+
author='user',
241+
branch=f'parent@other.child@{long_running_function_call.id}.grandchild',
242+
)
243+
mock_invocation_context.session.events = [event_to_pause, user_event]
244+
245+
# Act
246+
should_pause = mock_invocation_context.should_pause_invocation(
247+
event_to_pause
248+
)
249+
250+
# Assert
251+
assert not should_pause
252+
253+
def test_should_pause_when_user_resumes_in_different_branch(
254+
self, event_to_pause
255+
):
256+
"""We still pause the invocation if the subsequent user event belongs to a different branch."""
257+
# Arrange
258+
mock_invocation_context = self._create_test_invocation_context()
259+
user_event = Event(
260+
invocation_id='inv_1',
261+
author='user',
262+
branch='parent@different_id.child',
263+
)
264+
mock_invocation_context.session.events = [event_to_pause, user_event]
265+
266+
# Act
267+
should_pause = mock_invocation_context.should_pause_invocation(
268+
event_to_pause
269+
)
270+
271+
# Assert
272+
assert should_pause
273+
211274
def test_is_resumable_true(self):
212275
"""Tests that is_resumable is True when resumability is enabled."""
213276
invocation_context = self._create_test_invocation_context(
@@ -534,3 +597,32 @@ def test_find_matching_function_call_no_response_in_event(
534597
invocation_context = test_invocation_context([fc_event, fr_event])
535598
match = invocation_context._find_matching_function_call(fr_event_no_fr)
536599
assert match is None
600+
601+
def test_stamp_event_branch_context_preserves_isolation_scope(
602+
self, test_invocation_context
603+
):
604+
"""Tests stamp_event_branch_context does not overwrite existing isolation_scope with None."""
605+
fc = Part.from_function_call(name='some_tool', args={})
606+
fc.function_call.id = 'test_function_call_id'
607+
fc_event = Event(
608+
invocation_id='inv_1',
609+
author='agent',
610+
branch='root@1',
611+
isolation_scope=None, # Coordinator FC has None scope
612+
content=testing_utils.ModelContent([fc]),
613+
)
614+
fr = Part.from_function_response(
615+
name='some_tool', response={'result': 'ok'}
616+
)
617+
fr.function_response.id = 'test_function_call_id'
618+
fr_event = Event(
619+
invocation_id='inv_1',
620+
author='agent',
621+
isolation_scope='task_123', # Pre-populated active task scope
622+
content=Content(role='user', parts=[fr]),
623+
)
624+
invocation_context = test_invocation_context([fc_event, fr_event])
625+
626+
invocation_context.stamp_event_branch_context(fr_event)
627+
assert fr_event.branch == 'root@1'
628+
assert fr_event.isolation_scope == 'task_123'

0 commit comments

Comments
 (0)