@@ -15,7 +15,7 @@ class simulates the asyncpg pool interface using in-memory dicts to
1515from typing import Any
1616
1717import pytest
18- from chat_sdk .state .postgres import PostgresStateAdapter , _pg_timestamp_from_ms
18+ from chat_sdk .state .postgres import PostgresStateAdapter
1919from chat_sdk .types import Lock , QueueEntry
2020
2121
@@ -276,14 +276,16 @@ async def fetchrow(self, query: str, *args: Any) -> _Record | None:
276276 return _Record ({"_" : 1 })
277277 return None
278278
279- # -- locks: acquire (INSERT ... ON CONFLICT ... RETURNING ) --
279+ # -- locks: acquire (atomic upsert: INSERT ... ON CONFLICT DO UPDATE WHERE expired ) --
280280 if "insert into chat_state_locks" in q :
281- key_prefix , thread_id , token , expires_at = args [0 ], args [1 ], args [2 ], args [3 ]
281+ key_prefix , thread_id , token = args [0 ], args [1 ], args [2 ]
282+ ttl_ms = args [3 ]
282283 lock_key = (key_prefix , thread_id )
284+ expires_at = self ._now () + _dt .timedelta (milliseconds = ttl_ms )
283285 existing = self .locks .get (lock_key )
284286
285287 if existing is None :
286- # No existing lock -- acquire
288+ # No existing row -- INSERT succeeds
287289 self .locks [lock_key ] = {
288290 "token" : token ,
289291 "expires_at" : expires_at ,
@@ -297,7 +299,7 @@ async def fetchrow(self, query: str, *args: Any) -> _Record | None:
297299 }
298300 )
299301
300- # Existing lock present -- only overwrite if expired
302+ # Row exists -- DO UPDATE fires only when expired
301303 if existing ["expires_at" ] <= self ._now ():
302304 self .locks [lock_key ] = {
303305 "token" : token ,
@@ -312,7 +314,7 @@ async def fetchrow(self, query: str, *args: Any) -> _Record | None:
312314 }
313315 )
314316
315- # Lock is still held
317+ # Lock is still held -- DO UPDATE WHERE fails, RETURNING not fired
316318 return None
317319
318320 # -- cache: get (SELECT value FROM chat_state_cache) --
@@ -704,6 +706,64 @@ async def test_independent_locks_per_thread(self, pg_state: PostgresStateAdapter
704706 assert lock2 is not None
705707 assert lock1 .token != lock2 .token
706708
@pytest.mark.asyncio
async def test_acquire_lock_uses_single_atomic_upsert(
    self, pg_state: PostgresStateAdapter, mock_pool: MockAsyncpgPool
):
    """Verify acquire_lock issues exactly one SQL statement (atomic upsert).

    The old two-step approach (INSERT ... DO NOTHING then UPDATE ... WHERE
    expired) had a TOCTOU race: two callers could both see the INSERT fail,
    then both attempt the UPDATE. The fix uses a single INSERT ... ON
    CONFLICT DO UPDATE WHERE expired, which is atomic because Postgres
    acquires a row lock on the conflicting row.

    Three scenarios are covered, each expected to touch ``chat_state_locks``
    exactly once: a fresh acquire, a contended acquire (lock still held),
    and an acquire after the existing lock has expired.
    """

    def lock_queries() -> list[str]:
        # Every statement recorded by the mock pool that mentions the lock table.
        return [
            q for q in mock_pool.executed_queries if "chat_state_locks" in q.lower()
        ]

    # Clear any queries from fixture setup (connect / schema creation)
    mock_pool.executed_queries.clear()

    # First acquire: new row inserted
    lock1 = await pg_state.acquire_lock("race-thread", 30_000)
    assert lock1 is not None

    # Should have issued exactly one query for the lock acquisition
    issued = lock_queries()
    assert len(issued) == 1, (
        f"Expected 1 atomic upsert query, got {len(issued)}: {issued}"
    )

    # Second acquire while held: should fail in single query too
    mock_pool.executed_queries.clear()
    lock2 = await pg_state.acquire_lock("race-thread", 30_000)
    assert lock2 is None

    issued = lock_queries()
    assert len(issued) == 1, (
        f"Expected 1 atomic upsert query for contended lock, got {len(issued)}"
    )

    # Third acquire after expiry: should succeed in single query.
    # The lock is force-expired by rewriting its stored expiry into the past,
    # so no real waiting is needed (a blocking sleep would only stall the
    # event loop without affecting the outcome).
    mock_pool.executed_queries.clear()
    # NOTE(review): the "test" key prefix is presumably what the pg_state
    # fixture configures -- confirm against the fixture definition.
    lock_key = ("test", "race-thread")
    mock_pool.locks[lock_key]["expires_at"] = _dt.datetime.now(
        _dt.timezone.utc
    ) - _dt.timedelta(seconds=1)

    lock3 = await pg_state.acquire_lock("race-thread", 30_000)
    assert lock3 is not None

    issued = lock_queries()
    assert len(issued) == 1, (
        f"Expected 1 atomic upsert query for expired lock, got {len(issued)}"
    )
766+
707767
708768# ============================================================================
709769# List operations
0 commit comments