Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 12 additions & 20 deletions src/chat_sdk/state/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,35 +213,27 @@ async def acquire_lock(self, thread_id: str, ttl_ms: int) -> Lock | None:
self._ensure_connected()

token = _generate_token()
expires_at = _pg_timestamp_from_ms(ttl_ms)

# Two-step approach to prevent race condition when lock just expired.
# Step 1: Try INSERT for new locks (no existing row).
# Atomic upsert: INSERT succeeds for new rows; ON CONFLICT DO UPDATE
# fires only when the existing row is expired (WHERE expires_at <= now()).
# Postgres acquires a row lock on the conflicting row, so only one
# concurrent caller can win — eliminating the TOCTOU race that existed
# in the previous two-step INSERT-then-UPDATE approach.
row = await self._pool.fetchrow(
"""INSERT INTO chat_state_locks (key_prefix, thread_id, token, expires_at)
VALUES ($1, $2, $3, $4)
ON CONFLICT (key_prefix, thread_id) DO NOTHING
VALUES ($1, $2, $3, now() + make_interval(secs => $4::float / 1000))
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

For consistency with the extend_lock implementation (line 269) and to simplify the SQL, you can use the interval multiplication syntax instead of make_interval. This also avoids the explicit cast to float and division.

Suggested change
VALUES ($1, $2, $3, now() + make_interval(secs => $4::float / 1000))
VALUES ($1, $2, $3, now() + $4 * interval '1 millisecond')

ON CONFLICT (key_prefix, thread_id) DO UPDATE
SET token = EXCLUDED.token,
expires_at = EXCLUDED.expires_at,
updated_at = now()
WHERE chat_state_locks.expires_at <= now()
RETURNING thread_id, token, expires_at""",
self._key_prefix,
thread_id,
token,
expires_at,
ttl_ms,
)

if row is None:
# Step 2: Row exists — try UPDATE only if expired.
# UPDATE acquires a row lock, so only one concurrent caller wins.
row = await self._pool.fetchrow(
"""UPDATE chat_state_locks
SET token = $3, expires_at = $4, updated_at = now()
WHERE key_prefix = $1 AND thread_id = $2 AND expires_at <= now()
RETURNING thread_id, token, expires_at""",
self._key_prefix,
thread_id,
token,
expires_at,
)

if row is None:
return None

Expand Down
72 changes: 66 additions & 6 deletions tests/test_state_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class simulates the asyncpg pool interface using in-memory dicts to
from typing import Any

import pytest
from chat_sdk.state.postgres import PostgresStateAdapter, _pg_timestamp_from_ms
from chat_sdk.state.postgres import PostgresStateAdapter
from chat_sdk.types import Lock, QueueEntry


Expand Down Expand Up @@ -276,14 +276,16 @@ async def fetchrow(self, query: str, *args: Any) -> _Record | None:
return _Record({"_": 1})
return None

# -- locks: acquire (INSERT ... ON CONFLICT ... RETURNING) --
# -- locks: acquire (atomic upsert: INSERT ... ON CONFLICT DO UPDATE WHERE expired) --
if "insert into chat_state_locks" in q:
key_prefix, thread_id, token, expires_at = args[0], args[1], args[2], args[3]
key_prefix, thread_id, token = args[0], args[1], args[2]
ttl_ms = args[3]
lock_key = (key_prefix, thread_id)
expires_at = self._now() + _dt.timedelta(milliseconds=ttl_ms)
existing = self.locks.get(lock_key)

if existing is None:
# No existing lock -- acquire
# No existing row -- INSERT succeeds
self.locks[lock_key] = {
"token": token,
"expires_at": expires_at,
Expand All @@ -297,7 +299,7 @@ async def fetchrow(self, query: str, *args: Any) -> _Record | None:
}
)

# Existing lock present -- only overwrite if expired
# Row exists -- DO UPDATE fires only when expired
if existing["expires_at"] <= self._now():
self.locks[lock_key] = {
"token": token,
Expand All @@ -312,7 +314,7 @@ async def fetchrow(self, query: str, *args: Any) -> _Record | None:
}
)

# Lock is still held
# Lock is still held -- the DO UPDATE's WHERE clause is false, so RETURNING yields no row
return None

# -- cache: get (SELECT value FROM chat_state_cache) --
Expand Down Expand Up @@ -704,6 +706,64 @@ async def test_independent_locks_per_thread(self, pg_state: PostgresStateAdapter
assert lock2 is not None
assert lock1.token != lock2.token

@pytest.mark.asyncio
async def test_acquire_lock_uses_single_atomic_upsert(
self, pg_state: PostgresStateAdapter, mock_pool: MockAsyncpgPool
):
"""Verify acquire_lock issues exactly one SQL statement (atomic upsert).

The old two-step approach (INSERT ... DO NOTHING then UPDATE ... WHERE
expired) had a TOCTOU race: two callers could both see the INSERT fail,
then both attempt the UPDATE. The fix uses a single INSERT ... ON
CONFLICT DO UPDATE WHERE expired, which is atomic because Postgres
acquires a row lock on the conflicting row.
"""
# Clear any queries from fixture setup (connect / schema creation)
mock_pool.executed_queries.clear()

# First acquire: new row inserted
lock1 = await pg_state.acquire_lock("race-thread", 30_000)
assert lock1 is not None

# Should have issued exactly one query for the lock acquisition
lock_queries = [
q for q in mock_pool.executed_queries if "chat_state_locks" in q.lower()
]
assert len(lock_queries) == 1, (
f"Expected 1 atomic upsert query, got {len(lock_queries)}: {lock_queries}"
)

# Second acquire while held: should fail in single query too
mock_pool.executed_queries.clear()
lock2 = await pg_state.acquire_lock("race-thread", 30_000)
assert lock2 is None

lock_queries = [
q for q in mock_pool.executed_queries if "chat_state_locks" in q.lower()
]
assert len(lock_queries) == 1, (
f"Expected 1 atomic upsert query for contended lock, got {len(lock_queries)}"
)

# Third acquire after expiry: should succeed in single query
mock_pool.executed_queries.clear()
time.sleep(0.005)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Using time.sleep() in an asynchronous test blocks the event loop. While it might not cause issues in this specific isolated test, it is better practice to use await asyncio.sleep() to allow the event loop to continue running other tasks if necessary.

Suggested change
time.sleep(0.005)
await asyncio.sleep(0.005)

# Force-expire the lock for testing
lock_key = ("test", "race-thread")
mock_pool.locks[lock_key]["expires_at"] = _dt.datetime.now(
_dt.timezone.utc
) - _dt.timedelta(seconds=1)

lock3 = await pg_state.acquire_lock("race-thread", 30_000)
assert lock3 is not None

lock_queries = [
q for q in mock_pool.executed_queries if "chat_state_locks" in q.lower()
]
assert len(lock_queries) == 1, (
f"Expected 1 atomic upsert query for expired lock, got {len(lock_queries)}"
)


# ============================================================================
# List operations
Expand Down
Loading