Skip to content

Commit 0524797

Browse files
Amaad Martincopybara-github
authored andcommitted
fix(agents): fix visibility of output_key state delta in callbacks
Co-authored-by: Amaad Martin <amaadmartin@google.com> PiperOrigin-RevId: 916112779
1 parent 9a1e75f commit 0524797

2 files changed

Lines changed: 206 additions & 18 deletions

File tree

src/google/adk/runners.py

Lines changed: 26 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -608,7 +608,7 @@ async def execute(ctx: InvocationContext) -> AsyncGenerator[Event]:
608608
async with Aclosing(
609609
self._exec_with_plugin(
610610
invocation_context=invocation_context,
611-
session=session,
611+
session=invocation_context.session,
612612
execute_fn=execute,
613613
is_live_call=False,
614614
)
@@ -622,7 +622,7 @@ async def execute(ctx: InvocationContext) -> AsyncGenerator[Event]:
622622
logger.debug('Running event compactor.')
623623
await _run_compaction_for_sliding_window(
624624
self.app,
625-
session,
625+
invocation_context.session,
626626
self.session_service,
627627
skip_token_compaction=invocation_context.token_compaction_checked,
628628
)
@@ -841,7 +841,7 @@ async def _exec_with_plugin(
841841
842842
Args:
843843
invocation_context: The invocation context
844-
session: The current session
844+
session: The current session (ignored, kept for backward compatibility)
845845
execute_fn: A callable that returns an AsyncGenerator of Events
846846
is_live_call: Whether this is a live call
847847
@@ -866,7 +866,7 @@ async def _exec_with_plugin(
866866
)
867867
if self._should_append_event(early_exit_event, is_live_call):
868868
await self.session_service.append_event(
869-
session=session,
869+
session=invocation_context.session,
870870
event=early_exit_event,
871871
)
872872
yield early_exit_event
@@ -931,13 +931,13 @@ async def _exec_with_plugin(
931931
)
932932
if self._should_append_event(event, is_live_call):
933933
await self.session_service.append_event(
934-
session=session, event=output_event
934+
session=invocation_context.session, event=output_event
935935
)
936936

937937
for buffered_event in buffered_events:
938938
logger.debug('Appending buffered event: %s', buffered_event)
939939
await self.session_service.append_event(
940-
session=session, event=buffered_event
940+
session=invocation_context.session, event=buffered_event
941941
)
942942
yield buffered_event # yield buffered events to caller
943943
buffered_events = []
@@ -947,12 +947,12 @@ async def _exec_with_plugin(
947947
if self._should_append_event(event, is_live_call):
948948
logger.debug('Appending non-buffered event: %s', event)
949949
await self.session_service.append_event(
950-
session=session, event=output_event
950+
session=invocation_context.session, event=output_event
951951
)
952952
else:
953953
if event.partial is not True:
954954
await self.session_service.append_event(
955-
session=session, event=output_event
955+
session=invocation_context.session, event=output_event
956956
)
957957

958958
yield output_event
@@ -1004,8 +1004,8 @@ async def _append_new_message_to_session(
10041004
file_name = f'artifact_{invocation_context.invocation_id}_{i}'
10051005
await self.artifact_service.save_artifact(
10061006
app_name=self.app_name,
1007-
user_id=session.user_id,
1008-
session_id=session.id,
1007+
user_id=invocation_context.session.user_id,
1008+
session_id=invocation_context.session.id,
10091009
filename=file_name,
10101010
artifact=part,
10111011
)
@@ -1032,7 +1032,9 @@ async def _append_new_message_to_session(
10321032
if function_call := invocation_context._find_matching_function_call(event):
10331033
event.branch = function_call.branch
10341034

1035-
await self.session_service.append_event(session=session, event=event)
1035+
await self.session_service.append_event(
1036+
session=invocation_context.session, event=event
1037+
)
10361038

10371039
async def run_live(
10381040
self,
@@ -1127,7 +1129,9 @@ async def run_live(
11271129
)
11281130

11291131
root_agent = self.agent
1130-
invocation_context.agent = self._find_agent_to_run(session, root_agent)
1132+
invocation_context.agent = self._find_agent_to_run(
1133+
invocation_context.session, root_agent
1134+
)
11311135

11321136
async def execute(ctx: InvocationContext) -> AsyncGenerator[Event]:
11331137
async with Aclosing(ctx.agent.run_live(ctx)) as agen:
@@ -1137,7 +1141,7 @@ async def execute(ctx: InvocationContext) -> AsyncGenerator[Event]:
11371141
async with Aclosing(
11381142
self._exec_with_plugin(
11391143
invocation_context=invocation_context,
1140-
session=session,
1144+
session=invocation_context.session,
11411145
execute_fn=execute,
11421146
is_live_call=True,
11431147
)
@@ -1355,14 +1359,16 @@ async def _setup_context_for_new_invocation(
13551359
# Step 2: Handle new message, by running callbacks and appending to
13561360
# session.
13571361
await self._handle_new_message(
1358-
session=session,
1362+
session=invocation_context.session,
13591363
new_message=new_message,
13601364
invocation_context=invocation_context,
13611365
run_config=run_config,
13621366
state_delta=state_delta,
13631367
)
13641368
# Step 3: Set agent to run for the invocation.
1365-
invocation_context.agent = self._find_agent_to_run(session, self.agent)
1369+
invocation_context.agent = self._find_agent_to_run(
1370+
invocation_context.session, self.agent
1371+
)
13661372
return invocation_context
13671373

13681374
async def _setup_context_for_resumed_invocation(
@@ -1411,7 +1417,7 @@ async def _setup_context_for_resumed_invocation(
14111417
# Step 3: Maybe handle new message.
14121418
if new_message:
14131419
await self._handle_new_message(
1414-
session=session,
1420+
session=invocation_context.session,
14151421
new_message=user_message,
14161422
invocation_context=invocation_context,
14171423
run_config=run_config,
@@ -1425,7 +1431,9 @@ async def _setup_context_for_resumed_invocation(
14251431
# started from a sub-agent and paused on a sub-agent.
14261432
# We should find the appropriate agent to run to continue the invocation.
14271433
if self.agent.name not in invocation_context.end_of_agents:
1428-
invocation_context.agent = self._find_agent_to_run(session, self.agent)
1434+
invocation_context.agent = self._find_agent_to_run(
1435+
invocation_context.session, self.agent
1436+
)
14291437
return invocation_context
14301438

14311439
def _find_user_message_for_invocation(
@@ -1559,7 +1567,7 @@ async def _handle_new_message(
15591567
if 'save_input_blobs_as_artifacts' in run_config.model_fields_set:
15601568
deprecated_save_blobs = run_config.save_input_blobs_as_artifacts
15611569
await self._append_new_message_to_session(
1562-
session=session,
1570+
session=invocation_context.session,
15631571
new_message=new_message,
15641572
invocation_context=invocation_context,
15651573
save_input_blobs_as_artifacts=deprecated_save_blobs,
Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Tests for LlmAgent output_key visibility in callbacks."""
16+
17+
from google.adk.agents.callback_context import CallbackContext
18+
from google.adk.agents.live_request_queue import LiveRequestQueue
19+
from google.adk.agents.llm_agent import LlmAgent
20+
from google.adk.agents.sequential_agent import SequentialAgent
21+
from google.adk.events.event import Event
22+
from google.adk.flows.llm_flows.auto_flow import AutoFlow
23+
from google.genai import types
24+
import pytest
25+
from pytest_mock import MockerFixture
26+
27+
from .. import testing_utils
28+
29+
# Standard MockModel will be used instead of SafeMockModel
30+
31+
32+
@pytest.mark.asyncio
33+
async def test_output_key_visibility_in_after_agent_callback():
34+
"""Test that output_key state delta is visible in after_agent_callback."""
35+
mock_response = "Hello! How can I help you?"
36+
mock_model = testing_utils.MockModel.create(responses=[mock_response])
37+
38+
callback_called = False
39+
captured_state_value = None
40+
captured_session_state_value = None
41+
42+
async def check_output_key(callback_context: CallbackContext):
43+
nonlocal callback_called, captured_state_value, captured_session_state_value
44+
callback_called = True
45+
captured_state_value = callback_context.state.get("result", "NOT_FOUND")
46+
captured_session_state_value = callback_context.session.state.get(
47+
"result", "NOT_IN_RAW"
48+
)
49+
50+
agent = LlmAgent(
51+
name="my_agent",
52+
model=mock_model,
53+
instruction="Reply with a short greeting.",
54+
output_key="result",
55+
after_agent_callback=check_output_key,
56+
)
57+
58+
runner = testing_utils.InMemoryRunner(root_agent=agent)
59+
60+
events = await runner.run_async(new_message="hello")
61+
62+
assert callback_called, "Callback was not called"
63+
64+
assert (
65+
captured_state_value == mock_response
66+
), f"Expected {mock_response}, got {captured_state_value}"
67+
assert (
68+
captured_session_state_value == mock_response
69+
), f"Expected {mock_response}, got {captured_session_state_value}"
70+
71+
72+
@pytest.mark.asyncio
73+
async def test_output_key_visibility_in_run_live(mocker: MockerFixture):
74+
"""Test that output_key state delta is visible in after_agent_callback in run_live."""
75+
mock_response = "Hello! How can I help you?"
76+
mock_model = testing_utils.MockModel.create(responses=[mock_response])
77+
78+
callback_called = False
79+
captured_state_value = None
80+
captured_session_state_value = None
81+
82+
async def check_output_key(callback_context: CallbackContext):
83+
nonlocal callback_called, captured_state_value, captured_session_state_value
84+
callback_called = True
85+
captured_state_value = callback_context.state.get("result", "NOT_FOUND")
86+
captured_session_state_value = callback_context.session.state.get(
87+
"result", "NOT_IN_RAW"
88+
)
89+
90+
agent = LlmAgent(
91+
name="my_agent",
92+
model=mock_model,
93+
instruction="Reply with a short greeting.",
94+
output_key="result",
95+
after_agent_callback=check_output_key,
96+
)
97+
98+
async def mock_auto_flow_run_live(self, ctx):
99+
yield Event(
100+
id=Event.new_id(),
101+
invocation_id=ctx.invocation_id,
102+
author=ctx.agent.name,
103+
content=types.Content(parts=[types.Part(text=mock_response)]),
104+
)
105+
106+
mocker.patch.object(AutoFlow, "run_live", mock_auto_flow_run_live)
107+
108+
runner = testing_utils.InMemoryRunner(root_agent=agent)
109+
live_queue = LiveRequestQueue()
110+
111+
agen = runner.runner.run_live(
112+
user_id="test_user",
113+
session_id=runner.session.id,
114+
live_request_queue=live_queue,
115+
)
116+
117+
# Send a message to trigger the agent
118+
live_queue.send_content(
119+
types.Content(role="user", parts=[types.Part(text="hello")])
120+
)
121+
122+
live_queue.close()
123+
124+
async for event in agen:
125+
pass
126+
127+
assert callback_called, "Callback was not called"
128+
assert (
129+
captured_state_value == mock_response
130+
), f"Expected {mock_response}, got {captured_state_value}"
131+
assert (
132+
captured_session_state_value == mock_response
133+
), f"Expected {mock_response}, got {captured_session_state_value}"
134+
135+
136+
@pytest.mark.asyncio
137+
async def test_output_key_visibility_in_sequential_agent():
138+
"""Test that output_key state delta is visible in next agent's before_agent_callback."""
139+
mock_response = "Hello from agent 1!"
140+
mock_model = testing_utils.MockModel.create(
141+
responses=[mock_response, "Hello from agent 2!"]
142+
)
143+
144+
callback_called = False
145+
captured_session_state_value = None
146+
147+
async def check_before_agent(callback_context: CallbackContext):
148+
nonlocal callback_called, captured_session_state_value
149+
callback_called = True
150+
captured_session_state_value = callback_context.session.state.get(
151+
"result", "NOT_FOUND"
152+
)
153+
154+
agent_1 = LlmAgent(
155+
name="agent_1",
156+
model=mock_model,
157+
instruction="Reply with a short greeting.",
158+
output_key="result",
159+
)
160+
161+
agent_2 = LlmAgent(
162+
name="agent_2",
163+
model=mock_model,
164+
instruction="Reply with a short greeting.",
165+
before_agent_callback=check_before_agent,
166+
)
167+
168+
sequential_agent = SequentialAgent(
169+
name="seq_agent",
170+
sub_agents=[agent_1, agent_2],
171+
)
172+
173+
runner = testing_utils.InMemoryRunner(root_agent=sequential_agent)
174+
175+
events = await runner.run_async(new_message="hello")
176+
177+
assert callback_called, "Callback was not called"
178+
assert (
179+
captured_session_state_value == mock_response
180+
), f"Expected {mock_response}, got {captured_session_state_value}"

0 commit comments

Comments
 (0)