1313import json
1414import sys
1515import traceback
16+ import urllib .parse
1617from collections .abc import (
1718 AsyncGenerator ,
1819 AsyncIterator ,
114115)
115116from reflex .utils .exec import get_compile_context , is_prod_mode , is_testing_env
116117from reflex .utils .imports import ImportVar
118+ from reflex .utils .token_manager import TokenManager
117119from reflex .utils .types import ASGIApp , Message , Receive , Scope , Send
118120
119121if TYPE_CHECKING :
@@ -1958,12 +1960,6 @@ class EventNamespace(AsyncNamespace):
19581960 # The application object.
19591961 app : App
19601962
1961- # Keep a mapping between socket ID and client token.
1962- token_to_sid : dict [str , str ]
1963-
1964- # Keep a mapping between client token and socket ID.
1965- sid_to_token : dict [str , str ]
1966-
19671963 def __init__ (self , namespace : str , app : App ):
19681964 """Initialize the event namespace.
19691965
@@ -1972,17 +1968,45 @@ def __init__(self, namespace: str, app: App):
19721968 app: The application object.
19731969 """
19741970 super ().__init__ (namespace )
1975- self .token_to_sid = {}
1976- self .sid_to_token = {}
19771971 self .app = app
19781972
1979- def on_connect (self , sid : str , environ : dict ):
1973+ # Use TokenManager for distributed duplicate tab prevention
1974+ self ._token_manager = TokenManager .create ()
1975+
1976+ @property
1977+ def token_to_sid (self ) -> dict [str , str ]:
1978+ """Get token to SID mapping for backward compatibility.
1979+
1980+ Returns:
1981+ The token to SID mapping dict.
1982+ """
1983+ # For backward compatibility, expose the underlying dict
1984+ return self ._token_manager .token_to_sid
1985+
1986+ @property
1987+ def sid_to_token (self ) -> dict [str , str ]:
1988+ """Get SID to token mapping for backward compatibility.
1989+
1990+ Returns:
1991+ The SID to token mapping dict.
1992+ """
1993+ # For backward compatibility, expose the underlying dict
1994+ return self ._token_manager .sid_to_token
1995+
1996+ async def on_connect (self , sid : str , environ : dict ):
19801997 """Event for when the websocket is connected.
19811998
19821999 Args:
19832000 sid: The Socket.IO session id.
19842001 environ: The request information, including HTTP headers.
19852002 """
2003+ query_params = urllib .parse .parse_qs (environ .get ("QUERY_STRING" , "" ))
2004+ token_list = query_params .get ("token" , [])
2005+ if token_list :
2006+ await self .link_token_to_sid (sid , token_list [0 ])
2007+ else :
2008+ console .warn (f"No token provided in connection for session { sid } " )
2009+
19862010 subprotocol = environ .get ("HTTP_SEC_WEBSOCKET_PROTOCOL" )
19872011 if subprotocol and subprotocol != constants .Reflex .VERSION :
19882012 console .warn (
@@ -1995,9 +2019,18 @@ def on_disconnect(self, sid: str):
19952019 Args:
19962020 sid: The Socket.IO session id.
19972021 """
1998- disconnect_token = self .sid_to_token .pop (sid , None )
2022+ # Get token before cleaning up
2023+ disconnect_token = self .sid_to_token .get (sid )
19992024 if disconnect_token :
2000- self .token_to_sid .pop (disconnect_token , None )
2025+ # Use async cleanup through token manager
2026+ task = asyncio .create_task (
2027+ self ._token_manager .disconnect_token (disconnect_token , sid )
2028+ )
2029+ # Don't await to avoid blocking disconnect, but handle potential errors
2030+ task .add_done_callback (
2031+ lambda t : t .exception ()
2032+ and console .error (f"Token cleanup error: { t .exception ()} " )
2033+ )
20012034
20022035 async def emit_update (self , update : StateUpdate , sid : str ) -> None :
20032036 """Emit an update to the client.
@@ -2055,8 +2088,13 @@ async def on_event(self, sid: str, data: Any):
20552088 msg = f"Failed to deserialize event data: { fields } ."
20562089 raise exceptions .EventDeserializationError (msg ) from ex
20572090
2058- self .token_to_sid [event .token ] = sid
2059- self .sid_to_token [sid ] = event .token
2091+ # Correct the token if it doesn't match what we expect for this SID
2092+ expected_token = self .sid_to_token .get (sid )
2093+ if expected_token and event .token != expected_token :
2094+ # Create new event with corrected token since Event is frozen
2095+ from dataclasses import replace
2096+
2097+ event = replace (event , token = expected_token )
20602098
20612099 # Get the event environment.
20622100 if self .app .sio is None :
@@ -2106,3 +2144,17 @@ async def on_ping(self, sid: str):
21062144 """
21072145 # Emit the test event.
21082146 await self .emit (str (constants .SocketEvent .PING ), "pong" , to = sid )
2147+
2148+ async def link_token_to_sid (self , sid : str , token : str ):
2149+ """Link a token to a session id.
2150+
2151+ Args:
2152+ sid: The Socket.IO session id.
2153+ token: The client token.
2154+ """
2155+ # Use TokenManager for duplicate detection and Redis support
2156+ new_token = await self ._token_manager .link_token_to_sid (token , sid )
2157+
2158+ if new_token :
2159+ # Duplicate detected, emit new token to client
2160+ await self .emit ("new_token" , new_token , to = sid )
0 commit comments