Skip to content

Commit 46232ac

Browse files
committed
Use cache_key and lock_key in StateManagerDisk and StateManagerRedis
1 parent 3a0f532 commit 46232ac

2 files changed

Lines changed: 50 additions & 57 deletions

File tree

reflex/istate/manager/disk.py

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -168,12 +168,8 @@ async def get_state(
168168
Returns:
169169
The state for the token.
170170
"""
171-
if isinstance(token, BaseStateToken):
172-
root_state = self.states.get(token.ident)
173-
self._token_last_touched[token.ident] = time.time()
174-
else:
175-
root_state = self.states.get(str(token))
176-
self._token_last_touched[str(token)] = time.time()
171+
root_state = self.states.get(token.cache_key)
172+
self._token_last_touched[token.cache_key] = time.time()
177173
if root_state is not None:
178174
# Retrieved state from memory.
179175
return root_state
@@ -194,13 +190,13 @@ async def get_state(
194190
# Ensure all substates exist, even if they were not serialized previously.
195191
root_state.substates = fresh_root_state.substates
196192
await self.populate_substates(token, root_state, root_state)
197-
self.states[token.ident] = root_state
193+
self.states[token.cache_key] = root_state
198194
return cast(TOKEN_TYPE, root_state)
199195
# For non-BaseState tokens, if the deserialized state is None, we create a new instance using the token's cls.
200196
state = await self.load_state(token)
201197
if state is None:
202198
state = token.cls()
203-
self.states[str(token)] = state
199+
self.states[token.cache_key] = state
204200
return cast(TOKEN_TYPE, state)
205201

206202
async def set_state_for_substate(
@@ -275,10 +271,10 @@ async def _process_write_queue(self):
275271
token, self._write_queue.pop(token).state
276272
)
277273
# Check for expired states to purge.
278-
for token_ident, last_touched in list(self._token_last_touched.items()):
274+
for cache_key, last_touched in list(self._token_last_touched.items()):
279275
if now - last_touched > self.token_expiration:
280-
self._token_last_touched.pop(token_ident)
281-
self.states.pop(token_ident, None)
276+
self._token_last_touched.pop(cache_key)
277+
self.states.pop(cache_key, None)
282278
await run_in_thread(self._purge_expired_states)
283279
await self._process_write_queue_delay()
284280
except asyncio.CancelledError: # noqa: PERF203
@@ -363,12 +359,13 @@ async def modify_state(
363359
The state for the token.
364360
"""
365361
# Disk state manager ignores the substate suffix and always returns the top-level state.
366-
if token.ident not in self._states_locks:
362+
lock_key = token.lock_key
363+
if lock_key not in self._states_locks:
367364
async with self._state_manager_lock:
368-
if token.ident not in self._states_locks:
369-
self._states_locks[token.ident] = asyncio.Lock()
365+
if lock_key not in self._states_locks:
366+
self._states_locks[lock_key] = asyncio.Lock()
370367

371-
async with self._states_locks[token.ident]:
368+
async with self._states_locks[lock_key]:
372369
state = await self.get_state(token)
373370
yield state
374371
await self.set_state(token, state, **context)

reflex/istate/manager/redis.py

Lines changed: 38 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -362,7 +362,7 @@ async def set_state(
362362
Args:
363363
token: The token to set the state for.
364364
state: The state to set.
365-
lock_id: If provided, the lock_key must be set to this value to set the state.
365+
lock_id: If provided, the lock must be held with this value to set the state.
366366
context: The event context.
367367
368368
Raises:
@@ -397,9 +397,9 @@ async def set_state(
397397

398398
base_state = cast(BaseState, state)
399399

400-
client_token = token.ident
400+
lock_key = token.lock_key
401401

402-
if lock_id is not None and client_token not in self._local_leases:
402+
if lock_id is not None and lock_key not in self._local_leases:
403403
time_taken = (
404404
self.lock_expiration - (await self.redis.pttl(self._lock_key(token)))
405405
) / 1000
@@ -424,7 +424,7 @@ async def set_state(
424424
lock_id=lock_id,
425425
**context,
426426
),
427-
name=f"reflex_set_state|{client_token}|{substate.get_full_name()}",
427+
name=f"reflex_set_state|{lock_key}|{substate.get_full_name()}",
428428
)
429429
for substate in base_state.substates.values()
430430
]
@@ -472,7 +472,7 @@ async def _try_modify_state(
472472
return
473473

474474
# Opportunistic locking is enabled, so try to hold the lock across multiple calls.
475-
client_token = token.ident
475+
lock_key = token.lock_key
476476
lock_held_ctx = contextlib.AsyncExitStack()
477477
try:
478478
lock_id = await lock_held_ctx.enter_async_context(
@@ -484,12 +484,12 @@ async def _try_modify_state(
484484
else:
485485
# Do not create a lease break task when multiple instances are waiting.
486486
if (
487-
not await self._get_local_lease(client_token)
487+
not await self._get_local_lease(lock_key)
488488
and await self._n_lock_contenders(self._lock_key(token)) > 0
489489
):
490490
if self._debug_enabled:
491491
console.debug(
492-
f"{SMR} [{time.monotonic() - start:.3f}] {client_token} has contention, not leasing"
492+
f"{SMR} [{time.monotonic() - start:.3f}] {lock_key} has contention, not leasing"
493493
)
494494
async with lock_held_ctx:
495495
state = await self.get_state(token)
@@ -503,19 +503,19 @@ async def _try_modify_state(
503503
token, lock_id, cleanup_ctx=lock_held_ctx, **context
504504
)
505505
) is (
506-
current_lease_task := await self._get_local_lease(client_token)
506+
current_lease_task := await self._get_local_lease(lock_key)
507507
) and new_lease_task is not None:
508508
if self._debug_enabled:
509509
console.debug(
510-
f"{SMR} [{time.monotonic() - start:.3f}] {client_token} obtained lock {lock_id.decode()}."
510+
f"{SMR} [{time.monotonic() - start:.3f}] {lock_key} obtained lock {lock_id.decode()}."
511511
)
512512
elif current_lease_task is None:
513513
# Check if we still have the redis lock, then just try to send this one update and release it.
514514
await self._try_extend_lock(self._lock_key(token))
515515
if await self.redis.get(self._lock_key(token)) == lock_id:
516516
if self._debug_enabled:
517517
console.debug(
518-
f"{SMR} [{time.monotonic() - start:.3f}] {client_token} holding lock {lock_id.decode()}, {new_lease_task=} already exited, doing single update..."
518+
f"{SMR} [{time.monotonic() - start:.3f}] {lock_key} holding lock {lock_id.decode()}, {new_lease_task=} already exited, doing single update..."
519519
)
520520
async with lock_held_ctx:
521521
state = await self.get_state(token)
@@ -524,7 +524,7 @@ async def _try_modify_state(
524524
return
525525
elif self._debug_enabled:
526526
console.debug(
527-
f"{SMR} [{time.monotonic() - start:.3f}] {client_token} lock {lock_id.decode()} expired while waiting for lease task to exit..."
527+
f"{SMR} [{time.monotonic() - start:.3f}] {lock_key} lock {lock_id.decode()} expired while waiting for lease task to exit..."
528528
)
529529
# Have to retry getting the state, but now it's probably cached.
530530
yield None
@@ -561,17 +561,15 @@ async def _get_state_cached(
561561
Yields:
562562
The cached state for the token, or None if not cached/uncachable.
563563
"""
564-
client_token = token.ident
564+
lock_key = token.lock_key
565565
# Opportunistically reuse existing lock.
566566
if (
567-
client_token in self._local_leases
568-
and (state_lock := self._cached_states_locks.get(client_token)) is not None
567+
lock_key in self._local_leases
568+
and (state_lock := self._cached_states_locks.get(lock_key)) is not None
569569
):
570570
async with state_lock:
571-
if await self._get_local_lease(client_token) is not None:
572-
if (
573-
cached_state := self._cached_states.get(client_token)
574-
) is not None:
571+
if await self._get_local_lease(lock_key) is not None:
572+
if (cached_state := self._cached_states.get(lock_key)) is not None:
575573
if isinstance(token, BaseStateToken):
576574
# Make sure we have the substate cached (or fetch it from redis).
577575
state_path = token.cls.get_full_name()
@@ -592,11 +590,11 @@ async def _get_state_cached(
592590
return
593591
elif self._debug_enabled:
594592
console.debug(
595-
f"{SMR} [{time.monotonic() - start:.3f}] {client_token} lease task found, lock held, but no cached state"
593+
f"{SMR} [{time.monotonic() - start:.3f}] {lock_key} lease task found, lock held, but no cached state"
596594
)
597595
elif self._debug_enabled:
598596
console.debug(
599-
f"{SMR} [{time.monotonic() - start:.3f}] {client_token} no active lease task found"
597+
f"{SMR} [{time.monotonic() - start:.3f}] {lock_key} no active lease task found"
600598
)
601599
yield None
602600

@@ -636,32 +634,32 @@ async def _create_lease_break_task(
636634
"""
637635
self._ensure_lock_task()
638636

639-
client_token = token.ident
637+
lock_key = token.lock_key
640638

641639
async def do_flush() -> None:
642-
if (state_lock := self._cached_states_locks.get(client_token)) is None:
640+
if (state_lock := self._cached_states_locks.get(lock_key)) is None:
643641
# If we lost the lock, we can't write the state, something went wrong.
644642
console.warn(
645-
f"State lock for {client_token} missing while finalizing lease."
643+
f"State lock for {lock_key} missing while finalizing lease."
646644
)
647645
return
648646
async with state_lock:
649647
# Write the state to redis while no one else can modify the cached copy.
650-
state = self._cached_states.pop(client_token, None)
648+
state = self._cached_states.pop(lock_key, None)
651649
try:
652650
if state:
653651
if self._debug_enabled:
654652
console.debug(
655-
f"{SMR} [{time.monotonic() - start:.3f}] {client_token} lease breaker {lock_id.decode()} flushing state"
653+
f"{SMR} [{time.monotonic() - start:.3f}] {lock_key} lease breaker {lock_id.decode()} flushing state"
656654
)
657655
await self.set_state(token, state, lock_id=lock_id, **context)
658656
finally:
659-
if (current_lease := self._local_leases.get(client_token)) is task:
660-
self._local_leases.pop(client_token, None)
657+
if (current_lease := self._local_leases.get(lock_key)) is task:
658+
self._local_leases.pop(lock_key, None)
661659
# TODO: clean up the cached states locks periodically
662660
elif self._debug_enabled:
663661
console.debug(
664-
f"{SMR} [{time.monotonic() - start:.3f}] {client_token} lease breaker {lock_id.decode()} cleanup of {task=} found different task in _local_leases {current_lease=}."
662+
f"{SMR} [{time.monotonic() - start:.3f}] {lock_key} lease breaker {lock_id.decode()} cleanup of {task=} found different task in _local_leases {current_lease=}."
665663
)
666664

667665
async def lease_breaker():
@@ -670,7 +668,7 @@ async def lease_breaker():
670668
lease_break_time = self.oplock_hold_time_ms / 1000
671669
if self._debug_enabled:
672670
console.debug(
673-
f"{SMR} [{time.monotonic() - start:.3f}] {client_token} lease breaker {lock_id.decode()} started, sleeping for {lease_break_time}s"
671+
f"{SMR} [{time.monotonic() - start:.3f}] {lock_key} lease breaker {lock_id.decode()} started, sleeping for {lease_break_time}s"
674672
)
675673
try:
676674
await asyncio.sleep(lease_break_time)
@@ -679,7 +677,7 @@ async def lease_breaker():
679677
# We got cancelled so if someone is holding the lock,
680678
# extend the timeout so they get the full time to complete.
681679
if (
682-
state_lock := self._cached_states_locks[client_token]
680+
state_lock := self._cached_states_locks[lock_key]
683681
) is not None and state_lock.locked():
684682
await self._try_extend_lock(self._lock_key(token))
685683
try:
@@ -698,36 +696,34 @@ async def lease_breaker():
698696
if cancelled_error is not None:
699697
raise cancelled_error
700698

701-
if (state_lock := self._cached_states_locks.get(client_token)) is not None:
699+
if (state_lock := self._cached_states_locks.get(lock_key)) is not None:
702700
# We have an existing lock, so lets see if we have an existing lease to cancel.
703701
async with state_lock:
704-
if (existing_task := self._local_leases.get(client_token)) is not None:
702+
if (existing_task := self._local_leases.get(lock_key)) is not None:
705703
# There's already a lease break task, so cancel it to clear it out.
706704
existing_task.cancel()
707705
if existing_task is not None:
708706
with contextlib.suppress(asyncio.CancelledError):
709707
await existing_task
710708

711709
# Now we might need to create a new lock.
712-
if (state_lock := self._cached_states_locks.get(client_token)) is None:
710+
if (state_lock := self._cached_states_locks.get(lock_key)) is None:
713711
async with self._state_manager_lock:
714-
if (state_lock := self._cached_states_locks.get(client_token)) is None:
715-
state_lock = self._cached_states_locks[client_token] = (
716-
asyncio.Lock()
717-
)
712+
if (state_lock := self._cached_states_locks.get(lock_key)) is None:
713+
state_lock = self._cached_states_locks[lock_key] = asyncio.Lock()
718714

719715
async with state_lock:
720716
# Create the task now if one didn't sneak past us.
721717
if (
722-
client_token not in self._local_leases
718+
lock_key not in self._local_leases
723719
and await self._n_lock_contenders(self._lock_key(token)) == 0
724720
):
725-
self._local_leases[client_token] = task = asyncio.create_task(
721+
self._local_leases[lock_key] = task = asyncio.create_task(
726722
lease_breaker(),
727-
name=f"reflex_lease_breaker|{client_token}|{lock_id.decode()}",
723+
name=f"reflex_lease_breaker|{lock_key}|{lock_id.decode()}",
728724
)
729725
# Fetch the requested state into the cache.
730-
self._cached_states[client_token] = await self.get_state(token)
726+
self._cached_states[lock_key] = await self.get_state(token)
731727
return task
732728
return None
733729

@@ -741,7 +737,7 @@ def _lock_key(token: StateToken[Any]) -> bytes:
741737
Returns:
742738
The redis lock key for the token.
743739
"""
744-
return f"{token.ident}_lock".encode()
740+
return f"{token.lock_key}_lock".encode()
745741

746742
async def _try_extend_lock(self, lock_key: bytes) -> bool | None:
747743
"""Extends the current lock for another lock_expiration period.

0 commit comments

Comments
 (0)