Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 19 additions & 14 deletions reflex/istate/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import functools
import inspect
import json
import sys
from collections.abc import Callable, Sequence
from importlib.util import find_spec
from types import MethodType
Expand Down Expand Up @@ -132,15 +133,20 @@ async def __aenter__(self) -> Self:
raise ImmutableStateError(msg)

await self._self_actx_lock.acquire()
self._self_actx_lock_holder = current_task
self._self_actx = self._self_app.modify_state(
token=self._self_substate_token, background=True
)
mutable_state = await self._self_actx.__aenter__()
super().__setattr__(
"__wrapped__", mutable_state.get_substate(self._self_substate_path)
)
self._self_mutable = True
try:
self._self_actx_lock_holder = current_task
self._self_actx = self._self_app.modify_state(
token=self._self_substate_token, background=True
)
mutable_state = await self._self_actx.__aenter__()
self._self_mutable = True
super().__setattr__(
"__wrapped__", mutable_state.get_substate(self._self_substate_path)
)
except (Exception, asyncio.CancelledError):
# Restore the proxy to a consistent state since __aexit__ will not be called when __aenter__ raises.
await self.__aexit__(*sys.exc_info())
raise
return self

async def __aexit__(self, *exc_info: Any) -> None:
Expand All @@ -154,15 +160,14 @@ async def __aexit__(self, *exc_info: Any) -> None:
if self._self_parent_state_proxy is not None:
await self._self_parent_state_proxy.__aexit__(*exc_info)
return
if self._self_actx is None:
return
self._self_mutable = False
try:
await self._self_actx.__aexit__(*exc_info)
if self._self_mutable and self._self_actx is not None:
await self._self_actx.__aexit__(*exc_info)
finally:
self._self_actx = None
self._self_mutable = False
self._self_actx_lock_holder = None
self._self_actx_lock.release()
self._self_actx = None

def __enter__(self):
"""Enter the regular context manager protocol.
Expand Down
15 changes: 15 additions & 0 deletions tests/units/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,21 @@ def app_module_mock(monkeypatch) -> mock.Mock:
return app_module_mock


@pytest.fixture
def mock_app(app_module_mock: mock.Mock, app: App) -> App:
"""A mocked dummy app per test.

Args:
app_module_mock: The mock for the main app module.
app: A default App instance.

Returns:
The mock app instance.
"""
app_module_mock.app = app
return app


@pytest.fixture(scope="session")
def windows_platform() -> bool:
"""Check if system is windows.
Expand Down
33 changes: 32 additions & 1 deletion tests/units/istate/test_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,14 @@

import dataclasses
import pickle
from asyncio import CancelledError
from contextlib import asynccontextmanager
from unittest.mock import patch

import pytest

import reflex as rx
from reflex.istate.proxy import MutableProxy
from reflex.istate.proxy import MutableProxy, StateProxy


@dataclasses.dataclass
Expand Down Expand Up @@ -35,3 +40,29 @@ def test_mutable_proxy_pickle_preserves_object_identity():
assert unpickled["direct"][0].id == 1
assert unpickled["proxied"][0].id == 1
assert unpickled["direct"][0] is unpickled["proxied"][0]


@pytest.mark.usefixtures("mock_app")
@pytest.mark.asyncio
async def test_state_proxy_recovery():
"""Ensure that `async with self` can be re-entered after a lock issue."""
state = ProxyTestState()
state_proxy = StateProxy(state)

with patch("reflex.app.App.modify_state") as mock_modify_state:

@asynccontextmanager
async def mock_modify_state_context(*args, **kwargs): # noqa: RUF029
msg = "Simulated lock issue"
raise CancelledError(msg)
yield

mock_modify_state.side_effect = mock_modify_state_context

with pytest.raises(CancelledError, match="Simulated lock issue"):
async with state_proxy:
pass

# After the exception, we should be able to enter the context again without issues
async with state_proxy:
pass
Loading