Skip to content

Commit 26d8390

Browse files
Merge branch 'main' into fix/2942-jinja2-instruction-templating
2 parents 3b0f26b + 9e3b43f commit 26d8390

5 files changed

Lines changed: 225 additions & 52 deletions

File tree

src/google/adk/agents/invocation_context.py

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from ..artifacts.base_artifact_service import BaseArtifactService
3232
from ..auth.auth_credential import AuthCredential
3333
from ..auth.credential_service.base_credential_service import BaseCredentialService
34+
from ..events._branch_path import _BranchPath
3435
from ..events.event import Event
3536
from ..memory.base_memory_service import BaseMemoryService
3637
from ..plugins.plugin_manager import PluginManager
@@ -466,9 +467,28 @@ def should_pause_invocation(self, event: Event) -> bool:
466467
if not event.long_running_tool_ids or not event.get_function_calls():
467468
return False
468469

470+
events = self.session.events if self.session else []
469471
for fc in event.get_function_calls():
470472
if fc.id in event.long_running_tool_ids:
471-
return True
473+
# Check if there is a newer user event in the session that belongs to a sub-branch of this tool call.
474+
# This indicates the tool call is resuming to process that nested input.
475+
is_resolving_sub_branch = False
476+
event_index = -1
477+
# Search backwards since the checked event is typically near the end of history.
478+
for i in range(len(events) - 1, -1, -1):
479+
if events[i].id == event.id:
480+
event_index = i
481+
break
482+
if event_index != -1:
483+
is_resolving_sub_branch = any(
484+
e.author == "user"
485+
and e.branch
486+
and fc.id in _BranchPath.from_string(e.branch).run_ids
487+
for e in events[event_index + 1 :]
488+
)
489+
490+
if not is_resolving_sub_branch:
491+
return True
472492

473493
return False
474494

@@ -483,11 +503,23 @@ def _find_matching_function_call(
483503
if not function_responses:
484504
return None
485505

486-
# Search backwards from the event before the current response event.
506+
events = self._get_events(current_invocation=True)
507+
if events and events[-1].id == function_response_event.id:
508+
search_space = events[:-1]
509+
else:
510+
search_space = events
511+
487512
return find_event_by_function_call_id(
488-
self._get_events(current_invocation=True)[:-1], function_responses[0].id
513+
search_space, function_responses[0].id
489514
)
490515

516+
def stamp_event_branch_context(self, event: Event) -> None:
517+
"""Stamps the event with the branch and isolation scope of its matching function call."""
518+
if function_call := self._find_matching_function_call(event):
519+
event.branch = function_call.branch
520+
if function_call.isolation_scope is not None:
521+
event.isolation_scope = function_call.isolation_scope
522+
491523

492524
def new_invocation_context_id() -> str:
493525
return "e-" + cast(str, platform_uuid.new_uuid())

src/google/adk/models/gemini_llm_connection.py

Lines changed: 1 addition & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -86,33 +86,10 @@ async def send_history(self, history: list[types.Content]):
8686
]
8787

8888
if contents:
89-
# Gemini Enterprise Agent Platform does not support history_config in the
90-
# SDK. To initialize a live session with prior history without hitting a
91-
# 1007 protocol error (invalid role mid-session), we consolidate previous
92-
# multi-turn interactions into a unified contextual preamble on a single
93-
# user role turn.
94-
if (
95-
self._is_gemini_3_1_flash_live
96-
and self._api_backend != GoogleLLMVariant.GEMINI_API
97-
):
98-
collapsed_text = 'Previous conversation history:\n'
99-
for c in contents:
100-
text_parts = ''.join(p.text for p in c.parts if p.text)
101-
collapsed_text += f'[{c.role}]: {text_parts}\n'
102-
contents = [
103-
types.Content(
104-
role='user', parts=[types.Part.from_text(text=collapsed_text)]
105-
)
106-
]
107-
10889
logger.debug('Sending history to live connection: %s', contents)
10990
await self._gemini_session.send_client_content(
11091
turns=contents,
111-
turn_complete=(
112-
True
113-
if self._is_gemini_3_1_flash_live
114-
else contents[-1].role == 'user'
115-
),
92+
turn_complete=contents[-1].role == 'user',
11693
)
11794
else:
11895
logger.info('no content is sent')

src/google/adk/runners.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -740,6 +740,7 @@ async def _append_user_event(
740740
if iso is not None:
741741
event.isolation_scope = iso
742742
_apply_run_config_custom_metadata(event, ic.run_config)
743+
ic.stamp_event_branch_context(event)
743744
return await self.session_service.append_event(
744745
session=ic.session, event=event
745746
)
@@ -1482,10 +1483,7 @@ async def _append_new_message_to_session(
14821483
content=new_message,
14831484
)
14841485
_apply_run_config_custom_metadata(event, invocation_context.run_config)
1485-
# If new_message is a function response, find the matching function call
1486-
# and use its branch as the new event's branch.
1487-
if function_call := invocation_context._find_matching_function_call(event):
1488-
event.branch = function_call.branch
1486+
invocation_context.stamp_event_branch_context(event)
14891487

14901488
await self.session_service.append_event(
14911489
session=invocation_context.session, event=event

tests/unittests/agents/test_invocation_context.py

Lines changed: 94 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -152,14 +152,14 @@ def event_to_pause(self, long_running_function_call) -> Event:
152152
)
153153

154154
def _create_test_invocation_context(
155-
self, resumability_config
155+
self, resumability_config: ResumabilityConfig | None = None
156156
) -> InvocationContext:
157157
"""Create a mock invocation context for testing."""
158158
ctx = InvocationContext(
159159
session_service=Mock(spec=BaseSessionService),
160160
agent=Mock(spec=BaseAgent),
161161
invocation_id='inv_1',
162-
session=Mock(spec=Session),
162+
session=Mock(spec=Session, events=[]),
163163
resumability_config=resumability_config,
164164
)
165165
return ctx
@@ -208,6 +208,69 @@ def test_should_not_pause_invocation_with_no_function_calls(
208208
nonpausable_event
209209
)
210210

211+
def test_should_not_pause_when_user_resumes_in_sub_branch(
212+
self, event_to_pause, long_running_function_call
213+
):
214+
"""We do not pause the invocation if a subsequent user event belongs to a sub-branch."""
215+
# Arrange
216+
mock_invocation_context = self._create_test_invocation_context()
217+
user_event = Event(
218+
invocation_id='inv_1',
219+
author='user',
220+
branch=f'agent@{long_running_function_call.id}.child',
221+
)
222+
mock_invocation_context.session.events = [event_to_pause, user_event]
223+
224+
# Act
225+
should_pause = mock_invocation_context.should_pause_invocation(
226+
event_to_pause
227+
)
228+
229+
# Assert
230+
assert not should_pause
231+
232+
def test_should_not_pause_when_user_resumes_in_deeply_nested_sub_branch(
233+
self, event_to_pause, long_running_function_call
234+
):
235+
"""We do not pause if the user resumes in a deeply nested sub-branch containing the tool call."""
236+
# Arrange
237+
mock_invocation_context = self._create_test_invocation_context()
238+
user_event = Event(
239+
invocation_id='inv_1',
240+
author='user',
241+
branch=f'parent@other.child@{long_running_function_call.id}.grandchild',
242+
)
243+
mock_invocation_context.session.events = [event_to_pause, user_event]
244+
245+
# Act
246+
should_pause = mock_invocation_context.should_pause_invocation(
247+
event_to_pause
248+
)
249+
250+
# Assert
251+
assert not should_pause
252+
253+
def test_should_pause_when_user_resumes_in_different_branch(
254+
self, event_to_pause
255+
):
256+
"""We still pause the invocation if the subsequent user event belongs to a different branch."""
257+
# Arrange
258+
mock_invocation_context = self._create_test_invocation_context()
259+
user_event = Event(
260+
invocation_id='inv_1',
261+
author='user',
262+
branch='parent@different_id.child',
263+
)
264+
mock_invocation_context.session.events = [event_to_pause, user_event]
265+
266+
# Act
267+
should_pause = mock_invocation_context.should_pause_invocation(
268+
event_to_pause
269+
)
270+
271+
# Assert
272+
assert should_pause
273+
211274
def test_is_resumable_true(self):
212275
"""Tests that is_resumable is True when resumability is enabled."""
213276
invocation_context = self._create_test_invocation_context(
@@ -534,3 +597,32 @@ def test_find_matching_function_call_no_response_in_event(
534597
invocation_context = test_invocation_context([fc_event, fr_event])
535598
match = invocation_context._find_matching_function_call(fr_event_no_fr)
536599
assert match is None
600+
601+
def test_stamp_event_branch_context_preserves_isolation_scope(
602+
self, test_invocation_context
603+
):
604+
"""Tests stamp_event_branch_context does not overwrite existing isolation_scope with None."""
605+
fc = Part.from_function_call(name='some_tool', args={})
606+
fc.function_call.id = 'test_function_call_id'
607+
fc_event = Event(
608+
invocation_id='inv_1',
609+
author='agent',
610+
branch='root@1',
611+
isolation_scope=None, # Coordinator FC has None scope
612+
content=testing_utils.ModelContent([fc]),
613+
)
614+
fr = Part.from_function_response(
615+
name='some_tool', response={'result': 'ok'}
616+
)
617+
fr.function_response.id = 'test_function_call_id'
618+
fr_event = Event(
619+
invocation_id='inv_1',
620+
author='agent',
621+
isolation_scope='task_123', # Pre-populated active task scope
622+
content=Content(role='user', parts=[fr]),
623+
)
624+
invocation_context = test_invocation_context([fc_event, fr_event])
625+
626+
invocation_context.stamp_event_branch_context(fr_event)
627+
assert fr_event.branch == 'root@1'
628+
assert fr_event.isolation_scope == 'task_123'

tests/unittests/models/test_gemini_llm_connection.py

Lines changed: 93 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -949,54 +949,128 @@ async def test_send_history_filters_various_audio_mime_types(
949949

950950
@pytest.mark.asyncio
951951
async def test_send_history_gemini_31_turn_complete(mock_gemini_session):
952-
"""Verify Gemini 3.1 Live history seeding explicitly appends turn_complete=True."""
952+
"""Verify Gemini 3.1 Live history seeding sets turn_complete based on history[-1].role == 'user'."""
953953
conn = GeminiLlmConnection(
954954
mock_gemini_session,
955955
api_backend=GoogleLLMVariant.GEMINI_API,
956956
model_version='gemini-3.1-flash-live-preview',
957957
)
958958
mock_gemini_session.send_client_content = mock.AsyncMock()
959959

960-
mock_contents = [
960+
# Last turn is model -> turn_complete=False
961+
mock_contents_model = [
961962
types.Content(role='user', parts=[types.Part.from_text(text='hi')]),
962963
types.Content(role='model', parts=[types.Part.from_text(text='hello')]),
963964
]
964-
await conn.send_history(mock_contents)
965+
await conn.send_history(mock_contents_model)
965966

966967
mock_gemini_session.send_client_content.assert_called_once_with(
967-
turns=mock_contents,
968+
turns=mock_contents_model,
969+
turn_complete=False,
970+
)
971+
972+
# Last turn is user -> turn_complete=True
973+
mock_gemini_session.send_client_content.reset_mock()
974+
mock_contents_user = [
975+
types.Content(role='user', parts=[types.Part.from_text(text='hi')]),
976+
]
977+
await conn.send_history(mock_contents_user)
978+
979+
mock_gemini_session.send_client_content.assert_called_once_with(
980+
turns=mock_contents_user,
968981
turn_complete=True,
969982
)
970983

971984

972985
@pytest.mark.asyncio
973-
async def test_send_history_collapse_vertex_ai(mock_gemini_session):
974-
"""Verify history prompt collapse when seeding Gemini 3.1 Live on Vertex AI backend."""
986+
async def test_send_history_vertex_ai_no_collapse(mock_gemini_session):
987+
"""Verify history is sent without collapsing on Vertex AI backend."""
975988
conn = GeminiLlmConnection(
976989
mock_gemini_session,
977990
api_backend=GoogleLLMVariant.VERTEX_AI,
978991
model_version='gemini-3.1-flash-live-preview',
979992
)
980993
mock_gemini_session.send_client_content = mock.AsyncMock()
981994

982-
mock_contents = [
995+
# Last turn is model -> turn_complete=False
996+
mock_contents_model = [
983997
types.Content(role='user', parts=[types.Part.from_text(text='hi')]),
984998
types.Content(role='model', parts=[types.Part.from_text(text='hello')]),
985999
]
986-
await conn.send_history(mock_contents)
1000+
await conn.send_history(mock_contents_model)
9871001

988-
assert mock_gemini_session.send_client_content.call_count == 1
989-
called_turns = mock_gemini_session.send_client_content.call_args.kwargs[
990-
'turns'
1002+
mock_gemini_session.send_client_content.assert_called_once_with(
1003+
turns=mock_contents_model,
1004+
turn_complete=False,
1005+
)
1006+
1007+
# Last turn is user -> turn_complete=True
1008+
mock_gemini_session.send_client_content.reset_mock()
1009+
mock_contents_user = [
1010+
types.Content(role='user', parts=[types.Part.from_text(text='hi')]),
1011+
types.Content(role='model', parts=[types.Part.from_text(text='hello')]),
1012+
types.Content(
1013+
role='user', parts=[types.Part.from_text(text='how are you?')]
1014+
),
9911015
]
992-
assert len(called_turns) == 1
993-
assert called_turns[0].role == 'user'
994-
assert 'Previous conversation history:' in called_turns[0].parts[0].text
995-
assert '[user]: hi' in called_turns[0].parts[0].text
996-
assert '[model]: hello' in called_turns[0].parts[0].text
997-
assert (
998-
mock_gemini_session.send_client_content.call_args.kwargs['turn_complete']
999-
is True
1016+
await conn.send_history(mock_contents_user)
1017+
1018+
mock_gemini_session.send_client_content.assert_called_once_with(
1019+
turns=mock_contents_user,
1020+
turn_complete=True,
1021+
)
1022+
1023+
1024+
@pytest.mark.asyncio
1025+
async def test_send_history_turn_complete_determined_by_filtered_content(
1026+
mock_gemini_session,
1027+
):
1028+
"""Verify turn_complete is determined by the last element of filtered content instead of unfiltered history."""
1029+
conn = GeminiLlmConnection(
1030+
mock_gemini_session,
1031+
api_backend=GoogleLLMVariant.GEMINI_API,
1032+
model_version='gemini-3.1-flash-live-preview',
1033+
)
1034+
mock_gemini_session.send_client_content = mock.AsyncMock()
1035+
1036+
# Scenario: Last turn in history is a user audio turn (gets filtered out).
1037+
# The remaining last turn is model's turn -> turn_complete should be False.
1038+
audio_part = types.Part(
1039+
inline_data=types.Blob(data=b'\x00\xFF', mime_type='audio/pcm')
1040+
)
1041+
history_with_final_audio_user_turn = [
1042+
types.Content(role='user', parts=[types.Part.from_text(text='hi')]),
1043+
types.Content(role='model', parts=[types.Part.from_text(text='hello')]),
1044+
types.Content(role='user', parts=[audio_part]),
1045+
]
1046+
1047+
await conn.send_history(history_with_final_audio_user_turn)
1048+
1049+
expected_contents = [
1050+
types.Content(role='user', parts=[types.Part.from_text(text='hi')]),
1051+
types.Content(role='model', parts=[types.Part.from_text(text='hello')]),
1052+
]
1053+
mock_gemini_session.send_client_content.assert_called_once_with(
1054+
turns=expected_contents,
1055+
turn_complete=False,
1056+
)
1057+
1058+
# Scenario: Last turn in history is a model audio turn (gets filtered out).
1059+
# The remaining last turn is user's turn -> turn_complete should be True.
1060+
mock_gemini_session.send_client_content.reset_mock()
1061+
history_with_final_audio_model_turn = [
1062+
types.Content(role='user', parts=[types.Part.from_text(text='hi')]),
1063+
types.Content(role='model', parts=[audio_part]),
1064+
]
1065+
1066+
await conn.send_history(history_with_final_audio_model_turn)
1067+
1068+
expected_contents = [
1069+
types.Content(role='user', parts=[types.Part.from_text(text='hi')]),
1070+
]
1071+
mock_gemini_session.send_client_content.assert_called_once_with(
1072+
turns=expected_contents,
1073+
turn_complete=True,
10001074
)
10011075

10021076

0 commit comments

Comments
 (0)