Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 81 additions & 28 deletions src/agents/extensions/memory/sqlalchemy_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,13 @@
Table,
Text,
delete,
event,
insert,
select,
text as sql_text,
update,
)
from sqlalchemy.exc import IntegrityError
from sqlalchemy.exc import IntegrityError, OperationalError
from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker, create_async_engine

from ...items import TResponseInputItem
Expand All @@ -57,6 +58,10 @@ class SQLAlchemySession(SessionABC):

_table_init_locks: ClassVar[dict[tuple[str, str, str], threading.Lock]] = {}
_table_init_locks_guard: ClassVar[threading.Lock] = threading.Lock()
_sqlite_configured_engines: ClassVar[set[int]] = set()
_sqlite_configured_engines_guard: ClassVar[threading.Lock] = threading.Lock()
_SQLITE_BUSY_TIMEOUT_MS: ClassVar[int] = 5000
_SQLITE_LOCK_RETRY_DELAYS: ClassVar[tuple[float, ...]] = (0.05, 0.1, 0.2, 0.4, 0.8)
_metadata: MetaData
_sessions: Table
_messages: Table
Expand All @@ -78,6 +83,50 @@ def _get_table_init_lock(
cls._table_init_locks[lock_key] = lock
return lock

@classmethod
def _configure_sqlite_engine(cls, engine: AsyncEngine) -> None:
    """Apply SQLite settings that reduce transient lock failures.

    Installs a ``connect`` listener on the underlying sync engine so every
    new DBAPI connection gets a busy timeout (writers wait for the lock
    instead of failing immediately with "database is locked") and WAL
    journal mode. No-op for non-SQLite dialects; applied at most once per
    engine.

    Args:
        engine: The async engine whose connections should be configured.
    """
    import weakref  # stdlib; only needed on the SQLite path

    if engine.dialect.name != "sqlite":
        return

    sync_engine = engine.sync_engine
    engine_key = id(sync_engine)
    with cls._sqlite_configured_engines_guard:
        if engine_key in cls._sqlite_configured_engines:
            return

        @event.listens_for(sync_engine, "connect")
        def _configure_sqlite_connection(dbapi_connection: Any, _: Any) -> None:
            # Runs once for every new pooled DBAPI connection.
            cursor = dbapi_connection.cursor()
            try:
                cursor.execute(f"PRAGMA busy_timeout = {cls._SQLITE_BUSY_TIMEOUT_MS}")
                cursor.execute("PRAGMA journal_mode = WAL")
            finally:
                cursor.close()

        cls._sqlite_configured_engines.add(engine_key)
        # Bug fix: id() values may be recycled after an engine is garbage
        # collected, which would make a brand-new engine look "already
        # configured" and silently skip these PRAGMAs. Drop the stale key
        # when the engine dies so a recycled id cannot collide.
        weakref.finalize(
            sync_engine, cls._sqlite_configured_engines.discard, engine_key
        )

@staticmethod
def _is_sqlite_lock_error(exc: OperationalError) -> bool:
    """Report whether *exc* is SQLite's transient "database is locked" error."""
    message = str(exc).lower()
    return "database is locked" in message

async def _run_sqlite_write_with_retry(self, operation: Any) -> None:
    """Retry transient SQLite write lock failures with bounded backoff.

    Args:
        operation: Argument-less coroutine function performing the write.

    Non-SQLite dialects invoke ``operation`` exactly once with no retry
    machinery. Only errors recognized by ``_is_sqlite_lock_error`` are
    retried; any other ``OperationalError`` propagates immediately, as
    does a lock error on the final attempt.
    """
    if self._engine.dialect.name != "sqlite":
        await operation()
        return

    # One immediate attempt, then the configured backoff schedule.
    schedule = (0.0, *self._SQLITE_LOCK_RETRY_DELAYS)
    final_attempt = len(schedule) - 1
    for attempt, delay in enumerate(schedule):
        if delay:
            await asyncio.sleep(delay)
        try:
            await operation()
        except OperationalError as exc:
            if not self._is_sqlite_lock_error(exc):
                raise
            if attempt == final_attempt:
                raise
        else:
            return

def __init__(
self,
session_id: str,
Expand Down Expand Up @@ -105,6 +154,7 @@ def __init__(
self.session_id = session_id
self.session_settings = session_settings or SessionSettings()
self._engine = engine
self._configure_sqlite_engine(engine)
self._init_lock = (
self._get_table_init_lock(engine, sessions_table, messages_table)
if create_tables
Expand Down Expand Up @@ -294,34 +344,37 @@ async def add_items(self, items: list[TResponseInputItem]) -> None:
for item in items
]

async with self._session_factory() as sess:
async with sess.begin():
# Avoid check-then-insert races on the first write while keeping
# the common path free of avoidable integrity exceptions.
existing = await sess.execute(
select(self._sessions.c.session_id).where(
self._sessions.c.session_id == self.session_id
async def _write_items() -> None:
async with self._session_factory() as sess:
async with sess.begin():
# Avoid check-then-insert races on the first write while keeping
# the common path free of avoidable integrity exceptions.
existing = await sess.execute(
select(self._sessions.c.session_id).where(
self._sessions.c.session_id == self.session_id
)
)
)
if not existing.scalar_one_or_none():
try:
async with sess.begin_nested():
await sess.execute(
insert(self._sessions).values({"session_id": self.session_id})
)
except IntegrityError:
# Another concurrent writer created the parent row first.
pass

# Insert messages in bulk
await sess.execute(insert(self._messages), payload)

# Touch updated_at column
await sess.execute(
update(self._sessions)
.where(self._sessions.c.session_id == self.session_id)
.values(updated_at=sql_text("CURRENT_TIMESTAMP"))
)
if not existing.scalar_one_or_none():
try:
async with sess.begin_nested():
await sess.execute(
insert(self._sessions).values({"session_id": self.session_id})
)
except IntegrityError:
# Another concurrent writer created the parent row first.
pass

# Insert messages in bulk
await sess.execute(insert(self._messages), payload)

# Touch updated_at column
await sess.execute(
update(self._sessions)
.where(self._sessions.c.session_id == self.session_id)
.values(updated_at=sql_text("CURRENT_TIMESTAMP"))
)

await self._run_sqlite_write_with_retry(_write_items)

async def pop_item(self) -> TResponseInputItem | None:
"""Remove and return the most recent item from the session.
Expand Down
25 changes: 25 additions & 0 deletions tests/extensions/memory/test_sqlalchemy_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,31 @@ async def worker(content: str) -> None:
assert sorted(stored_contents) == sorted(submitted)


async def test_add_items_waits_for_transient_sqlite_write_lock(tmp_path):
    """SQLite writes should wait briefly for a transient lock instead of failing."""
    db_url = f"sqlite+aiosqlite:///{tmp_path / 'sqlite_write_lock_retry.db'}"
    session = SQLAlchemySession.from_url(
        "sqlite_write_lock_retry",
        url=db_url,
        create_tables=True,
    )
    # Warm up the session first so any later failure is attributable to
    # the held lock rather than initial setup.
    await session.get_items()

    async with session.engine.connect() as conn:
        # Take the database write lock, then kick off a write that must wait.
        await conn.execute(text("BEGIN IMMEDIATE"))
        pending_write = asyncio.create_task(
            session.add_items([{"role": "user", "content": "after-lock"}])
        )
        await asyncio.sleep(0.1)
        await conn.rollback()

    await asyncio.wait_for(pending_write, timeout=5)

    stored_items = await session.get_items()
    assert len(stored_items) == 1
    assert stored_items[0].get("content") == "after-lock"


async def test_add_items_concurrent_first_access_across_sessions_with_shared_engine(tmp_path):
"""Concurrent first writes should not race table creation across session instances."""
db_url = f"sqlite+aiosqlite:///{tmp_path / 'concurrent_shared_engine.db'}"
Expand Down
Loading