|
61 | 61 |
|
62 | 62 | logger = logging.getLogger("google_adk." + __name__) |
63 | 63 |
|
| 64 | +_STALE_SESSION_ERROR_MESSAGE = ( |
| 65 | + "The session has been modified in storage since it was loaded. " |
| 66 | + "Please reload the session before appending more events." |
| 67 | +) |
| 68 | + |
64 | 69 | _SQLITE_DIALECT = "sqlite" |
65 | 70 | _MARIADB_DIALECT = "mariadb" |
66 | 71 | _MYSQL_DIALECT = "mysql" |
@@ -309,6 +314,39 @@ async def _prepare_tables(self): |
309 | 314 |
|
310 | 315 | self._tables_created = True |
311 | 316 |
|
| 317 | + async def _session_matches_storage_revision( |
| 318 | + self, |
| 319 | + *, |
| 320 | + sql_session: DatabaseSessionFactory, |
| 321 | + schema: _SchemaClasses, |
| 322 | + session: Session, |
| 323 | + ) -> bool: |
| 324 | + """Returns whether a marker-less session still matches stored events.""" |
| 325 | + if not session.events: |
| 326 | + stmt = ( |
| 327 | + select(schema.StorageEvent.id) |
| 328 | + .filter(schema.StorageEvent.app_name == session.app_name) |
| 329 | + .filter(schema.StorageEvent.session_id == session.id) |
| 330 | + .filter(schema.StorageEvent.user_id == session.user_id) |
| 331 | + .limit(1) |
| 332 | + ) |
| 333 | + result = await sql_session.execute(stmt) |
| 334 | + return result.scalar_one_or_none() is None |
| 335 | + |
| 336 | + stmt = ( |
| 337 | + select(schema.StorageEvent.id) |
| 338 | + .filter(schema.StorageEvent.app_name == session.app_name) |
| 339 | + .filter(schema.StorageEvent.session_id == session.id) |
| 340 | + .filter(schema.StorageEvent.user_id == session.user_id) |
| 341 | + .order_by( |
| 342 | + schema.StorageEvent.timestamp.desc(), schema.StorageEvent.id.desc() |
| 343 | + ) |
| 344 | + .limit(1) |
| 345 | + ) |
| 346 | + result = await sql_session.execute(stmt) |
| 347 | + latest_storage_event_id = result.scalar_one_or_none() |
| 348 | + return latest_storage_event_id == session.events[-1].id |
| 349 | + |
312 | 350 | @override |
313 | 351 | async def create_session( |
314 | 352 | self, |
@@ -529,9 +567,9 @@ async def append_event(self, session: Session, event: Event) -> Event: |
529 | 567 | # Trim temp state before persisting |
530 | 568 | event = self._trim_temp_delta_state(event) |
531 | 569 |
|
532 | | - # 1. Check if timestamp is stale |
533 | | - # 2. Update session attributes based on event config |
534 | | - # 3. Store event to table |
| 570 | + # 1. Validate the session has not gone stale. |
| 571 | + # 2. Update session attributes based on event config. |
| 572 | + # 3. Store the new event. |
535 | 573 | schema = self._get_schema_classes() |
536 | 574 | is_sqlite = self.db_engine.dialect.name == _SQLITE_DIALECT |
537 | 575 | use_row_level_locking = self._supports_row_level_locking() |
@@ -563,6 +601,8 @@ async def append_event(self, session: Session, event: Event) -> Event: |
563 | 601 | storage_session = storage_session_result.scalars().one_or_none() |
564 | 602 | if storage_session is None: |
565 | 603 | raise ValueError(f"Session {session.id} not found.") |
| 604 | + storage_update_time = storage_session.get_update_timestamp(is_sqlite) |
| 605 | + storage_update_marker = storage_session.get_update_marker() |
566 | 606 |
|
567 | 607 | storage_app_state = await _select_required_state( |
568 | 608 | sql_session=sql_session, |
@@ -591,27 +631,27 @@ async def append_event(self, session: Session, event: Event) -> Event: |
591 | 631 | ), |
592 | 632 | ) |
593 | 633 |
|
594 | | - if ( |
595 | | - storage_session.get_update_timestamp(is_sqlite) |
596 | | - > session.last_update_time |
597 | | - ): |
598 | | - # Reload the session from storage if it has been updated since it was |
599 | | - # loaded. |
600 | | - app_state = storage_app_state.state |
601 | | - user_state = storage_user_state.state |
602 | | - session_state = storage_session.state |
603 | | - session.state = _merge_state(app_state, user_state, session_state) |
604 | | - |
605 | | - stmt = ( |
606 | | - select(schema.StorageEvent) |
607 | | - .filter(schema.StorageEvent.app_name == session.app_name) |
608 | | - .filter(schema.StorageEvent.session_id == session.id) |
609 | | - .filter(schema.StorageEvent.user_id == session.user_id) |
610 | | - .order_by(schema.StorageEvent.timestamp.asc()) |
611 | | - ) |
612 | | - result = await sql_session.stream_scalars(stmt) |
613 | | - storage_events = [e async for e in result] |
614 | | - session.events = [e.to_event() for e in storage_events] |
| 634 | + if session._storage_update_marker is not None: |
| 635 | + # Sessions loaded by DatabaseSessionService carry an exact storage |
| 636 | + # revision marker, so stale-writer detection can use that marker |
| 637 | + # instead of relying on rounded timestamps. |
| 638 | + if session._storage_update_marker != storage_update_marker: |
| 639 | + raise ValueError(_STALE_SESSION_ERROR_MESSAGE) |
| 640 | + # Keep the float timestamp synchronized with the exact storage value |
| 641 | + # so tiny round-trip differences do not trigger false stale checks on |
| 642 | + # the next append. |
| 643 | + session.last_update_time = storage_update_time |
| 644 | + elif storage_update_time > session.last_update_time: |
| 645 | + # Backward-compatible fallback for marker-less session objects, such |
| 646 | + # as older in-memory sessions or manually constructed Session values. |
| 647 | + # Only reject when storage has actually advanced beyond the in-memory |
| 648 | + # revision represented by session.events. |
| 649 | + if not await self._session_matches_storage_revision( |
| 650 | + sql_session=sql_session, schema=schema, session=session |
| 651 | + ): |
| 652 | + raise ValueError(_STALE_SESSION_ERROR_MESSAGE) |
| 653 | + session.last_update_time = storage_update_time |
| 654 | + session._storage_update_marker = storage_update_marker |
615 | 655 |
|
616 | 656 | # Merge pre-extracted state deltas into storage. |
617 | 657 | if has_app_delta: |
@@ -642,6 +682,7 @@ async def append_event(self, session: Session, event: Event) -> Event: |
642 | 682 | session.last_update_time = storage_session.get_update_timestamp( |
643 | 683 | is_sqlite |
644 | 684 | ) |
| 685 | + session._storage_update_marker = storage_session.get_update_marker() |
645 | 686 |
|
646 | 687 | # Also update the in-memory session |
647 | 688 | await super().append_event(session=session, event=event) |
|
0 commit comments