@@ -2826,6 +2826,14 @@ async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
28262826 """
28272827 yield self .state ()
28282828
2829+ async def disconnect (self , token : str ) -> None :
2830+ """Disconnect the client with the given token.
2831+
2832+ Args:
2833+ token: The token to disconnect.
2834+ """
2835+ pass
2836+
28292837
28302838class StateManagerMemory (StateManager ):
28312839 """A state manager that stores states in memory."""
@@ -2895,6 +2903,20 @@ async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
28952903 yield state
28962904 await self .set_state (token , state )
28972905
2906+ @override
2907+ async def disconnect (self , token : str ) -> None :
2908+ """Disconnect the client with the given token.
2909+
2910+ Args:
2911+ token: The token to disconnect.
2912+ """
2913+ if token in self .states :
2914+ del self .states [token ]
2915+ if lock := self ._states_locks .get (token ):
2916+ if lock .locked ():
2917+ lock .release ()
2918+ del self ._states_locks [token ]
2919+
28982920
28992921def _default_token_expiration () -> int :
29002922 """Get the default token expiration time.
@@ -3183,6 +3205,9 @@ class StateManagerRedis(StateManager):
31833205 b"evicted" ,
31843206 }
31853207
3208+ # This lock is used to ensure we only subscribe to keyspace events once per token and worker
3209+ _pubsub_locks : Dict [bytes , asyncio .Lock ] = pydantic .PrivateAttr ({})
3210+
31863211 async def _get_parent_state (
31873212 self , token : str , state : BaseState | None = None
31883213 ) -> BaseState | None :
@@ -3458,7 +3483,9 @@ async def _wait_lock(self, lock_key: bytes, lock_id: bytes) -> None:
34583483 # Some redis servers only allow out-of-band configuration, so ignore errors here.
34593484 if not environment .REFLEX_IGNORE_REDIS_CONFIG_ERROR .get ():
34603485 raise
3461- async with self .redis .pubsub () as pubsub :
3486+ if lock_key not in self ._pubsub_locks :
3487+ self ._pubsub_locks [lock_key ] = asyncio .Lock ()
3488+ async with self ._pubsub_locks [lock_key ], self .redis .pubsub () as pubsub :
34623489 await pubsub .psubscribe (lock_key_channel )
34633490 while not state_is_locked :
34643491 # wait for the lock to be released
@@ -3475,6 +3502,19 @@ async def _wait_lock(self, lock_key: bytes, lock_id: bytes) -> None:
34753502 break
34763503 state_is_locked = await self ._try_get_lock (lock_key , lock_id )
34773504
3505+ @override
3506+ async def disconnect (self , token : str ):
3507+ """Disconnect the token from the redis client.
3508+
3509+ Args:
3510+ token: The token to disconnect.
3511+ """
3512+ lock_key = self ._lock_key (token )
3513+ if lock := self ._pubsub_locks .get (lock_key ):
3514+ if lock .locked ():
3515+ lock .release ()
3516+ del self ._pubsub_locks [lock_key ]
3517+
34783518 @contextlib .asynccontextmanager
34793519 async def _lock (self , token : str ):
34803520 """Obtain a redis lock for a token.
0 commit comments