diff --git a/reflex/istate/manager/redis.py b/reflex/istate/manager/redis.py index bbfd8e20ae2..35b87b92eaa 100644 --- a/reflex/istate/manager/redis.py +++ b/reflex/istate/manager/redis.py @@ -61,6 +61,10 @@ def _default_oplock_hold_time_ms() -> int: ) +# The lock waiter task should subscribe to lock channel updates within this period. +LOCK_SUBSCRIBE_TASK_TIMEOUT = 2 # seconds + + SMR = f"[SMR:{os.getpid()}]" start = time.monotonic() @@ -153,6 +157,10 @@ class StateManagerRedis(StateManager): default_factory=dict, init=False, ) + _lock_updates_subscribed: asyncio.Event = dataclasses.field( + default_factory=asyncio.Event, + init=False, + ) _lock_task: asyncio.Task | None = dataclasses.field(default=None, init=False) # Whether debug prints are enabled. @@ -802,8 +810,12 @@ async def _subscribe_lock_updates(self): } async with self.redis.pubsub() as pubsub: await pubsub.psubscribe(**handlers) # pyright: ignore[reportArgumentType] - async for _ in pubsub.listen(): - pass + self._lock_updates_subscribed.set() + try: + async for _ in pubsub.listen(): + pass + finally: + self._lock_updates_subscribed.clear() def _ensure_lock_task(self) -> None: """Ensure the lock updates subscriber task is running.""" @@ -814,6 +826,30 @@ def _ensure_lock_task(self) -> None: suppress_exceptions=[Exception], ) + async def _ensure_lock_task_subscribed(self, timeout: float | None = None) -> None: + """Ensure the lock updates subscriber task is running and subscribed to avoid missing notifications. + + Args: + timeout: How long to wait for the subscriber to be subscribed before + raising an error. If None, defaults to + min(LOCK_SUBSCRIBE_TASK_TIMEOUT, lock_expiration). + + Raises: + TimeoutError: If the lock updates subscriber task fails to subscribe in time. + """ + if timeout is None: + timeout = min( + LOCK_SUBSCRIBE_TASK_TIMEOUT, + max(self.lock_expiration / 1000, 0), + ) + # Make sure lock waiter task is running. + self._ensure_lock_task() + # Make sure the lock waiter is subscribed to avoid missing notifications. + await asyncio.wait_for( + self._lock_updates_subscribed.wait(), + timeout=timeout, + ) + async def _enable_keyspace_notifications(self): """Enable keyspace notifications for the redis server. @@ -970,7 +1006,8 @@ async def _wait_lock(self, lock_key: bytes, lock_id: bytes) -> None: ) return # Make sure lock waiter task is running. - self._ensure_lock_task() + with contextlib.suppress(TimeoutError, asyncio.TimeoutError): + await self._ensure_lock_task_subscribed() async with ( self._lock_waiter(lock_key) as lock_released_event, self._request_lock_release(lock_key, lock_id), diff --git a/tests/units/istate/manager/test_redis.py b/tests/units/istate/manager/test_redis.py index 076268ca9c6..d5fee452c5a 100644 --- a/tests/units/istate/manager/test_redis.py +++ b/tests/units/istate/manager/test_redis.py @@ -22,6 +22,14 @@ class RedisTestState(BaseState): count: int = 0 +class SubState1(RedisTestState): + """A test substate for redis state manager tests.""" + + +class SubState2(RedisTestState): + """A test substate for redis state manager tests.""" + + @pytest.fixture def root_state() -> type[RedisTestState]: @@ -65,6 +73,22 @@ def event_log(state_manager_redis: StateManagerRedis) -> list[dict[str, Any]]: return state_manager_redis.redis._internals["event_log"] # pyright: ignore[reportAttributeAccessIssue] +@pytest.fixture +def event_log_on_update(state_manager_redis: StateManagerRedis) -> asyncio.Event: + """Get the event for new event records being added to the redis event log. + + Test is responsible for calling `.clear` before an operation when it needs + to detect a new event added afterward. + + Args: + state_manager_redis: The StateManagerRedis. + + Returns: + The event that is set when new events are added to the redis event log. + """ + return state_manager_redis.redis._internals["event_log_on_update"] # pyright: ignore[reportAttributeAccessIssue] + + @pytest.mark.asyncio async def test_basic_get_set( state_manager_redis: StateManagerRedis, @@ -123,6 +147,7 @@ async def test_modify_oplock( state_manager_redis: StateManagerRedis, root_state: type[RedisTestState], event_log: list[dict[str, Any]], + event_log_on_update: asyncio.Event, ): """Test modifying state with StateManagerRedis with optimistic locking. @@ -130,6 +155,7 @@ async def test_modify_oplock( state_manager_redis: The StateManagerRedis to test. root_state: The root state class. event_log: The redis event log. + event_log_on_update: The event for new event records being added to the redis event log. """ token = str(uuid.uuid4()) @@ -143,6 +169,8 @@ async def test_modify_oplock( state_manager_2._debug_enabled = True state_manager_2._oplock_enabled = True + event_log_on_update.clear() + # Initial modify should set count to 1 async with state_manager_redis.modify_state( _substate_key(token, root_state), @@ -159,6 +187,7 @@ async def test_modify_oplock( assert state_lock_1 is not None assert not state_lock_1.locked() + await event_log_on_update.wait() lock_events_before = len([ ev for ev in event_log @@ -182,6 +211,7 @@ async def test_modify_oplock( assert lock_events_before == lock_events_after # Contend the lock from another state manager + event_log_on_update.clear() async with state_manager_2.modify_state( _substate_key(token, root_state), ) as new_state: @@ -203,6 +233,7 @@ async def test_modify_oplock( assert token not in state_manager_redis._cached_states # There should have been another redis lock taken. + await event_log_on_update.wait() lock_events_after_2 = len([ ev for ev in event_log @@ -228,7 +259,9 @@ async def test_modify_oplock( assert token_set_events == 1 # Now close the contender to release its lease. + event_log_on_update.clear() await state_manager_2.close() + await event_log_on_update.wait() # Both locks should have been released. unlock_events = len([ @@ -562,13 +595,6 @@ async def test_oplock_fetch_substate( root_state: The root state class. event_log: The redis event log. """ - - class SubState1(root_state): - pass - - class SubState2(root_state): - pass - token = str(uuid.uuid4()) state_manager_redis._debug_enabled = True @@ -627,6 +653,7 @@ async def test_oplock_hold_oplock_after_cancel( state_manager_redis: StateManagerRedis, root_state: type[RedisTestState], event_log: list[dict[str, Any]], + event_log_on_update: asyncio.Event, short_lock_expiration: int, ): """Test that cancelling a modify does not release the oplock prematurely. @@ -635,6 +662,7 @@ async def test_oplock_hold_oplock_after_cancel( state_manager_redis: The StateManagerRedis to test. root_state: The root state class. event_log: The redis event log. + event_log_on_update: The event log update event. short_lock_expiration: The lock expiration time in milliseconds. """ token = str(uuid.uuid4()) @@ -683,6 +711,7 @@ async def modify(): await lease_task # Modify the state again, this should get a new lock and lease + event_log_on_update.clear() async with state_manager_redis.modify_state( _substate_key(token, root_state), ) as new_state: @@ -690,6 +719,7 @@ async def modify(): new_state.count += 1 # There should have been two redis lock acquisitions. + await event_log_on_update.wait() lock_events = len([ ev for ev in event_log diff --git a/tests/units/mock_redis.py b/tests/units/mock_redis.py index 4329b2139f7..a9832bbcccc 100644 --- a/tests/units/mock_redis.py +++ b/tests/units/mock_redis.py @@ -26,6 +26,12 @@ def mock_redis() -> Redis: expire_times: dict[bytes, float] = {} event_log: list[dict[str, bytes]] = [] event_log_new_events = asyncio.Event() + event_log_on_update = asyncio.Event() + + def _event_log_append_notify(event: dict[str, bytes]) -> None: + event_log.append(event) + event_log_new_events.set() + event_log_on_update.set() def _key_bytes(key: KeyT) -> bytes: if isinstance(key, str): @@ -39,8 +45,7 @@ def _keyspace_event(key: KeyT, data: str | bytes): key = key.encode() if isinstance(data, str): data = data.encode() - event_log.append({"channel": b"__keyspace@1__:" + key, "data": data}) - event_log_new_events.set() + _event_log_append_notify({"channel": b"__keyspace@1__:" + key, "data": data}) def _expire_keys(): to_delete = [] @@ -192,12 +197,16 @@ async def psubscribe( # noqa: RUF029 for pattern in patterns: watch_patterns[pattern] = None - event_log.append({"channel": b"psubscribe", "data": pattern.encode()}) - event_log_new_events.set() + _event_log_append_notify({ + "channel": b"psubscribe", + "data": pattern.encode(), + }) for pattern, handler in handlers.items(): watch_patterns[pattern] = handler - event_log.append({"channel": b"psubscribe", "data": pattern.encode()}) - event_log_new_events.set() + _event_log_append_notify({ + "channel": b"psubscribe", + "data": pattern.encode(), + }) async def listen() -> AsyncGenerator[dict[str, Any] | None, None]: nonlocal event_log_pointer @@ -243,6 +252,7 @@ async def listen() -> AsyncGenerator[dict[str, Any] | None, None]: "keys": keys, "expire_times": expire_times, "event_log": event_log, + "event_log_on_update": event_log_on_update, } return redis_mock @@ -261,24 +271,30 @@ async def real_redis() -> AsyncGenerator[Redis | None]: # Create a pubsub to keep the internal event log for assertions. event_log = [] + event_log_on_update = asyncio.Event() object.__setattr__( redis, "_internals", { "event_log": event_log, + "event_log_on_update": event_log_on_update, }, ) redis_db = redis.get_connection_kwargs().get("db", 0) + pubsub_ready = asyncio.Event() async def log_events(): async with redis.pubsub() as pubsub: await pubsub.psubscribe(f"__keyspace@{redis_db}__:*") + pubsub_ready.set() async for message in pubsub.listen(): if message is None: continue event_log.append(message) + event_log_on_update.set() log_events_task = asyncio.create_task(log_events()) + await pubsub_ready.wait() try: yield redis finally: