Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 24 additions & 7 deletions src/google/adk/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,7 @@ async def _run_node_async(
user_id: str,
session_id: str,
new_message: Optional[types.Content] = None,
state_delta: Optional[dict[str, Any]] = None,
run_config: Optional[RunConfig] = None,
yield_user_message: bool = False,
node: Optional['BaseNode'] = None,
Expand Down Expand Up @@ -512,7 +513,9 @@ async def _run_node_async(

# Append user message to session for history
if new_message:
user_event = await self._append_user_event(ic, new_message)
user_event = await self._append_user_event(
ic, new_message, state_delta=state_delta
)
if yield_user_message and user_event:
yield user_event

Expand Down Expand Up @@ -706,14 +709,26 @@ def _resolve_invocation_id_from_fr(
return invocation_ids.pop()

async def _append_user_event(
self, ic: InvocationContext, content: types.Content
self,
ic: InvocationContext,
content: types.Content,
*,
state_delta: Optional[dict[str, Any]] = None,
) -> Event:
"""Append a user message event to the session and return it."""
event = Event(
invocation_id=ic.invocation_id,
author='user',
content=content,
)
if state_delta:
event = Event(
invocation_id=ic.invocation_id,
author='user',
actions=EventActions(state_delta=state_delta),
content=content,
)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a good fix. It correctly threads state_delta through the NodeRunner path, ensuring the session state is updated before any nodes or agents execute.

One minor suggestion: this block could be slightly more concise, though the current version is perfectly clear.

    event = Event(
        invocation_id=ic.invocation_id,
        author='user',
        actions=EventActions(state_delta=state_delta) if state_delta else None,
        content=content,
    )

else:
event = Event(
invocation_id=ic.invocation_id,
author='user',
content=content,
)
# when a paused task delegation is in flight, stamp
# the new user message with that task's isolation_scope so the
# task agent's content-build (scoped to <fc_id>) sees it.
Expand Down Expand Up @@ -989,6 +1004,7 @@ async def run_async(
user_id=user_id,
session_id=session_id,
new_message=new_message,
state_delta=state_delta,
run_config=run_config,
yield_user_message=yield_user_message,
node=agent_to_run,
Expand All @@ -1008,6 +1024,7 @@ async def run_async(
user_id=user_id,
session_id=session_id,
new_message=new_message,
state_delta=state_delta,
run_config=run_config,
yield_user_message=yield_user_message,
)
Expand Down
161 changes: 161 additions & 0 deletions tests/unittests/runners/test_runner_state_delta.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
# Copyright 2026 Google LLC
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Excellent test coverage. Verifying both BaseNode and LlmAgent callback scenarios ensures that the state delta is correctly applied in different execution paths.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you move the tests to tests/unittests/runners/test_runner_node.py?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I moved tests to tests/unittests/runners/test_runner_node.py as your request

#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for Runner state_delta handling.

Verifies that state deltas supplied to Runner.run_async are applied to the
initial user event before NodeRunner executes nodes or agents.
"""

from __future__ import annotations

from typing import Any
from typing import AsyncGenerator

from google.adk.agents.callback_context import CallbackContext
from google.adk.agents.context import Context
from google.adk.agents.llm_agent import LlmAgent
from google.adk.events.event import Event
from google.adk.runners import Runner
from google.adk.sessions.in_memory_session_service import InMemorySessionService
from google.adk.workflow._base_node import BaseNode
from google.genai import types
import pytest


def _user_message(text: str = "hello") -> types.Content:
return types.Content(parts=[types.Part(text=text)], role="user")


@pytest.mark.asyncio
async def test_node_runner_applies_state_delta_before_base_node_runs():
"""A BaseNode sees run_async state_delta as session state."""

class _StateReaderNode(BaseNode):

async def _run_impl(
self, *, ctx: Context, node_input: Any
) -> AsyncGenerator[Any, None]:
yield f"state:{ctx.state['test_state']}"

session_service = InMemorySessionService()
runner = Runner(
app_name="test",
node=_StateReaderNode(name="reader"),
session_service=session_service,
)
session = await session_service.create_session(app_name="test", user_id="u")

events: list[Event] = []
async for event in runner.run_async(
user_id="u",
session_id=session.id,
new_message=_user_message(),
state_delta={"test_state": "must_change"},
):
events.append(event)

updated = await session_service.get_session(
app_name="test", user_id="u", session_id=session.id
)
user_events = [event for event in updated.events if event.author == "user"]

assert [event.output for event in events if event.output is not None] == [
"state:must_change"
]
assert updated.state["test_state"] == "must_change"
assert user_events[0].actions.state_delta == {"test_state": "must_change"}


@pytest.mark.asyncio
async def test_node_runner_yields_user_event_with_state_delta():
"""yield_user_message=True yields the user event with state_delta."""

class _NoopNode(BaseNode):

async def _run_impl(
self, *, ctx: Context, node_input: Any
) -> AsyncGenerator[Any, None]:
yield "done"

session_service = InMemorySessionService()
runner = Runner(
app_name="test",
node=_NoopNode(name="noop"),
session_service=session_service,
)
session = await session_service.create_session(app_name="test", user_id="u")

events: list[Event] = []
async for event in runner.run_async(
user_id="u",
session_id=session.id,
new_message=_user_message(),
state_delta={"test_state": "must_change"},
yield_user_message=True,
):
events.append(event)

assert events[0].author == "user"
assert events[0].actions.state_delta == {"test_state": "must_change"}


@pytest.mark.asyncio
async def test_node_runner_applies_state_delta_before_llm_agent_runs():
"""An LlmAgent callback sees run_async state_delta before model execution."""

captured_state_value = None

def _before_agent_callback(
callback_context: CallbackContext,
) -> types.Content:
nonlocal captured_state_value
captured_state_value = callback_context.state["test_state"]
return types.Content(
role="model",
parts=[types.Part(text=f"state:{captured_state_value}")],
)

session_service = InMemorySessionService()
agent = LlmAgent(
name="state_agent",
before_agent_callback=_before_agent_callback,
)
runner = Runner(app_name="test", agent=agent, session_service=session_service)
session = await session_service.create_session(app_name="test", user_id="u")

events: list[Event] = []
async for event in runner.run_async(
user_id="u",
session_id=session.id,
new_message=_user_message(),
state_delta={"test_state": "must_change"},
):
events.append(event)

updated = await session_service.get_session(
app_name="test", user_id="u", session_id=session.id
)
user_events = [event for event in updated.events if event.author == "user"]
response_texts = [
part.text
for event in events
if event.content
for part in event.content.parts
if part.text
]

assert captured_state_value == "must_change"
assert "state:must_change" in response_texts
assert user_events[0].actions.state_delta == {"test_state": "must_change"}
Loading