@@ -2830,6 +2830,14 @@ async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
28302830 """
28312831 yield self .state ()
28322832
2833+ async def disconnect (self , token : str ) -> None :
2834+ """Disconnect the client with the given token.
2835+
2836+ Args:
2837+ token: The token to disconnect.
2838+ """
2839+ pass
2840+
28332841
28342842class StateManagerMemory (StateManager ):
28352843 """A state manager that stores states in memory."""
@@ -2899,6 +2907,20 @@ async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
28992907 yield state
29002908 await self .set_state (token , state )
29012909
2910+ @override
2911+ async def disconnect (self , token : str ) -> None :
2912+ """Disconnect the client with the given token.
2913+
2914+ Args:
2915+ token: The token to disconnect.
2916+ """
2917+ if token in self .states :
2918+ del self .states [token ]
2919+ if lock := self ._states_locks .get (token ):
2920+ if lock .locked ():
2921+ lock .release ()
2922+ del self ._states_locks [token ]
2923+
29022924
29032925def _default_token_expiration () -> int :
29042926 """Get the default token expiration time.
@@ -3187,6 +3209,9 @@ class StateManagerRedis(StateManager):
31873209 b"evicted" ,
31883210 }
31893211
3212+ # This lock is used to ensure we only subscribe to keyspace events once per token and worker
3213+ _pubsub_locks : Dict [bytes , asyncio .Lock ] = pydantic .PrivateAttr ({})
3214+
31903215 async def _get_parent_state (
31913216 self , token : str , state : BaseState | None = None
31923217 ) -> BaseState | None :
@@ -3462,7 +3487,9 @@ async def _wait_lock(self, lock_key: bytes, lock_id: bytes) -> None:
34623487 # Some redis servers only allow out-of-band configuration, so ignore errors here.
34633488 if not environment .REFLEX_IGNORE_REDIS_CONFIG_ERROR .get ():
34643489 raise
3465- async with self .redis .pubsub () as pubsub :
3490+ if lock_key not in self ._pubsub_locks :
3491+ self ._pubsub_locks [lock_key ] = asyncio .Lock ()
3492+ async with self ._pubsub_locks [lock_key ], self .redis .pubsub () as pubsub :
34663493 await pubsub .psubscribe (lock_key_channel )
34673494 while not state_is_locked :
34683495 # wait for the lock to be released
@@ -3479,6 +3506,19 @@ async def _wait_lock(self, lock_key: bytes, lock_id: bytes) -> None:
34793506 break
34803507 state_is_locked = await self ._try_get_lock (lock_key , lock_id )
34813508
3509+ @override
3510+ async def disconnect (self , token : str ):
3511+ """Disconnect the token from the redis client.
3512+
3513+ Args:
3514+ token: The token to disconnect.
3515+ """
3516+ lock_key = self ._lock_key (token )
3517+ if lock := self ._pubsub_locks .get (lock_key ):
3518+ if lock .locked ():
3519+ lock .release ()
3520+ del self ._pubsub_locks [lock_key ]
3521+
34823522 @contextlib .asynccontextmanager
34833523 async def _lock (self , token : str ):
34843524 """Obtain a redis lock for a token.
0 commit comments