Skip to content

Commit 761842e

Browse files
committed
refactor(sessions): clean up and add configurable collection names
- Remove unused _SessionLockKey type alias - Move import copy to module level (stdlib, no reason to lazy-import) - Store self._firestore in __init__ to avoid repeated guarded imports inside create_session / _persist_batch - Add sessions_collection, events_collection, app_state_collection, and user_state_collection constructor params (keyword-only, with defaults) so developers can customise the Firestore collection layout without subclassing
1 parent fa936aa commit 761842e

1 file changed

Lines changed: 21 additions & 15 deletions

File tree

src/google/adk_community/sessions/firestore_session_service.py

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
from collections import deque
4848
from collections.abc import Awaitable
4949
from collections.abc import Callable
50+
import copy
5051
from dataclasses import dataclass
5152
from dataclasses import field
5253
from datetime import datetime
@@ -76,8 +77,6 @@
7677
DEFAULT_APP_STATE_COLLECTION = "app_states"
7778
DEFAULT_USER_STATE_COLLECTION = "user_states"
7879

79-
_SessionLockKey = tuple[str, str, str]
80-
8180
# Transient Firestore / gRPC failures worth retrying. Matched by class name to
8281
# avoid a hard dependency on google.api_core being importable everywhere.
8382
_RETRYABLE_ERROR_NAMES = frozenset({
@@ -141,6 +140,10 @@ def __init__(
141140
client: Any = None,
142141
root_collection: Optional[str] = None,
143142
*,
143+
sessions_collection: str = DEFAULT_SESSIONS_COLLECTION,
144+
events_collection: str = DEFAULT_EVENTS_COLLECTION,
145+
app_state_collection: str = DEFAULT_APP_STATE_COLLECTION,
146+
user_state_collection: str = DEFAULT_USER_STATE_COLLECTION,
144147
durable_mode: bool = False,
145148
buffer_max_events: int = 10,
146149
flush_interval_seconds: float = 120.0,
@@ -155,6 +158,14 @@ def __init__(
155158
client: An optional Firestore ``AsyncClient``. If not provided, a new one
156159
is created (requires ``google-cloud-firestore``).
157160
root_collection: Root collection name. Defaults to ``'adk-session'``.
161+
sessions_collection: Subcollection name for sessions. Defaults to
162+
``'sessions'``.
163+
events_collection: Subcollection name for events. Defaults to
164+
``'events'``.
165+
app_state_collection: Root collection for app-scoped state. Defaults to
166+
``'app_states'``.
167+
user_state_collection: Root collection for user-scoped state. Defaults
168+
to ``'user_states'``.
158169
durable_mode: When True, every event is persisted immediately and no
159170
buffering happens.
160171
buffer_max_events: Flush a session once this many events are buffered.
@@ -172,12 +183,13 @@ def __init__(
172183
" Install it with: pip install google-adk-community[firestore]"
173184
) from e
174185

186+
self._firestore = firestore
175187
self.client = client if client is not None else firestore.AsyncClient()
176188
self.root_collection = root_collection or DEFAULT_ROOT_COLLECTION
177-
self.sessions_collection = DEFAULT_SESSIONS_COLLECTION
178-
self.events_collection = DEFAULT_EVENTS_COLLECTION
179-
self.app_state_collection = DEFAULT_APP_STATE_COLLECTION
180-
self.user_state_collection = DEFAULT_USER_STATE_COLLECTION
189+
self.sessions_collection = sessions_collection
190+
self.events_collection = events_collection
191+
self.app_state_collection = app_state_collection
192+
self.user_state_collection = user_state_collection
181193

182194
self._durable_mode = durable_mode
183195
self._buffer_max_events = buffer_max_events
@@ -224,8 +236,6 @@ def _merge_state(
224236
user_state: dict[str, Any],
225237
session_state: dict[str, Any],
226238
) -> dict[str, Any]:
227-
import copy
228-
229239
merged = copy.deepcopy(session_state)
230240
for key, value in app_state.items():
231241
merged[State.APP_PREFIX + key] = value
@@ -258,8 +268,6 @@ async def create_session(
258268
session_id: Optional[str] = None,
259269
) -> Session:
260270
"""Creates a new session (raises AlreadyExistsError on a duplicate id)."""
261-
from google.cloud import firestore
262-
263271
session_id = session_id or str(uuid.uuid4())
264272
deltas = _session_util.extract_state_delta(state or {})
265273
session_ref = self._get_sessions_ref(app_name, user_id).document(session_id)
@@ -270,8 +278,8 @@ async def create_session(
270278
"appName": app_name,
271279
"userId": user_id,
272280
"state": deltas["session"],
273-
"createTime": firestore.SERVER_TIMESTAMP,
274-
"updateTime": firestore.SERVER_TIMESTAMP,
281+
"createTime": self._firestore.SERVER_TIMESTAMP,
282+
"updateTime": self._firestore.SERVER_TIMESTAMP,
275283
"revision": 1,
276284
}
277285

@@ -546,8 +554,6 @@ async def _persist_with_retry(
546554

547555
async def _persist_batch(self, session: Session, events: list[Event]) -> None:
548556
"""Persists a batch of events for one session in a single transaction."""
549-
from google.cloud import firestore
550-
551557
session_ref = self._get_sessions_ref(
552558
session.app_name, session.user_id
553559
).document(session.id)
@@ -628,7 +634,7 @@ async def _append_txn(transaction: Any) -> int:
628634
session_ref,
629635
{
630636
"state": session_only_state,
631-
"updateTime": firestore.SERVER_TIMESTAMP,
637+
"updateTime": self._firestore.SERVER_TIMESTAMP,
632638
"revision": new_revision,
633639
},
634640
)

0 commit comments

Comments
 (0)