|
60 | 60 | from .schemas.v1 import StorageSession as StorageSessionV1 |
61 | 61 | from .schemas.v1 import StorageUserState as StorageUserStateV1 |
62 | 62 | from .session import Session |
| 63 | +from .session_data_transformer import SessionDataTransformer |
63 | 64 | from .state import State |
64 | 65 |
|
65 | 66 | logger = logging.getLogger("google_adk." + __name__) |
@@ -188,7 +189,13 @@ def __init__(self, version: str): |
188 | 189 | class DatabaseSessionService(BaseSessionService): |
189 | 190 | """A session service that uses a database for storage.""" |
190 | 191 |
|
191 | | - def __init__(self, db_url: str, **kwargs: Any): |
| 192 | + def __init__( |
| 193 | + self, |
| 194 | + db_url: str, |
| 195 | + *, |
| 196 | + transformer: Optional[SessionDataTransformer] = None, |
| 197 | + **kwargs: Any, |
| 198 | + ): |
192 | 199 | """Initializes the database session service with a database URL.""" |
193 | 200 | # 1. Create DB engine for db connection |
194 | 201 | # 2. Create all tables based on schema |
@@ -248,6 +255,7 @@ def __init__(self, db_url: str, **kwargs: Any): |
248 | 255 | self._session_locks: dict[_SessionLockKey, asyncio.Lock] = {} |
249 | 256 | self._session_lock_ref_count: dict[_SessionLockKey, int] = {} |
250 | 257 | self._session_locks_guard = asyncio.Lock() |
| 258 | + self.transformer = transformer |
251 | 259 |
|
252 | 260 | def _get_schema_classes(self) -> _SchemaClasses: |
253 | 261 | return _SchemaClasses(self._db_schema_version) |
@@ -446,7 +454,12 @@ async def create_session( |
446 | 454 | ) |
447 | 455 |
|
448 | 456 | # Extract state deltas |
449 | | - state_deltas = _session_util.extract_state_delta(state) |
| 457 | + transformed_state = ( |
| 458 | + self.transformer.before_persist_state(state) |
| 459 | + if self.transformer and state is not None |
| 460 | + else state |
| 461 | + ) |
| 462 | + state_deltas = _session_util.extract_state_delta(transformed_state) |
450 | 463 | app_state_delta = state_deltas["app"] |
451 | 464 | user_state_delta = state_deltas["user"] |
452 | 465 | session_state = state_deltas["session"] |
@@ -479,6 +492,8 @@ async def create_session( |
479 | 492 | merged_state = _merge_state( |
480 | 493 | storage_app_state.state, storage_user_state.state, session_state |
481 | 494 | ) |
| 495 | + if self.transformer: |
| 496 | + merged_state = self.transformer.after_load_state(merged_state) |
482 | 497 | session = storage_session.to_session( |
483 | 498 | state=merged_state, is_sqlite=is_sqlite |
484 | 499 | ) |
@@ -540,9 +555,16 @@ async def get_session( |
540 | 555 |
|
541 | 556 | # Merge states |
542 | 557 | merged_state = _merge_state(app_state, user_state, session_state) |
| 558 | + if self.transformer: |
| 559 | + merged_state = self.transformer.after_load_state(merged_state) |
543 | 560 |
|
544 | 561 | # Convert storage session to session |
545 | | - events = [e.to_event() for e in reversed(storage_events)] |
| 562 | + events = [] |
| 563 | + for e in reversed(storage_events): |
| 564 | + evt = e.to_event() |
| 565 | + if self.transformer: |
| 566 | + evt = self.transformer.after_load_event(evt) |
| 567 | + events.append(evt) |
546 | 568 | is_sqlite = self.db_engine.dialect.name == _SQLITE_DIALECT |
547 | 569 | session = storage_session.to_session( |
548 | 570 | state=merged_state, events=events, is_sqlite=is_sqlite |
@@ -596,6 +618,8 @@ async def list_sessions( |
596 | 618 | session_state = storage_session.state |
597 | 619 | user_state = user_states_map.get(storage_session.user_id, {}) |
598 | 620 | merged_state = _merge_state(app_state, user_state, session_state) |
| 621 | + if self.transformer: |
| 622 | + merged_state = self.transformer.after_load_state(merged_state) |
599 | 623 | sessions.append( |
600 | 624 | storage_session.to_session(state=merged_state, is_sqlite=is_sqlite) |
601 | 625 | ) |
@@ -640,6 +664,8 @@ async def append_event(self, session: Session, event: Event) -> Event: |
640 | 664 | if event.actions and event.actions.state_delta |
641 | 665 | else {} |
642 | 666 | ) |
| 667 | + if self.transformer: |
| 668 | + state_delta = self.transformer.before_persist_state(state_delta) |
643 | 669 | state_deltas = _session_util.extract_state_delta(state_delta) |
644 | 670 | has_app_delta = bool(state_deltas["app"]) |
645 | 671 | has_user_delta = bool(state_deltas["user"]) |
@@ -735,7 +761,13 @@ async def append_event(self, session: Session, event: Event) -> Event: |
735 | 761 | else: |
736 | 762 | update_time = datetime.fromtimestamp(event.timestamp) |
737 | 763 | storage_session.update_time = update_time |
738 | | - sql_session.add(schema.StorageEvent.from_event(session, event)) |
| 764 | + |
| 765 | + transformed_event = ( |
| 766 | + self.transformer.before_persist_event(event) |
| 767 | + if self.transformer |
| 768 | + else event |
| 769 | + ) |
| 770 | + sql_session.add(schema.StorageEvent.from_event(session, transformed_event)) |
739 | 771 |
|
740 | 772 | await sql_session.commit() |
741 | 773 |
|
|
0 commit comments