Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions reflex/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -1996,6 +1996,13 @@ async def emit_update(self, update: StateUpdate, sid: str) -> None:
update: The state update to send.
sid: The Socket.IO session id.
"""
if not sid:
# If the sid is None, we are not connected to a client. Prevent sending
# updates to all clients.
return
if sid not in self.sid_to_token:
console.warn(f"Attempting to send delta to disconnected websocket {sid}")
return
# Creating a task prevents the update from being blocked behind other coroutines.
await asyncio.create_task(
self.emit(str(constants.SocketEvent.EVENT), update, to=sid)
Expand Down
29 changes: 25 additions & 4 deletions tests/units/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -1948,21 +1948,37 @@ class ModelDC:


@pytest.mark.asyncio
async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App):
async def test_state_proxy(
grandchild_state: GrandchildState, mock_app: rx.App, token: str
):
"""Test that the state proxy works.

Args:
grandchild_state: A grandchild state.
mock_app: An app that will be returned by `get_app()`
token: A token.
"""
child_state = grandchild_state.parent_state
assert child_state is not None
parent_state = child_state.parent_state
assert parent_state is not None
router_data = RouterData({"query": {}, "token": token, "sid": "test_sid"})
grandchild_state.router = router_data
namespace = mock_app.event_namespace
assert namespace is not None
namespace.sid_to_token[router_data.session.session_id] = token
if isinstance(mock_app.state_manager, (StateManagerMemory, StateManagerDisk)):
mock_app.state_manager.states[parent_state.router.session.client_token] = (
parent_state
)
elif isinstance(mock_app.state_manager, StateManagerRedis):
pickle_state = parent_state._serialize()
if pickle_state:
await mock_app.state_manager.redis.set(
_substate_key(parent_state.router.session.client_token, parent_state),
pickle_state,
ex=mock_app.state_manager.token_expiration,
)

sp = StateProxy(grandchild_state)
assert sp.__wrapped__ == grandchild_state
Expand Down Expand Up @@ -2029,6 +2045,7 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App):
assert mcall.args[0] == str(SocketEvent.EVENT)
assert mcall.args[1] == StateUpdate(
delta={
TestState.get_full_name(): {"router": router_data},
grandchild_state.get_full_name(): {
"value2": "42",
},
Expand Down Expand Up @@ -2154,7 +2171,11 @@ async def test_background_task_no_block(mock_app: rx.App, token: str):
mock_app: An app that will be returned by `get_app()`
token: A token.
"""
router_data = {"query": {}}
router_data = {"query": {}, "token": token}
sid = "test_sid"
namespace = mock_app.event_namespace
assert namespace is not None
namespace.sid_to_token[sid] = token
mock_app.state_manager.state = mock_app._state = BackgroundTaskState
async for update in rx.app.process(
mock_app,
Expand All @@ -2164,7 +2185,7 @@ async def test_background_task_no_block(mock_app: rx.App, token: str):
router_data=router_data,
payload={},
),
sid="",
sid=sid,
headers={},
client_ip="",
):
Expand All @@ -2184,7 +2205,7 @@ async def test_background_task_no_block(mock_app: rx.App, token: str):
router_data=router_data,
payload={},
),
sid="",
sid=sid,
headers={},
client_ip="",
):
Expand Down