Skip to content
Merged
14 changes: 8 additions & 6 deletions reflex/.templates/web/utils/state.js
Original file line number Diff line number Diff line change
Expand Up @@ -300,10 +300,7 @@ export const applyEvent = async (event, socket) => {

// Send the event to the server.
if (socket) {
socket.emit(
"event",
event,
);
socket.emit("event", event);
return true;
}

Expand Down Expand Up @@ -408,9 +405,10 @@ export const connect = async (
path: endpoint["pathname"],
transports: transports,
autoUnref: false,
query: { token: getToken() },
});
// Ensure undefined fields in events are sent as null instead of removed
socket.current.io.encoder.replacer = (k, v) => (v === undefined ? null : v)
socket.current.io.encoder.replacer = (k, v) => (v === undefined ? null : v);

function checkVisibility() {
if (document.visibilityState === "visible") {
Expand Down Expand Up @@ -461,6 +459,10 @@ export const connect = async (
event_processing = false;
queueEvents([...initialEvents(), event], socket);
});
socket.current.on("new_token", async (new_token) => {
token = new_token;
window.sessionStorage.setItem(TOKEN_KEY, new_token);
});

document.addEventListener("visibilitychange", checkVisibility);
};
Expand Down Expand Up @@ -488,7 +490,7 @@ export const uploadFiles = async (
return false;
}

const upload_ref_name = `__upload_controllers_${upload_id}`
const upload_ref_name = `__upload_controllers_${upload_id}`;

if (refs[upload_ref_name]) {
console.log("Upload already in progress for ", upload_id);
Expand Down
26 changes: 21 additions & 5 deletions reflex/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import platform
import sys
import traceback
import uuid
from datetime import datetime
from pathlib import Path
from types import SimpleNamespace
Expand Down Expand Up @@ -1528,14 +1529,18 @@ def __init__(self, namespace: str, app: App):
self.sid_to_token = {}
self.app = app

def on_connect(self, sid, environ):
async def on_connect(self, sid, environ):
"""Event for when the websocket is connected.

Args:
sid: The Socket.IO session id.
environ: The request information, including HTTP headers.
"""
pass
query_string = environ.get("QUERY_STRING")
query_params = dict(
qc.split("=") for qc in query_string.split("&") if "=" in qc
)
Comment thread
Lendemor marked this conversation as resolved.
Outdated
await self.link_token_to_sid(sid, query_params.get("token"))

def on_disconnect(self, sid):
"""Event for when the websocket disconnects.
Expand Down Expand Up @@ -1575,9 +1580,6 @@ async def on_event(self, sid, data):
**{k: v for k, v in fields.items() if k not in ("handler", "event_actions")}
)

self.token_to_sid[event.token] = sid
self.sid_to_token[sid] = event.token

# Get the event environment.
if self.app.sio is None:
raise RuntimeError("Socket.IO is not initialized.")
Expand Down Expand Up @@ -1610,3 +1612,17 @@ async def on_ping(self, sid):
"""
# Emit the test event.
await self.emit(str(constants.SocketEvent.PING), "pong", to=sid)

async def link_token_to_sid(self, sid, token):
"""Link a token to a session id.

Args:
sid: The Socket.IO session id.
token: The client token.
"""
if token in self.sid_to_token.values() and sid != self.token_to_sid.get(token):
token = uuid.uuid4().hex
Comment thread
Lendemor marked this conversation as resolved.
Outdated
await self.emit("new_token", token, to=sid)

self.token_to_sid[token] = sid
self.sid_to_token[sid] = token