Skip to content

Commit fefa026

Browse files
masenfFarhanAliRaza
andcommitted
Fix race conditions in StateManagerRedis lock detection
Fix other race conditions also present in the tests. In short: both the framework and tests must deterministically wait for the Redis pubsub psubscribe call to complete to ensure that events, such as lock release, are properly recorded. Co-authored-by: Farhan Ali Raza <farhanalirazaazeemi@gmail.com>
1 parent 0a5e76e commit fefa026

File tree

3 files changed

+71
-15
lines changed

3 files changed

+71
-15
lines changed

reflex/istate/manager/redis.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,10 @@ class StateManagerRedis(StateManager):
153153
default_factory=dict,
154154
init=False,
155155
)
156+
_lock_updates_subscribed: asyncio.Event = dataclasses.field(
157+
default_factory=asyncio.Event,
158+
init=False,
159+
)
156160
_lock_task: asyncio.Task | None = dataclasses.field(default=None, init=False)
157161

158162
# Whether debug prints are enabled.
@@ -802,8 +806,12 @@ async def _subscribe_lock_updates(self):
802806
}
803807
async with self.redis.pubsub() as pubsub:
804808
await pubsub.psubscribe(**handlers) # pyright: ignore[reportArgumentType]
805-
async for _ in pubsub.listen():
806-
pass
809+
self._lock_updates_subscribed.set()
810+
try:
811+
async for _ in pubsub.listen():
812+
pass
813+
finally:
814+
self._lock_updates_subscribed.clear()
807815

808816
def _ensure_lock_task(self) -> None:
809817
"""Ensure the lock updates subscriber task is running."""
@@ -971,6 +979,8 @@ async def _wait_lock(self, lock_key: bytes, lock_id: bytes) -> None:
971979
return
972980
# Make sure lock waiter task is running.
973981
self._ensure_lock_task()
982+
# Make sure the lock waiter is subscribed to avoid missing notifications.
983+
await self._lock_updates_subscribed.wait()
974984
async with (
975985
self._lock_waiter(lock_key) as lock_released_event,
976986
self._request_lock_release(lock_key, lock_id),

tests/units/istate/manager/test_redis.py

Lines changed: 37 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,14 @@ class RedisTestState(BaseState):
2222
count: int = 0
2323

2424

25+
class SubState1(RedisTestState):
26+
"""A test substate for redis state manager tests."""
27+
28+
29+
class SubState2(RedisTestState):
30+
"""A test substate for redis state manager tests."""
31+
32+
2533
@pytest.fixture
2634
def root_state() -> type[RedisTestState]:
2735

@@ -65,6 +73,22 @@ def event_log(state_manager_redis: StateManagerRedis) -> list[dict[str, Any]]:
6573
return state_manager_redis.redis._internals["event_log"] # pyright: ignore[reportAttributeAccessIssue]
6674

6775

76+
@pytest.fixture
77+
def event_log_on_update(state_manager_redis: StateManagerRedis) -> asyncio.Event:
78+
"""Get the event for new event records being added to the redis event log.
79+
80+
Test is responsible for calling `.clear` before an operation when it needs
81+
to detect a new event added afterward.
82+
83+
Args:
84+
state_manager_redis: The StateManagerRedis.
85+
86+
Returns:
87+
The event that is set when new events are added to the redis event log.
88+
"""
89+
return state_manager_redis.redis._internals["event_log_on_update"] # pyright: ignore[reportAttributeAccessIssue]
90+
91+
6892
@pytest.mark.asyncio
6993
async def test_basic_get_set(
7094
state_manager_redis: StateManagerRedis,
@@ -123,13 +147,15 @@ async def test_modify_oplock(
123147
state_manager_redis: StateManagerRedis,
124148
root_state: type[RedisTestState],
125149
event_log: list[dict[str, Any]],
150+
event_log_on_update: asyncio.Event,
126151
):
127152
"""Test modifying state with StateManagerRedis with optimistic locking.
128153
129154
Args:
130155
state_manager_redis: The StateManagerRedis to test.
131156
root_state: The root state class.
132157
event_log: The redis event log.
158+
event_log_on_update: The event for new event records being added to the redis event log.
133159
"""
134160
token = str(uuid.uuid4())
135161

@@ -143,6 +169,8 @@ async def test_modify_oplock(
143169
state_manager_2._debug_enabled = True
144170
state_manager_2._oplock_enabled = True
145171

172+
event_log_on_update.clear()
173+
146174
# Initial modify should set count to 1
147175
async with state_manager_redis.modify_state(
148176
_substate_key(token, root_state),
@@ -159,6 +187,7 @@ async def test_modify_oplock(
159187
assert state_lock_1 is not None
160188
assert not state_lock_1.locked()
161189

190+
await event_log_on_update.wait()
162191
lock_events_before = len([
163192
ev
164193
for ev in event_log
@@ -182,6 +211,7 @@ async def test_modify_oplock(
182211
assert lock_events_before == lock_events_after
183212

184213
# Contend the lock from another state manager
214+
event_log_on_update.clear()
185215
async with state_manager_2.modify_state(
186216
_substate_key(token, root_state),
187217
) as new_state:
@@ -203,6 +233,7 @@ async def test_modify_oplock(
203233
assert token not in state_manager_redis._cached_states
204234

205235
# There should have been another redis lock taken.
236+
await event_log_on_update.wait()
206237
lock_events_after_2 = len([
207238
ev
208239
for ev in event_log
@@ -228,7 +259,9 @@ async def test_modify_oplock(
228259
assert token_set_events == 1
229260

230261
# Now close the contender to release its lease.
262+
event_log_on_update.clear()
231263
await state_manager_2.close()
264+
await event_log_on_update.wait()
232265

233266
# Both locks should have been released.
234267
unlock_events = len([
@@ -562,13 +595,6 @@ async def test_oplock_fetch_substate(
562595
root_state: The root state class.
563596
event_log: The redis event log.
564597
"""
565-
566-
class SubState1(root_state):
567-
pass
568-
569-
class SubState2(root_state):
570-
pass
571-
572598
token = str(uuid.uuid4())
573599

574600
state_manager_redis._debug_enabled = True
@@ -627,6 +653,7 @@ async def test_oplock_hold_oplock_after_cancel(
627653
state_manager_redis: StateManagerRedis,
628654
root_state: type[RedisTestState],
629655
event_log: list[dict[str, Any]],
656+
event_log_on_update: asyncio.Event,
630657
short_lock_expiration: int,
631658
):
632659
"""Test that cancelling a modify does not release the oplock prematurely.
@@ -635,6 +662,7 @@ async def test_oplock_hold_oplock_after_cancel(
635662
state_manager_redis: The StateManagerRedis to test.
636663
root_state: The root state class.
637664
event_log: The redis event log.
665+
event_log_on_update: The event log update event.
638666
short_lock_expiration: The lock expiration time in milliseconds.
639667
"""
640668
token = str(uuid.uuid4())
@@ -683,13 +711,15 @@ async def modify():
683711
await lease_task
684712

685713
# Modify the state again, this should get a new lock and lease
714+
event_log_on_update.clear()
686715
async with state_manager_redis.modify_state(
687716
_substate_key(token, root_state),
688717
) as new_state:
689718
assert isinstance(new_state, root_state)
690719
new_state.count += 1
691720

692721
# There should have been two redis lock acquisitions.
722+
await event_log_on_update.wait()
693723
lock_events = len([
694724
ev
695725
for ev in event_log

tests/units/mock_redis.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,12 @@ def mock_redis() -> Redis:
2626
expire_times: dict[bytes, float] = {}
2727
event_log: list[dict[str, bytes]] = []
2828
event_log_new_events = asyncio.Event()
29+
event_log_on_update = asyncio.Event()
30+
31+
def _event_log_append_notify(event: dict[str, bytes]) -> None:
32+
event_log.append(event)
33+
event_log_new_events.set()
34+
event_log_on_update.set()
2935

3036
def _key_bytes(key: KeyT) -> bytes:
3137
if isinstance(key, str):
@@ -39,8 +45,7 @@ def _keyspace_event(key: KeyT, data: str | bytes):
3945
key = key.encode()
4046
if isinstance(data, str):
4147
data = data.encode()
42-
event_log.append({"channel": b"__keyspace@1__:" + key, "data": data})
43-
event_log_new_events.set()
48+
_event_log_append_notify({"channel": b"__keyspace@1__:" + key, "data": data})
4449

4550
def _expire_keys():
4651
to_delete = []
@@ -192,12 +197,16 @@ async def psubscribe( # noqa: RUF029
192197

193198
for pattern in patterns:
194199
watch_patterns[pattern] = None
195-
event_log.append({"channel": b"psubscribe", "data": pattern.encode()})
196-
event_log_new_events.set()
200+
_event_log_append_notify({
201+
"channel": b"psubscribe",
202+
"data": pattern.encode(),
203+
})
197204
for pattern, handler in handlers.items():
198205
watch_patterns[pattern] = handler
199-
event_log.append({"channel": b"psubscribe", "data": pattern.encode()})
200-
event_log_new_events.set()
206+
_event_log_append_notify({
207+
"channel": b"psubscribe",
208+
"data": pattern.encode(),
209+
})
201210

202211
async def listen() -> AsyncGenerator[dict[str, Any] | None, None]:
203212
nonlocal event_log_pointer
@@ -243,6 +252,7 @@ async def listen() -> AsyncGenerator[dict[str, Any] | None, None]:
243252
"keys": keys,
244253
"expire_times": expire_times,
245254
"event_log": event_log,
255+
"event_log_on_update": event_log_on_update,
246256
}
247257
return redis_mock
248258

@@ -261,24 +271,30 @@ async def real_redis() -> AsyncGenerator[Redis | None]:
261271

262272
# Create a pubsub to keep the internal event log for assertions.
263273
event_log = []
274+
event_log_on_update = asyncio.Event()
264275
object.__setattr__(
265276
redis,
266277
"_internals",
267278
{
268279
"event_log": event_log,
280+
"event_log_on_update": event_log_on_update,
269281
},
270282
)
271283
redis_db = redis.get_connection_kwargs().get("db", 0)
284+
pubsub_ready = asyncio.Event()
272285

273286
async def log_events():
274287
async with redis.pubsub() as pubsub:
275288
await pubsub.psubscribe(f"__keyspace@{redis_db}__:*")
289+
pubsub_ready.set()
276290
async for message in pubsub.listen():
277291
if message is None:
278292
continue
279293
event_log.append(message)
294+
event_log_on_update.set()
280295

281296
log_events_task = asyncio.create_task(log_events())
297+
await pubsub_ready.wait()
282298
try:
283299
yield redis
284300
finally:

0 commit comments

Comments
 (0)