Skip to content

Commit c06cd45

Browse files
authored
fix: harden SQLAlchemySession against transient SQLite locks (#2854)
1 parent aeb653e commit c06cd45

File tree

2 files changed

+106
-28
lines changed

2 files changed

+106
-28
lines changed

src/agents/extensions/memory/sqlalchemy_session.py

Lines changed: 81 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,13 @@
3939
Table,
4040
Text,
4141
delete,
42+
event,
4243
insert,
4344
select,
4445
text as sql_text,
4546
update,
4647
)
47-
from sqlalchemy.exc import IntegrityError
48+
from sqlalchemy.exc import IntegrityError, OperationalError
4849
from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker, create_async_engine
4950

5051
from ...items import TResponseInputItem
@@ -57,6 +58,10 @@ class SQLAlchemySession(SessionABC):
5758

5859
_table_init_locks: ClassVar[dict[tuple[str, str, str], threading.Lock]] = {}
5960
_table_init_locks_guard: ClassVar[threading.Lock] = threading.Lock()
61+
_sqlite_configured_engines: ClassVar[set[int]] = set()
62+
_sqlite_configured_engines_guard: ClassVar[threading.Lock] = threading.Lock()
63+
_SQLITE_BUSY_TIMEOUT_MS: ClassVar[int] = 5000
64+
_SQLITE_LOCK_RETRY_DELAYS: ClassVar[tuple[float, ...]] = (0.05, 0.1, 0.2, 0.4, 0.8)
6065
_metadata: MetaData
6166
_sessions: Table
6267
_messages: Table
@@ -78,6 +83,50 @@ def _get_table_init_lock(
7883
cls._table_init_locks[lock_key] = lock
7984
return lock
8085

86+
@classmethod
def _configure_sqlite_engine(cls, engine: AsyncEngine) -> None:
    """Apply SQLite settings that reduce transient lock failures.

    For SQLite engines, registers a ``connect`` listener on the underlying
    sync engine that sets ``busy_timeout`` and switches the journal mode to
    WAL on every new DBAPI connection. Engines of any other dialect are left
    untouched. Calling this repeatedly for the same live engine registers the
    listener only once.

    Args:
        engine: The async engine backing this session.
    """
    if engine.dialect.name != "sqlite":
        return

    # Function-scope import so this method does not depend on the module's
    # import block.
    import weakref

    sync_engine = engine.sync_engine
    engine_key = id(sync_engine)
    with cls._sqlite_configured_engines_guard:
        if engine_key in cls._sqlite_configured_engines:
            return

        @event.listens_for(sync_engine, "connect")
        def _configure_sqlite_connection(dbapi_connection: Any, _: Any) -> None:
            cursor = dbapi_connection.cursor()
            try:
                cursor.execute(f"PRAGMA busy_timeout = {cls._SQLITE_BUSY_TIMEOUT_MS}")
                cursor.execute("PRAGMA journal_mode = WAL")
            finally:
                cursor.close()

        cls._sqlite_configured_engines.add(engine_key)

        # ``id()`` values are recycled once an object is garbage collected.
        # Forget the key when this engine goes away so that a later engine
        # that happens to reuse the same id is still configured, and so the
        # registry does not grow without bound. The callback deliberately
        # does not take the guard lock: it may fire during garbage collection
        # on a thread that already holds it.
        registry_cleanup = weakref.finalize(
            sync_engine, cls._sqlite_configured_engines.discard, engine_key
        )
        # Nothing needs to be cleaned up at interpreter shutdown.
        registry_cleanup.atexit = False
107+
108+
@staticmethod
def _is_sqlite_lock_error(exc: OperationalError) -> bool:
    """Return whether ``exc`` reports SQLite's "database is locked" condition.

    Only the driver-level message is inspected when it is available.
    SQLAlchemy's string form of a wrapped DBAPI error also includes the SQL
    statement and bound parameters, so matching against it could misclassify
    an unrelated failure whose payload merely contains the phrase (for
    example, a stored message with that text).

    Args:
        exc: The error raised by the failed database operation.

    Returns:
        ``True`` if the message contains "database is locked"
        (case-insensitive), otherwise ``False``.
    """
    # ``orig`` carries the exception raised by the DBAPI driver; fall back to
    # the wrapper itself when the attribute is missing or unset.
    driver_error = getattr(exc, "orig", None)
    message = str(exc if driver_error is None else driver_error)
    return "database is locked" in message.lower()
111+
112+
async def _run_sqlite_write_with_retry(self, operation: Any) -> None:
    """Run ``operation``, retrying SQLite lock failures with bounded backoff.

    On non-SQLite engines the operation is awaited exactly once. On SQLite it
    is attempted immediately and then once more after each delay in
    ``_SQLITE_LOCK_RETRY_DELAYS``; only ``OperationalError`` instances that
    ``_is_sqlite_lock_error`` recognises are retried.

    Args:
        operation: Zero-argument callable returning an awaitable that
            performs the write.

    Raises:
        OperationalError: If the error is not a lock error, or if the lock
            error persists after the final attempt.
    """
    if self._engine.dialect.name != "sqlite":
        await operation()
        return

    backoff = self._SQLITE_LOCK_RETRY_DELAYS
    retries_used = 0
    while True:
        try:
            await operation()
        except OperationalError as exc:
            # Anything other than a lock error is not transient.
            if not self._is_sqlite_lock_error(exc):
                raise
            # Every configured delay has been spent; give up.
            if retries_used == len(backoff):
                raise
            pause = backoff[retries_used]
            retries_used += 1
        else:
            return

        if pause:
            await asyncio.sleep(pause)
129+
81130
def __init__(
82131
self,
83132
session_id: str,
@@ -105,6 +154,7 @@ def __init__(
105154
self.session_id = session_id
106155
self.session_settings = session_settings or SessionSettings()
107156
self._engine = engine
157+
self._configure_sqlite_engine(engine)
108158
self._init_lock = (
109159
self._get_table_init_lock(engine, sessions_table, messages_table)
110160
if create_tables
@@ -294,34 +344,37 @@ async def add_items(self, items: list[TResponseInputItem]) -> None:
294344
for item in items
295345
]
296346

297-
async with self._session_factory() as sess:
298-
async with sess.begin():
299-
# Avoid check-then-insert races on the first write while keeping
300-
# the common path free of avoidable integrity exceptions.
301-
existing = await sess.execute(
302-
select(self._sessions.c.session_id).where(
303-
self._sessions.c.session_id == self.session_id
347+
async def _write_items() -> None:
348+
async with self._session_factory() as sess:
349+
async with sess.begin():
350+
# Avoid check-then-insert races on the first write while keeping
351+
# the common path free of avoidable integrity exceptions.
352+
existing = await sess.execute(
353+
select(self._sessions.c.session_id).where(
354+
self._sessions.c.session_id == self.session_id
355+
)
304356
)
305-
)
306-
if not existing.scalar_one_or_none():
307-
try:
308-
async with sess.begin_nested():
309-
await sess.execute(
310-
insert(self._sessions).values({"session_id": self.session_id})
311-
)
312-
except IntegrityError:
313-
# Another concurrent writer created the parent row first.
314-
pass
315-
316-
# Insert messages in bulk
317-
await sess.execute(insert(self._messages), payload)
318-
319-
# Touch updated_at column
320-
await sess.execute(
321-
update(self._sessions)
322-
.where(self._sessions.c.session_id == self.session_id)
323-
.values(updated_at=sql_text("CURRENT_TIMESTAMP"))
324-
)
357+
if not existing.scalar_one_or_none():
358+
try:
359+
async with sess.begin_nested():
360+
await sess.execute(
361+
insert(self._sessions).values({"session_id": self.session_id})
362+
)
363+
except IntegrityError:
364+
# Another concurrent writer created the parent row first.
365+
pass
366+
367+
# Insert messages in bulk
368+
await sess.execute(insert(self._messages), payload)
369+
370+
# Touch updated_at column
371+
await sess.execute(
372+
update(self._sessions)
373+
.where(self._sessions.c.session_id == self.session_id)
374+
.values(updated_at=sql_text("CURRENT_TIMESTAMP"))
375+
)
376+
377+
await self._run_sqlite_write_with_retry(_write_items)
325378

326379
async def pop_item(self) -> TResponseInputItem | None:
327380
"""Remove and return the most recent item from the session.

tests/extensions/memory/test_sqlalchemy_session.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,31 @@ async def worker(content: str) -> None:
272272
assert sorted(stored_contents) == sorted(submitted)
273273

274274

275+
async def test_add_items_waits_for_transient_sqlite_write_lock(tmp_path):
    """A write started while another connection holds the lock should still land."""
    session = SQLAlchemySession.from_url(
        "sqlite_write_lock_retry",
        url=f"sqlite+aiosqlite:///{tmp_path / 'sqlite_write_lock_retry.db'}",
        create_tables=True,
    )
    # Touch the session once before the lock is taken
    # (presumably this triggers table creation; confirm against from_url/get_items).
    await session.get_items()

    async with session.engine.connect() as lock_holder:
        # Hold the write lock on a separate connection while the write starts.
        await lock_holder.execute(text("BEGIN IMMEDIATE"))
        pending_write = asyncio.create_task(
            session.add_items([{"role": "user", "content": "after-lock"}])
        )
        await asyncio.sleep(0.1)
        # Release the lock so the pending write can proceed.
        await lock_holder.rollback()

    await asyncio.wait_for(pending_write, timeout=5)

    persisted = await session.get_items()
    assert len(persisted) == 1
    assert persisted[0].get("content") == "after-lock"
299+
275300
async def test_add_items_concurrent_first_access_across_sessions_with_shared_engine(tmp_path):
276301
"""Concurrent first writes should not race table creation across session instances."""
277302
db_url = f"sqlite+aiosqlite:///{tmp_path / 'concurrent_shared_engine.db'}"

0 commit comments

Comments
 (0)