Skip to content

Commit 71a54bf

Browse files
authored
fix duplicate tab issue (#4607)
* fix duplicate tab issue * review changes * fix issue * handle token_to_sid mapping via redis when enabled * fix for precommit * fix closed loop * use once on get_redis * revert once * fix * handle indexerror * fix * fix race condition
1 parent 0809c11 commit 71a54bf

5 files changed

Lines changed: 700 additions & 13 deletions

File tree

reflex/.templates/web/utils/state.js

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -530,6 +530,7 @@ export const connect = async (
530530
transports: transports,
531531
protocols: [reflexEnvironment.version],
532532
autoUnref: false,
533+
query: { token: getToken() },
533534
});
534535
// Ensure undefined fields in events are sent as null instead of removed
535536
socket.current.io.encoder.replacer = (k, v) => (v === undefined ? null : v);
@@ -601,6 +602,10 @@ export const connect = async (
601602
event_processing = false;
602603
queueEvents([...initialEvents(), event], socket, true, navigate, params);
603604
});
605+
socket.current.on("new_token", async (new_token) => {
606+
token = new_token;
607+
window.sessionStorage.setItem(TOKEN_KEY, new_token);
608+
});
604609

605610
document.addEventListener("visibilitychange", checkVisibility);
606611
};

reflex/app.py

Lines changed: 65 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import json
1414
import sys
1515
import traceback
16+
import urllib.parse
1617
from collections.abc import (
1718
AsyncGenerator,
1819
AsyncIterator,
@@ -114,6 +115,7 @@
114115
)
115116
from reflex.utils.exec import get_compile_context, is_prod_mode, is_testing_env
116117
from reflex.utils.imports import ImportVar
118+
from reflex.utils.token_manager import TokenManager
117119
from reflex.utils.types import ASGIApp, Message, Receive, Scope, Send
118120

119121
if TYPE_CHECKING:
@@ -1958,12 +1960,6 @@ class EventNamespace(AsyncNamespace):
19581960
# The application object.
19591961
app: App
19601962

1961-
# Keep a mapping between socket ID and client token.
1962-
token_to_sid: dict[str, str]
1963-
1964-
# Keep a mapping between client token and socket ID.
1965-
sid_to_token: dict[str, str]
1966-
19671963
def __init__(self, namespace: str, app: App):
19681964
"""Initialize the event namespace.
19691965
@@ -1972,17 +1968,45 @@ def __init__(self, namespace: str, app: App):
19721968
app: The application object.
19731969
"""
19741970
super().__init__(namespace)
1975-
self.token_to_sid = {}
1976-
self.sid_to_token = {}
19771971
self.app = app
19781972

1979-
def on_connect(self, sid: str, environ: dict):
1973+
# Use TokenManager for distributed duplicate tab prevention
1974+
self._token_manager = TokenManager.create()
1975+
1976+
@property
1977+
def token_to_sid(self) -> dict[str, str]:
1978+
"""Get token to SID mapping for backward compatibility.
1979+
1980+
Returns:
1981+
The token to SID mapping dict.
1982+
"""
1983+
# For backward compatibility, expose the underlying dict
1984+
return self._token_manager.token_to_sid
1985+
1986+
@property
1987+
def sid_to_token(self) -> dict[str, str]:
1988+
"""Get SID to token mapping for backward compatibility.
1989+
1990+
Returns:
1991+
The SID to token mapping dict.
1992+
"""
1993+
# For backward compatibility, expose the underlying dict
1994+
return self._token_manager.sid_to_token
1995+
1996+
async def on_connect(self, sid: str, environ: dict):
19801997
"""Event for when the websocket is connected.
19811998
19821999
Args:
19832000
sid: The Socket.IO session id.
19842001
environ: The request information, including HTTP headers.
19852002
"""
2003+
query_params = urllib.parse.parse_qs(environ.get("QUERY_STRING", ""))
2004+
token_list = query_params.get("token", [])
2005+
if token_list:
2006+
await self.link_token_to_sid(sid, token_list[0])
2007+
else:
2008+
console.warn(f"No token provided in connection for session {sid}")
2009+
19862010
subprotocol = environ.get("HTTP_SEC_WEBSOCKET_PROTOCOL")
19872011
if subprotocol and subprotocol != constants.Reflex.VERSION:
19882012
console.warn(
@@ -1995,9 +2019,18 @@ def on_disconnect(self, sid: str):
19952019
Args:
19962020
sid: The Socket.IO session id.
19972021
"""
1998-
disconnect_token = self.sid_to_token.pop(sid, None)
2022+
# Get token before cleaning up
2023+
disconnect_token = self.sid_to_token.get(sid)
19992024
if disconnect_token:
2000-
self.token_to_sid.pop(disconnect_token, None)
2025+
# Use async cleanup through token manager
2026+
task = asyncio.create_task(
2027+
self._token_manager.disconnect_token(disconnect_token, sid)
2028+
)
2029+
# Don't await to avoid blocking disconnect, but handle potential errors
2030+
task.add_done_callback(
2031+
lambda t: t.exception()
2032+
and console.error(f"Token cleanup error: {t.exception()}")
2033+
)
20012034

20022035
async def emit_update(self, update: StateUpdate, sid: str) -> None:
20032036
"""Emit an update to the client.
@@ -2055,8 +2088,13 @@ async def on_event(self, sid: str, data: Any):
20552088
msg = f"Failed to deserialize event data: {fields}."
20562089
raise exceptions.EventDeserializationError(msg) from ex
20572090

2058-
self.token_to_sid[event.token] = sid
2059-
self.sid_to_token[sid] = event.token
2091+
# Correct the token if it doesn't match what we expect for this SID
2092+
expected_token = self.sid_to_token.get(sid)
2093+
if expected_token and event.token != expected_token:
2094+
# Create new event with corrected token since Event is frozen
2095+
from dataclasses import replace
2096+
2097+
event = replace(event, token=expected_token)
20602098

20612099
# Get the event environment.
20622100
if self.app.sio is None:
@@ -2106,3 +2144,17 @@ async def on_ping(self, sid: str):
21062144
"""
21072145
# Emit the test event.
21082146
await self.emit(str(constants.SocketEvent.PING), "pong", to=sid)
2147+
2148+
async def link_token_to_sid(self, sid: str, token: str):
2149+
"""Link a token to a session id.
2150+
2151+
Args:
2152+
sid: The Socket.IO session id.
2153+
token: The client token.
2154+
"""
2155+
# Use TokenManager for duplicate detection and Redis support
2156+
new_token = await self._token_manager.link_token_to_sid(token, sid)
2157+
2158+
if new_token:
2159+
# Duplicate detected, emit new token to client
2160+
await self.emit("new_token", new_token, to=sid)

reflex/testing.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,17 @@ async def _reset_backend_state_manager(self):
376376
msg = "Failed to reset state manager."
377377
raise RuntimeError(msg)
378378

379+
# Also reset the TokenManager to avoid loop affinity issues
380+
if (
381+
hasattr(self.app_instance, "event_namespace")
382+
and self.app_instance.event_namespace is not None
383+
and hasattr(self.app_instance.event_namespace, "_token_manager")
384+
):
385+
# Import here to avoid circular imports
386+
from reflex.utils.token_manager import TokenManager
387+
388+
self.app_instance.event_namespace._token_manager = TokenManager.create()
389+
379390
def _start_frontend(self):
380391
# Set up the frontend.
381392
with chdir(self.app_path):

reflex/utils/token_manager.py

Lines changed: 215 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,215 @@
1+
"""Token manager for handling client token to session ID mappings."""
2+
3+
from __future__ import annotations
4+
5+
import uuid
6+
from abc import ABC, abstractmethod
7+
from typing import TYPE_CHECKING
8+
9+
from reflex.utils import console, prerequisites
10+
11+
if TYPE_CHECKING:
12+
from redis.asyncio import Redis
13+
14+
15+
def _get_new_token() -> str:
16+
"""Generate a new unique token.
17+
18+
Returns:
19+
A new UUID4 token string.
20+
"""
21+
return str(uuid.uuid4())
22+
23+
24+
class TokenManager(ABC):
25+
"""Abstract base class for managing client token to session ID mappings."""
26+
27+
def __init__(self):
28+
"""Initialize the token manager with local dictionaries."""
29+
# Keep a mapping between socket ID and client token.
30+
self.token_to_sid: dict[str, str] = {}
31+
# Keep a mapping between client token and socket ID.
32+
self.sid_to_token: dict[str, str] = {}
33+
34+
@abstractmethod
35+
async def link_token_to_sid(self, token: str, sid: str) -> str | None:
36+
"""Link a token to a session ID.
37+
38+
Args:
39+
token: The client token.
40+
sid: The Socket.IO session ID.
41+
42+
Returns:
43+
New token if duplicate detected and new token generated, None otherwise.
44+
"""
45+
46+
@abstractmethod
47+
async def disconnect_token(self, token: str, sid: str) -> None:
48+
"""Clean up token mapping when client disconnects.
49+
50+
Args:
51+
token: The client token.
52+
sid: The Socket.IO session ID.
53+
"""
54+
55+
@classmethod
56+
def create(cls) -> TokenManager:
57+
"""Factory method to create appropriate TokenManager implementation.
58+
59+
Returns:
60+
RedisTokenManager if Redis is available, LocalTokenManager otherwise.
61+
"""
62+
if prerequisites.check_redis_used():
63+
redis_client = prerequisites.get_redis()
64+
if redis_client is not None:
65+
return RedisTokenManager(redis_client)
66+
67+
return LocalTokenManager()
68+
69+
70+
class LocalTokenManager(TokenManager):
71+
"""Token manager using local in-memory dictionaries (single worker)."""
72+
73+
def __init__(self):
74+
"""Initialize the local token manager."""
75+
super().__init__()
76+
77+
async def link_token_to_sid(self, token: str, sid: str) -> str | None:
78+
"""Link a token to a session ID.
79+
80+
Args:
81+
token: The client token.
82+
sid: The Socket.IO session ID.
83+
84+
Returns:
85+
New token if duplicate detected and new token generated, None otherwise.
86+
"""
87+
# Check if token is already mapped to a different SID (duplicate tab)
88+
if token in self.token_to_sid and sid != self.token_to_sid.get(token):
89+
new_token = _get_new_token()
90+
self.token_to_sid[new_token] = sid
91+
self.sid_to_token[sid] = new_token
92+
return new_token
93+
94+
# Normal case - link token to SID
95+
self.token_to_sid[token] = sid
96+
self.sid_to_token[sid] = token
97+
return None
98+
99+
async def disconnect_token(self, token: str, sid: str) -> None:
100+
"""Clean up token mapping when client disconnects.
101+
102+
Args:
103+
token: The client token.
104+
sid: The Socket.IO session ID.
105+
"""
106+
# Clean up both mappings
107+
self.token_to_sid.pop(token, None)
108+
self.sid_to_token.pop(sid, None)
109+
110+
111+
class RedisTokenManager(LocalTokenManager):
112+
"""Token manager using Redis for distributed multi-worker support.
113+
114+
Inherits local dict logic from LocalTokenManager and adds Redis layer
115+
for cross-worker duplicate detection.
116+
"""
117+
118+
def __init__(self, redis: Redis):
119+
"""Initialize the Redis token manager.
120+
121+
Args:
122+
redis: The Redis client instance.
123+
"""
124+
# Initialize parent's local dicts
125+
super().__init__()
126+
127+
self.redis = redis
128+
129+
# Get token expiration from config (default 1 hour)
130+
from reflex.config import get_config
131+
132+
config = get_config()
133+
self.token_expiration = config.redis_token_expiration
134+
135+
def _get_redis_key(self, token: str) -> str:
136+
"""Get Redis key for token mapping.
137+
138+
Args:
139+
token: The client token.
140+
141+
Returns:
142+
Redis key following Reflex conventions: {token}_sid
143+
"""
144+
return f"{token}_sid"
145+
146+
async def link_token_to_sid(self, token: str, sid: str) -> str | None:
147+
"""Link a token to a session ID with Redis-based duplicate detection.
148+
149+
Args:
150+
token: The client token.
151+
sid: The Socket.IO session ID.
152+
153+
Returns:
154+
New token if duplicate detected and new token generated, None otherwise.
155+
"""
156+
# Fast local check first (handles reconnections)
157+
if token in self.token_to_sid and self.token_to_sid[token] == sid:
158+
return None # Same token, same SID = reconnection, no Redis check needed
159+
160+
# Check Redis for cross-worker duplicates
161+
redis_key = self._get_redis_key(token)
162+
163+
try:
164+
token_exists_in_redis = await self.redis.exists(redis_key)
165+
except Exception as e:
166+
console.error(f"Redis error checking token existence: {e}")
167+
return await super().link_token_to_sid(token, sid)
168+
169+
if token_exists_in_redis:
170+
# Duplicate exists somewhere - generate new token
171+
new_token = _get_new_token()
172+
new_redis_key = self._get_redis_key(new_token)
173+
174+
try:
175+
# Store in Redis
176+
await self.redis.set(new_redis_key, "1", ex=self.token_expiration)
177+
except Exception as e:
178+
console.error(f"Redis error storing new token: {e}")
179+
# Still update local dicts and continue
180+
181+
# Store in local dicts (always do this)
182+
self.token_to_sid[new_token] = sid
183+
self.sid_to_token[sid] = new_token
184+
return new_token
185+
186+
# Normal case - store in both Redis and local dicts
187+
try:
188+
await self.redis.set(redis_key, "1", ex=self.token_expiration)
189+
except Exception as e:
190+
console.error(f"Redis error storing token: {e}")
191+
# Continue with local storage
192+
193+
# Store in local dicts (always do this)
194+
self.token_to_sid[token] = sid
195+
self.sid_to_token[sid] = token
196+
return None
197+
198+
async def disconnect_token(self, token: str, sid: str) -> None:
199+
"""Clean up token mapping when client disconnects.
200+
201+
Args:
202+
token: The client token.
203+
sid: The Socket.IO session ID.
204+
"""
205+
# Only clean up if we own it locally (fast ownership check)
206+
if self.token_to_sid.get(token) == sid:
207+
# Clean up Redis
208+
redis_key = self._get_redis_key(token)
209+
try:
210+
await self.redis.delete(redis_key)
211+
except Exception as e:
212+
console.error(f"Redis error deleting token: {e}")
213+
214+
# Clean up local dicts (always do this)
215+
await super().disconnect_token(token, sid)

0 commit comments

Comments
 (0)