Skip to content

Commit fcfe55b

Browse files
Merge pull request #20 from Chinchill-AI/fix/acquire-lock-toctou-race
fix: replace two-step acquire_lock with atomic upsert to prevent TOCTOU race
2 parents 2d655bb + 2326089 commit fcfe55b

2 files changed

Lines changed: 78 additions & 26 deletions

File tree

src/chat_sdk/state/postgres.py

Lines changed: 12 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -213,35 +213,27 @@ async def acquire_lock(self, thread_id: str, ttl_ms: int) -> Lock | None:
213213
self._ensure_connected()
214214

215215
token = _generate_token()
216-
expires_at = _pg_timestamp_from_ms(ttl_ms)
217216

218-
# Two-step approach to prevent race condition when lock just expired.
219-
# Step 1: Try INSERT for new locks (no existing row).
217+
# Atomic upsert: INSERT succeeds for new rows; ON CONFLICT DO UPDATE
218+
# fires only when the existing row is expired (WHERE expires_at <= now()).
219+
# Postgres acquires a row lock on the conflicting row, so only one
220+
# concurrent caller can win — eliminating the TOCTOU race that existed
221+
# in the previous two-step INSERT-then-UPDATE approach.
220222
row = await self._pool.fetchrow(
221223
"""INSERT INTO chat_state_locks (key_prefix, thread_id, token, expires_at)
222-
VALUES ($1, $2, $3, $4)
223-
ON CONFLICT (key_prefix, thread_id) DO NOTHING
224+
VALUES ($1, $2, $3, now() + make_interval(secs => $4::float / 1000))
225+
ON CONFLICT (key_prefix, thread_id) DO UPDATE
226+
SET token = EXCLUDED.token,
227+
expires_at = EXCLUDED.expires_at,
228+
updated_at = now()
229+
WHERE chat_state_locks.expires_at <= now()
224230
RETURNING thread_id, token, expires_at""",
225231
self._key_prefix,
226232
thread_id,
227233
token,
228-
expires_at,
234+
ttl_ms,
229235
)
230236

231-
if row is None:
232-
# Step 2: Row exists — try UPDATE only if expired.
233-
# UPDATE acquires a row lock, so only one concurrent caller wins.
234-
row = await self._pool.fetchrow(
235-
"""UPDATE chat_state_locks
236-
SET token = $3, expires_at = $4, updated_at = now()
237-
WHERE key_prefix = $1 AND thread_id = $2 AND expires_at <= now()
238-
RETURNING thread_id, token, expires_at""",
239-
self._key_prefix,
240-
thread_id,
241-
token,
242-
expires_at,
243-
)
244-
245237
if row is None:
246238
return None
247239

tests/test_state_postgres.py

Lines changed: 66 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ class simulates the asyncpg pool interface using in-memory dicts to
1515
from typing import Any
1616

1717
import pytest
18-
from chat_sdk.state.postgres import PostgresStateAdapter, _pg_timestamp_from_ms
18+
from chat_sdk.state.postgres import PostgresStateAdapter
1919
from chat_sdk.types import Lock, QueueEntry
2020

2121

@@ -276,14 +276,16 @@ async def fetchrow(self, query: str, *args: Any) -> _Record | None:
276276
return _Record({"_": 1})
277277
return None
278278

279-
# -- locks: acquire (INSERT ... ON CONFLICT ... RETURNING) --
279+
# -- locks: acquire (atomic upsert: INSERT ... ON CONFLICT DO UPDATE WHERE expired) --
280280
if "insert into chat_state_locks" in q:
281-
key_prefix, thread_id, token, expires_at = args[0], args[1], args[2], args[3]
281+
key_prefix, thread_id, token = args[0], args[1], args[2]
282+
ttl_ms = args[3]
282283
lock_key = (key_prefix, thread_id)
284+
expires_at = self._now() + _dt.timedelta(milliseconds=ttl_ms)
283285
existing = self.locks.get(lock_key)
284286

285287
if existing is None:
286-
# No existing lock -- acquire
288+
# No existing row -- INSERT succeeds
287289
self.locks[lock_key] = {
288290
"token": token,
289291
"expires_at": expires_at,
@@ -297,7 +299,7 @@ async def fetchrow(self, query: str, *args: Any) -> _Record | None:
297299
}
298300
)
299301

300-
# Existing lock present -- only overwrite if expired
302+
# Row exists -- DO UPDATE fires only when expired
301303
if existing["expires_at"] <= self._now():
302304
self.locks[lock_key] = {
303305
"token": token,
@@ -312,7 +314,7 @@ async def fetchrow(self, query: str, *args: Any) -> _Record | None:
312314
}
313315
)
314316

315-
# Lock is still held
317+
# Lock is still held -- DO UPDATE WHERE fails, RETURNING not fired
316318
return None
317319

318320
# -- cache: get (SELECT value FROM chat_state_cache) --
@@ -704,6 +706,64 @@ async def test_independent_locks_per_thread(self, pg_state: PostgresStateAdapter
704706
assert lock2 is not None
705707
assert lock1.token != lock2.token
706708

709+
@pytest.mark.asyncio
710+
async def test_acquire_lock_uses_single_atomic_upsert(
711+
self, pg_state: PostgresStateAdapter, mock_pool: MockAsyncpgPool
712+
):
713+
"""Verify acquire_lock issues exactly one SQL statement (atomic upsert).
714+
715+
The old two-step approach (INSERT ... DO NOTHING then UPDATE ... WHERE
716+
expired) had a TOCTOU race: two callers could both see the INSERT fail,
717+
then both attempt the UPDATE. The fix uses a single INSERT ... ON
718+
CONFLICT DO UPDATE WHERE expired, which is atomic because Postgres
719+
acquires a row lock on the conflicting row.
720+
"""
721+
# Clear any queries from fixture setup (connect / schema creation)
722+
mock_pool.executed_queries.clear()
723+
724+
# First acquire: new row inserted
725+
lock1 = await pg_state.acquire_lock("race-thread", 30_000)
726+
assert lock1 is not None
727+
728+
# Should have issued exactly one query for the lock acquisition
729+
lock_queries = [
730+
q for q in mock_pool.executed_queries if "chat_state_locks" in q.lower()
731+
]
732+
assert len(lock_queries) == 1, (
733+
f"Expected 1 atomic upsert query, got {len(lock_queries)}: {lock_queries}"
734+
)
735+
736+
# Second acquire while held: should fail in single query too
737+
mock_pool.executed_queries.clear()
738+
lock2 = await pg_state.acquire_lock("race-thread", 30_000)
739+
assert lock2 is None
740+
741+
lock_queries = [
742+
q for q in mock_pool.executed_queries if "chat_state_locks" in q.lower()
743+
]
744+
assert len(lock_queries) == 1, (
745+
f"Expected 1 atomic upsert query for contended lock, got {len(lock_queries)}"
746+
)
747+
748+
# Third acquire after expiry: should succeed in single query
749+
mock_pool.executed_queries.clear()
750+
time.sleep(0.005)
751+
# Force-expire the lock for testing
752+
lock_key = ("test", "race-thread")
753+
mock_pool.locks[lock_key]["expires_at"] = _dt.datetime.now(
754+
_dt.timezone.utc
755+
) - _dt.timedelta(seconds=1)
756+
757+
lock3 = await pg_state.acquire_lock("race-thread", 30_000)
758+
assert lock3 is not None
759+
760+
lock_queries = [
761+
q for q in mock_pool.executed_queries if "chat_state_locks" in q.lower()
762+
]
763+
assert len(lock_queries) == 1, (
764+
f"Expected 1 atomic upsert query for expired lock, got {len(lock_queries)}"
765+
)
766+
707767

708768
# ============================================================================
709769
# List operations

0 commit comments

Comments
 (0)