From fefa0264ea70dd74d8001c47b1d18c6d3ce1e433 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Thu, 19 Mar 2026 14:02:46 -0700 Subject: [PATCH 1/3] Fix race conditions in StateManagerRedis lock detection Fix other race conditions also present in the tests. In short: both the framework and tests must deterministically wait for the Redis pubsub psubscribe call to complete to ensure that events, such as lock release, are properly recorded. Co-authored-by: Farhan Ali Raza --- reflex/istate/manager/redis.py | 14 ++++++-- tests/units/istate/manager/test_redis.py | 44 ++++++++++++++++++++---- tests/units/mock_redis.py | 28 +++++++++++---- 3 files changed, 71 insertions(+), 15 deletions(-) diff --git a/reflex/istate/manager/redis.py b/reflex/istate/manager/redis.py index bbfd8e20ae2..abc6fbaf935 100644 --- a/reflex/istate/manager/redis.py +++ b/reflex/istate/manager/redis.py @@ -153,6 +153,10 @@ class StateManagerRedis(StateManager): default_factory=dict, init=False, ) + _lock_updates_subscribed: asyncio.Event = dataclasses.field( + default_factory=asyncio.Event, + init=False, + ) _lock_task: asyncio.Task | None = dataclasses.field(default=None, init=False) # Whether debug prints are enabled. @@ -802,8 +806,12 @@ async def _subscribe_lock_updates(self): } async with self.redis.pubsub() as pubsub: await pubsub.psubscribe(**handlers) # pyright: ignore[reportArgumentType] - async for _ in pubsub.listen(): - pass + self._lock_updates_subscribed.set() + try: + async for _ in pubsub.listen(): + pass + finally: + self._lock_updates_subscribed.clear() def _ensure_lock_task(self) -> None: """Ensure the lock updates subscriber task is running.""" @@ -971,6 +979,8 @@ async def _wait_lock(self, lock_key: bytes, lock_id: bytes) -> None: return # Make sure lock waiter task is running. self._ensure_lock_task() + # Make sure the lock waiter is subscribed to avoid missing notifications. + await self._lock_updates_subscribed.wait() async with ( self._lock_waiter(lock_key) as lock_released_event, self._request_lock_release(lock_key, lock_id), diff --git a/tests/units/istate/manager/test_redis.py b/tests/units/istate/manager/test_redis.py index 076268ca9c6..d5fee452c5a 100644 --- a/tests/units/istate/manager/test_redis.py +++ b/tests/units/istate/manager/test_redis.py @@ -22,6 +22,14 @@ class RedisTestState(BaseState): count: int = 0 +class SubState1(RedisTestState): + """A test substate for redis state manager tests.""" + + +class SubState2(RedisTestState): + """A test substate for redis state manager tests.""" + + @pytest.fixture def root_state() -> type[RedisTestState]: @@ -65,6 +73,22 @@ def event_log(state_manager_redis: StateManagerRedis) -> list[dict[str, Any]]: return state_manager_redis.redis._internals["event_log"] # pyright: ignore[reportAttributeAccessIssue] +@pytest.fixture +def event_log_on_update(state_manager_redis: StateManagerRedis) -> asyncio.Event: + """Get the event for new event records being added to the redis event log. + + Test is responsible for calling `.clear` before an operation when it needs + to detect a new event added afterward. + + Args: + state_manager_redis: The StateManagerRedis. + + Returns: + The event that is set when new events are added to the redis event log. 
+ """ + return state_manager_redis.redis._internals["event_log_on_update"] # pyright: ignore[reportAttributeAccessIssue] + + @pytest.mark.asyncio async def test_basic_get_set( state_manager_redis: StateManagerRedis, @@ -123,6 +147,7 @@ async def test_modify_oplock( state_manager_redis: StateManagerRedis, root_state: type[RedisTestState], event_log: list[dict[str, Any]], + event_log_on_update: asyncio.Event, ): """Test modifying state with StateManagerRedis with optimistic locking. @@ -130,6 +155,7 @@ async def test_modify_oplock( state_manager_redis: The StateManagerRedis to test. root_state: The root state class. event_log: The redis event log. + event_log_on_update: The event for new event records being added to the redis event log. """ token = str(uuid.uuid4()) @@ -143,6 +169,8 @@ async def test_modify_oplock( state_manager_2._debug_enabled = True state_manager_2._oplock_enabled = True + event_log_on_update.clear() + # Initial modify should set count to 1 async with state_manager_redis.modify_state( _substate_key(token, root_state), @@ -159,6 +187,7 @@ async def test_modify_oplock( assert state_lock_1 is not None assert not state_lock_1.locked() + await event_log_on_update.wait() lock_events_before = len([ ev for ev in event_log @@ -182,6 +211,7 @@ async def test_modify_oplock( assert lock_events_before == lock_events_after # Contend the lock from another state manager + event_log_on_update.clear() async with state_manager_2.modify_state( _substate_key(token, root_state), ) as new_state: @@ -203,6 +233,7 @@ async def test_modify_oplock( assert token not in state_manager_redis._cached_states # There should have been another redis lock taken. + await event_log_on_update.wait() lock_events_after_2 = len([ ev for ev in event_log @@ -228,7 +259,9 @@ async def test_modify_oplock( assert token_set_events == 1 # Now close the contender to release its lease. + event_log_on_update.clear() await state_manager_2.close() + await event_log_on_update.wait() # Both locks should have been released. unlock_events = len([ @@ -562,13 +595,6 @@ async def test_oplock_fetch_substate( root_state: The root state class. event_log: The redis event log. """ - - class SubState1(root_state): - pass - - class SubState2(root_state): - pass - token = str(uuid.uuid4()) state_manager_redis._debug_enabled = True @@ -627,6 +653,7 @@ async def test_oplock_hold_oplock_after_cancel( state_manager_redis: StateManagerRedis, root_state: type[RedisTestState], event_log: list[dict[str, Any]], + event_log_on_update: asyncio.Event, short_lock_expiration: int, ): """Test that cancelling a modify does not release the oplock prematurely. @@ -635,6 +662,7 @@ async def test_oplock_hold_oplock_after_cancel( state_manager_redis: The StateManagerRedis to test. root_state: The root state class. event_log: The redis event log. + event_log_on_update: The event log update event. short_lock_expiration: The lock expiration time in milliseconds. """ token = str(uuid.uuid4()) @@ -683,6 +711,7 @@ async def modify(): await lease_task # Modify the state again, this should get a new lock and lease + event_log_on_update.clear() async with state_manager_redis.modify_state( _substate_key(token, root_state), ) as new_state: @@ -690,6 +719,7 @@ async def modify(): new_state.count += 1 # There should have been two redis lock acquisitions. 
+ await event_log_on_update.wait() lock_events = len([ ev for ev in event_log diff --git a/tests/units/mock_redis.py b/tests/units/mock_redis.py index 4329b2139f7..a9832bbcccc 100644 --- a/tests/units/mock_redis.py +++ b/tests/units/mock_redis.py @@ -26,6 +26,12 @@ def mock_redis() -> Redis: expire_times: dict[bytes, float] = {} event_log: list[dict[str, bytes]] = [] event_log_new_events = asyncio.Event() + event_log_on_update = asyncio.Event() + + def _event_log_append_notify(event: dict[str, bytes]) -> None: + event_log.append(event) + event_log_new_events.set() + event_log_on_update.set() def _key_bytes(key: KeyT) -> bytes: if isinstance(key, str): @@ -39,8 +45,7 @@ def _keyspace_event(key: KeyT, data: str | bytes): key = key.encode() if isinstance(data, str): data = data.encode() - event_log.append({"channel": b"__keyspace@1__:" + key, "data": data}) - event_log_new_events.set() + _event_log_append_notify({"channel": b"__keyspace@1__:" + key, "data": data}) def _expire_keys(): to_delete = [] @@ -192,12 +197,16 @@ async def psubscribe( # noqa: RUF029 for pattern in patterns: watch_patterns[pattern] = None - event_log.append({"channel": b"psubscribe", "data": pattern.encode()}) - event_log_new_events.set() + _event_log_append_notify({ + "channel": b"psubscribe", + "data": pattern.encode(), + }) for pattern, handler in handlers.items(): watch_patterns[pattern] = handler - event_log.append({"channel": b"psubscribe", "data": pattern.encode()}) - event_log_new_events.set() + _event_log_append_notify({ + "channel": b"psubscribe", + "data": pattern.encode(), + }) async def listen() -> AsyncGenerator[dict[str, Any] | None, None]: nonlocal event_log_pointer @@ -243,6 +252,7 @@ async def listen() -> AsyncGenerator[dict[str, Any] | None, None]: "keys": keys, "expire_times": expire_times, "event_log": event_log, + "event_log_on_update": event_log_on_update, } return redis_mock @@ -261,24 +271,30 @@ async def real_redis() -> AsyncGenerator[Redis | None]: # Create a pubsub to keep the internal event log for assertions. event_log = [] + event_log_on_update = asyncio.Event() object.__setattr__( redis, "_internals", { "event_log": event_log, + "event_log_on_update": event_log_on_update, }, ) redis_db = redis.get_connection_kwargs().get("db", 0) + pubsub_ready = asyncio.Event() async def log_events(): async with redis.pubsub() as pubsub: await pubsub.psubscribe(f"__keyspace@{redis_db}__:*") + pubsub_ready.set() async for message in pubsub.listen(): if message is None: continue event_log.append(message) + event_log_on_update.set() log_events_task = asyncio.create_task(log_events()) + await pubsub_ready.wait() try: yield redis finally: From 4eaf1c3e1fe2dd849d605b6d90db74fd237500c0 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Thu, 19 Mar 2026 14:25:03 -0700 Subject: [PATCH 2/3] bounded wait for lock waiter's psubscribe to start --- reflex/istate/manager/redis.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/reflex/istate/manager/redis.py b/reflex/istate/manager/redis.py index abc6fbaf935..35570334860 100644 --- a/reflex/istate/manager/redis.py +++ b/reflex/istate/manager/redis.py @@ -61,6 +61,10 @@ def _default_oplock_hold_time_ms() -> int: ) +# The lock waiter task should subscribe to lock channel updates within this period. +LOCK_SUBSCRIBE_TASK_TIMEOUT = 2 # seconds + + SMR = f"[SMR:{os.getpid()}]" start = time.monotonic() @@ -980,7 +984,12 @@ async def _wait_lock(self, lock_key: bytes, lock_id: bytes) -> None: # Make sure lock waiter task is running. 
         self._ensure_lock_task()
         # Make sure the lock waiter is subscribed to avoid missing notifications.
-        await self._lock_updates_subscribed.wait()
+        await asyncio.wait_for(
+            self._lock_updates_subscribed.wait(),
+            timeout=min(
+                LOCK_SUBSCRIBE_TASK_TIMEOUT, max(self.lock_expiration / 1000, 0)
+            ),
+        )
         async with (
             self._lock_waiter(lock_key) as lock_released_event,
             self._request_lock_release(lock_key, lock_id),

From 797c8a104511598af72b0656b733ed5a9496b8d1 Mon Sep 17 00:00:00 2001
From: Masen Furer
Date: Thu, 19 Mar 2026 14:31:52 -0700
Subject: [PATCH 3/3] create _ensure_lock_task_subscribed helper to encapsulate
 logic

Ignore TimeoutError for this call, because we still want to enter the
lock waiter loop even if there was some Redis problem; the waiter may
just have to wait for the full lock expiration (fallback to the
previous behavior).

But we should _NEVER_ deadlock the process waiting for the psubscribe
event, and we should _NEVER_ abandon _wait_lock just because the wait
for the subscription timed out.
---
 reflex/istate/manager/redis.py | 34 ++++++++++++++++++++++++++--------
 1 file changed, 26 insertions(+), 8 deletions(-)

diff --git a/reflex/istate/manager/redis.py b/reflex/istate/manager/redis.py
index 35570334860..35b87b92eaa 100644
--- a/reflex/istate/manager/redis.py
+++ b/reflex/istate/manager/redis.py
@@ -826,6 +826,30 @@ def _ensure_lock_task(self) -> None:
             suppress_exceptions=[Exception],
         )

+    async def _ensure_lock_task_subscribed(self, timeout: float | None = None) -> None:
+        """Ensure the lock updates subscriber task is running and subscribed to avoid missing notifications.
+
+        Args:
+            timeout: How long to wait for the subscriber to be subscribed before
+                raising an error. If None, defaults to
+                min(LOCK_SUBSCRIBE_TASK_TIMEOUT, lock_expiration).
+
+        Raises:
+            TimeoutError: If the lock updates subscriber task fails to subscribe in time.
+        """
+        if timeout is None:
+            timeout = min(
+                LOCK_SUBSCRIBE_TASK_TIMEOUT,
+                max(self.lock_expiration / 1000, 0),
+            )
+        # Make sure lock waiter task is running.
+        self._ensure_lock_task()
+        # Make sure the lock waiter is subscribed to avoid missing notifications.
+        await asyncio.wait_for(
+            self._lock_updates_subscribed.wait(),
+            timeout=timeout,
+        )
+
     async def _enable_keyspace_notifications(self):
         """Enable keyspace notifications for the redis server.

@@ -982,14 +1006,8 @@ async def _wait_lock(self, lock_key: bytes, lock_id: bytes) -> None:
             )
             return
         # Make sure lock waiter task is running.
-        self._ensure_lock_task()
-        # Make sure the lock waiter is subscribed to avoid missing notifications.
-        await asyncio.wait_for(
-            self._lock_updates_subscribed.wait(),
-            timeout=min(
-                LOCK_SUBSCRIBE_TASK_TIMEOUT, max(self.lock_expiration / 1000, 0)
-            ),
-        )
+        with contextlib.suppress(TimeoutError, asyncio.TimeoutError):
+            await self._ensure_lock_task_subscribed()
         async with (
             self._lock_waiter(lock_key) as lock_released_event,
             self._request_lock_release(lock_key, lock_id),
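
For readers following the series, the race closed by PATCH 1/3 is the classic subscribe-before-notify gap: if _wait_lock starts waiting for a lock-release notification before the waiter task has finished its psubscribe, a release published in that window is silently dropped and the waiter blocks until the lock expiration. Below is a minimal sketch of the handshake using only asyncio primitives, not the actual Reflex or Redis API; subscribed, channel, lock_released and other_worker are illustrative stand-ins.

    import asyncio


    async def main() -> None:
        subscribed = asyncio.Event()        # plays the role of _lock_updates_subscribed
        lock_released = asyncio.Event()
        channel: asyncio.Queue[str] = asyncio.Queue()  # stand-in for the pubsub channel

        async def subscriber() -> None:
            # Stand-in for `await pubsub.psubscribe(...)`: messages published
            # before this completes never reach the listen loop below.
            await asyncio.sleep(0.01)
            subscribed.set()                # subscription is live; waiters may rely on it
            while True:
                if await channel.get() == "del":
                    lock_released.set()     # keyspace "del" event means the lock was released

        async def other_worker() -> None:
            await channel.put("del")        # the contended lock gets released

        async def waiter() -> None:
            await subscribed.wait()         # the fix: never race the psubscribe
            release = asyncio.create_task(other_worker())
            await lock_released.wait()
            await release
            print("release observed without waiting for lock expiration")

        subscriber_task = asyncio.create_task(subscriber())
        await waiter()
        subscriber_task.cancel()


    if __name__ == "__main__":
        asyncio.run(main())

Without the subscribed.wait() call, other_worker could publish its release before the subscriber loop is listening, and the waiter would only recover once the lock expiration elapsed, which is exactly the flakiness the series removes.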
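
PATCH 2/3 and PATCH 3/3 then bound that handshake so a missing or broken subscriber can never hang the caller: the wait is capped at min(LOCK_SUBSCRIBE_TASK_TIMEOUT, lock_expiration) and the call site swallows the timeout, so _wait_lock still enters its waiter loop and merely degrades to the old wait-for-expiration behavior. Here is a condensed standalone version of that call-site pattern; wait_for_subscription and its parameters are illustrative, not the Reflex signatures.

    import asyncio
    import contextlib

    LOCK_SUBSCRIBE_TASK_TIMEOUT = 2  # seconds, mirroring the constant added in PATCH 2/3


    async def wait_for_subscription(subscribed: asyncio.Event, lock_expiration_ms: int) -> None:
        # Bound the wait: no longer than the subscribe grace period, no longer
        # than the lock expiration itself, and never negative.
        timeout = min(LOCK_SUBSCRIBE_TASK_TIMEOUT, max(lock_expiration_ms / 1000, 0))
        # Swallow the timeout: a slow subscriber must not abort the lock waiter;
        # it only loses the fast-path notification.
        with contextlib.suppress(TimeoutError, asyncio.TimeoutError):
            await asyncio.wait_for(subscribed.wait(), timeout=timeout)
        # Control always reaches this point, with or without a live subscription.


    async def demo() -> None:
        never_subscribed = asyncio.Event()
        await wait_for_subscription(never_subscribed, lock_expiration_ms=500)
        print("continued after at most 0.5s even though the subscriber never came up")


    if __name__ == "__main__":
        asyncio.run(demo())

Suppressing both spellings of the timeout mirrors the patch and stays correct on Python versions where asyncio.TimeoutError is not yet an alias of the builtin TimeoutError.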
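
The test changes apply the same discipline on the other side: rather than assuming the mock's event log is already current, a test clears event_log_on_update before the operation and waits on it before asserting. A hypothetical condensed helper showing that arm, act, wait shape (this is not a test from the suite; the filter on b"del" is only an example of counting lock-release events):

    import asyncio
    from typing import Any


    async def count_release_events(
        state_manager: Any,
        event_log: list[dict[str, Any]],
        event_log_on_update: asyncio.Event,
        substate_key: str,
    ) -> int:
        event_log_on_update.clear()                  # arm the signal before acting
        async with state_manager.modify_state(substate_key) as state:
            state.count += 1                         # any mutation that publishes keyspace events
        await event_log_on_update.wait()             # deterministic, no sleep()-based polling
        return len([ev for ev in event_log if ev["data"] == b"del"])

Waiting on the event instead of sleeping is what makes the lock-count assertions in test_modify_oplock and test_oplock_hold_oplock_after_cancel deterministic.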