Skip to content

Commit 7251ed3

Browse files
authored
Merge branch 'main' into wjy/fix-interrupted
2 parents 30e1505 + 454188d commit 7251ed3

6 files changed

Lines changed: 168 additions & 49 deletions

File tree

pyproject.toml

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,11 @@ dependencies = [
4444
"google-cloud-spanner>=3.56.0, <4.0.0", # For Spanner database
4545
"google-cloud-speech>=2.30.0, <3.0.0", # For Audio Transcription
4646
"google-cloud-storage>=2.18.0, <4.0.0", # For GCS Artifact service
47-
"google-genai>=1.64.0, <2.0.0", # Google GenAI SDK
47+
"google-genai>=1.72.0, <2.0.0", # Google GenAI SDK
4848
"graphviz>=0.20.2, <1.0.0", # Graphviz for graph rendering
4949
"httpx>=0.27.0, <1.0.0", # HTTP client library
5050
"jsonschema>=4.23.0, <5.0.0", # Agent Builder config validation
51-
"mcp>=1.23.0, <2.0.0", # For MCP Toolset
51+
"mcp>=1.24.0, <2.0.0", # For MCP Toolset
5252
"opentelemetry-api>=1.36.0, <1.39.0", # OpenTelemetry - keep below 1.39.0 due to current agent_engines exporter constraints.
5353
"opentelemetry-exporter-gcp-logging>=1.9.0a0, <2.0.0",
5454
"opentelemetry-exporter-gcp-monitoring>=1.9.0a0, <2.0.0",
@@ -123,7 +123,8 @@ test = [
123123
"a2a-sdk>=0.3.0,<0.4.0",
124124
"anthropic>=0.43.0", # For anthropic model tests
125125
"crewai[tools];python_version>='3.11' and python_version<'3.12'", # For CrewaiTool tests; chromadb/pypika fail on 3.12+
126-
"google-cloud-firestore>=2.11.0",
126+
"google-cloud-firestore>=2.11.0, <3.0.0",
127+
"google-cloud-iamconnectorcredentials>=0.1.0, <0.2.0",
127128
"google-cloud-parametermanager>=0.4.0, <1.0.0",
128129
"kubernetes>=29.0.0", # For GkeCodeExecutor
129130
"langchain-community>=0.3.17",
@@ -159,7 +160,7 @@ extensions = [
159160
"beautifulsoup4>=3.2.2", # For load_web_page tool.
160161
"crewai[tools];python_version>='3.11' and python_version<'3.12'", # For CrewaiTool; chromadb/pypika fail on 3.12+
161162
"docker>=7.0.0", # For ContainerCodeExecutor
162-
"google-cloud-firestore>=2.11.0", # For Firestore services
163+
"google-cloud-firestore>=2.11.0, <3.0.0", # For Firestore services
163164
"google-cloud-parametermanager>=0.4.0, <1.0.0",
164165
"kubernetes>=29.0.0", # For GkeCodeExecutor
165166
"k8s-agent-sandbox>=0.1.1.post3", # For GkeCodeExecutor sandbox mode
@@ -178,6 +179,10 @@ toolbox = ["toolbox-adk>=1.0.0, <2.0.0"]
178179

179180
slack = ["slack-bolt>=1.22.0"]
180181

182+
agent-identity = [
183+
"google-cloud-iamconnectorcredentials>=0.1.0, <0.2.0",
184+
]
185+
181186
[tool.pyink]
182187
# Format py files following Google style-guide
183188
line-length = 80

src/google/adk/models/gemini_llm_connection.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -81,11 +81,9 @@ async def send_history(self, history: list[types.Content]):
8181

8282
if contents:
8383
logger.debug('Sending history to live connection: %s', contents)
84-
await self._gemini_session.send(
85-
input=types.LiveClientContent(
86-
turns=contents,
87-
turn_complete=contents[-1].role == 'user',
88-
),
84+
await self._gemini_session.send_client_content(
85+
turns=contents,
86+
turn_complete=contents[-1].role == 'user',
8987
)
9088
else:
9189
logger.info('no content is sent')
@@ -105,10 +103,8 @@ async def send_content(self, content: types.Content):
105103
# All parts have to be function responses.
106104
function_responses = [part.function_response for part in content.parts]
107105
logger.debug('Sending LLM function response: %s', function_responses)
108-
await self._gemini_session.send(
109-
input=types.LiveClientToolResponse(
110-
function_responses=function_responses
111-
),
106+
await self._gemini_session.send_tool_response(
107+
function_responses=function_responses
112108
)
113109
else:
114110
logger.debug('Sending LLM new content %s', content)

src/google/adk/plugins/base_plugin.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -155,10 +155,10 @@ async def before_run_callback(
155155
async def on_event_callback(
156156
self, *, invocation_context: InvocationContext, event: Event
157157
) -> Optional[Event]:
158-
"""Callback executed after an event is yielded from runner.
158+
"""Callback executed when the runner produces an event.
159159
160-
This is the ideal place to make modification to the event before the event
161-
is handled by the underlying agent app.
160+
This is the ideal place to modify the event before it is persisted to the
161+
session service and yielded to the caller.
162162
163163
Args:
164164
invocation_context: The context for the entire invocation.

src/google/adk/runners.py

Lines changed: 44 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -791,6 +791,34 @@ def _should_append_event(self, event: Event, is_live_call: bool) -> bool:
791791
return False
792792
return True
793793

794+
def _get_output_event(
795+
self,
796+
*,
797+
original_event: Event,
798+
modified_event: Event | None,
799+
run_config: RunConfig | None,
800+
) -> Event:
801+
"""Returns the event that should be persisted and yielded.
802+
803+
Plugins may return a replacement event that only overrides a subset of
804+
fields. Merge those changes onto the original event so the streamed event
805+
and the persisted event stay aligned without losing the original event
806+
identity.
807+
"""
808+
if modified_event is None:
809+
return original_event
810+
811+
_apply_run_config_custom_metadata(modified_event, run_config)
812+
update = {}
813+
for field_name in modified_event.model_fields_set:
814+
if field_name in {'id', 'invocation_id', 'timestamp'}:
815+
continue
816+
update[field_name] = modified_event.__dict__[field_name]
817+
output_event = original_event.model_copy(update=update)
818+
if not output_event.author:
819+
output_event.author = original_event.author
820+
return output_event
821+
794822
async def _exec_with_plugin(
795823
self,
796824
invocation_context: InvocationContext,
@@ -854,13 +882,24 @@ async def _exec_with_plugin(
854882
_apply_run_config_custom_metadata(
855883
event, invocation_context.run_config
856884
)
885+
# Step 3: Run the on_event callbacks before persisting so callback
886+
# changes are stored in the session and match the streamed event.
887+
modified_event = await plugin_manager.run_on_event_callback(
888+
invocation_context=invocation_context, event=event
889+
)
890+
output_event = self._get_output_event(
891+
original_event=event,
892+
modified_event=modified_event,
893+
run_config=invocation_context.run_config,
894+
)
895+
857896
if is_live_call:
858897
if event.partial and _is_transcription(event):
859898
is_transcribing = True
860899
if is_transcribing and _is_tool_call_or_response(event):
861900
# only buffer function call and function response event which is
862901
# non-partial
863-
buffered_events.append(event)
902+
buffered_events.append(output_event)
864903
continue
865904
# Note for live/bidi: for audio response, it's considered as
866905
# non-partial event(event.partial=None)
@@ -881,7 +920,7 @@ async def _exec_with_plugin(
881920
)
882921
if self._should_append_event(event, is_live_call):
883922
await self.session_service.append_event(
884-
session=session, event=event
923+
session=session, event=output_event
885924
)
886925

887926
for buffered_event in buffered_events:
@@ -897,25 +936,15 @@ async def _exec_with_plugin(
897936
if self._should_append_event(event, is_live_call):
898937
logger.debug('Appending non-buffered event: %s', event)
899938
await self.session_service.append_event(
900-
session=session, event=event
939+
session=session, event=output_event
901940
)
902941
else:
903942
if event.partial is not True:
904943
await self.session_service.append_event(
905-
session=session, event=event
944+
session=session, event=output_event
906945
)
907946

908-
# Step 3: Run the on_event callbacks to optionally modify the event.
909-
modified_event = await plugin_manager.run_on_event_callback(
910-
invocation_context=invocation_context, event=event
911-
)
912-
if modified_event:
913-
_apply_run_config_custom_metadata(
914-
modified_event, invocation_context.run_config
915-
)
916-
yield modified_event
917-
else:
918-
yield event
947+
yield output_event
919948

920949
# Step 4: Run the after_run callbacks to perform global cleanup tasks or
921950
# finalizing logs and metrics data.

tests/unittests/models/test_gemini_llm_connection.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -81,11 +81,11 @@ async def test_send_history(gemini_connection, mock_gemini_session):
8181

8282
await gemini_connection.send_history(history)
8383

84-
mock_gemini_session.send.assert_called_once()
85-
call_args = mock_gemini_session.send.call_args[1]
86-
assert 'input' in call_args
87-
assert call_args['input'].turns == history
88-
assert call_args['input'].turn_complete is False # Last message is from model
84+
mock_gemini_session.send_client_content.assert_called_once()
85+
call_args = mock_gemini_session.send_client_content.call_args[1]
86+
assert 'turns' in call_args
87+
assert call_args['turns'] == history
88+
assert call_args['turn_complete'] is False # Last message is from model
8989

9090

9191
@pytest.mark.asyncio
@@ -118,10 +118,10 @@ async def test_send_content_function_response(
118118

119119
await gemini_connection.send_content(content)
120120

121-
mock_gemini_session.send.assert_called_once()
122-
call_args = mock_gemini_session.send.call_args[1]
123-
assert 'input' in call_args
124-
assert call_args['input'].function_responses == [function_response]
121+
mock_gemini_session.send_tool_response.assert_called_once()
122+
call_args = mock_gemini_session.send_tool_response.call_args[1]
123+
assert 'function_responses' in call_args
124+
assert call_args['function_responses'] == [function_response]
125125

126126

127127
@pytest.mark.asyncio
@@ -668,9 +668,9 @@ async def test_send_history_filters_audio(mock_gemini_session, audio_part):
668668

669669
await connection.send_history(history)
670670

671-
mock_gemini_session.send.assert_called_once()
672-
call_args = mock_gemini_session.send.call_args[1]
673-
sent_contents = call_args['input'].turns
671+
mock_gemini_session.send_client_content.assert_called_once()
672+
call_args = mock_gemini_session.send_client_content.call_args[1]
673+
sent_contents = call_args['turns']
674674
# Only the model response should be sent (user audio filtered out)
675675
assert len(sent_contents) == 1
676676
assert sent_contents[0].role == 'model'
@@ -696,9 +696,9 @@ async def test_send_history_keeps_image_data(mock_gemini_session):
696696

697697
await connection.send_history(history)
698698

699-
mock_gemini_session.send.assert_called_once()
700-
call_args = mock_gemini_session.send.call_args[1]
701-
sent_contents = call_args['input'].turns
699+
mock_gemini_session.send_client_content.assert_called_once()
700+
call_args = mock_gemini_session.send_client_content.call_args[1]
701+
sent_contents = call_args['turns']
702702
# Both contents should be sent (image is not filtered)
703703
assert len(sent_contents) == 2
704704
assert sent_contents[0].parts[0].inline_data == image_blob
@@ -728,9 +728,9 @@ async def test_send_history_mixed_content_filters_only_audio(
728728

729729
await connection.send_history(history)
730730

731-
mock_gemini_session.send.assert_called_once()
732-
call_args = mock_gemini_session.send.call_args[1]
733-
sent_contents = call_args['input'].turns
731+
mock_gemini_session.send_client_content.assert_called_once()
732+
call_args = mock_gemini_session.send_client_content.call_args[1]
733+
sent_contents = call_args['turns']
734734
# Content should be sent but only with the text part
735735
assert len(sent_contents) == 1
736736
assert len(sent_contents[0].parts) == 1

tests/unittests/test_runners.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@ class MockPlugin(BasePlugin):
139139
"Modified user message ON_USER_CALLBACK_MSG from MockPlugin"
140140
)
141141
ON_EVENT_CALLBACK_MSG = "Modified event ON_EVENT_CALLBACK_MSG from MockPlugin"
142+
ON_EVENT_CALLBACK_METADATA = {"plugin_key": "plugin_value"}
142143

143144
def __init__(self):
144145
super().__init__(name="mock_plugin")
@@ -184,6 +185,7 @@ async def on_event_callback(
184185
],
185186
role=event.content.role,
186187
),
188+
custom_metadata=self.ON_EVENT_CALLBACK_METADATA,
187189
)
188190

189191

@@ -359,6 +361,60 @@ async def test_run_live_auto_create_session():
359361
assert session is not None
360362

361363

364+
@pytest.mark.asyncio
365+
async def test_run_live_persists_event_callback_modifications():
366+
"""run_live should persist the same event it streams after callback changes."""
367+
session_service = InMemorySessionService()
368+
artifact_service = InMemoryArtifactService()
369+
plugin = MockPlugin()
370+
plugin.enable_event_callback = True
371+
runner = Runner(
372+
app_name="live_app",
373+
agent=MockLiveAgent("live_agent"),
374+
session_service=session_service,
375+
artifact_service=artifact_service,
376+
plugins=[plugin],
377+
)
378+
await session_service.create_session(
379+
app_name="live_app", user_id="user", session_id="live_session"
380+
)
381+
382+
from google.adk.agents.live_request_queue import LiveRequestQueue
383+
384+
live_queue = LiveRequestQueue()
385+
agen = runner.run_live(
386+
user_id="user",
387+
session_id="live_session",
388+
live_request_queue=live_queue,
389+
)
390+
391+
streamed_event = await agen.__anext__()
392+
await agen.aclose()
393+
394+
session = await session_service.get_session(
395+
app_name="live_app", user_id="user", session_id="live_session"
396+
)
397+
persisted_event = session.events[0]
398+
399+
assert streamed_event.author == "live_agent"
400+
assert streamed_event.invocation_id
401+
assert streamed_event.content.parts[0].text == (
402+
MockPlugin.ON_EVENT_CALLBACK_MSG
403+
)
404+
assert streamed_event.custom_metadata == MockPlugin.ON_EVENT_CALLBACK_METADATA
405+
406+
assert persisted_event.id == streamed_event.id
407+
assert persisted_event.timestamp == streamed_event.timestamp
408+
assert persisted_event.author == streamed_event.author
409+
assert persisted_event.invocation_id == streamed_event.invocation_id
410+
assert persisted_event.content.parts[0].text == (
411+
MockPlugin.ON_EVENT_CALLBACK_MSG
412+
)
413+
assert (
414+
persisted_event.custom_metadata == MockPlugin.ON_EVENT_CALLBACK_METADATA
415+
)
416+
417+
362418
@pytest.mark.asyncio
363419
async def test_runner_allows_nested_agent_directories(tmp_path, monkeypatch):
364420
project_root = tmp_path / "workspace"
@@ -747,6 +803,39 @@ async def test_runner_modifies_event_after_execution(self):
747803

748804
assert modified_event_message == MockPlugin.ON_EVENT_CALLBACK_MSG
749805

806+
@pytest.mark.asyncio
807+
async def test_runner_persists_event_callback_modifications(self):
808+
"""Event callback output should be persisted, not only streamed."""
809+
self.plugin.enable_event_callback = True
810+
811+
events = await self.run_test()
812+
streamed_event = events[0]
813+
814+
session = await self.session_service.get_session(
815+
app_name=TEST_APP_ID, user_id=TEST_USER_ID, session_id=TEST_SESSION_ID
816+
)
817+
persisted_event = session.events[1]
818+
819+
assert streamed_event.author == "test_agent"
820+
assert streamed_event.invocation_id
821+
assert streamed_event.content.parts[0].text == (
822+
MockPlugin.ON_EVENT_CALLBACK_MSG
823+
)
824+
assert (
825+
streamed_event.custom_metadata == MockPlugin.ON_EVENT_CALLBACK_METADATA
826+
)
827+
828+
assert persisted_event.id == streamed_event.id
829+
assert persisted_event.timestamp == streamed_event.timestamp
830+
assert persisted_event.author == streamed_event.author
831+
assert persisted_event.invocation_id == streamed_event.invocation_id
832+
assert persisted_event.content.parts[0].text == (
833+
MockPlugin.ON_EVENT_CALLBACK_MSG
834+
)
835+
assert (
836+
persisted_event.custom_metadata == MockPlugin.ON_EVENT_CALLBACK_METADATA
837+
)
838+
750839
@pytest.mark.asyncio
751840
async def test_runner_close_calls_plugin_close(self):
752841
"""Test that runner.close() calls plugin manager close."""

0 commit comments

Comments
 (0)