Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 40 additions & 3 deletions reflex/istate/manager/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ def _default_oplock_hold_time_ms() -> int:
)


# Upper bound (seconds) on how long the lock waiter task may take to
# subscribe to lock channel updates; used as the default wait cap when
# ensuring the subscriber is active before contending for a lock.
LOCK_SUBSCRIBE_TASK_TIMEOUT = 2 # seconds


SMR = f"[SMR:{os.getpid()}]"
start = time.monotonic()

Expand Down Expand Up @@ -153,6 +157,10 @@ class StateManagerRedis(StateManager):
default_factory=dict,
init=False,
)
_lock_updates_subscribed: asyncio.Event = dataclasses.field(
default_factory=asyncio.Event,
init=False,
)
_lock_task: asyncio.Task | None = dataclasses.field(default=None, init=False)

# Whether debug prints are enabled.
Expand Down Expand Up @@ -802,8 +810,12 @@ async def _subscribe_lock_updates(self):
}
async with self.redis.pubsub() as pubsub:
await pubsub.psubscribe(**handlers) # pyright: ignore[reportArgumentType]
async for _ in pubsub.listen():
pass
self._lock_updates_subscribed.set()
try:
async for _ in pubsub.listen():
pass
finally:
self._lock_updates_subscribed.clear()

def _ensure_lock_task(self) -> None:
"""Ensure the lock updates subscriber task is running."""
Expand All @@ -814,6 +826,30 @@ def _ensure_lock_task(self) -> None:
suppress_exceptions=[Exception],
)

async def _ensure_lock_task_subscribed(self, timeout: float | None = None) -> None:
"""Ensure the lock updates subscriber task is running and subscribed to avoid missing notifications.

Args:
timeout: How long to wait for the subscriber to be subscribed before
raising an error. If None, defaults to
min(LOCK_SUBSCRIBE_TASK_TIMEOUT, lock_expiration).

Raises:
TimeoutError: If the lock updates subscriber task fails to subscribe in time.
"""
if timeout is None:
timeout = min(
LOCK_SUBSCRIBE_TASK_TIMEOUT,
max(self.lock_expiration / 1000, 0),
)
# Make sure lock waiter task is running.
self._ensure_lock_task()
# Make sure the lock waiter is subscribed to avoid missing notifications.
await asyncio.wait_for(
self._lock_updates_subscribed.wait(),
timeout=timeout,
)

async def _enable_keyspace_notifications(self):
"""Enable keyspace notifications for the redis server.

Expand Down Expand Up @@ -970,7 +1006,8 @@ async def _wait_lock(self, lock_key: bytes, lock_id: bytes) -> None:
)
return
# Make sure lock waiter task is running.
self._ensure_lock_task()
with contextlib.suppress(TimeoutError, asyncio.TimeoutError):
await self._ensure_lock_task_subscribed()
async with (
self._lock_waiter(lock_key) as lock_released_event,
self._request_lock_release(lock_key, lock_id),
Expand Down
44 changes: 37 additions & 7 deletions tests/units/istate/manager/test_redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,14 @@ class RedisTestState(BaseState):
count: int = 0


class SubState1(RedisTestState):
    """First substate used to exercise substate handling in redis manager tests."""


class SubState2(RedisTestState):
    """Second substate used to exercise substate handling in redis manager tests."""


@pytest.fixture
def root_state() -> type[RedisTestState]:

Expand Down Expand Up @@ -65,6 +73,22 @@ def event_log(state_manager_redis: StateManagerRedis) -> list[dict[str, Any]]:
return state_manager_redis.redis._internals["event_log"] # pyright: ignore[reportAttributeAccessIssue]


@pytest.fixture
def event_log_on_update(state_manager_redis: StateManagerRedis) -> asyncio.Event:
    """Return the asyncio.Event signaled whenever the redis event log grows.

    A test that needs to detect an event produced by some operation should
    call `.clear` on this event before the operation and wait on it after.

    Args:
        state_manager_redis: The StateManagerRedis under test.

    Returns:
        The event set each time a record is appended to the redis event log.
    """
    internals = state_manager_redis.redis._internals  # pyright: ignore[reportAttributeAccessIssue]
    return internals["event_log_on_update"]


@pytest.mark.asyncio
async def test_basic_get_set(
state_manager_redis: StateManagerRedis,
Expand Down Expand Up @@ -123,13 +147,15 @@ async def test_modify_oplock(
state_manager_redis: StateManagerRedis,
root_state: type[RedisTestState],
event_log: list[dict[str, Any]],
event_log_on_update: asyncio.Event,
):
"""Test modifying state with StateManagerRedis with optimistic locking.

Args:
state_manager_redis: The StateManagerRedis to test.
root_state: The root state class.
event_log: The redis event log.
event_log_on_update: The event for new event records being added to the redis event log.
"""
token = str(uuid.uuid4())

Expand All @@ -143,6 +169,8 @@ async def test_modify_oplock(
state_manager_2._debug_enabled = True
state_manager_2._oplock_enabled = True

event_log_on_update.clear()

# Initial modify should set count to 1
async with state_manager_redis.modify_state(
_substate_key(token, root_state),
Expand All @@ -159,6 +187,7 @@ async def test_modify_oplock(
assert state_lock_1 is not None
assert not state_lock_1.locked()

await event_log_on_update.wait()
lock_events_before = len([
ev
for ev in event_log
Expand All @@ -182,6 +211,7 @@ async def test_modify_oplock(
assert lock_events_before == lock_events_after

# Contend the lock from another state manager
event_log_on_update.clear()
async with state_manager_2.modify_state(
_substate_key(token, root_state),
) as new_state:
Expand All @@ -203,6 +233,7 @@ async def test_modify_oplock(
assert token not in state_manager_redis._cached_states

# There should have been another redis lock taken.
await event_log_on_update.wait()
lock_events_after_2 = len([
ev
for ev in event_log
Expand All @@ -228,7 +259,9 @@ async def test_modify_oplock(
assert token_set_events == 1

# Now close the contender to release its lease.
event_log_on_update.clear()
await state_manager_2.close()
await event_log_on_update.wait()

# Both locks should have been released.
unlock_events = len([
Expand Down Expand Up @@ -562,13 +595,6 @@ async def test_oplock_fetch_substate(
root_state: The root state class.
event_log: The redis event log.
"""

class SubState1(root_state):
pass

class SubState2(root_state):
pass

token = str(uuid.uuid4())

state_manager_redis._debug_enabled = True
Expand Down Expand Up @@ -627,6 +653,7 @@ async def test_oplock_hold_oplock_after_cancel(
state_manager_redis: StateManagerRedis,
root_state: type[RedisTestState],
event_log: list[dict[str, Any]],
event_log_on_update: asyncio.Event,
short_lock_expiration: int,
):
"""Test that cancelling a modify does not release the oplock prematurely.
Expand All @@ -635,6 +662,7 @@ async def test_oplock_hold_oplock_after_cancel(
state_manager_redis: The StateManagerRedis to test.
root_state: The root state class.
event_log: The redis event log.
event_log_on_update: The event log update event.
short_lock_expiration: The lock expiration time in milliseconds.
"""
token = str(uuid.uuid4())
Expand Down Expand Up @@ -683,13 +711,15 @@ async def modify():
await lease_task

# Modify the state again, this should get a new lock and lease
event_log_on_update.clear()
async with state_manager_redis.modify_state(
_substate_key(token, root_state),
) as new_state:
assert isinstance(new_state, root_state)
new_state.count += 1

# There should have been two redis lock acquisitions.
await event_log_on_update.wait()
lock_events = len([
ev
for ev in event_log
Expand Down
28 changes: 22 additions & 6 deletions tests/units/mock_redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@ def mock_redis() -> Redis:
expire_times: dict[bytes, float] = {}
event_log: list[dict[str, bytes]] = []
event_log_new_events = asyncio.Event()
event_log_on_update = asyncio.Event()

def _event_log_append_notify(event: dict[str, bytes]) -> None:
    """Record an event and wake every watcher of the event log."""
    event_log.append(event)
    # Signal both the pubsub listener and any test waiting on log growth.
    for waiter in (event_log_new_events, event_log_on_update):
        waiter.set()

def _key_bytes(key: KeyT) -> bytes:
if isinstance(key, str):
Expand All @@ -39,8 +45,7 @@ def _keyspace_event(key: KeyT, data: str | bytes):
key = key.encode()
if isinstance(data, str):
data = data.encode()
event_log.append({"channel": b"__keyspace@1__:" + key, "data": data})
event_log_new_events.set()
_event_log_append_notify({"channel": b"__keyspace@1__:" + key, "data": data})

def _expire_keys():
to_delete = []
Expand Down Expand Up @@ -192,12 +197,16 @@ async def psubscribe( # noqa: RUF029

for pattern in patterns:
watch_patterns[pattern] = None
event_log.append({"channel": b"psubscribe", "data": pattern.encode()})
event_log_new_events.set()
_event_log_append_notify({
"channel": b"psubscribe",
"data": pattern.encode(),
})
for pattern, handler in handlers.items():
watch_patterns[pattern] = handler
event_log.append({"channel": b"psubscribe", "data": pattern.encode()})
event_log_new_events.set()
_event_log_append_notify({
"channel": b"psubscribe",
"data": pattern.encode(),
})

async def listen() -> AsyncGenerator[dict[str, Any] | None, None]:
nonlocal event_log_pointer
Expand Down Expand Up @@ -243,6 +252,7 @@ async def listen() -> AsyncGenerator[dict[str, Any] | None, None]:
"keys": keys,
"expire_times": expire_times,
"event_log": event_log,
"event_log_on_update": event_log_on_update,
}
return redis_mock

Expand All @@ -261,24 +271,30 @@ async def real_redis() -> AsyncGenerator[Redis | None]:

# Create a pubsub to keep the internal event log for assertions.
event_log = []
event_log_on_update = asyncio.Event()
object.__setattr__(
redis,
"_internals",
{
"event_log": event_log,
"event_log_on_update": event_log_on_update,
},
)
redis_db = redis.get_connection_kwargs().get("db", 0)
pubsub_ready = asyncio.Event()

async def log_events():
    """Mirror keyspace notifications into event_log, signaling each arrival."""
    async with redis.pubsub() as pubsub:
        await pubsub.psubscribe(f"__keyspace@{redis_db}__:*")
        # Let the fixture proceed only once the subscription is live.
        pubsub_ready.set()
        async for message in pubsub.listen():
            if message is not None:
                event_log.append(message)
                event_log_on_update.set()

log_events_task = asyncio.create_task(log_events())
await pubsub_ready.wait()
try:
yield redis
finally:
Expand Down
Loading