|
1 | 1 | """Base classes for shared / linked states.""" |
2 | 2 |
|
| 3 | +import asyncio |
3 | 4 | import contextlib |
4 | 5 | from collections.abc import AsyncIterator |
5 | 6 |
|
6 | 7 | from reflex.event import Event, get_hydrate_event |
7 | 8 | from reflex.state import BaseState, State, _override_base_method, _substate_key |
8 | 9 | from reflex.utils.exceptions import ReflexRuntimeError |
9 | 10 |
|
| 11 | +UPDATE_OTHER_CLIENT_TASKS: set[asyncio.Task] = set() |
| 12 | + |
| 13 | + |
| 14 | +def _do_update_other_tokens( |
| 15 | + affected_tokens: set[str], |
| 16 | + previous_dirty_vars: dict[str, set[str]], |
| 17 | + state_type: type[BaseState], |
| 18 | +) -> list[asyncio.Task]: |
| 19 | + """Update other clients after a shared state update. |
| 20 | +
|
| 21 | + Submit the updates in separate asyncio tasks to avoid deadlocking. |
| 22 | +
|
| 23 | + Args: |
| 24 | + affected_tokens: The tokens to update. |
| 25 | + previous_dirty_vars: The dirty vars to apply to other clients. |
| 26 | + state_type: The type of the shared state. |
| 27 | +
|
| 28 | + Returns: |
| 29 | + The list of asyncio tasks created to perform the updates. |
| 30 | + """ |
| 31 | + from reflex.utils.prerequisites import get_app |
| 32 | + |
| 33 | + app = get_app().app |
| 34 | + |
| 35 | + async def _update_client(token: str): |
| 36 | + async with app.modify_state( |
| 37 | + _substate_key(token, state_type), |
| 38 | + previous_dirty_vars=previous_dirty_vars, |
| 39 | + ): |
| 40 | + pass |
| 41 | + |
| 42 | + tasks = [] |
| 43 | + for affected_token in affected_tokens: |
| 44 | + # Don't send updates for disconnected clients. |
| 45 | + if affected_token not in app.event_namespace._token_manager.token_to_socket: |
| 46 | + continue |
| 47 | + # TODO: remove disconnected client's after some time. |
| 48 | + t = asyncio.create_task(_update_client(affected_token)) |
| 49 | + UPDATE_OTHER_CLIENT_TASKS.add(t) |
| 50 | + t.add_done_callback(UPDATE_OTHER_CLIENT_TASKS.discard) |
| 51 | + tasks.append(t) |
| 52 | + return tasks |
| 53 | + |
10 | 54 |
|
11 | 55 | class SharedStateBaseInternal(State): |
12 | 56 | """The private base state for all shared states.""" |
@@ -215,23 +259,11 @@ async def _modify_linked_states( |
215 | 259 |
|
216 | 260 | # Only propagate dirty vars when we are not already propagating from another state. |
217 | 261 | if previous_dirty_vars is None: |
218 | | - from reflex.utils.prerequisites import get_app |
219 | | - |
220 | | - app = get_app().app |
221 | | - |
222 | | - for affected_token in affected_tokens: |
223 | | - # Don't send updates for disconnected clients. |
224 | | - if ( |
225 | | - affected_token |
226 | | - not in app.event_namespace._token_manager.token_to_socket |
227 | | - ): |
228 | | - continue |
229 | | - # TODO: remove disconnected client's after some time. |
230 | | - async with app.modify_state( |
231 | | - _substate_key(affected_token, type(self)), |
232 | | - previous_dirty_vars=current_dirty_vars, |
233 | | - ): |
234 | | - pass |
| 262 | + _do_update_other_tokens( |
| 263 | + affected_tokens=affected_tokens, |
| 264 | + previous_dirty_vars=current_dirty_vars, |
| 265 | + state_type=type(self), |
| 266 | + ) |
235 | 267 |
|
236 | 268 |
|
237 | 269 | class SharedState(SharedStateBaseInternal, mixin=True): |
|
0 commit comments