Skip to content

Commit c411073

Browse files
committed
style: remove __future__ annotations, clean up comments, align with pyink
- Drop `from __future__ import annotations`; use X | None syntax directly (requires Python >=3.10, already in project metadata) - Remove Optional import; all annotations now use built-in union syntax - Remove vague section-header comments - Simplify is_retryable_error return (single return name in _RETRYABLE_ERROR_NAMES) - Update firestore extra version range to match ADK's own constraint (>=2.11,<3)
1 parent dff032f commit c411073

2 files changed

Lines changed: 75 additions & 97 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ documentation = "https://google.github.io/adk-docs/"
4040

4141
[project.optional-dependencies]
4242
firestore = [
43-
"google-cloud-firestore>=2.11.0, <3.0.0", # For BufferedFirestoreSessionService
43+
"google-cloud-firestore>=2.11,<3",
4444
]
4545
s3 = [
4646
"aioboto3>=13.0.0", # For S3ArtifactService

src/google/adk_community/sessions/firestore_session_service.py

Lines changed: 74 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -18,31 +18,28 @@
1818
``google.adk.integrations.firestore.FirestoreSessionService`` (same collection
1919
hierarchy, app/user/session state scoping, optimistic concurrency via a
2020
``revision`` field, and idempotent event documents keyed by ``event.id``) but
21-
**owns** the Firestore I/O so it can persist a whole batch of buffered events in
22-
a **single transaction**.
21+
**owns** the Firestore I/O so it can persist a whole batch of buffered events
22+
in a **single transaction**.
2323
24-
Collection hierarchy::
24+
Collection hierarchy (matches the ADK builtin)::
2525
2626
adk-session/{app}/users/{user}/sessions/{session}/events/{event}
2727
app_states/{app}
2828
user_states/{app}/users/{user}
2929
3030
Events accumulate in a per-session in-memory buffer and flush when the buffer
31-
reaches ``buffer_max_events``, when ``flush_interval_seconds`` elapses (the
32-
background task started by :meth:`start`), when ``flush_session`` / ``flush_all``
33-
/ ``flush`` is called, or when :meth:`stop` runs. Set ``durable_mode=True`` to
34-
persist every event immediately (no buffering).
35-
36-
Batching does not change the event-document count, but it collapses the repeated
37-
session-doc + state-doc updates and per-event transactions from N to 1 (fewer
38-
round-trips and less optimistic-lock contention). On an abrupt process death
39-
before a flush, up to ``flush_interval_seconds`` of events (or
40-
``buffer_max_events - 1`` per session) may be lost; ``stop()`` flushes on
41-
graceful shutdown but cannot protect against crashes.
31+
reaches ``buffer_max_events``, when ``flush_interval_seconds`` elapses (via
32+
the background task started by :meth:`start`), when ``flush_session`` /
33+
``flush_all`` / ``flush`` is called, or when :meth:`stop` runs. Set
34+
``durable_mode=True`` to persist every event immediately (no buffering).
35+
36+
Batching collapses the repeated session-doc + state-doc updates and per-event
37+
transactions from N to 1 (fewer round-trips, less optimistic-lock contention).
38+
On an abrupt process death before a flush, up to ``flush_interval_seconds`` of
39+
events (or ``buffer_max_events - 1`` per session) may be lost; :meth:`stop`
40+
flushes on graceful shutdown.
4241
"""
4342

44-
from __future__ import annotations
45-
4643
import asyncio
4744
from collections import deque
4845
from collections.abc import Awaitable
@@ -56,7 +53,6 @@
5653
import random
5754
import time
5855
from typing import Any
59-
from typing import Optional
6056
import uuid
6157

6258
from google.adk.errors.already_exists_error import AlreadyExistsError
@@ -77,8 +73,6 @@
7773
DEFAULT_APP_STATE_COLLECTION = "app_states"
7874
DEFAULT_USER_STATE_COLLECTION = "user_states"
7975

80-
# Transient Firestore / gRPC failures worth retrying. Matched by class name to
81-
# avoid a hard dependency on google.api_core being importable everywhere.
8276
_RETRYABLE_ERROR_NAMES = frozenset({
8377
"DeadlineExceeded",
8478
"ServiceUnavailable",
@@ -117,9 +111,7 @@ def is_retryable_error(exc: BaseException) -> bool:
117111
name = type(exc).__name__
118112
if name in _NON_RETRYABLE_ERROR_NAMES:
119113
return False
120-
if name in _RETRYABLE_ERROR_NAMES:
121-
return True
122-
return False
114+
return name in _RETRYABLE_ERROR_NAMES
123115

124116

125117
@dataclass
@@ -138,7 +130,7 @@ class BufferedFirestoreSessionService(BaseSessionService): # type: ignore[misc]
138130
def __init__(
139131
self,
140132
client: Any = None,
141-
root_collection: Optional[str] = None,
133+
root_collection: str | None = None,
142134
*,
143135
sessions_collection: str = DEFAULT_SESSIONS_COLLECTION,
144136
events_collection: str = DEFAULT_EVENTS_COLLECTION,
@@ -156,26 +148,25 @@ def __init__(
156148
"""Initializes the buffered Firestore session service.
157149
158150
Args:
159-
client: An optional Firestore ``AsyncClient``. If not provided, a new one
160-
is created (requires ``google-cloud-firestore``).
151+
client: An optional Firestore ``AsyncClient``. If not provided, a new
152+
one is created (requires ``google-cloud-firestore``).
161153
root_collection: Root collection name. Defaults to ``'adk-session'``.
162-
sessions_collection: Subcollection name for sessions. Defaults to
154+
sessions_collection: Sessions subcollection name. Defaults to
163155
``'sessions'``.
164-
events_collection: Subcollection name for events. Defaults to
165-
``'events'``.
166-
app_state_collection: Root collection for app-scoped state. Defaults to
156+
events_collection: Events subcollection name. Defaults to ``'events'``.
157+
app_state_collection: Collection for app-scoped state. Defaults to
167158
``'app_states'``.
168-
user_state_collection: Root collection for user-scoped state. Defaults
169-
to ``'user_states'``.
170-
flat_layout: When True, session documents live directly in
171-
``root_collection/{session_id}`` (no ``{app}/users/{user}/sessions/``
172-
nesting). Useful when the session id already encodes the user (e.g.
173-
``{phone}-{date}``). Defaults to False.
174-
durable_mode: When True, every event is persisted immediately and no
175-
buffering happens.
176-
buffer_max_events: Flush a session once this many events are buffered.
159+
user_state_collection: Collection for user-scoped state. Defaults to
160+
``'user_states'``.
161+
flat_layout: When ``True``, session documents are stored directly at
162+
``root_collection/{session_id}`` instead of the default nested ADK
163+
path. Useful when the session id already encodes the user (e.g.
164+
``{phone}-{date}``) or to match an existing flat collection.
165+
durable_mode: When ``True``, every event is persisted immediately (no
166+
buffering). Equivalent to the builtin service behaviour.
167+
buffer_max_events: Flush when this many events are buffered per session.
177168
flush_interval_seconds: Background flush cadence (see :meth:`start`).
178-
max_retry_attempts: Max attempts when a flush hits a retryable error.
169+
max_retry_attempts: Max attempts on a retryable Firestore error.
179170
retry_base_delay_seconds: Base delay for exponential backoff with jitter.
180171
clock: Monotonic clock, injectable for tests.
181172
sleeper: Async sleep function, injectable for tests.
@@ -195,29 +186,22 @@ def __init__(
195186
self.events_collection = events_collection
196187
self.app_state_collection = app_state_collection
197188
self.user_state_collection = user_state_collection
198-
# flat_layout=True: sessions/{session_id} (no {app}/users/{user} nesting)
199-
# flat_layout=False (default): {root}/{app}/users/{user}/{sessions}/{session_id}
200189
self._flat_layout = flat_layout
201-
202190
self._durable_mode = durable_mode
203191
self._buffer_max_events = buffer_max_events
204192
self._flush_interval_seconds = flush_interval_seconds
205193
self._max_retry_attempts = max_retry_attempts
206194
self._retry_base_delay_seconds = retry_base_delay_seconds
207195
self._clock = clock
208196
self._sleeper = sleeper
209-
# Injectable so tests can drive a fake client without the real transactional
210-
# retry wrapper.
211197
self._transactional = firestore.async_transactional
212198

213199
self._buffers: dict[str, _SessionBuffer] = {}
214200
self._session_refs: dict[str, Session] = {}
215201
self._buffers_guard = asyncio.Lock()
216-
self._task: Optional[asyncio.Task[None]] = None
202+
self._task: asyncio.Task[None] | None = None
217203
self._check_interval = max(1.0, min(flush_interval_seconds, 5.0))
218204

219-
# -- Firestore refs / helpers ---------------------------------------------
220-
221205
def _get_sessions_ref(self, app_name: str, user_id: str) -> Any:
222206
if self._flat_layout:
223207
return self.client.collection(self.root_collection)
@@ -266,16 +250,14 @@ def _coerce_timestamp(value: Any) -> float:
266250
except (ValueError, TypeError):
267251
return 0.0
268252

269-
# -- CRUD ------------------------------------------------------------------
270-
271253
@override
272254
async def create_session(
273255
self,
274256
*,
275257
app_name: str,
276258
user_id: str,
277-
state: Optional[dict[str, Any]] = None,
278-
session_id: Optional[str] = None,
259+
state: dict[str, Any] | None = None,
260+
session_id: str | None = None,
279261
) -> Session:
280262
"""Creates a new session (raises AlreadyExistsError on a duplicate id)."""
281263
session_id = session_id or str(uuid.uuid4())
@@ -334,8 +316,8 @@ async def get_session(
334316
app_name: str,
335317
user_id: str,
336318
session_id: str,
337-
config: Optional[GetSessionConfig] = None,
338-
) -> Optional[Session]:
319+
config: GetSessionConfig | None = None,
320+
) -> Session | None:
339321
"""Gets a session, merging persisted and not-yet-flushed buffered events."""
340322
session_ref = self._get_sessions_ref(app_name, user_id).document(session_id)
341323
doc = await session_ref.get()
@@ -378,7 +360,7 @@ async def get_session(
378360

379361
@override
380362
async def list_sessions(
381-
self, *, app_name: str, user_id: Optional[str] = None
363+
self, *, app_name: str, user_id: str | None = None
382364
) -> ListSessionsResponse:
383365
"""Lists sessions for an app (optionally a single user)."""
384366
if self._flat_layout:
@@ -468,8 +450,6 @@ async def get_user_state(
468450
"""Returns the raw (un-prefixed) user-scoped state for an app/user."""
469451
return dict(await self._read_state(self._user_state_ref(app_name, user_id)))
470452

471-
# -- buffered append -------------------------------------------------------
472-
473453
@override
474454
async def append_event(self, session: Session, event: Event) -> Event:
475455
"""Appends an event in memory and buffers (or immediately persists) it."""
@@ -500,21 +480,47 @@ async def flush_all(self) -> None:
500480
for session_id in list(self._buffers.keys()):
501481
try:
502482
await self._flush(session_id, explicit=False)
503-
except Exception: # noqa: BLE001 - never abort shutdown; already logged
483+
except Exception: # noqa: BLE001
504484
logger.exception("flush_all_session_failed session_id=%s", session_id)
505485

506486
async def flush(self) -> None:
507487
"""ADK lifecycle hook (Runner.close()): flushes all buffered sessions."""
508488
await self.flush_all()
509489

490+
async def start(self) -> None:
491+
"""Starts the background periodic-flush task (idempotent)."""
492+
if self._task is not None and not self._task.done():
493+
return
494+
self._task = asyncio.create_task(self._periodic_flush_loop())
495+
496+
async def stop(self) -> None:
497+
"""Stops the background task and performs a final flush (idempotent)."""
498+
task = self._task
499+
self._task = None
500+
if task is not None:
501+
task.cancel()
502+
try:
503+
await task
504+
except asyncio.CancelledError:
505+
pass
506+
await self.flush_all()
507+
508+
async def close(self) -> None:
509+
"""Closes the underlying Firestore AsyncClient."""
510+
closer = getattr(self.client, "close", None)
511+
if closer is not None:
512+
result = closer()
513+
if asyncio.iscoroutine(result):
514+
await result
515+
510516
async def _flush(self, session_id: str, *, explicit: bool) -> None:
511517
buffer = self._buffers.get(session_id)
512518
if buffer is None:
513519
return
514520

515521
async with buffer.lock:
516522
if buffer.flush_in_progress:
517-
return # only one flush per session at a time
523+
return
518524
if not buffer.pending_events:
519525
buffer.last_flush_monotonic = self._clock()
520526
return
@@ -524,15 +530,15 @@ async def _flush(self, session_id: str, *, explicit: bool) -> None:
524530
buffer.last_flush_monotonic = self._clock()
525531
session = self._session_refs.get(session_id)
526532

527-
if session is None: # pragma: no cover - defensive
533+
if session is None: # pragma: no cover
528534
async with buffer.lock:
529535
buffer.pending_events.extendleft(reversed(batch))
530536
buffer.flush_in_progress = False
531537
return
532538

533539
try:
534540
await self._persist_with_retry(session, batch, session_id)
535-
except Exception as exc: # noqa: BLE001 - reclassified; never silently dropped
541+
except Exception as exc: # noqa: BLE001
536542
async with buffer.lock:
537543
buffer.pending_events.extendleft(reversed(batch))
538544
buffer.flush_in_progress = False
@@ -554,7 +560,7 @@ async def _persist_with_retry(
554560
try:
555561
await self._persist_batch(session, batch)
556562
return
557-
except Exception as exc: # noqa: BLE001 - retryable vs permanent
563+
except Exception as exc: # noqa: BLE001
558564
if not is_retryable_error(exc) or attempt >= self._max_retry_attempts:
559565
logger.error(
560566
"session_flush_failed session_id=%s events=%s attempt=%s"
@@ -625,12 +631,11 @@ async def _append_txn(transaction: Any) -> int:
625631
event_ref = session_ref.collection(self.events_collection).document(
626632
event.id
627633
)
634+
# Use event's own timestamp so intra-batch order survives a shared commit time.
628635
transaction.set(
629636
event_ref,
630637
{
631638
"event_data": event.model_dump(exclude_none=True, mode="json"),
632-
# The event's own timestamp (not SERVER_TIMESTAMP) so order is
633-
# preserved within a batch that shares a commit time.
634639
"timestamp": datetime.fromtimestamp(
635640
event.timestamp, tz=timezone.utc
636641
),
@@ -664,34 +669,6 @@ async def _append_txn(transaction: Any) -> int:
664669
if events:
665670
session.last_update_time = events[-1].timestamp
666671

667-
# -- periodic flushing -----------------------------------------------------
668-
669-
async def start(self) -> None:
670-
"""Starts the background periodic-flush task (idempotent)."""
671-
if self._task is not None and not self._task.done():
672-
return
673-
self._task = asyncio.create_task(self._periodic_flush_loop())
674-
675-
async def stop(self) -> None:
676-
"""Stops the background task and performs a final flush (idempotent)."""
677-
task = self._task
678-
self._task = None
679-
if task is not None:
680-
task.cancel()
681-
try:
682-
await task
683-
except asyncio.CancelledError:
684-
pass
685-
await self.flush_all()
686-
687-
async def close(self) -> None:
688-
"""Closes the underlying Firestore AsyncClient."""
689-
closer = getattr(self.client, "close", None)
690-
if closer is not None:
691-
result = closer()
692-
if asyncio.iscoroutine(result):
693-
await result
694-
695672
async def _periodic_flush_loop(self) -> None:
696673
try:
697674
while True:
@@ -704,8 +681,11 @@ async def _flush_due(self) -> list[asyncio.Task[None]]:
704681
now = self._clock()
705682
tasks: list[asyncio.Task[None]] = []
706683
for session_id, buffer in list(self._buffers.items()):
707-
due = (now - buffer.last_flush_monotonic) >= self._flush_interval_seconds
708-
if buffer.pending_events and due:
684+
if (
685+
buffer.pending_events
686+
and (now - buffer.last_flush_monotonic)
687+
>= self._flush_interval_seconds
688+
):
709689
tasks.append(
710690
asyncio.create_task(self._safe_background_flush(session_id))
711691
)
@@ -714,11 +694,9 @@ async def _flush_due(self) -> list[asyncio.Task[None]]:
714694
async def _safe_background_flush(self, session_id: str) -> None:
715695
try:
716696
await self._flush(session_id, explicit=False)
717-
except Exception: # noqa: BLE001 - background task must not raise unhandled
697+
except Exception: # noqa: BLE001
718698
logger.exception("background_flush_failed session_id=%s", session_id)
719699

720-
# -- internal helpers ------------------------------------------------------
721-
722700
async def _get_or_create_buffer(self, session: Session) -> _SessionBuffer:
723701
async with self._buffers_guard:
724702
buffer = self._buffers.get(session.id)

0 commit comments

Comments
 (0)