Skip to content

Commit 1fabe33

Browse files
fix: only open one connection/sub for each token per worker
bonus: properly cleanup StateManager connections on disconnect
1 parent 1ee325f commit 1fabe33

File tree

2 files changed

+44
-3
lines changed

2 files changed

+44
-3
lines changed

reflex/app.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1477,7 +1477,7 @@ def __init__(self, namespace: str, app: App):
14771477
super().__init__(namespace)
14781478
self.app = app
14791479

1480-
def on_connect(self, sid, environ):
1480+
async def on_connect(self, sid, environ):
14811481
"""Event for when the websocket is connected.
14821482
14831483
Args:
@@ -1486,7 +1486,7 @@ def on_connect(self, sid, environ):
14861486
"""
14871487
pass
14881488

1489-
def on_disconnect(self, sid):
1489+
async def on_disconnect(self, sid):
14901490
"""Event for when the websocket disconnects.
14911491
14921492
Args:
@@ -1495,6 +1495,7 @@ def on_disconnect(self, sid):
14951495
disconnect_token = self.sid_to_token.pop(sid, None)
14961496
if disconnect_token:
14971497
self.token_to_sid.pop(disconnect_token, None)
1498+
await self.app.state_manager.disconnect(sid)
14981499

14991500
async def emit_update(self, update: StateUpdate, sid: str) -> None:
15001501
"""Emit an update to the client.

reflex/state.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2826,6 +2826,14 @@ async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
28262826
"""
28272827
yield self.state()
28282828

2829+
async def disconnect(self, token: str) -> None:
2830+
"""Disconnect the client with the given token.
2831+
2832+
Args:
2833+
token: The token to disconnect.
2834+
"""
2835+
pass
2836+
28292837

28302838
class StateManagerMemory(StateManager):
28312839
"""A state manager that stores states in memory."""
@@ -2895,6 +2903,20 @@ async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
28952903
yield state
28962904
await self.set_state(token, state)
28972905

2906+
@override
2907+
async def disconnect(self, token: str) -> None:
2908+
"""Disconnect the client with the given token.
2909+
2910+
Args:
2911+
token: The token to disconnect.
2912+
"""
2913+
if token in self.states:
2914+
del self.states[token]
2915+
if lock := self._states_locks.get(token):
2916+
if lock.locked():
2917+
lock.release()
2918+
del self._states_locks[token]
2919+
28982920

28992921
def _default_token_expiration() -> int:
29002922
"""Get the default token expiration time.
@@ -3183,6 +3205,9 @@ class StateManagerRedis(StateManager):
31833205
b"evicted",
31843206
}
31853207

3208+
# This lock is used to ensure we only subscribe to keyspace events once per token and worker
3209+
_pubsub_locks: Dict[bytes, asyncio.Lock] = pydantic.PrivateAttr({})
3210+
31863211
async def _get_parent_state(
31873212
self, token: str, state: BaseState | None = None
31883213
) -> BaseState | None:
@@ -3458,7 +3483,9 @@ async def _wait_lock(self, lock_key: bytes, lock_id: bytes) -> None:
34583483
# Some redis servers only allow out-of-band configuration, so ignore errors here.
34593484
if not environment.REFLEX_IGNORE_REDIS_CONFIG_ERROR.get():
34603485
raise
3461-
async with self.redis.pubsub() as pubsub:
3486+
if lock_key not in self._pubsub_locks:
3487+
self._pubsub_locks[lock_key] = asyncio.Lock()
3488+
async with self._pubsub_locks[lock_key], self.redis.pubsub() as pubsub:
34623489
await pubsub.psubscribe(lock_key_channel)
34633490
while not state_is_locked:
34643491
# wait for the lock to be released
@@ -3475,6 +3502,19 @@ async def _wait_lock(self, lock_key: bytes, lock_id: bytes) -> None:
34753502
break
34763503
state_is_locked = await self._try_get_lock(lock_key, lock_id)
34773504

3505+
@override
3506+
async def disconnect(self, token: str):
3507+
"""Disconnect the token from the redis client.
3508+
3509+
Args:
3510+
token: The token to disconnect.
3511+
"""
3512+
lock_key = self._lock_key(token)
3513+
if lock := self._pubsub_locks.get(lock_key):
3514+
if lock.locked():
3515+
lock.release()
3516+
del self._pubsub_locks[lock_key]
3517+
34783518
@contextlib.asynccontextmanager
34793519
async def _lock(self, token: str):
34803520
"""Obtain a redis lock for a token.

0 commit comments

Comments
 (0)