Skip to content

Commit b8e7647

Browse files
GWeale and copybara-github
authored and committed
fix: Reject appends to stale sessions in DatabaseSessionService
This change introduces optimistic concurrency control for session updates. Instead of automatically reloading and merging when an append is attempted on a session that has been modified in storage, the service now raises a ValueError.

Close #4751

Co-authored-by: George Weale <gweale@google.com>
PiperOrigin-RevId: 885334444
1 parent 4c9c01f commit b8e7647

5 files changed

Lines changed: 184 additions & 38 deletions

File tree

src/google/adk/sessions/database_session_service.py

Lines changed: 65 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -61,6 +61,11 @@
6161

6262
logger = logging.getLogger("google_adk." + __name__)
6363

64+
_STALE_SESSION_ERROR_MESSAGE = (
65+
"The session has been modified in storage since it was loaded. "
66+
"Please reload the session before appending more events."
67+
)
68+
6469
_SQLITE_DIALECT = "sqlite"
6570
_MARIADB_DIALECT = "mariadb"
6671
_MYSQL_DIALECT = "mysql"
@@ -309,6 +314,39 @@ async def _prepare_tables(self):
309314

310315
self._tables_created = True
311316

317+
async def _session_matches_storage_revision(
    self,
    *,
    sql_session: DatabaseSessionFactory,
    schema: _SchemaClasses,
    session: Session,
) -> bool:
    """Check whether a marker-less session is still current with storage.

    When the in-memory session holds no events, it is current only if storage
    holds no events for it either. Otherwise the id of the newest stored event
    (ordered by timestamp, then id, both descending) must equal the id of the
    last in-memory event.

    Args:
      sql_session: Database session used to execute the lookup query.
      schema: Schema classes that provide the StorageEvent model.
      session: In-memory session whose events are compared against storage.

    Returns:
      True if the stored events agree with the in-memory session, else False.
    """
    event_model = schema.StorageEvent
    has_local_events = bool(session.events)

    # Both cases look at the same rows; only the ordering requirement differs.
    query = (
        select(event_model.id)
        .filter(event_model.app_name == session.app_name)
        .filter(event_model.session_id == session.id)
        .filter(event_model.user_id == session.user_id)
    )
    if has_local_events:
        # The newest stored event is needed for the comparison below.
        query = query.order_by(
            event_model.timestamp.desc(), event_model.id.desc()
        )

    rows = await sql_session.execute(query.limit(1))
    stored_event_id = rows.scalar_one_or_none()

    if not has_local_events:
        # Any stored event means storage moved ahead of the empty session.
        return stored_event_id is None
    return stored_event_id == session.events[-1].id
349+
312350
@override
313351
async def create_session(
314352
self,
@@ -529,9 +567,9 @@ async def append_event(self, session: Session, event: Event) -> Event:
529567
# Trim temp state before persisting
530568
event = self._trim_temp_delta_state(event)
531569

532-
# 1. Check if timestamp is stale
533-
# 2. Update session attributes based on event config
534-
# 3. Store event to table
570+
# 1. Validate the session has not gone stale.
571+
# 2. Update session attributes based on event config.
572+
# 3. Store the new event.
535573
schema = self._get_schema_classes()
536574
is_sqlite = self.db_engine.dialect.name == _SQLITE_DIALECT
537575
use_row_level_locking = self._supports_row_level_locking()
@@ -563,6 +601,8 @@ async def append_event(self, session: Session, event: Event) -> Event:
563601
storage_session = storage_session_result.scalars().one_or_none()
564602
if storage_session is None:
565603
raise ValueError(f"Session {session.id} not found.")
604+
storage_update_time = storage_session.get_update_timestamp(is_sqlite)
605+
storage_update_marker = storage_session.get_update_marker()
566606

567607
storage_app_state = await _select_required_state(
568608
sql_session=sql_session,
@@ -591,27 +631,27 @@ async def append_event(self, session: Session, event: Event) -> Event:
591631
),
592632
)
593633

594-
if (
595-
storage_session.get_update_timestamp(is_sqlite)
596-
> session.last_update_time
597-
):
598-
# Reload the session from storage if it has been updated since it was
599-
# loaded.
600-
app_state = storage_app_state.state
601-
user_state = storage_user_state.state
602-
session_state = storage_session.state
603-
session.state = _merge_state(app_state, user_state, session_state)
604-
605-
stmt = (
606-
select(schema.StorageEvent)
607-
.filter(schema.StorageEvent.app_name == session.app_name)
608-
.filter(schema.StorageEvent.session_id == session.id)
609-
.filter(schema.StorageEvent.user_id == session.user_id)
610-
.order_by(schema.StorageEvent.timestamp.asc())
611-
)
612-
result = await sql_session.stream_scalars(stmt)
613-
storage_events = [e async for e in result]
614-
session.events = [e.to_event() for e in storage_events]
634+
if session._storage_update_marker is not None:
635+
# Sessions loaded by DatabaseSessionService carry an exact storage
636+
# revision marker, so stale-writer detection can use that marker
637+
# instead of relying on rounded timestamps.
638+
if session._storage_update_marker != storage_update_marker:
639+
raise ValueError(_STALE_SESSION_ERROR_MESSAGE)
640+
# Keep the float timestamp synchronized with the exact storage value
641+
# so tiny round-trip differences do not trigger false stale checks on
642+
# the next append.
643+
session.last_update_time = storage_update_time
644+
elif storage_update_time > session.last_update_time:
645+
# Backward-compatible fallback for marker-less session objects, such
646+
# as older in-memory sessions or manually constructed Session values.
647+
# Only reject when storage has actually advanced beyond the in-memory
648+
# revision represented by session.events.
649+
if not await self._session_matches_storage_revision(
650+
sql_session=sql_session, schema=schema, session=session
651+
):
652+
raise ValueError(_STALE_SESSION_ERROR_MESSAGE)
653+
session.last_update_time = storage_update_time
654+
session._storage_update_marker = storage_update_marker
615655

616656
# Merge pre-extracted state deltas into storage.
617657
if has_app_delta:
@@ -642,6 +682,7 @@ async def append_event(self, session: Session, event: Event) -> Event:
642682
session.last_update_time = storage_session.get_update_timestamp(
643683
is_sqlite
644684
)
685+
session._storage_update_marker = storage_session.get_update_marker()
645686

646687
# Also update the in-memory session
647688
await super().append_event(session=session, event=event)

src/google/adk/sessions/schemas/v0.py

Lines changed: 10 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -154,6 +154,13 @@ def get_update_timestamp(self, is_sqlite: bool) -> float:
154154
return self.update_time.replace(tzinfo=timezone.utc).timestamp()
155155
return self.update_time.timestamp()
156156

157+
def get_update_marker(self) -> str:
    """Return the update time as an ISO-8601 string with microseconds.

    Timezone-aware values are converted to UTC before formatting; naive
    values are formatted unchanged. The resulting string serves as a
    revision marker for optimistic concurrency checks.
    """
    stamp = self.update_time
    is_aware = stamp.tzinfo is not None
    normalized = stamp.astimezone(timezone.utc) if is_aware else stamp
    return normalized.isoformat(timespec="microseconds")
163+
157164
def to_session(
158165
self,
159166
state: dict[str, Any] | None = None,
@@ -166,14 +173,16 @@ def to_session(
166173
if events is None:
167174
events = []
168175

169-
return Session(
176+
session = Session(
170177
app_name=self.app_name,
171178
user_id=self.user_id,
172179
id=self.id,
173180
state=state,
174181
events=events,
175182
last_update_time=self.get_update_timestamp(is_sqlite=is_sqlite),
176183
)
184+
session._storage_update_marker = self.get_update_marker()
185+
return session
177186

178187

179188
class StorageEvent(Base):

src/google/adk/sessions/schemas/v1.py

Lines changed: 10 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -128,6 +128,13 @@ def get_update_timestamp(self, is_sqlite: bool) -> float:
128128
return self.update_time.replace(tzinfo=timezone.utc).timestamp()
129129
return self.update_time.timestamp()
130130

131+
def get_update_marker(self) -> str:
    """Return the update time as an ISO-8601 string with microseconds.

    Timezone-aware values are converted to UTC before formatting; naive
    values are formatted unchanged. The resulting string serves as a
    revision marker for optimistic concurrency checks.
    """
    stamp = self.update_time
    is_aware = stamp.tzinfo is not None
    normalized = stamp.astimezone(timezone.utc) if is_aware else stamp
    return normalized.isoformat(timespec="microseconds")
137+
131138
def to_session(
132139
self,
133140
state: dict[str, Any] | None = None,
@@ -140,14 +147,16 @@ def to_session(
140147
if events is None:
141148
events = []
142149

143-
return Session(
150+
session = Session(
144151
app_name=self.app_name,
145152
user_id=self.user_id,
146153
id=self.id,
147154
state=state,
148155
events=events,
149156
last_update_time=self.get_update_timestamp(is_sqlite=is_sqlite),
150157
)
158+
session._storage_update_marker = self.get_update_marker()
159+
return session
151160

152161

153162
class StorageEvent(Base):

src/google/adk/sessions/session.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,7 @@
2020
from pydantic import BaseModel
2121
from pydantic import ConfigDict
2222
from pydantic import Field
23+
from pydantic import PrivateAttr
2324

2425
from ..events.event import Event
2526

@@ -48,3 +49,6 @@ class Session(BaseModel):
4849
call/response, etc."""
4950
last_update_time: float = 0.0
5051
"""The last update time of the session."""
52+
53+
_storage_update_marker: str | None = PrivateAttr(default=None)
54+
"""Internal storage revision marker used for stale-session detection."""

tests/unittests/sessions/test_session_service.py

Lines changed: 95 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -657,28 +657,27 @@ async def test_append_event_to_stale_session():
657657
assert len(original_session.events) == 1
658658
assert 'sk2' not in original_session.state
659659

660-
# Appending another event to stale original_session
660+
# Appending another event to stale original_session should be rejected.
661661
event3 = Event(
662662
invocation_id='inv3',
663663
author='user',
664664
timestamp=current_time + 3,
665665
actions=EventActions(state_delta={'sk3': 'v3'}),
666666
)
667-
await session_service.append_event(original_session, event3)
667+
with pytest.raises(ValueError, match='modified in storage'):
668+
await session_service.append_event(original_session, event3)
668669

669-
# If we fetch session from DB, it should contain all 3 events and all state
670-
# changes.
670+
# If we fetch session from DB, it should only contain the committed events.
671671
session_final = await session_service.get_session(
672672
app_name=app_name, user_id=user_id, session_id=original_session.id
673673
)
674-
assert len(session_final.events) == 3
674+
assert len(session_final.events) == 2
675675
assert session_final.state.get('sk1') == 'v1'
676676
assert session_final.state.get('sk2') == 'v2'
677-
assert session_final.state.get('sk3') == 'v3'
677+
assert session_final.state.get('sk3') is None
678678
assert [e.invocation_id for e in session_final.events] == [
679679
'inv1',
680680
'inv2',
681-
'inv3',
682681
]
683682

684683

@@ -738,7 +737,7 @@ async def test_append_event_raises_if_user_state_row_missing():
738737

739738

740739
@pytest.mark.asyncio
741-
async def test_append_event_concurrent_stale_sessions_preserve_all_state():
740+
async def test_append_event_concurrent_stale_sessions_reject_stale_writer():
742741
session_service = get_session_service(
743742
service_type=SessionServiceType.DATABASE
744743
)
@@ -771,19 +770,103 @@ async def test_append_event_concurrent_stale_sessions_preserve_all_state():
771770
actions=EventActions(state_delta={f'sk{i}-2': f'v{i}-2'}),
772771
)
773772

774-
await asyncio.gather(
773+
results = await asyncio.gather(
775774
session_service.append_event(stale_session_1, event_1),
776775
session_service.append_event(stale_session_2, event_2),
776+
return_exceptions=True,
777777
)
778+
errors = [result for result in results if isinstance(result, Exception)]
779+
successes = [
780+
result for result in results if not isinstance(result, Exception)
781+
]
782+
assert len(successes) == 1
783+
assert len(errors) == 1
784+
assert isinstance(errors[0], ValueError)
785+
assert 'modified in storage' in str(errors[0])
778786

779787
session_final = await session_service.get_session(
780788
app_name=app_name, user_id=user_id, session_id=session.id
781789
)
782790

783791
for i in range(iteration_count):
784-
assert session_final.state.get(f'sk{i}-1') == f'v{i}-1'
785-
assert session_final.state.get(f'sk{i}-2') == f'v{i}-2'
786-
assert len(session_final.events) == iteration_count * 2
792+
event_values = {
793+
session_final.state.get(f'sk{i}-1'),
794+
session_final.state.get(f'sk{i}-2'),
795+
}
796+
assert event_values & {f'v{i}-1', f'v{i}-2'}
797+
assert None in event_values
798+
assert len(session_final.events) == iteration_count
799+
800+
801+
@pytest.mark.asyncio
async def test_append_event_allows_timestamp_drift_for_current_session():
    """A small drift in last_update_time must not block the next append."""
    app_name, user_id = 'my_app', 'user'
    db_service = DatabaseSessionService('sqlite+aiosqlite:///:memory:')
    try:
        live_session = await db_service.create_session(
            app_name=app_name, user_id=user_id, session_id='s1'
        )
        first_event = Event(
            invocation_id='inv1',
            author='user',
            timestamp=live_session.last_update_time + 10,
        )
        await db_service.append_event(live_session, first_event)

        # Mimic a float round-trip mismatch while leaving the persisted
        # session revision untouched.
        live_session.last_update_time = live_session.last_update_time - 0.0001

        second_event = Event(
            invocation_id='inv2',
            author='user',
            timestamp=first_event.timestamp + 10,
        )
        await db_service.append_event(live_session, second_event)

        reloaded = await db_service.get_session(
            app_name=app_name, user_id=user_id, session_id=live_session.id
        )
        stored_invocations = [stored.invocation_id for stored in reloaded.events]
        assert stored_invocations == ['inv1', 'inv2']
    finally:
        await db_service.close()
835+
836+
837+
@pytest.mark.asyncio
async def test_append_event_allows_markerless_current_session():
    """A session without a storage marker can still append when up to date."""
    app_name, user_id = 'my_app', 'user'
    db_service = DatabaseSessionService('sqlite+aiosqlite:///:memory:')
    try:
        live_session = await db_service.create_session(
            app_name=app_name, user_id=user_id, session_id='s1'
        )
        first_event = Event(
            invocation_id='inv1',
            author='user',
            timestamp=live_session.last_update_time + 10,
        )
        await db_service.append_event(live_session, first_event)

        # Drop the marker and shift the timestamp slightly backwards so the
        # append has to go through the marker-less comparison.
        live_session._storage_update_marker = None
        live_session.last_update_time = live_session.last_update_time - 0.0001

        second_event = Event(
            invocation_id='inv2',
            author='user',
            timestamp=first_event.timestamp + 10,
        )
        await db_service.append_event(live_session, second_event)

        reloaded = await db_service.get_session(
            app_name=app_name, user_id=user_id, session_id=live_session.id
        )
        stored_invocations = [stored.invocation_id for stored in reloaded.events]
        assert stored_invocations == ['inv1', 'inv2']
    finally:
        await db_service.close()
787870

788871

789872
@pytest.mark.asyncio

0 commit comments

Comments
 (0)