|
33 | 33 | from sqlalchemy.engine import Connection |
34 | 34 | from sqlalchemy.engine import make_url |
35 | 35 | from sqlalchemy.exc import ArgumentError |
| 36 | +from sqlalchemy.exc import IntegrityError |
36 | 37 | from sqlalchemy.ext.asyncio import async_sessionmaker |
37 | 38 | from sqlalchemy.ext.asyncio import AsyncEngine |
38 | 39 | from sqlalchemy.ext.asyncio import AsyncSession as DatabaseSessionFactory |
@@ -103,6 +104,35 @@ async def _select_required_state( |
103 | 104 | return state_row |
104 | 105 |
|
105 | 106 |
|
| 107 | +async def _get_or_create_state( |
| 108 | + *, |
| 109 | + sql_session: DatabaseSessionFactory, |
| 110 | + state_model: type[_StorageStateT], |
| 111 | + primary_key: Any, |
| 112 | + defaults: dict[str, Any], |
| 113 | +) -> _StorageStateT: |
| 114 | + """Returns an existing state row or creates one, handling concurrent inserts. |
| 115 | +
|
| 116 | + Uses a SAVEPOINT so that an IntegrityError from a racing INSERT does not |
| 117 | + invalidate the outer transaction. |
| 118 | + """ |
| 119 | + row = await sql_session.get(state_model, primary_key) |
| 120 | + if row is not None: |
| 121 | + return row |
| 122 | + try: |
| 123 | + async with sql_session.begin_nested(): |
| 124 | + row = state_model(**defaults) |
| 125 | + sql_session.add(row) |
| 126 | + return row |
| 127 | + except IntegrityError: |
| 128 | + # Another concurrent caller inserted the row first. |
| 129 | + # The savepoint was rolled back, so re-fetch the winner's row. |
| 130 | + row = await sql_session.get(state_model, primary_key) |
| 131 | + if row is None: |
| 132 | + raise |
| 133 | + return row |
| 134 | + |
| 135 | + |
106 | 136 | def _set_sqlite_pragma(dbapi_connection, connection_record): |
107 | 137 | cursor = dbapi_connection.cursor() |
108 | 138 | cursor.execute("PRAGMA foreign_keys=ON") |
@@ -401,24 +431,20 @@ async def create_session( |
401 | 431 | raise AlreadyExistsError( |
402 | 432 | f"Session with id {session_id} already exists." |
403 | 433 | ) |
404 | | - # Fetch app and user states from storage |
405 | | - storage_app_state = await sql_session.get( |
406 | | - schema.StorageAppState, (app_name) |
| 434 | + # Get or create state rows, handling concurrent insert races. |
| 435 | + storage_app_state = await _get_or_create_state( |
| 436 | + sql_session=sql_session, |
| 437 | + state_model=schema.StorageAppState, |
| 438 | + primary_key=app_name, |
| 439 | + defaults={"app_name": app_name, "state": {}}, |
407 | 440 | ) |
408 | | - storage_user_state = await sql_session.get( |
409 | | - schema.StorageUserState, (app_name, user_id) |
| 441 | + storage_user_state = await _get_or_create_state( |
| 442 | + sql_session=sql_session, |
| 443 | + state_model=schema.StorageUserState, |
| 444 | + primary_key=(app_name, user_id), |
| 445 | + defaults={"app_name": app_name, "user_id": user_id, "state": {}}, |
410 | 446 | ) |
411 | 447 |
|
412 | | - # Create state tables if not exist |
413 | | - if not storage_app_state: |
414 | | - storage_app_state = schema.StorageAppState(app_name=app_name, state={}) |
415 | | - sql_session.add(storage_app_state) |
416 | | - if not storage_user_state: |
417 | | - storage_user_state = schema.StorageUserState( |
418 | | - app_name=app_name, user_id=user_id, state={} |
419 | | - ) |
420 | | - sql_session.add(storage_user_state) |
421 | | - |
422 | 448 | # Extract state deltas |
423 | 449 | state_deltas = _session_util.extract_state_delta(state) |
424 | 450 | app_state_delta = state_deltas["app"] |
|
0 commit comments