Skip to content

Commit f515a9f

Browse files
committed
feat(sessions): introduce pluggable SessionDataTransformer hooks for masking in DatabaseSessionService
1 parent abcf14c commit f515a9f

6 files changed

Lines changed: 169 additions & 6 deletions

File tree

contributing/samples/gepa/experiment.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@
4343
from tau_bench.types import EnvRunResult
4444
from tau_bench.types import RunConfig
4545
import tau_bench_agent as tau_bench_agent_lib
46-
4746
import utils
4847

4948

contributing/samples/gepa/run_experiment.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
from absl import flags
2626
import experiment
2727
from google.genai import types
28-
2928
import utils
3029

3130
_OUTPUT_DIR = flags.DEFINE_string(

src/google/adk/sessions/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,16 @@
2222
'DatabaseSessionService',
2323
'InMemorySessionService',
2424
'Session',
25+
'SessionDataTransformer',
2526
'State',
2627
'VertexAiSessionService',
2728
]
2829

2930

3031
def __getattr__(name: str):
32+
if name == 'SessionDataTransformer':
33+
from .session_data_transformer import SessionDataTransformer
34+
return SessionDataTransformer
3135
if name == 'DatabaseSessionService':
3236
try:
3337
from .database_session_service import DatabaseSessionService

src/google/adk/sessions/database_session_service.py

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
from .schemas.v1 import StorageSession as StorageSessionV1
6161
from .schemas.v1 import StorageUserState as StorageUserStateV1
6262
from .session import Session
63+
from .session_data_transformer import SessionDataTransformer
6364
from .state import State
6465

6566
logger = logging.getLogger("google_adk." + __name__)
@@ -188,7 +189,13 @@ def __init__(self, version: str):
188189
class DatabaseSessionService(BaseSessionService):
189190
"""A session service that uses a database for storage."""
190191

191-
def __init__(self, db_url: str, **kwargs: Any):
192+
def __init__(
193+
self,
194+
db_url: str,
195+
*,
196+
transformer: Optional[SessionDataTransformer] = None,
197+
**kwargs: Any,
198+
):
192199
"""Initializes the database session service with a database URL."""
193200
# 1. Create DB engine for db connection
194201
# 2. Create all tables based on schema
@@ -248,6 +255,7 @@ def __init__(self, db_url: str, **kwargs: Any):
248255
self._session_locks: dict[_SessionLockKey, asyncio.Lock] = {}
249256
self._session_lock_ref_count: dict[_SessionLockKey, int] = {}
250257
self._session_locks_guard = asyncio.Lock()
258+
self.transformer = transformer
251259

252260
def _get_schema_classes(self) -> _SchemaClasses:
253261
return _SchemaClasses(self._db_schema_version)
@@ -446,7 +454,12 @@ async def create_session(
446454
)
447455

448456
# Extract state deltas
449-
state_deltas = _session_util.extract_state_delta(state)
457+
transformed_state = (
458+
self.transformer.before_persist_state(state)
459+
if self.transformer and state is not None
460+
else state
461+
)
462+
state_deltas = _session_util.extract_state_delta(transformed_state)
450463
app_state_delta = state_deltas["app"]
451464
user_state_delta = state_deltas["user"]
452465
session_state = state_deltas["session"]
@@ -479,6 +492,8 @@ async def create_session(
479492
merged_state = _merge_state(
480493
storage_app_state.state, storage_user_state.state, session_state
481494
)
495+
if self.transformer:
496+
merged_state = self.transformer.after_load_state(merged_state)
482497
session = storage_session.to_session(
483498
state=merged_state, is_sqlite=is_sqlite
484499
)
@@ -540,9 +555,16 @@ async def get_session(
540555

541556
# Merge states
542557
merged_state = _merge_state(app_state, user_state, session_state)
558+
if self.transformer:
559+
merged_state = self.transformer.after_load_state(merged_state)
543560

544561
# Convert storage session to session
545-
events = [e.to_event() for e in reversed(storage_events)]
562+
events = []
563+
for e in reversed(storage_events):
564+
evt = e.to_event()
565+
if self.transformer:
566+
evt = self.transformer.after_load_event(evt)
567+
events.append(evt)
546568
is_sqlite = self.db_engine.dialect.name == _SQLITE_DIALECT
547569
session = storage_session.to_session(
548570
state=merged_state, events=events, is_sqlite=is_sqlite
@@ -596,6 +618,8 @@ async def list_sessions(
596618
session_state = storage_session.state
597619
user_state = user_states_map.get(storage_session.user_id, {})
598620
merged_state = _merge_state(app_state, user_state, session_state)
621+
if self.transformer:
622+
merged_state = self.transformer.after_load_state(merged_state)
599623
sessions.append(
600624
storage_session.to_session(state=merged_state, is_sqlite=is_sqlite)
601625
)
@@ -640,6 +664,8 @@ async def append_event(self, session: Session, event: Event) -> Event:
640664
if event.actions and event.actions.state_delta
641665
else {}
642666
)
667+
if self.transformer:
668+
state_delta = self.transformer.before_persist_state(state_delta)
643669
state_deltas = _session_util.extract_state_delta(state_delta)
644670
has_app_delta = bool(state_deltas["app"])
645671
has_user_delta = bool(state_deltas["user"])
@@ -735,7 +761,13 @@ async def append_event(self, session: Session, event: Event) -> Event:
735761
else:
736762
update_time = datetime.fromtimestamp(event.timestamp)
737763
storage_session.update_time = update_time
738-
sql_session.add(schema.StorageEvent.from_event(session, event))
764+
765+
transformed_event = (
766+
self.transformer.before_persist_event(event)
767+
if self.transformer
768+
else event
769+
)
770+
sql_session.add(schema.StorageEvent.from_event(session, transformed_event))
739771

740772
await sql_session.commit()
741773

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from typing import Any
15+
from typing import Mapping
16+
from typing import Protocol
17+
18+
from google.adk.events.event import Event
19+
20+
21+
class SessionDataTransformer(Protocol):
  """Structural interface for rewriting session data at the storage boundary.

  An implementation provides four hooks that sit between the in-memory
  session objects and what is written to / read from the database: one pair
  for events and one pair for state dictionaries. Typical uses are
  field-level encryption, PII masking, and secret scrubbing. The in-memory
  core structures do not need to change, provided every hook hands back a
  valid Event or storage dictionary.

  If callers are supposed to see the original values again after a reload,
  each ``after_load_*`` hook has to undo what the matching
  ``before_persist_*`` hook did.
  """

  def before_persist_event(self, event: Event) -> Event:
    """Produces the Event that will be serialized and stored.

    Args:
      event: The in-memory event that is about to be persisted.

    Returns:
      The event to write to the database.

    NOTE(review): the caller keeps using the event it passed in, so
    implementations should presumably return a modified copy rather than
    mutate ``event`` in place -- confirm against the service code.
    """
    ...

  def after_load_event(self, event: Event) -> Event:
    """Produces the Event handed back to callers after a database read.

    Args:
      event: The event as it was deserialized from storage.

    Returns:
      The event to expose in the loaded session.
    """
    ...

  def before_persist_state(self, state: Mapping[str, Any]) -> dict[str, Any]:
    """Produces the state mapping that will be stored.

    Args:
      state: Either a complete state mapping or only a partial delta of
        changed keys.

    Returns:
      A dictionary holding the values to persist.

    NOTE(review): the service appears to split the returned dictionary into
    app / user / session parts by key afterwards, so keys should presumably
    be left unchanged -- confirm.
    """
    ...

  def after_load_state(self, state: Mapping[str, Any]) -> dict[str, Any]:
    """Produces the state dictionary exposed on a loaded session.

    Args:
      state: The combined application, user, and session state as read from
        storage.

    Returns:
      A dictionary holding the values callers should see.
    """
    ...

tests/unittests/sessions/test_session_service.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1626,3 +1626,89 @@ async def tracking_fn(**kwargs):
16261626
finally:
16271627
database_session_service._select_required_state = original_fn
16281628
await service.close()
1629+
1630+
import json
1631+
1632+
1633+
class MockPIIMaskerTransformer:
  """Test double that "masks" data by appending a marker suffix.

  String state values and the event ``invocation_id`` get ``_masked``
  appended before persistence and have that trailing marker removed again
  after loading. Non-string state values pass through unchanged. Every hook
  returns a new object and leaves its input untouched.
  """

  _SUFFIX = "_masked"

  def before_persist_state(self, state):
    """Returns a new dict with the marker appended to every string value."""
    return {
        key: f"{value}{self._SUFFIX}" if isinstance(value, str) else value
        for key, value in state.items()
    }

  def after_load_state(self, state):
    """Returns a new dict with the trailing marker removed from string values."""
    return {key: self._unmask(value) for key, value in state.items()}

  def before_persist_event(self, event: "Event") -> "Event":
    """Returns a copy of ``event`` whose non-empty invocation_id is masked."""
    new_event = self._copy(event)
    if new_event.invocation_id:
      new_event.invocation_id += self._SUFFIX
    return new_event

  def after_load_event(self, event: "Event") -> "Event":
    """Returns a copy of ``event`` whose invocation_id has the marker removed."""
    new_event = self._copy(event)
    if new_event.invocation_id:
      new_event.invocation_id = self._unmask(new_event.invocation_id)
    return new_event

  @classmethod
  def _unmask(cls, value):
    """Removes one trailing marker from a string; returns other values as is."""
    if isinstance(value, str) and value.endswith(cls._SUFFIX):
      # Strip only the suffix. str.replace() would also delete markers that
      # are part of the original value (e.g. "a_masked_b"), which breaks the
      # mask/unmask round trip.
      return value[: -len(cls._SUFFIX)]
    return value

  @staticmethod
  def _copy(event):
    """Copies the event via ``model_copy()`` when available, else ``copy()``."""
    return event.model_copy() if hasattr(event, "model_copy") else event.copy()
1651+
1652+
@pytest.mark.asyncio
async def test_session_data_transformer():
  """Round-trips state and one event through a masking transformer.

  Checks three things against an in-memory SQLite database:
    * callers see unmasked state and events from create/append/get;
    * the rows read back with raw SQL contain the masked values;
    * reloading the session restores the original values.
  """
  # Imported locally so the raw-SQL helper is only needed by this test.
  from sqlalchemy import text

  masker = MockPIIMaskerTransformer()
  service = DatabaseSessionService(
      'sqlite+aiosqlite:///:memory:', transformer=masker
  )
  try:
    initial_state = {'app:secret': 'foo', 'user:pii': 'bar'}
    session = await service.create_session(
        app_name='app',
        user_id='user',
        session_id='s1',
        state=dict(initial_state),
    )
    assert session.state == initial_state

    # The stored app-level state must carry the masked value.
    async with service.db_engine.connect() as conn:
      app_rows = await conn.execute(
          text("SELECT state FROM app_states WHERE app_name = 'app'")
      )
      stored_app_state = app_rows.scalar()
      assert "foo_masked" in json.dumps(stored_app_state)

    delta_actions = EventActions(state_delta={'sk1': 'pass'})
    event = Event(invocation_id='inv1', author='user', actions=delta_actions)
    returned_event = await service.append_event(session, event)

    # The caller-facing event and session stay unmasked.
    assert returned_event.invocation_id == 'inv1'
    assert session.state.get('sk1') == 'pass'

    # The stored session state and event payload must both be masked.
    async with service.db_engine.connect() as conn:
      session_rows = await conn.execute(
          text("SELECT id, state FROM sessions WHERE id = 's1'")
      )
      session_row = session_rows.fetchone()
      assert "pass_masked" in json.dumps(session_row[1])

      event_rows = await conn.execute(
          text("SELECT event_data FROM events WHERE session_id = 's1' LIMIT 1")
      )
      stored_event = event_rows.scalar()
      assert "inv1_masked" in json.dumps(stored_event)

    # Loading through the service reverses the masking again.
    loaded_session = await service.get_session(
        app_name='app', user_id='user', session_id='s1'
    )
    expected_state = {'app:secret': 'foo', 'user:pii': 'bar', 'sk1': 'pass'}
    assert loaded_session.state == expected_state
    assert len(loaded_session.events) == 1
    assert loaded_session.events[0].invocation_id == 'inv1'
  finally:
    await service.close()
1691+
1692+
class ErrorMaskerTransformer:
  """Transformer whose state-persist hook always fails.

  The other three hooks return their argument unchanged, so the only effect
  of this transformer is the ValueError raised when state is about to be
  persisted.
  """

  def before_persist_event(self, event: Event) -> Event:
    """Returns the event unchanged."""
    return event

  def after_load_event(self, event: Event) -> Event:
    """Returns the event unchanged."""
    return event

  def after_load_state(self, state):
    """Returns the state unchanged."""
    return state

  def before_persist_state(self, state):
    """Always raises ValueError to simulate a failing transformer."""
    raise ValueError("Transformer exception test")
1704+
1705+
@pytest.mark.asyncio
async def test_session_data_transformer_handles_exception():
  """A transformer error raised while persisting state reaches the caller.

  The service is built with a transformer whose before_persist_state hook
  raises, and create_session is expected to surface that ValueError.
  """
  failing_service = DatabaseSessionService(
      'sqlite+aiosqlite:///:memory:', transformer=ErrorMaskerTransformer()
  )
  try:
    create_args = dict(
        app_name='app',
        user_id='user',
        session_id='s1',
        state={'app:secret': 'foo'},
    )
    with pytest.raises(ValueError, match="Transformer exception test"):
      await failing_service.create_session(**create_args)
  finally:
    # Always release the engine, whether or not the assertion held.
    await failing_service.close()

0 commit comments

Comments
 (0)