Skip to content

Commit bd93af0

Browse files
Fix race conditions in StateManagerRedis lock detection (#6196)
* Fix race conditions in StateManagerRedis lock detection. Fix other race conditions also present in the tests. In short: both the framework and tests must deterministically wait for the Redis pubsub psubscribe call to complete to ensure that events, such as lock release, are properly recorded. Co-authored-by: Farhan Ali Raza <farhanalirazaazeemi@gmail.com> * Bounded wait for the lock waiter's psubscribe to start. * Create _ensure_lock_task_subscribed helper to encapsulate the logic. Ignore TimeoutError for this call, because we still want to enter the lock waiter loop even if there was some Redis problem; the waiter might just have to wait for the full lock expiration (fallback to the previous behavior). But we should _NEVER_ deadlock the process waiting for the psubscribe event, and we should _NEVER_ exit _wait_for just because the timeout for the previous call expired. --------- Co-authored-by: Farhan Ali Raza <farhanalirazaazeemi@gmail.com>
1 parent 0a5e76e commit bd93af0

File tree

3 files changed

+99
-16
lines changed

3 files changed

+99
-16
lines changed

reflex/istate/manager/redis.py

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,10 @@ def _default_oplock_hold_time_ms() -> int:
6161
)
6262

6363

64+
# The lock waiter task should subscribe to lock channel updates within this period.
65+
LOCK_SUBSCRIBE_TASK_TIMEOUT = 2 # seconds
66+
67+
6468
SMR = f"[SMR:{os.getpid()}]"
6569
start = time.monotonic()
6670

@@ -153,6 +157,10 @@ class StateManagerRedis(StateManager):
153157
default_factory=dict,
154158
init=False,
155159
)
160+
_lock_updates_subscribed: asyncio.Event = dataclasses.field(
161+
default_factory=asyncio.Event,
162+
init=False,
163+
)
156164
_lock_task: asyncio.Task | None = dataclasses.field(default=None, init=False)
157165

158166
# Whether debug prints are enabled.
@@ -802,8 +810,12 @@ async def _subscribe_lock_updates(self):
802810
}
803811
async with self.redis.pubsub() as pubsub:
804812
await pubsub.psubscribe(**handlers) # pyright: ignore[reportArgumentType]
805-
async for _ in pubsub.listen():
806-
pass
813+
self._lock_updates_subscribed.set()
814+
try:
815+
async for _ in pubsub.listen():
816+
pass
817+
finally:
818+
self._lock_updates_subscribed.clear()
807819

808820
def _ensure_lock_task(self) -> None:
809821
"""Ensure the lock updates subscriber task is running."""
@@ -814,6 +826,30 @@ def _ensure_lock_task(self) -> None:
814826
suppress_exceptions=[Exception],
815827
)
816828

829+
async def _ensure_lock_task_subscribed(self, timeout: float | None = None) -> None:
830+
"""Ensure the lock updates subscriber task is running and subscribed to avoid missing notifications.
831+
832+
Args:
833+
timeout: How long to wait for the subscriber to be subscribed before
834+
raising an error. If None, defaults to
835+
min(LOCK_SUBSCRIBE_TASK_TIMEOUT, lock_expiration).
836+
837+
Raises:
838+
TimeoutError: If the lock updates subscriber task fails to subscribe in time.
839+
"""
840+
if timeout is None:
841+
timeout = min(
842+
LOCK_SUBSCRIBE_TASK_TIMEOUT,
843+
max(self.lock_expiration / 1000, 0),
844+
)
845+
# Make sure lock waiter task is running.
846+
self._ensure_lock_task()
847+
# Make sure the lock waiter is subscribed to avoid missing notifications.
848+
await asyncio.wait_for(
849+
self._lock_updates_subscribed.wait(),
850+
timeout=timeout,
851+
)
852+
817853
async def _enable_keyspace_notifications(self):
818854
"""Enable keyspace notifications for the redis server.
819855
@@ -970,7 +1006,8 @@ async def _wait_lock(self, lock_key: bytes, lock_id: bytes) -> None:
9701006
)
9711007
return
9721008
# Make sure lock waiter task is running.
973-
self._ensure_lock_task()
1009+
with contextlib.suppress(TimeoutError, asyncio.TimeoutError):
1010+
await self._ensure_lock_task_subscribed()
9741011
async with (
9751012
self._lock_waiter(lock_key) as lock_released_event,
9761013
self._request_lock_release(lock_key, lock_id),

tests/units/istate/manager/test_redis.py

Lines changed: 37 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,14 @@ class RedisTestState(BaseState):
2222
count: int = 0
2323

2424

25+
class SubState1(RedisTestState):
26+
"""A test substate for redis state manager tests."""
27+
28+
29+
class SubState2(RedisTestState):
30+
"""A test substate for redis state manager tests."""
31+
32+
2533
@pytest.fixture
2634
def root_state() -> type[RedisTestState]:
2735

@@ -65,6 +73,22 @@ def event_log(state_manager_redis: StateManagerRedis) -> list[dict[str, Any]]:
6573
return state_manager_redis.redis._internals["event_log"] # pyright: ignore[reportAttributeAccessIssue]
6674

6775

76+
@pytest.fixture
77+
def event_log_on_update(state_manager_redis: StateManagerRedis) -> asyncio.Event:
78+
"""Get the event for new event records being added to the redis event log.
79+
80+
Test is responsible for calling `.clear` before an operation when it needs
81+
to detect a new event added afterward.
82+
83+
Args:
84+
state_manager_redis: The StateManagerRedis.
85+
86+
Returns:
87+
The event that is set when new events are added to the redis event log.
88+
"""
89+
return state_manager_redis.redis._internals["event_log_on_update"] # pyright: ignore[reportAttributeAccessIssue]
90+
91+
6892
@pytest.mark.asyncio
6993
async def test_basic_get_set(
7094
state_manager_redis: StateManagerRedis,
@@ -123,13 +147,15 @@ async def test_modify_oplock(
123147
state_manager_redis: StateManagerRedis,
124148
root_state: type[RedisTestState],
125149
event_log: list[dict[str, Any]],
150+
event_log_on_update: asyncio.Event,
126151
):
127152
"""Test modifying state with StateManagerRedis with optimistic locking.
128153
129154
Args:
130155
state_manager_redis: The StateManagerRedis to test.
131156
root_state: The root state class.
132157
event_log: The redis event log.
158+
event_log_on_update: The event for new event records being added to the redis event log.
133159
"""
134160
token = str(uuid.uuid4())
135161

@@ -143,6 +169,8 @@ async def test_modify_oplock(
143169
state_manager_2._debug_enabled = True
144170
state_manager_2._oplock_enabled = True
145171

172+
event_log_on_update.clear()
173+
146174
# Initial modify should set count to 1
147175
async with state_manager_redis.modify_state(
148176
_substate_key(token, root_state),
@@ -159,6 +187,7 @@ async def test_modify_oplock(
159187
assert state_lock_1 is not None
160188
assert not state_lock_1.locked()
161189

190+
await event_log_on_update.wait()
162191
lock_events_before = len([
163192
ev
164193
for ev in event_log
@@ -182,6 +211,7 @@ async def test_modify_oplock(
182211
assert lock_events_before == lock_events_after
183212

184213
# Contend the lock from another state manager
214+
event_log_on_update.clear()
185215
async with state_manager_2.modify_state(
186216
_substate_key(token, root_state),
187217
) as new_state:
@@ -203,6 +233,7 @@ async def test_modify_oplock(
203233
assert token not in state_manager_redis._cached_states
204234

205235
# There should have been another redis lock taken.
236+
await event_log_on_update.wait()
206237
lock_events_after_2 = len([
207238
ev
208239
for ev in event_log
@@ -228,7 +259,9 @@ async def test_modify_oplock(
228259
assert token_set_events == 1
229260

230261
# Now close the contender to release its lease.
262+
event_log_on_update.clear()
231263
await state_manager_2.close()
264+
await event_log_on_update.wait()
232265

233266
# Both locks should have been released.
234267
unlock_events = len([
@@ -562,13 +595,6 @@ async def test_oplock_fetch_substate(
562595
root_state: The root state class.
563596
event_log: The redis event log.
564597
"""
565-
566-
class SubState1(root_state):
567-
pass
568-
569-
class SubState2(root_state):
570-
pass
571-
572598
token = str(uuid.uuid4())
573599

574600
state_manager_redis._debug_enabled = True
@@ -627,6 +653,7 @@ async def test_oplock_hold_oplock_after_cancel(
627653
state_manager_redis: StateManagerRedis,
628654
root_state: type[RedisTestState],
629655
event_log: list[dict[str, Any]],
656+
event_log_on_update: asyncio.Event,
630657
short_lock_expiration: int,
631658
):
632659
"""Test that cancelling a modify does not release the oplock prematurely.
@@ -635,6 +662,7 @@ async def test_oplock_hold_oplock_after_cancel(
635662
state_manager_redis: The StateManagerRedis to test.
636663
root_state: The root state class.
637664
event_log: The redis event log.
665+
event_log_on_update: The event log update event.
638666
short_lock_expiration: The lock expiration time in milliseconds.
639667
"""
640668
token = str(uuid.uuid4())
@@ -683,13 +711,15 @@ async def modify():
683711
await lease_task
684712

685713
# Modify the state again, this should get a new lock and lease
714+
event_log_on_update.clear()
686715
async with state_manager_redis.modify_state(
687716
_substate_key(token, root_state),
688717
) as new_state:
689718
assert isinstance(new_state, root_state)
690719
new_state.count += 1
691720

692721
# There should have been two redis lock acquisitions.
722+
await event_log_on_update.wait()
693723
lock_events = len([
694724
ev
695725
for ev in event_log

tests/units/mock_redis.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,12 @@ def mock_redis() -> Redis:
2626
expire_times: dict[bytes, float] = {}
2727
event_log: list[dict[str, bytes]] = []
2828
event_log_new_events = asyncio.Event()
29+
event_log_on_update = asyncio.Event()
30+
31+
def _event_log_append_notify(event: dict[str, bytes]) -> None:
32+
event_log.append(event)
33+
event_log_new_events.set()
34+
event_log_on_update.set()
2935

3036
def _key_bytes(key: KeyT) -> bytes:
3137
if isinstance(key, str):
@@ -39,8 +45,7 @@ def _keyspace_event(key: KeyT, data: str | bytes):
3945
key = key.encode()
4046
if isinstance(data, str):
4147
data = data.encode()
42-
event_log.append({"channel": b"__keyspace@1__:" + key, "data": data})
43-
event_log_new_events.set()
48+
_event_log_append_notify({"channel": b"__keyspace@1__:" + key, "data": data})
4449

4550
def _expire_keys():
4651
to_delete = []
@@ -192,12 +197,16 @@ async def psubscribe( # noqa: RUF029
192197

193198
for pattern in patterns:
194199
watch_patterns[pattern] = None
195-
event_log.append({"channel": b"psubscribe", "data": pattern.encode()})
196-
event_log_new_events.set()
200+
_event_log_append_notify({
201+
"channel": b"psubscribe",
202+
"data": pattern.encode(),
203+
})
197204
for pattern, handler in handlers.items():
198205
watch_patterns[pattern] = handler
199-
event_log.append({"channel": b"psubscribe", "data": pattern.encode()})
200-
event_log_new_events.set()
206+
_event_log_append_notify({
207+
"channel": b"psubscribe",
208+
"data": pattern.encode(),
209+
})
201210

202211
async def listen() -> AsyncGenerator[dict[str, Any] | None, None]:
203212
nonlocal event_log_pointer
@@ -243,6 +252,7 @@ async def listen() -> AsyncGenerator[dict[str, Any] | None, None]:
243252
"keys": keys,
244253
"expire_times": expire_times,
245254
"event_log": event_log,
255+
"event_log_on_update": event_log_on_update,
246256
}
247257
return redis_mock
248258

@@ -261,24 +271,30 @@ async def real_redis() -> AsyncGenerator[Redis | None]:
261271

262272
# Create a pubsub to keep the internal event log for assertions.
263273
event_log = []
274+
event_log_on_update = asyncio.Event()
264275
object.__setattr__(
265276
redis,
266277
"_internals",
267278
{
268279
"event_log": event_log,
280+
"event_log_on_update": event_log_on_update,
269281
},
270282
)
271283
redis_db = redis.get_connection_kwargs().get("db", 0)
284+
pubsub_ready = asyncio.Event()
272285

273286
async def log_events():
274287
async with redis.pubsub() as pubsub:
275288
await pubsub.psubscribe(f"__keyspace@{redis_db}__:*")
289+
pubsub_ready.set()
276290
async for message in pubsub.listen():
277291
if message is None:
278292
continue
279293
event_log.append(message)
294+
event_log_on_update.set()
280295

281296
log_events_task = asyncio.create_task(log_events())
297+
await pubsub_ready.wait()
282298
try:
283299
yield redis
284300
finally:

0 commit comments

Comments
 (0)