Skip to content

Commit 743f0c3

Browse files
committed
remove background task deadlock updating linked state
perform the subsequent updates in an asyncio.Task to allow the original caller to drop the lock for the other shared states.
1 parent a549af7 commit 743f0c3

2 files changed

Lines changed: 49 additions & 20 deletions

File tree

reflex/istate/shared.py

Lines changed: 49 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,56 @@
11
"""Base classes for shared / linked states."""
22

3+
import asyncio
34
import contextlib
45
from collections.abc import AsyncIterator
56

67
from reflex.event import Event, get_hydrate_event
78
from reflex.state import BaseState, State, _override_base_method, _substate_key
89
from reflex.utils.exceptions import ReflexRuntimeError
910

11+
UPDATE_OTHER_CLIENT_TASKS: set[asyncio.Task] = set()
12+
13+
14+
def _do_update_other_tokens(
15+
affected_tokens: set[str],
16+
previous_dirty_vars: dict[str, set[str]],
17+
state_type: type[BaseState],
18+
) -> list[asyncio.Task]:
19+
"""Update other clients after a shared state update.
20+
21+
Submit the updates in separate asyncio tasks to avoid deadlocking.
22+
23+
Args:
24+
affected_tokens: The tokens to update.
25+
previous_dirty_vars: The dirty vars to apply to other clients.
26+
state_type: The type of the shared state.
27+
28+
Returns:
29+
The list of asyncio tasks created to perform the updates.
30+
"""
31+
from reflex.utils.prerequisites import get_app
32+
33+
app = get_app().app
34+
35+
async def _update_client(token: str):
36+
async with app.modify_state(
37+
_substate_key(token, state_type),
38+
previous_dirty_vars=previous_dirty_vars,
39+
):
40+
pass
41+
42+
tasks = []
43+
for affected_token in affected_tokens:
44+
# Don't send updates for disconnected clients.
45+
if affected_token not in app.event_namespace._token_manager.token_to_socket:
46+
continue
47+
# TODO: remove disconnected client's after some time.
48+
t = asyncio.create_task(_update_client(affected_token))
49+
UPDATE_OTHER_CLIENT_TASKS.add(t)
50+
t.add_done_callback(UPDATE_OTHER_CLIENT_TASKS.discard)
51+
tasks.append(t)
52+
return tasks
53+
1054

1155
class SharedStateBaseInternal(State):
1256
"""The private base state for all shared states."""
@@ -215,23 +259,11 @@ async def _modify_linked_states(
215259

216260
# Only propagate dirty vars when we are not already propagating from another state.
217261
if previous_dirty_vars is None:
218-
from reflex.utils.prerequisites import get_app
219-
220-
app = get_app().app
221-
222-
for affected_token in affected_tokens:
223-
# Don't send updates for disconnected clients.
224-
if (
225-
affected_token
226-
not in app.event_namespace._token_manager.token_to_socket
227-
):
228-
continue
229-
# TODO: remove disconnected client's after some time.
230-
async with app.modify_state(
231-
_substate_key(affected_token, type(self)),
232-
previous_dirty_vars=current_dirty_vars,
233-
):
234-
pass
262+
_do_update_other_tokens(
263+
affected_tokens=affected_tokens,
264+
previous_dirty_vars=current_dirty_vars,
265+
state_type=type(self),
266+
)
235267

236268

237269
class SharedState(SharedStateBaseInternal, mixin=True):

tests/integration/test_linked_state.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
def LinkedStateApp():
1818
"""Test that linked state works as expected."""
19-
import asyncio
2019
from typing import Any
2120

2221
import reflex as rx
@@ -72,13 +71,11 @@ async def bump_counter_bg(self):
7271
async with self:
7372
ss = await self.get_state(SharedState)
7473
ss.counter += 1
75-
await asyncio.sleep(0)
7674
async with self:
7775
ss = await self.get_state(SharedState)
7876
for _ in range(5):
7977
async with ss:
8078
ss.counter += 1
81-
await asyncio.sleep(0)
8279

8380
@rx.event
8481
async def bump_counter_yield(self):

0 commit comments

Comments
 (0)