Skip to content

Commit d78422a

Browse files
GWealecopybara-github
authored andcommitted
fix: Handle concurrent creation of app/user state rows in DatabaseSessionService
Introduces a new helper function, `_get_or_create_state`, which uses a nested transaction (SAVEPOINT) to safely attempt creating a state row. If a concurrent transaction has already inserted the row, the inner insert will fail with an IntegrityError, which is caught, and the already-existing row is then fetched Close #4954 Co-authored-by: George Weale <gweale@google.com> PiperOrigin-RevId: 889441684
1 parent f434d25 commit d78422a

File tree

2 files changed

+172
-15
lines changed

2 files changed

+172
-15
lines changed

src/google/adk/sessions/database_session_service.py

Lines changed: 41 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from sqlalchemy.engine import Connection
3434
from sqlalchemy.engine import make_url
3535
from sqlalchemy.exc import ArgumentError
36+
from sqlalchemy.exc import IntegrityError
3637
from sqlalchemy.ext.asyncio import async_sessionmaker
3738
from sqlalchemy.ext.asyncio import AsyncEngine
3839
from sqlalchemy.ext.asyncio import AsyncSession as DatabaseSessionFactory
@@ -103,6 +104,35 @@ async def _select_required_state(
103104
return state_row
104105

105106

107+
async def _get_or_create_state(
108+
*,
109+
sql_session: DatabaseSessionFactory,
110+
state_model: type[_StorageStateT],
111+
primary_key: Any,
112+
defaults: dict[str, Any],
113+
) -> _StorageStateT:
114+
"""Returns an existing state row or creates one, handling concurrent inserts.
115+
116+
Uses a SAVEPOINT so that an IntegrityError from a racing INSERT does not
117+
invalidate the outer transaction.
118+
"""
119+
row = await sql_session.get(state_model, primary_key)
120+
if row is not None:
121+
return row
122+
try:
123+
async with sql_session.begin_nested():
124+
row = state_model(**defaults)
125+
sql_session.add(row)
126+
return row
127+
except IntegrityError:
128+
# Another concurrent caller inserted the row first.
129+
# The savepoint was rolled back, so re-fetch the winner's row.
130+
row = await sql_session.get(state_model, primary_key)
131+
if row is None:
132+
raise
133+
return row
134+
135+
106136
def _set_sqlite_pragma(dbapi_connection, connection_record):
107137
cursor = dbapi_connection.cursor()
108138
cursor.execute("PRAGMA foreign_keys=ON")
@@ -401,24 +431,20 @@ async def create_session(
401431
raise AlreadyExistsError(
402432
f"Session with id {session_id} already exists."
403433
)
404-
# Fetch app and user states from storage
405-
storage_app_state = await sql_session.get(
406-
schema.StorageAppState, (app_name)
434+
# Get or create state rows, handling concurrent insert races.
435+
storage_app_state = await _get_or_create_state(
436+
sql_session=sql_session,
437+
state_model=schema.StorageAppState,
438+
primary_key=app_name,
439+
defaults={"app_name": app_name, "state": {}},
407440
)
408-
storage_user_state = await sql_session.get(
409-
schema.StorageUserState, (app_name, user_id)
441+
storage_user_state = await _get_or_create_state(
442+
sql_session=sql_session,
443+
state_model=schema.StorageUserState,
444+
primary_key=(app_name, user_id),
445+
defaults={"app_name": app_name, "user_id": user_id, "state": {}},
410446
)
411447

412-
# Create state tables if not exist
413-
if not storage_app_state:
414-
storage_app_state = schema.StorageAppState(app_name=app_name, state={})
415-
sql_session.add(storage_app_state)
416-
if not storage_user_state:
417-
storage_user_state = schema.StorageUserState(
418-
app_name=app_name, user_id=user_id, state={}
419-
)
420-
sql_session.add(storage_user_state)
421-
422448
# Extract state deltas
423449
state_deltas = _session_util.extract_state_delta(state)
424450
app_state_delta = state_deltas["app"]

tests/unittests/sessions/test_session_service.py

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1370,6 +1370,137 @@ async def test_prepare_tables_serializes_schema_detection_and_creation():
13701370
await service.close()
13711371

13721372

1373+
@pytest.mark.asyncio
1374+
async def test_get_or_create_state_returns_existing_row():
1375+
"""_get_or_create_state returns an existing row without inserting."""
1376+
service = DatabaseSessionService('sqlite+aiosqlite:///:memory:')
1377+
try:
1378+
await service._prepare_tables()
1379+
schema = service._get_schema_classes()
1380+
1381+
# Pre-create the app_state row.
1382+
async with service.database_session_factory() as sql_session:
1383+
sql_session.add(schema.StorageAppState(app_name='app1', state={'k': 'v'}))
1384+
await sql_session.commit()
1385+
1386+
# _get_or_create_state should find and return it.
1387+
async with service.database_session_factory() as sql_session:
1388+
row = await database_session_service._get_or_create_state(
1389+
sql_session=sql_session,
1390+
state_model=schema.StorageAppState,
1391+
primary_key='app1',
1392+
defaults={'app_name': 'app1', 'state': {}},
1393+
)
1394+
assert row.app_name == 'app1'
1395+
assert row.state == {'k': 'v'}
1396+
finally:
1397+
await service.close()
1398+
1399+
1400+
@pytest.mark.asyncio
1401+
async def test_get_or_create_state_creates_new_row():
1402+
"""_get_or_create_state creates a row when none exists."""
1403+
service = DatabaseSessionService('sqlite+aiosqlite:///:memory:')
1404+
try:
1405+
await service._prepare_tables()
1406+
schema = service._get_schema_classes()
1407+
1408+
async with service.database_session_factory() as sql_session:
1409+
row = await database_session_service._get_or_create_state(
1410+
sql_session=sql_session,
1411+
state_model=schema.StorageAppState,
1412+
primary_key='new_app',
1413+
defaults={'app_name': 'new_app', 'state': {}},
1414+
)
1415+
await sql_session.commit()
1416+
assert row.app_name == 'new_app'
1417+
assert row.state == {}
1418+
1419+
# Verify the row was actually persisted.
1420+
async with service.database_session_factory() as sql_session:
1421+
persisted = await sql_session.get(schema.StorageAppState, 'new_app')
1422+
assert persisted is not None
1423+
finally:
1424+
await service.close()
1425+
1426+
1427+
@pytest.mark.asyncio
1428+
async def test_get_or_create_state_handles_race_condition():
1429+
"""_get_or_create_state recovers when a concurrent INSERT wins the race.
1430+
1431+
Simulates the race from https://github.com/google/adk-python/issues/4954:
1432+
the initial SELECT returns None (another caller hasn't committed yet), but
1433+
by the time we INSERT, the other caller has committed — so the INSERT fails
1434+
with IntegrityError and we fall back to re-fetching.
1435+
"""
1436+
service = DatabaseSessionService('sqlite+aiosqlite:///:memory:')
1437+
try:
1438+
await service._prepare_tables()
1439+
schema = service._get_schema_classes()
1440+
1441+
# Pre-create the row to guarantee the INSERT will fail.
1442+
async with service.database_session_factory() as sql_session:
1443+
sql_session.add(schema.StorageAppState(app_name='race_app', state={}))
1444+
await sql_session.commit()
1445+
1446+
# Patch session.get to return None on the first call (simulating the
1447+
# race window), then fall through to the real implementation.
1448+
async with service.database_session_factory() as sql_session:
1449+
original_get = sql_session.get
1450+
call_count = 0
1451+
1452+
async def patched_get(*args, **kwargs):
1453+
nonlocal call_count
1454+
call_count += 1
1455+
if call_count == 1:
1456+
return None # Simulate: row not yet visible
1457+
return await original_get(*args, **kwargs)
1458+
1459+
sql_session.get = patched_get
1460+
1461+
row = await database_session_service._get_or_create_state(
1462+
sql_session=sql_session,
1463+
state_model=schema.StorageAppState,
1464+
primary_key='race_app',
1465+
defaults={'app_name': 'race_app', 'state': {}},
1466+
)
1467+
assert row.app_name == 'race_app'
1468+
# The function should have called get twice: once before the INSERT
1469+
# (patched to return None) and once after the IntegrityError.
1470+
assert call_count == 2
1471+
finally:
1472+
await service.close()
1473+
1474+
1475+
@pytest.mark.asyncio
1476+
async def test_create_session_sequential_same_app_name():
1477+
"""Sequential create_session calls for the same app_name work correctly.
1478+
1479+
The second call reuses the existing app_states row.
1480+
"""
1481+
service = DatabaseSessionService('sqlite+aiosqlite:///:memory:')
1482+
try:
1483+
s1 = await service.create_session(
1484+
app_name='shared', user_id='u1', session_id='s1'
1485+
)
1486+
s2 = await service.create_session(
1487+
app_name='shared', user_id='u2', session_id='s2'
1488+
)
1489+
assert s1.app_name == 'shared'
1490+
assert s2.app_name == 'shared'
1491+
1492+
got1 = await service.get_session(
1493+
app_name='shared', user_id='u1', session_id='s1'
1494+
)
1495+
got2 = await service.get_session(
1496+
app_name='shared', user_id='u2', session_id='s2'
1497+
)
1498+
assert got1 is not None
1499+
assert got2 is not None
1500+
finally:
1501+
await service.close()
1502+
1503+
13731504
@pytest.mark.asyncio
13741505
async def test_prepare_tables_idempotent_after_creation():
13751506
"""Calling _prepare_tables multiple times is safe and idempotent.

0 commit comments

Comments
 (0)