Skip to content

Commit 3838dd4

Browse files
GWeale and copybara-github
authored and committed
fix: fix rewind to preserve initial session state
The rewind logic is updated to ensure that state keys set during session creation are not nullified when rewinding. Previously, any key not present in the state at the rewind point was removed. Now, only keys that have appeared in any event's state delta are considered for nullification during a rewind, preventing the removal of initial session state.

Close #4933

Co-authored-by: George Weale <gweale@google.com>
PiperOrigin-RevId: 905271916
1 parent 70a7add commit 3838dd4

8 files changed

Lines changed: 345 additions & 65 deletions

File tree

src/google/adk/runners.py

Lines changed: 37 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -646,6 +646,12 @@ async def rewind_async(
646646
session_id=session_id,
647647
get_session_config=run_config.get_session_config,
648648
)
649+
if not rewind_before_invocation_id:
650+
# Guard against matching the synthetic initial-state event that is
651+
# appended by `create_session`; that event has an empty invocation_id by
652+
# design and is not a valid rewind target.
653+
raise ValueError('rewind_before_invocation_id must be non-empty.')
654+
649655
rewind_event_index = -1
650656
for i, event in enumerate(session.events):
651657
if event.invocation_id == rewind_before_invocation_id:
@@ -686,16 +692,34 @@ async def _compute_state_delta_for_rewind(
686692
self, session: Session, rewind_event_index: int
687693
) -> dict[str, Any]:
688694
"""Computes the state delta to reverse changes."""
695+
# State at the rewind point is reconstructed entirely from the event
696+
# stream. Session-scoped initial state from `create_session` is captured
697+
# as a synthetic event by `BaseSessionService._record_initial_state_event`,
698+
# so walking events naturally restores initial values even when a later
699+
# event overwrote them.
689700
state_at_rewind_point: dict[str, Any] = {}
690-
for i in range(rewind_event_index):
691-
if session.events[i].actions.state_delta:
692-
for k, v in session.events[i].actions.state_delta.items():
693-
if k.startswith('app:') or k.startswith('user:'):
694-
continue
695-
if v is None:
696-
state_at_rewind_point.pop(k, None)
697-
else:
698-
state_at_rewind_point[k] = v
701+
all_event_keys: set[str] = set()
702+
703+
for event in session.events[:rewind_event_index]:
704+
if not event.actions.state_delta:
705+
continue
706+
for k, v in event.actions.state_delta.items():
707+
if k.startswith('app:') or k.startswith('user:'):
708+
continue
709+
all_event_keys.add(k)
710+
if v is None:
711+
state_at_rewind_point.pop(k, None)
712+
else:
713+
state_at_rewind_point[k] = v
714+
715+
# Collect any other keys touched by events after the rewind point so we
716+
# know which keys were ever event-sourced.
717+
for event in session.events[rewind_event_index:]:
718+
if not event.actions.state_delta:
719+
continue
720+
for k in event.actions.state_delta:
721+
if not k.startswith('app:') and not k.startswith('user:'):
722+
all_event_keys.add(k)
699723

700724
current_state = session.state
701725
rewind_state_delta = {}
@@ -706,12 +730,13 @@ async def _compute_state_delta_for_rewind(
706730
rewind_state_delta[key] = value_at_rewind
707731

708732
# 2. Set keys to None in rewind_state_delta if they are in current_state
709-
# but not in state_at_rewind_point. These keys were added after the
710-
# rewind point and need to be removed.
733+
# but not in state_at_rewind_point. Only nullify keys that were
734+
# introduced or modified through events; keys set outside the event
735+
# stream are preserved.
711736
for key in current_state:
712737
if key.startswith('app:') or key.startswith('user:'):
713738
continue
714-
if key not in state_at_rewind_point:
739+
if key not in state_at_rewind_point and key in all_event_keys:
715740
rewind_state_delta[key] = None
716741

717742
return rewind_state_delta

src/google/adk/sessions/base_session_service.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,12 @@
1818
from typing import Any
1919
from typing import Optional
2020

21+
from google.adk.platform import time as platform_time
2122
from pydantic import BaseModel
2223
from pydantic import Field
2324

2425
from ..events.event import Event
26+
from ..events.event_actions import EventActions
2527
from .session import Session
2628
from .state import State
2729

@@ -160,3 +162,36 @@ def _update_session_state(self, session: Session, event: Event) -> None:
160162
return
161163
for key, value in event.actions.state_delta.items():
162164
session.state.update({key: value})
165+
166+
async def _record_initial_state_event(
167+
self, session: Session, state: Optional[dict[str, Any]]
168+
) -> None:
169+
"""Appends a synthetic event carrying the initial non-temp session state.
170+
171+
Subclasses call this from `create_session` so that initial state flows
172+
through `append_event` (the single state-merging path) and so that
173+
`rewind_async` can restore session-scoped initial values for keys later
174+
overwritten or introduced by subsequent events.
175+
176+
Args:
177+
session: The newly created session to attach the event to.
178+
state: The initial state dict supplied to `create_session`. Temp-prefixed
179+
keys are dropped because temp state is ephemeral and never persisted.
180+
"""
181+
if not state:
182+
return
183+
state_delta = {
184+
k: v for k, v in state.items() if not k.startswith(State.TEMP_PREFIX)
185+
}
186+
if not state_delta:
187+
return
188+
# Round to microseconds so the timestamp roundtrips exactly through
189+
# storage backends that persist timestamps as datetime (microsecond
190+
# precision) — keeps in-memory and reloaded events comparable.
191+
timestamp = round(platform_time.get_time(), 6)
192+
initial_event = Event(
193+
author='user',
194+
timestamp=timestamp,
195+
actions=EventActions(state_delta=dict(state_delta)),
196+
)
197+
await self.append_event(session=session, event=initial_event)

src/google/adk/sessions/database_session_service.py

Lines changed: 12 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -417,11 +417,13 @@ async def create_session(
417417
state: Optional[dict[str, Any]] = None,
418418
session_id: Optional[str] = None,
419419
) -> Session:
420-
# 1. Populate states.
421-
# 2. Build storage session object
422-
# 3. Add the object to the table
423-
# 4. Build the session object with generated id
424-
# 5. Return the session
420+
# 1. Ensure app/user state rows exist (append_event requires them) and
421+
# insert an empty session row.
422+
# 2. Build the in-memory session reflecting any pre-existing app/user
423+
# state.
424+
# 3. Apply the caller-supplied initial state through the synthetic event
425+
# in `_record_initial_state_event` so all state writes share a single
426+
# code path.
425427
await self._prepare_tables()
426428
schema = self._get_schema_classes()
427429
async with self._rollback_on_exception_session() as sql_session:
@@ -432,6 +434,7 @@ async def create_session(
432434
f"Session with id {session_id} already exists."
433435
)
434436
# Get or create state rows, handling concurrent insert races.
437+
# `append_event` requires the app/user state rows to exist.
435438
storage_app_state = await _get_or_create_state(
436439
sql_session=sql_session,
437440
state_model=schema.StorageAppState,
@@ -445,19 +448,6 @@ async def create_session(
445448
defaults={"app_name": app_name, "user_id": user_id, "state": {}},
446449
)
447450

448-
# Extract state deltas
449-
state_deltas = _session_util.extract_state_delta(state)
450-
app_state_delta = state_deltas["app"]
451-
user_state_delta = state_deltas["user"]
452-
session_state = state_deltas["session"]
453-
454-
# Apply state delta
455-
if app_state_delta:
456-
storage_app_state.state = storage_app_state.state | app_state_delta
457-
if user_state_delta:
458-
storage_user_state.state = storage_user_state.state | user_state_delta
459-
460-
# Store the session
461451
now = datetime.fromtimestamp(platform_time.get_time(), tz=timezone.utc)
462452
is_sqlite = self.db_engine.dialect.name == _SQLITE_DIALECT
463453
is_postgresql = self.db_engine.dialect.name == _POSTGRESQL_DIALECT
@@ -468,20 +458,21 @@ async def create_session(
468458
app_name=app_name,
469459
user_id=user_id,
470460
id=session_id,
471-
state=session_state,
461+
state={},
472462
create_time=now,
473463
update_time=now,
474464
)
475465
sql_session.add(storage_session)
476466
await sql_session.commit()
477467

478-
# Merge states for response
479468
merged_state = _merge_state(
480-
storage_app_state.state, storage_user_state.state, session_state
469+
storage_app_state.state, storage_user_state.state, {}
481470
)
482471
session = storage_session.to_session(
483472
state=merged_state, is_sqlite=is_sqlite
484473
)
474+
475+
await self._record_initial_state_event(session, state)
485476
return session
486477

487478
@override

src/google/adk/sessions/in_memory_session_service.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,12 +83,18 @@ async def create_session(
8383
state: Optional[dict[str, Any]] = None,
8484
session_id: Optional[str] = None,
8585
) -> Session:
86-
return self._create_session_impl(
86+
# Initial state flows through `_record_initial_state_event` ->
87+
# `append_event` so the in-memory dicts and the event stream are written
88+
# exactly once. The deprecated `create_session_sync` keeps the legacy
89+
# direct-write path because it cannot await `append_event`.
90+
session = self._create_session_impl(
8791
app_name=app_name,
8892
user_id=user_id,
89-
state=state,
93+
state=None,
9094
session_id=session_id,
9195
)
96+
await self._record_initial_state_event(session, state)
97+
return session
9298

9399
def create_session_sync(
94100
self,

src/google/adk/sessions/sqlite_session_service.py

Lines changed: 13 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -179,25 +179,9 @@ async def create_session(
179179
f"Session with id {session_id} already exists."
180180
)
181181

182-
# Extract state deltas
183-
state_deltas = _session_util.extract_state_delta(state)
184-
app_state_delta = state_deltas["app"]
185-
user_state_delta = state_deltas["user"]
186-
session_state = state_deltas["session"]
187-
188-
# Apply state delta and update/insert states atomically
189-
if app_state_delta:
190-
await self._upsert_app_state(db, app_name, app_state_delta, now)
191-
if user_state_delta:
192-
await self._upsert_user_state(
193-
db, app_name, user_id, user_state_delta, now
194-
)
195-
196-
# Fetch current state after upserts
197-
storage_app_state = await self._get_app_state(db, app_name)
198-
storage_user_state = await self._get_user_state(db, app_name, user_id)
199-
200-
# Store the session
182+
# Insert the session row with empty per-session state. Initial state
183+
# (including app:/user:-prefixed keys) is applied through the synthetic
184+
# event below so that all state writes go through `append_event`.
201185
await db.execute(
202186
"""
203187
INSERT INTO sessions (app_name, user_id, id, state, create_time, update_time)
@@ -207,18 +191,19 @@ async def create_session(
207191
app_name,
208192
user_id,
209193
session_id,
210-
json.dumps(session_state),
194+
json.dumps({}),
211195
now,
212196
now,
213197
),
214198
)
215199
await db.commit()
216200

217-
# Merge states for response
218-
merged_state = _merge_state(
219-
storage_app_state, storage_user_state, session_state
220-
)
221-
return Session(
201+
# Reflect already-persisted app/user state so subsequent appends start
202+
# from the correct merged view.
203+
storage_app_state = await self._get_app_state(db, app_name)
204+
storage_user_state = await self._get_user_state(db, app_name, user_id)
205+
merged_state = _merge_state(storage_app_state, storage_user_state, {})
206+
session = Session(
222207
app_name=app_name,
223208
user_id=user_id,
224209
id=session_id,
@@ -227,6 +212,9 @@ async def create_session(
227212
last_update_time=now,
228213
)
229214

215+
await self._record_initial_state_event(session, state)
216+
return session
217+
230218
@override
231219
async def get_session(
232220
self,

src/google/adk/sessions/vertex_ai_session_service.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -125,10 +125,13 @@ async def create_session(
125125
"""
126126
reasoning_engine_id = self._get_reasoning_engine_id(app_name)
127127

128-
config = {'session_state': state} if state else {}
128+
# Initial state is persisted exclusively through the synthetic event
129+
# below (which is sent via `events.append`); avoid passing it as
130+
# `session_state` here so the same data is not written to the backend
131+
# twice.
132+
config = dict(kwargs)
129133
if session_id:
130134
config['session_id'] = session_id
131-
config.update(kwargs)
132135
async with self._get_api_client() as api_client:
133136
api_response = await api_client.agent_engines.sessions.create(
134137
name=f'reasoningEngines/{reasoning_engine_id}',
@@ -143,9 +146,11 @@ async def create_session(
143146
app_name=app_name,
144147
user_id=user_id,
145148
id=session_id,
146-
state=getattr(get_session_response, 'session_state', None) or {},
149+
state={},
147150
last_update_time=get_session_response.update_time.timestamp(),
148151
)
152+
153+
await self._record_initial_state_event(session, state)
149154
return session
150155

151156
@override
@@ -213,9 +218,21 @@ async def get_session(
213218
# to discard events written milliseconds after the session resource was
214219
# updated. Clock skew between those writes can otherwise drop tool_result
215220
# events and permanently break the replayed conversation.
221+
#
222+
# Apply each event's state_delta as we go so callers see the same state
223+
# whether or not the backend mirrors it onto the session_state field
224+
# (e.g. Vertex stores initial state via the synthetic create_session
225+
# event rather than the session_state field).
216226
if events_iterator is not None:
217227
async for event in events_iterator:
218-
session.events.append(_from_api_event(event))
228+
adk_event = _from_api_event(event)
229+
session.events.append(adk_event)
230+
if adk_event.actions and adk_event.actions.state_delta:
231+
for key, value in adk_event.actions.state_delta.items():
232+
if value is None:
233+
session.state.pop(key, None)
234+
else:
235+
session.state[key] = value
219236

220237
if config:
221238
# Filter events based on num_recent_events.

0 commit comments

Comments
 (0)