Skip to content
Merged
5 changes: 5 additions & 0 deletions reflex/.templates/web/utils/state.js
Original file line number Diff line number Diff line change
Expand Up @@ -530,6 +530,7 @@ export const connect = async (
transports: transports,
protocols: [reflexEnvironment.version],
autoUnref: false,
query: { token: getToken() },
});
// Ensure undefined fields in events are sent as null instead of removed
socket.current.io.encoder.replacer = (k, v) => (v === undefined ? null : v);
Expand Down Expand Up @@ -601,6 +602,10 @@ export const connect = async (
event_processing = false;
queueEvents([...initialEvents(), event], socket, true, navigate, params);
});
socket.current.on("new_token", async (new_token) => {
token = new_token;
window.sessionStorage.setItem(TOKEN_KEY, new_token);
});

document.addEventListener("visibilitychange", checkVisibility);
};
Expand Down
68 changes: 54 additions & 14 deletions reflex/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import json
import sys
import traceback
import urllib.parse
from collections.abc import (
AsyncGenerator,
AsyncIterator,
Expand Down Expand Up @@ -114,6 +115,7 @@
)
from reflex.utils.exec import get_compile_context, is_prod_mode, is_testing_env
from reflex.utils.imports import ImportVar
from reflex.utils.token_manager import TokenManager
from reflex.utils.types import ASGIApp, Message, Receive, Scope, Send

if TYPE_CHECKING:
Expand Down Expand Up @@ -1952,12 +1954,6 @@ class EventNamespace(AsyncNamespace):
# The application object.
app: App

# Keep a mapping between socket ID and client token.
token_to_sid: dict[str, str]

# Keep a mapping between client token and socket ID.
sid_to_token: dict[str, str]

def __init__(self, namespace: str, app: App):
"""Initialize the event namespace.

Expand All @@ -1966,17 +1962,41 @@ def __init__(self, namespace: str, app: App):
app: The application object.
"""
super().__init__(namespace)
self.token_to_sid = {}
self.sid_to_token = {}
self.app = app

def on_connect(self, sid: str, environ: dict):
# Use TokenManager for distributed duplicate tab prevention
self._token_manager = TokenManager.create()
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

consider putting the TokenManager on App as a real API for accessing this data.

down the road we could even expose helper iterators on this new TokenManager to make cancelling background tasks easier, like session_is_connected and token_is_connected which would save the passed sid/token and yield True as long as these values were still being tracked.

Like so
@rx.event(background=True)
async def thing_checker(self):
    for _task_active in app.token_manager.session_is_connected(
        sid=self.router_data.session.session_id
    ):
        async with self:
            if await self._check_thing():
                await self._update_thing()
        await asyncio.sleep(CHECK_INTERVAL)

There are a lot of possibilities here, e.g. an asyncio.Event-based API for getting notified when a session or client disconnects, which people have asked for.


@property
def token_to_sid(self) -> dict[str, str]:
    """Backward-compatible accessor for the token -> SID mapping.

    Returns:
        The token to SID mapping dict.
    """
    # Legacy callers read this dict directly; delegate to the manager.
    manager = self._token_manager
    return manager.token_to_sid
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should raise a deprecation warning here and below


@property
def sid_to_token(self) -> dict[str, str]:
    """Backward-compatible accessor for the SID -> token mapping.

    Returns:
        The SID to token mapping dict.
    """
    # Legacy callers read this dict directly; delegate to the manager.
    manager = self._token_manager
    return manager.sid_to_token

async def on_connect(self, sid: str, environ: dict):
"""Event for when the websocket is connected.

Args:
sid: The Socket.IO session id.
environ: The request information, including HTTP headers.
"""
query_params = urllib.parse.parse_qs(environ.get("QUERY_STRING"))
await self.link_token_to_sid(sid, query_params.get("token", [])[0])
Comment thread
Lendemor marked this conversation as resolved.
Outdated

subprotocol = environ.get("HTTP_SEC_WEBSOCKET_PROTOCOL")
if subprotocol and subprotocol != constants.Reflex.VERSION:
console.warn(
Expand All @@ -1989,9 +2009,18 @@ def on_disconnect(self, sid: str):
Args:
sid: The Socket.IO session id.
"""
disconnect_token = self.sid_to_token.pop(sid, None)
# Get token before cleaning up
disconnect_token = self.sid_to_token.get(sid)
if disconnect_token:
self.token_to_sid.pop(disconnect_token, None)
# Use async cleanup through token manager
task = asyncio.create_task(
self._token_manager.disconnect_token(disconnect_token, sid)
)
# Don't await to avoid blocking disconnect, but handle potential errors
task.add_done_callback(
lambda t: t.exception()
and console.error(f"Token cleanup error: {t.exception()}")
)

async def emit_update(self, update: StateUpdate, sid: str) -> None:
"""Emit an update to the client.
Expand Down Expand Up @@ -2049,9 +2078,6 @@ async def on_event(self, sid: str, data: Any):
msg = f"Failed to deserialize event data: {fields}."
raise exceptions.EventDeserializationError(msg) from ex

self.token_to_sid[event.token] = sid
self.sid_to_token[sid] = event.token

# Get the event environment.
if self.app.sio is None:
msg = "Socket.IO is not initialized."
Expand Down Expand Up @@ -2100,3 +2126,17 @@ async def on_ping(self, sid: str):
"""
# Emit the test event.
await self.emit(str(constants.SocketEvent.PING), "pong", to=sid)

async def link_token_to_sid(self, sid: str, token: str):
    """Associate a client token with a Socket.IO session id.

    Args:
        sid: The Socket.IO session id.
        token: The client token.
    """
    # The token manager performs duplicate detection (and Redis
    # coordination when enabled). It returns a replacement token only
    # when the supplied one is already claimed by another session.
    replacement = await self._token_manager.link_token_to_sid(token, sid)
    if not replacement:
        return
    # Duplicate tab detected: tell this client to switch to a fresh token.
    await self.emit("new_token", replacement, to=sid)
11 changes: 11 additions & 0 deletions reflex/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,17 @@ async def _reset_backend_state_manager(self):
msg = "Failed to reset state manager."
raise RuntimeError(msg)

# Also reset the TokenManager to avoid loop affinity issues
if (
hasattr(self.app_instance, "event_namespace")
and self.app_instance.event_namespace is not None
and hasattr(self.app_instance.event_namespace, "_token_manager")
):
# Import here to avoid circular imports
from reflex.utils.token_manager import TokenManager

self.app_instance.event_namespace._token_manager = TokenManager.create()

def _start_frontend(self):
# Set up the frontend.
with chdir(self.app_path):
Expand Down
217 changes: 217 additions & 0 deletions reflex/utils/token_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,217 @@
"""Token manager for handling client token to session ID mappings."""

from __future__ import annotations

import uuid
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING

from reflex.utils import console, prerequisites

if TYPE_CHECKING:
from redis.asyncio import Redis


def _get_new_token() -> str:
    """Generate a fresh, globally-unique client token.

    Returns:
        A new UUID4 token string.
    """
    return f"{uuid.uuid4()}"


class TokenManager(ABC):
    """Abstract base class for managing client token to session ID mappings."""

    def __init__(self):
        """Set up the in-memory mapping tables."""
        # Maps client token -> Socket.IO session ID.
        self.token_to_sid: dict[str, str] = {}
        # Maps Socket.IO session ID -> client token.
        self.sid_to_token: dict[str, str] = {}

    @abstractmethod
    async def link_token_to_sid(self, token: str, sid: str) -> str | None:
        """Link a token to a session ID.

        Args:
            token: The client token.
            sid: The Socket.IO session ID.

        Returns:
            New token if duplicate detected and new token generated, None otherwise.
        """

    @abstractmethod
    async def disconnect_token(self, token: str, sid: str) -> None:
        """Clean up token mapping when client disconnects.

        Args:
            token: The client token.
            sid: The Socket.IO session ID.
        """

    @classmethod
    def create(cls) -> TokenManager:
        """Factory method to create appropriate TokenManager implementation.

        Returns:
            RedisTokenManager if Redis is available, LocalTokenManager otherwise.
        """
        # Only attempt to obtain a Redis client when the config says Redis
        # is in use; fall back to the local (single-worker) implementation.
        redis_client = (
            prerequisites.get_redis() if prerequisites.check_redis_used() else None
        )
        if redis_client is None:
            return LocalTokenManager()
        return RedisTokenManager(redis_client)


class LocalTokenManager(TokenManager):
    """Token manager using local in-memory dictionaries (single worker)."""

    # NOTE: the base class __init__ already initializes the two dicts, so no
    # override is needed (the previous no-op __init__ was removed).

    async def link_token_to_sid(self, token: str, sid: str) -> str | None:
        """Link a token to a session ID.

        Args:
            token: The client token.
            sid: The Socket.IO session ID.

        Returns:
            New token if duplicate detected and new token generated, None otherwise.
        """
        # A token already claimed by a *different* SID means the user opened
        # a duplicate tab: mint that tab its own fresh token.
        # NOTE(review): the membership scan over values() is O(n) per connect;
        # swapping it for a token_to_sid key lookup looks equivalent but the
        # two dicts may drift (stale keys) — confirm before optimizing.
        if token in self.sid_to_token.values() and sid != self.token_to_sid.get(token):
            new_token = _get_new_token()
            self.token_to_sid[new_token] = sid
            self.sid_to_token[sid] = new_token
            return new_token

        # Normal case - link token to SID
        self.token_to_sid[token] = sid
        self.sid_to_token[sid] = token
        return None

    async def disconnect_token(self, token: str, sid: str) -> None:
        """Clean up token mapping when client disconnects.

        Args:
            token: The client token.
            sid: The Socket.IO session ID.
        """
        # Drop both directions of the mapping; missing keys are ignored.
        self.token_to_sid.pop(token, None)
        self.sid_to_token.pop(sid, None)


class RedisTokenManager(LocalTokenManager):
    """Token manager using Redis for distributed multi-worker support.

    Inherits local dict logic from LocalTokenManager and adds Redis layer
    for cross-worker duplicate detection.
    """

    def __init__(self, redis: Redis):
        """Initialize the Redis token manager.

        Args:
            redis: The Redis client instance.
        """
        super().__init__()
        self.redis = redis

        # Imported lazily to avoid a module-level import cycle.
        from reflex.config import get_config

        # Token expiration (seconds) comes from config (default 1 hour).
        self.token_expiration = get_config().redis_token_expiration

    def _get_redis_key(self, token: str) -> str:
        """Get Redis key for token mapping.

        Args:
            token: The client token.

        Returns:
            Redis key following Reflex conventions: {token}_sid
        """
        return f"{token}_sid"

    async def link_token_to_sid(self, token: str, sid: str) -> str | None:
        """Link a token to a session ID with Redis-based duplicate detection.

        Args:
            token: The client token.
            sid: The Socket.IO session ID.

        Returns:
            New token if duplicate detected and new token generated, None otherwise.
        """
        # Reconnection with the same token/SID pair needs no Redis round-trip.
        if self.token_to_sid.get(token) == sid:
            return None

        key = self._get_redis_key(token)
        try:
            duplicate = await self.redis.exists(key)
        except Exception as e:
            console.error(f"Redis error checking token existence: {e}")
            # Redis unreachable: degrade gracefully to local-only handling.
            return await super().link_token_to_sid(token, sid)

        if duplicate:
            # Some worker (or another tab) already owns this token: mint a
            # fresh one for this connection.
            fresh = _get_new_token()
            try:
                await self.redis.set(
                    self._get_redis_key(fresh), "1", ex=self.token_expiration
                )
            except Exception as e:
                console.error(f"Redis error storing new token: {e}")
                # Still update local dicts and continue.
            self.token_to_sid[fresh] = sid
            self.sid_to_token[sid] = fresh
            console.debug(f"Duplicate tab detected. Generated new token: {fresh}")
            return fresh

        # Normal case: claim the token in Redis, then record it locally.
        try:
            await self.redis.set(key, "1", ex=self.token_expiration)
        except Exception as e:
            console.error(f"Redis error storing token: {e}")
            # Continue with local storage.
        self.token_to_sid[token] = sid
        self.sid_to_token[sid] = token
        return None

    async def disconnect_token(self, token: str, sid: str) -> None:
        """Clean up token mapping when client disconnects.

        Args:
            token: The client token.
            sid: The Socket.IO session ID.
        """
        # Only the worker that owns the token->sid link deletes the Redis key.
        if self.token_to_sid.get(token) == sid:
            try:
                await self.redis.delete(self._get_redis_key(token))
            except Exception as e:
                console.error(f"Redis error deleting token: {e}")

        # Local cleanup always happens, owned or not.
        await super().disconnect_token(token, sid)
Loading