Skip to content

Commit aa483fe

Browse files
authored
fix(memory): make SQLAlchemySession first writes race-safe (#2725)
1 parent 54444ad commit aa483fe

2 files changed

Lines changed: 321 additions & 9 deletions

File tree

src/agents/extensions/memory/sqlalchemy_session.py

Lines changed: 50 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@
2525

2626
import asyncio
2727
import json
28-
from typing import Any
28+
import threading
29+
from typing import Any, ClassVar
2930

3031
from sqlalchemy import (
3132
TIMESTAMP,
@@ -43,6 +44,7 @@
4344
text as sql_text,
4445
update,
4546
)
47+
from sqlalchemy.exc import IntegrityError
4648
from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker, create_async_engine
4749

4850
from ...items import TResponseInputItem
@@ -53,11 +55,29 @@
5355
class SQLAlchemySession(SessionABC):
5456
"""SQLAlchemy implementation of :pyclass:`agents.memory.session.Session`."""
5557

58+
_table_init_locks: ClassVar[dict[tuple[str, str, str], threading.Lock]] = {}
59+
_table_init_locks_guard: ClassVar[threading.Lock] = threading.Lock()
5660
_metadata: MetaData
5761
_sessions: Table
5862
_messages: Table
5963
session_settings: SessionSettings | None = None
6064

65+
@classmethod
66+
def _get_table_init_lock(
67+
cls, engine: AsyncEngine, sessions_table: str, messages_table: str
68+
) -> threading.Lock:
69+
lock_key = (
70+
engine.url.render_as_string(hide_password=True),
71+
sessions_table,
72+
messages_table,
73+
)
74+
with cls._table_init_locks_guard:
75+
lock = cls._table_init_locks.get(lock_key)
76+
if lock is None:
77+
lock = threading.Lock()
78+
cls._table_init_locks[lock_key] = lock
79+
return lock
80+
6181
def __init__(
6282
self,
6383
session_id: str,
@@ -85,7 +105,11 @@ def __init__(
85105
self.session_id = session_id
86106
self.session_settings = session_settings or SessionSettings()
87107
self._engine = engine
88-
self._lock = asyncio.Lock()
108+
self._init_lock = (
109+
self._get_table_init_lock(engine, sessions_table, messages_table)
110+
if create_tables
111+
else None
112+
)
89113

90114
self._metadata = MetaData()
91115
self._sessions = Table(
@@ -182,10 +206,23 @@ async def _deserialize_item(self, item: str) -> TResponseInputItem:
182206
# ------------------------------------------------------------------
183207
async def _ensure_tables(self) -> None:
184208
"""Ensure tables are created before any database operations."""
185-
if self._create_tables:
209+
if not self._create_tables:
210+
return
211+
212+
assert self._init_lock is not None
213+
while not self._init_lock.acquire(blocking=False):
214+
# Poll without handing lock acquisition to a background thread so
215+
# cancellation cannot strand the shared init lock in the acquired state.
216+
await asyncio.sleep(0.01)
217+
try:
218+
if not self._create_tables:
219+
return
220+
186221
async with self._engine.begin() as conn:
187222
await conn.run_sync(self._metadata.create_all)
188223
self._create_tables = False # Only create once
224+
finally:
225+
self._init_lock.release()
189226

190227
async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
191228
"""Retrieve the conversation history for this session.
@@ -259,18 +296,22 @@ async def add_items(self, items: list[TResponseInputItem]) -> None:
259296

260297
async with self._session_factory() as sess:
261298
async with sess.begin():
262-
# Ensure the parent session row exists - use merge for cross-DB compatibility
263-
# Check if session exists
299+
# Avoid check-then-insert races on the first write while keeping
300+
# the common path free of avoidable integrity exceptions.
264301
existing = await sess.execute(
265302
select(self._sessions.c.session_id).where(
266303
self._sessions.c.session_id == self.session_id
267304
)
268305
)
269306
if not existing.scalar_one_or_none():
270-
# Session doesn't exist, create it
271-
await sess.execute(
272-
insert(self._sessions).values({"session_id": self.session_id})
273-
)
307+
try:
308+
async with sess.begin_nested():
309+
await sess.execute(
310+
insert(self._sessions).values({"session_id": self.session_id})
311+
)
312+
except IntegrityError:
313+
# Another concurrent writer created the parent row first.
314+
pass
274315

275316
# Insert messages in bulk
276317
await sess.execute(insert(self._messages), payload)

tests/extensions/memory/test_sqlalchemy_session.py

Lines changed: 271 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from __future__ import annotations
22

3+
import asyncio
34
import json
5+
import threading
46
from collections.abc import Iterable, Sequence
57
from contextlib import asynccontextmanager
68
from datetime import datetime, timedelta
@@ -203,6 +205,275 @@ async def test_add_empty_items_list():
203205
assert len(items_after_add) == 0
204206

205207

208+
async def test_add_items_concurrent_first_access_with_create_tables(tmp_path):
209+
"""Concurrent first writes should not race table creation or drop items."""
210+
db_url = f"sqlite+aiosqlite:///{tmp_path / 'concurrent_first_access.db'}"
211+
session = SQLAlchemySession.from_url(
212+
"concurrent_first_access",
213+
url=db_url,
214+
create_tables=True,
215+
)
216+
submitted = [f"msg-{i}" for i in range(25)]
217+
218+
async def worker(content: str) -> None:
219+
await session.add_items([{"role": "user", "content": content}])
220+
221+
results = await asyncio.gather(
222+
*(worker(content) for content in submitted),
223+
return_exceptions=True,
224+
)
225+
226+
assert [result for result in results if isinstance(result, Exception)] == []
227+
228+
stored = await session.get_items()
229+
assert len(stored) == len(submitted)
230+
stored_contents: list[str] = []
231+
for item in stored:
232+
content = item.get("content")
233+
assert isinstance(content, str)
234+
stored_contents.append(content)
235+
assert sorted(stored_contents) == sorted(submitted)
236+
237+
238+
async def test_add_items_concurrent_first_write_after_tables_exist(tmp_path):
239+
"""Concurrent first writes should not race parent session creation."""
240+
db_url = f"sqlite+aiosqlite:///{tmp_path / 'concurrent_first_write.db'}"
241+
setup_session = SQLAlchemySession.from_url(
242+
"concurrent_first_write",
243+
url=db_url,
244+
create_tables=True,
245+
)
246+
await setup_session.get_items()
247+
248+
session = SQLAlchemySession.from_url(
249+
"concurrent_first_write",
250+
url=db_url,
251+
create_tables=False,
252+
)
253+
submitted = [f"msg-{i}" for i in range(25)]
254+
255+
async def worker(content: str) -> None:
256+
await session.add_items([{"role": "user", "content": content}])
257+
258+
results = await asyncio.gather(
259+
*(worker(content) for content in submitted),
260+
return_exceptions=True,
261+
)
262+
263+
assert [result for result in results if isinstance(result, Exception)] == []
264+
265+
stored = await session.get_items()
266+
assert len(stored) == len(submitted)
267+
stored_contents: list[str] = []
268+
for item in stored:
269+
content = item.get("content")
270+
assert isinstance(content, str)
271+
stored_contents.append(content)
272+
assert sorted(stored_contents) == sorted(submitted)
273+
274+
275+
async def test_add_items_concurrent_first_access_across_sessions_with_shared_engine(tmp_path):
276+
"""Concurrent first writes should not race table creation across session instances."""
277+
db_url = f"sqlite+aiosqlite:///{tmp_path / 'concurrent_shared_engine.db'}"
278+
engine = create_async_engine(db_url)
279+
try:
280+
session_a = SQLAlchemySession("shared_engine_a", engine=engine, create_tables=True)
281+
session_b = SQLAlchemySession("shared_engine_b", engine=engine, create_tables=True)
282+
283+
results = await asyncio.gather(
284+
session_a.add_items([{"role": "user", "content": "one"}]),
285+
session_b.add_items([{"role": "user", "content": "two"}]),
286+
return_exceptions=True,
287+
)
288+
289+
assert [result for result in results if isinstance(result, Exception)] == []
290+
291+
stored_a = await session_a.get_items()
292+
assert len(stored_a) == 1
293+
assert stored_a[0].get("content") == "one"
294+
295+
stored_b = await session_b.get_items()
296+
assert len(stored_b) == 1
297+
assert stored_b[0].get("content") == "two"
298+
finally:
299+
await engine.dispose()
300+
301+
302+
async def test_add_items_concurrent_first_access_across_from_url_sessions(tmp_path):
303+
"""Concurrent first writes should not race table creation across from_url sessions."""
304+
db_url = f"sqlite+aiosqlite:///{tmp_path / 'concurrent_from_url.db'}"
305+
session_a = SQLAlchemySession.from_url("from_url_a", url=db_url, create_tables=True)
306+
session_b = SQLAlchemySession.from_url("from_url_b", url=db_url, create_tables=True)
307+
try:
308+
results = await asyncio.gather(
309+
session_a.add_items([{"role": "user", "content": "one"}]),
310+
session_b.add_items([{"role": "user", "content": "two"}]),
311+
return_exceptions=True,
312+
)
313+
314+
assert [result for result in results if isinstance(result, Exception)] == []
315+
316+
stored_a = await session_a.get_items()
317+
assert len(stored_a) == 1
318+
assert stored_a[0].get("content") == "one"
319+
320+
stored_b = await session_b.get_items()
321+
assert len(stored_b) == 1
322+
assert stored_b[0].get("content") == "two"
323+
finally:
324+
await session_a.engine.dispose()
325+
await session_b.engine.dispose()
326+
327+
328+
async def test_add_items_concurrent_first_access_across_from_url_sessions_cross_loop(tmp_path):
329+
"""Concurrent first writes should not race or hang across event loops."""
330+
db_url = f"sqlite+aiosqlite:///{tmp_path / 'concurrent_from_url_cross_loop.db'}"
331+
barrier = threading.Barrier(2)
332+
results: list[tuple[str, str, Any]] = []
333+
results_lock = threading.Lock()
334+
335+
def worker(session_id: str, content: str) -> None:
336+
async def run() -> tuple[str, Any]:
337+
session = SQLAlchemySession.from_url(session_id, url=db_url, create_tables=True)
338+
barrier.wait()
339+
try:
340+
await asyncio.wait_for(
341+
session.add_items([{"role": "user", "content": content}]),
342+
timeout=5,
343+
)
344+
stored = await session.get_items()
345+
return ("ok", stored)
346+
finally:
347+
await session.engine.dispose()
348+
349+
try:
350+
status, payload = asyncio.run(run())
351+
except Exception as exc:
352+
status, payload = type(exc).__name__, str(exc)
353+
354+
with results_lock:
355+
results.append((session_id, status, payload))
356+
357+
threads = [
358+
threading.Thread(target=worker, args=("from_url_cross_loop_a", "one")),
359+
threading.Thread(target=worker, args=("from_url_cross_loop_b", "two")),
360+
]
361+
for thread in threads:
362+
thread.start()
363+
for thread in threads:
364+
await asyncio.to_thread(thread.join)
365+
366+
assert len(results) == 2
367+
assert [status for _, status, _ in results] == ["ok", "ok"]
368+
369+
stored_by_session = {
370+
session_id: cast(list[TResponseInputItem], payload) for session_id, _, payload in results
371+
}
372+
assert stored_by_session["from_url_cross_loop_a"][0].get("content") == "one"
373+
assert stored_by_session["from_url_cross_loop_b"][0].get("content") == "two"
374+
375+
376+
async def test_add_items_concurrent_first_access_with_shared_session_cross_loop(tmp_path):
377+
"""A shared session instance should not hang when used from two event loops."""
378+
db_url = f"sqlite+aiosqlite:///{tmp_path / 'shared_session_cross_loop.db'}"
379+
session = SQLAlchemySession.from_url(
380+
"shared_session_cross_loop",
381+
url=db_url,
382+
create_tables=True,
383+
)
384+
barrier = threading.Barrier(2)
385+
results: list[tuple[str, str]] = []
386+
results_lock = threading.Lock()
387+
388+
def worker(content: str) -> None:
389+
async def run() -> None:
390+
barrier.wait()
391+
await asyncio.wait_for(
392+
session.add_items([{"role": "user", "content": content}]),
393+
timeout=5,
394+
)
395+
396+
try:
397+
asyncio.run(run())
398+
status = "ok"
399+
except Exception as exc:
400+
status = type(exc).__name__
401+
402+
with results_lock:
403+
results.append((content, status))
404+
405+
threads = [
406+
threading.Thread(target=worker, args=("one",)),
407+
threading.Thread(target=worker, args=("two",)),
408+
]
409+
try:
410+
for thread in threads:
411+
thread.start()
412+
for thread in threads:
413+
await asyncio.to_thread(thread.join)
414+
415+
assert sorted(results) == [("one", "ok"), ("two", "ok")]
416+
417+
stored = await session.get_items()
418+
stored_contents: list[str] = []
419+
for item in stored:
420+
content = item.get("content")
421+
assert isinstance(content, str)
422+
stored_contents.append(content)
423+
assert sorted(stored_contents) == ["one", "two"]
424+
finally:
425+
await session.engine.dispose()
426+
427+
428+
async def test_add_items_cancelled_waiter_does_not_strand_table_init_lock(tmp_path):
429+
"""Cancelling a waiting initializer must not leave the shared init lock acquired."""
430+
db_url = f"sqlite+aiosqlite:///{tmp_path / 'cancelled_table_init_waiter.db'}"
431+
holder = SQLAlchemySession.from_url("holder", url=db_url, create_tables=True)
432+
waiter = SQLAlchemySession.from_url("waiter", url=db_url, create_tables=True)
433+
follower = SQLAlchemySession.from_url("follower", url=db_url, create_tables=True)
434+
435+
assert holder._init_lock is waiter._init_lock
436+
assert waiter._init_lock is follower._init_lock
437+
assert holder._init_lock is not None
438+
439+
acquired = holder._init_lock.acquire(blocking=False)
440+
assert acquired
441+
442+
try:
443+
blocked = asyncio.create_task(waiter.add_items([{"role": "user", "content": "waiter"}]))
444+
await asyncio.sleep(0.05)
445+
blocked.cancel()
446+
with pytest.raises(asyncio.CancelledError):
447+
await blocked
448+
finally:
449+
holder._init_lock.release()
450+
451+
try:
452+
await asyncio.wait_for(
453+
follower.add_items([{"role": "user", "content": "follower"}]),
454+
timeout=2,
455+
)
456+
stored = await follower.get_items()
457+
assert len(stored) == 1
458+
assert stored[0].get("content") == "follower"
459+
finally:
460+
await holder.engine.dispose()
461+
await waiter.engine.dispose()
462+
await follower.engine.dispose()
463+
464+
465+
async def test_create_tables_false_does_not_allocate_shared_init_lock(tmp_path):
466+
"""Sessions that skip auto-create should not populate the shared lock map."""
467+
db_url = f"sqlite+aiosqlite:///{tmp_path / 'no_create_tables_lock.db'}"
468+
before = len(SQLAlchemySession._table_init_locks)
469+
session = SQLAlchemySession.from_url("no_create_tables_lock", url=db_url, create_tables=False)
470+
try:
471+
assert session._init_lock is None
472+
assert len(SQLAlchemySession._table_init_locks) == before
473+
finally:
474+
await session.engine.dispose()
475+
476+
206477
async def test_get_items_same_timestamp_consistent_order():
207478
"""Test that items with identical timestamps keep insertion order."""
208479
session_id = "same_timestamp_test"

0 commit comments

Comments
 (0)