4747from collections import deque
4848from collections .abc import Awaitable
4949from collections .abc import Callable
50+ import copy
5051from dataclasses import dataclass
5152from dataclasses import field
5253from datetime import datetime
7677DEFAULT_APP_STATE_COLLECTION = "app_states"
7778DEFAULT_USER_STATE_COLLECTION = "user_states"
7879
79- _SessionLockKey = tuple [str , str , str ]
80-
8180# Transient Firestore / gRPC failures worth retrying. Matched by class name to
8281# avoid a hard dependency on google.api_core being importable everywhere.
8382_RETRYABLE_ERROR_NAMES = frozenset ({
@@ -141,6 +140,10 @@ def __init__(
141140 client : Any = None ,
142141 root_collection : Optional [str ] = None ,
143142 * ,
143+ sessions_collection : str = DEFAULT_SESSIONS_COLLECTION ,
144+ events_collection : str = DEFAULT_EVENTS_COLLECTION ,
145+ app_state_collection : str = DEFAULT_APP_STATE_COLLECTION ,
146+ user_state_collection : str = DEFAULT_USER_STATE_COLLECTION ,
144147 durable_mode : bool = False ,
145148 buffer_max_events : int = 10 ,
146149 flush_interval_seconds : float = 120.0 ,
@@ -155,6 +158,14 @@ def __init__(
155158 client: An optional Firestore ``AsyncClient``. If not provided, a new one
156159 is created (requires ``google-cloud-firestore``).
157160 root_collection: Root collection name. Defaults to ``'adk-session'``.
161+ sessions_collection: Subcollection name for sessions. Defaults to
162+ ``'sessions'``.
163+ events_collection: Subcollection name for events. Defaults to
164+ ``'events'``.
165+ app_state_collection: Root collection for app-scoped state. Defaults to
166+ ``'app_states'``.
167+ user_state_collection: Root collection for user-scoped state. Defaults
168+ to ``'user_states'``.
158169 durable_mode: When True, every event is persisted immediately and no
159170 buffering happens.
160171 buffer_max_events: Flush a session once this many events are buffered.
@@ -172,12 +183,13 @@ def __init__(
172183 " Install it with: pip install google-adk-community[firestore]"
173184 ) from e
174185
186+ self ._firestore = firestore
175187 self .client = client if client is not None else firestore .AsyncClient ()
176188 self .root_collection = root_collection or DEFAULT_ROOT_COLLECTION
177- self .sessions_collection = DEFAULT_SESSIONS_COLLECTION
178- self .events_collection = DEFAULT_EVENTS_COLLECTION
179- self .app_state_collection = DEFAULT_APP_STATE_COLLECTION
180- self .user_state_collection = DEFAULT_USER_STATE_COLLECTION
189+ self .sessions_collection = sessions_collection
190+ self .events_collection = events_collection
191+ self .app_state_collection = app_state_collection
192+ self .user_state_collection = user_state_collection
181193
182194 self ._durable_mode = durable_mode
183195 self ._buffer_max_events = buffer_max_events
@@ -224,8 +236,6 @@ def _merge_state(
224236 user_state : dict [str , Any ],
225237 session_state : dict [str , Any ],
226238 ) -> dict [str , Any ]:
227- import copy
228-
229239 merged = copy .deepcopy (session_state )
230240 for key , value in app_state .items ():
231241 merged [State .APP_PREFIX + key ] = value
@@ -258,8 +268,6 @@ async def create_session(
258268 session_id : Optional [str ] = None ,
259269 ) -> Session :
260270 """Creates a new session (raises AlreadyExistsError on a duplicate id)."""
261- from google .cloud import firestore
262-
263271 session_id = session_id or str (uuid .uuid4 ())
264272 deltas = _session_util .extract_state_delta (state or {})
265273 session_ref = self ._get_sessions_ref (app_name , user_id ).document (session_id )
@@ -270,8 +278,8 @@ async def create_session(
270278 "appName" : app_name ,
271279 "userId" : user_id ,
272280 "state" : deltas ["session" ],
273- "createTime" : firestore .SERVER_TIMESTAMP ,
274- "updateTime" : firestore .SERVER_TIMESTAMP ,
281+ "createTime" : self . _firestore .SERVER_TIMESTAMP ,
282+ "updateTime" : self . _firestore .SERVER_TIMESTAMP ,
275283 "revision" : 1 ,
276284 }
277285
@@ -546,8 +554,6 @@ async def _persist_with_retry(
546554
547555 async def _persist_batch (self , session : Session , events : list [Event ]) -> None :
548556 """Persists a batch of events for one session in a single transaction."""
549- from google .cloud import firestore
550-
551557 session_ref = self ._get_sessions_ref (
552558 session .app_name , session .user_id
553559 ).document (session .id )
@@ -628,7 +634,7 @@ async def _append_txn(transaction: Any) -> int:
628634 session_ref ,
629635 {
630636 "state" : session_only_state ,
631- "updateTime" : firestore .SERVER_TIMESTAMP ,
637+ "updateTime" : self . _firestore .SERVER_TIMESTAMP ,
632638 "revision" : new_revision ,
633639 },
634640 )
0 commit comments