Skip to content

Commit 16a1a18

Browse files
google-genai-botcopybara-github
authored andcommitted
feat: Option to use shallow-copy for session in InMemorySessionService
PiperOrigin-RevId: 893606046
1 parent 4d7b951 commit 16a1a18

File tree

3 files changed

+46
-4
lines changed

3 files changed

+46
-4
lines changed

src/google/adk/features/_feature_registry.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ class FeatureName(str, Enum):
5353
TOOL_CONFIRMATION = "TOOL_CONFIRMATION"
5454
PLUGGABLE_AUTH = "PLUGGABLE_AUTH"
5555
SNAKE_CASE_SKILL_NAME = "SNAKE_CASE_SKILL_NAME"
56+
IN_MEMORY_SESSION_SERVICE_LIGHT_COPY = "IN_MEMORY_SESSION_SERVICE_LIGHT_COPY"
5657

5758

5859
class FeatureStage(Enum):
@@ -166,6 +167,9 @@ class FeatureConfig:
166167
FeatureName.SNAKE_CASE_SKILL_NAME: FeatureConfig(
167168
FeatureStage.EXPERIMENTAL, default_on=False
168169
),
170+
FeatureName.IN_MEMORY_SESSION_SERVICE_LIGHT_COPY: FeatureConfig(
171+
FeatureStage.WIP, default_on=False
172+
),
169173
}
170174

171175
# Track which experimental features have already warned (warn only once)

src/google/adk/sessions/in_memory_session_service.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
from . import _session_util
2626
from ..errors.already_exists_error import AlreadyExistsError
2727
from ..events.event import Event
28+
from ..features import FeatureName
29+
from ..features import is_feature_enabled
2830
from .base_session_service import BaseSessionService
2931
from .base_session_service import GetSessionConfig
3032
from .base_session_service import ListSessionsResponse
@@ -34,6 +36,28 @@
3436
logger = logging.getLogger('google_adk.' + __name__)
3537

3638

39+
def _light_copy(session: Session) -> Session:
40+
"""Returns a light copy of the session.
41+
42+
Main difference between this and true shallow-copy is that container fields
43+
(e.g., events and state) are also shallow-copied. What this means is appending
44+
to events/state of the copied session won't affect the original while avoiding
45+
the potentially expensive cost of a full/recursive deep-copy of all events and
46+
state.
47+
"""
48+
copied_session = session.model_copy(deep=False)
49+
copied_session.events = copy.copy(session.events)
50+
copied_session.state = copy.copy(session.state)
51+
return copied_session
52+
53+
54+
def _copy_session(session: Session) -> Session:
55+
if is_feature_enabled(FeatureName.IN_MEMORY_SESSION_SERVICE_LIGHT_COPY):
56+
return _light_copy(session)
57+
else:
58+
return copy.deepcopy(session)
59+
60+
3761
class InMemorySessionService(BaseSessionService):
3862
"""An in-memory implementation of the session service.
3963
@@ -124,7 +148,7 @@ def _create_session_impl(
124148
self.sessions[app_name][user_id] = {}
125149
self.sessions[app_name][user_id][session_id] = session
126150

127-
copied_session = copy.deepcopy(session)
151+
copied_session = _copy_session(session)
128152
return self._merge_state(app_name, user_id, copied_session)
129153

130154
@override
@@ -175,7 +199,7 @@ def _get_session_impl(
175199
return None
176200

177201
session = self.sessions[app_name][user_id].get(session_id)
178-
copied_session = copy.deepcopy(session)
202+
copied_session = _copy_session(session)
179203

180204
if config:
181205
if config.num_recent_events is not None:
@@ -247,13 +271,13 @@ def _list_sessions_impl(
247271
if user_id is None:
248272
for uid in list(self.sessions[app_name].keys()):
249273
for session in list(self.sessions[app_name][uid].values()):
250-
copied_session = copy.deepcopy(session)
274+
copied_session = _copy_session(session)
251275
copied_session.events = []
252276
copied_session = self._merge_state(app_name, uid, copied_session)
253277
sessions_without_events.append(copied_session)
254278
else:
255279
for session in list(self.sessions[app_name][user_id].values()):
256-
copied_session = copy.deepcopy(session)
280+
copied_session = _copy_session(session)
257281
copied_session.events = []
258282
copied_session = self._merge_state(app_name, user_id, copied_session)
259283
sessions_without_events.append(copied_session)

tests/unittests/sessions/test_session_service.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
from google.adk.errors.already_exists_error import AlreadyExistsError
2424
from google.adk.events.event import Event
2525
from google.adk.events.event_actions import EventActions
26+
from google.adk.features import FeatureName
27+
from google.adk.features import override_feature_enabled
2628
from google.adk.sessions import database_session_service
2729
from google.adk.sessions.base_session_service import GetSessionConfig
2830
from google.adk.sessions.database_session_service import DatabaseSessionService
@@ -35,6 +37,7 @@
3537

3638
class SessionServiceType(enum.Enum):
3739
IN_MEMORY = 'IN_MEMORY'
40+
IN_MEMORY_WITH_LIGHT_COPY_ENABLED = 'IN_MEMORY_WITH_LIGHT_COPY_ENABLED'
3841
DATABASE = 'DATABASE'
3942
SQLITE = 'SQLITE'
4043

@@ -48,22 +51,33 @@ def get_session_service(
4851
return DatabaseSessionService('sqlite+aiosqlite:///:memory:')
4952
if service_type == SessionServiceType.SQLITE:
5053
return SqliteSessionService(str(tmp_path / 'sqlite.db'))
54+
if service_type == SessionServiceType.IN_MEMORY_WITH_LIGHT_COPY_ENABLED:
55+
return InMemorySessionService()
5156
return InMemorySessionService()
5257

5358

5459
@pytest.fixture(
5560
params=[
5661
SessionServiceType.IN_MEMORY,
62+
SessionServiceType.IN_MEMORY_WITH_LIGHT_COPY_ENABLED,
5763
SessionServiceType.DATABASE,
5864
SessionServiceType.SQLITE,
5965
]
6066
)
6167
async def session_service(request, tmp_path):
6268
"""Provides a session service and closes database backends on teardown."""
69+
if request.param == SessionServiceType.IN_MEMORY_WITH_LIGHT_COPY_ENABLED:
70+
override_feature_enabled(
71+
FeatureName.IN_MEMORY_SESSION_SERVICE_LIGHT_COPY, True
72+
)
6373
service = get_session_service(request.param, tmp_path)
6474
yield service
6575
if isinstance(service, DatabaseSessionService):
6676
await service.close()
77+
if request.param == SessionServiceType.IN_MEMORY_WITH_LIGHT_COPY_ENABLED:
78+
override_feature_enabled(
79+
FeatureName.IN_MEMORY_SESSION_SERVICE_LIGHT_COPY, False
80+
)
6781

6882

6983
def test_database_session_service_enables_pool_pre_ping_by_default():

0 commit comments

Comments
 (0)