diff --git a/src/agents/extensions/memory/sqlalchemy_session.py b/src/agents/extensions/memory/sqlalchemy_session.py index 759ddaf5d5..d84f2c78fb 100644 --- a/src/agents/extensions/memory/sqlalchemy_session.py +++ b/src/agents/extensions/memory/sqlalchemy_session.py @@ -39,12 +39,13 @@ Table, Text, delete, + event, insert, select, text as sql_text, update, ) -from sqlalchemy.exc import IntegrityError +from sqlalchemy.exc import IntegrityError, OperationalError from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker, create_async_engine from ...items import TResponseInputItem @@ -57,6 +58,10 @@ class SQLAlchemySession(SessionABC): _table_init_locks: ClassVar[dict[tuple[str, str, str], threading.Lock]] = {} _table_init_locks_guard: ClassVar[threading.Lock] = threading.Lock() + _sqlite_configured_engines: ClassVar[set[int]] = set() + _sqlite_configured_engines_guard: ClassVar[threading.Lock] = threading.Lock() + _SQLITE_BUSY_TIMEOUT_MS: ClassVar[int] = 5000 + _SQLITE_LOCK_RETRY_DELAYS: ClassVar[tuple[float, ...]] = (0.05, 0.1, 0.2, 0.4, 0.8) _metadata: MetaData _sessions: Table _messages: Table @@ -78,6 +83,50 @@ def _get_table_init_lock( cls._table_init_locks[lock_key] = lock return lock + @classmethod + def _configure_sqlite_engine(cls, engine: AsyncEngine) -> None: + """Apply SQLite settings that reduce transient lock failures.""" + if engine.dialect.name != "sqlite": + return + + engine_key = id(engine.sync_engine) + with cls._sqlite_configured_engines_guard: + if engine_key in cls._sqlite_configured_engines: + return + + @event.listens_for(engine.sync_engine, "connect") + def _configure_sqlite_connection(dbapi_connection: Any, _: Any) -> None: + cursor = dbapi_connection.cursor() + try: + cursor.execute(f"PRAGMA busy_timeout = {cls._SQLITE_BUSY_TIMEOUT_MS}") + cursor.execute("PRAGMA journal_mode = WAL") + finally: + cursor.close() + + cls._sqlite_configured_engines.add(engine_key) + + @staticmethod + def _is_sqlite_lock_error(exc: OperationalError) -> 
bool: + return "database is locked" in str(exc).lower() + + async def _run_sqlite_write_with_retry(self, operation: Any) -> None: + """Retry transient SQLite write lock failures with bounded backoff.""" + if self._engine.dialect.name != "sqlite": + await operation() + return + + for attempt, delay in enumerate((0.0, *self._SQLITE_LOCK_RETRY_DELAYS)): + if delay: + await asyncio.sleep(delay) + try: + await operation() + return + except OperationalError as exc: + if not self._is_sqlite_lock_error(exc): + raise + if attempt == len(self._SQLITE_LOCK_RETRY_DELAYS): + raise + def __init__( self, session_id: str, @@ -105,6 +154,7 @@ def __init__( self.session_id = session_id self.session_settings = session_settings or SessionSettings() self._engine = engine + self._configure_sqlite_engine(engine) self._init_lock = ( self._get_table_init_lock(engine, sessions_table, messages_table) if create_tables @@ -294,34 +344,37 @@ async def add_items(self, items: list[TResponseInputItem]) -> None: for item in items ] - async with self._session_factory() as sess: - async with sess.begin(): - # Avoid check-then-insert races on the first write while keeping - # the common path free of avoidable integrity exceptions. - existing = await sess.execute( - select(self._sessions.c.session_id).where( - self._sessions.c.session_id == self.session_id + async def _write_items() -> None: + async with self._session_factory() as sess: + async with sess.begin(): + # Avoid check-then-insert races on the first write while keeping + # the common path free of avoidable integrity exceptions. + existing = await sess.execute( + select(self._sessions.c.session_id).where( + self._sessions.c.session_id == self.session_id + ) ) - ) - if not existing.scalar_one_or_none(): - try: - async with sess.begin_nested(): - await sess.execute( - insert(self._sessions).values({"session_id": self.session_id}) - ) - except IntegrityError: - # Another concurrent writer created the parent row first. 
async def test_add_items_waits_for_transient_sqlite_write_lock(tmp_path):
    """SQLite writes should wait briefly for a transient lock instead of failing."""
    db_url = f"sqlite+aiosqlite:///{tmp_path / 'sqlite_write_lock_retry.db'}"
    session = SQLAlchemySession.from_url(
        "sqlite_write_lock_retry",
        url=db_url,
        create_tables=True,
    )
    # Force table creation up front so the write below only contends on
    # the data lock, not on DDL.
    await session.get_items()

    async with session.engine.connect() as conn:
        # Take the write lock explicitly, then start a write that must
        # block behind it.
        await conn.execute(text("BEGIN IMMEDIATE"))
        blocked_write = asyncio.create_task(
            session.add_items([{"role": "user", "content": "after-lock"}])
        )
        # Give the blocked task a chance to hit the lock, then release it.
        await asyncio.sleep(0.1)
        await conn.rollback()

    # The retry/backoff path should let the write complete promptly.
    await asyncio.wait_for(blocked_write, timeout=5)

    stored = await session.get_items()
    assert len(stored) == 1
    assert stored[0].get("content") == "after-lock"