diff --git a/reflex/istate/proxy.py b/reflex/istate/proxy.py index 80c3f83e1fe..11bff0d5aae 100644 --- a/reflex/istate/proxy.py +++ b/reflex/istate/proxy.py @@ -8,6 +8,7 @@ import functools import inspect import json +import sys from collections.abc import Callable, Sequence from importlib.util import find_spec from types import MethodType @@ -132,15 +133,20 @@ async def __aenter__(self) -> Self: raise ImmutableStateError(msg) await self._self_actx_lock.acquire() - self._self_actx_lock_holder = current_task - self._self_actx = self._self_app.modify_state( - token=self._self_substate_token, background=True - ) - mutable_state = await self._self_actx.__aenter__() - super().__setattr__( - "__wrapped__", mutable_state.get_substate(self._self_substate_path) - ) - self._self_mutable = True + try: + self._self_actx_lock_holder = current_task + self._self_actx = self._self_app.modify_state( + token=self._self_substate_token, background=True + ) + mutable_state = await self._self_actx.__aenter__() + self._self_mutable = True + super().__setattr__( + "__wrapped__", mutable_state.get_substate(self._self_substate_path) + ) + except (Exception, asyncio.CancelledError): + # Restore the proxy to a consistent state since __aexit__ will not be called when __aenter__ raises. + await self.__aexit__(*sys.exc_info()) + raise return self async def __aexit__(self, *exc_info: Any) -> None: @@ -154,15 +160,14 @@ async def __aexit__(self, *exc_info: Any) -> None: if self._self_parent_state_proxy is not None: await self._self_parent_state_proxy.__aexit__(*exc_info) return - if self._self_actx is None: - return - self._self_mutable = False try: - await self._self_actx.__aexit__(*exc_info) + if self._self_mutable and self._self_actx is not None: + await self._self_actx.__aexit__(*exc_info) finally: + self._self_actx = None + self._self_mutable = False self._self_actx_lock_holder = None self._self_actx_lock.release() - self._self_actx = None def __enter__(self): """Enter the regular context manager protocol. diff --git a/tests/units/conftest.py b/tests/units/conftest.py index c2dd936fce0..612d8beaf85 100644 --- a/tests/units/conftest.py +++ b/tests/units/conftest.py @@ -46,6 +46,21 @@ def app_module_mock(monkeypatch) -> mock.Mock: return app_module_mock +@pytest.fixture +def mock_app(app_module_mock: mock.Mock, app: App) -> App: + """A mocked dummy app per test. + + Args: + app_module_mock: The mock for the main app module. + app: A default App instance. + + Returns: + The mock app instance. + """ + app_module_mock.app = app + return app + + @pytest.fixture(scope="session") def windows_platform() -> bool: """Check if system is windows. diff --git a/tests/units/istate/test_proxy.py b/tests/units/istate/test_proxy.py index b71d91e619d..5fd29725fa9 100644 --- a/tests/units/istate/test_proxy.py +++ b/tests/units/istate/test_proxy.py @@ -2,9 +2,14 @@ import dataclasses import pickle +from asyncio import CancelledError +from contextlib import asynccontextmanager +from unittest.mock import patch + +import pytest import reflex as rx -from reflex.istate.proxy import MutableProxy +from reflex.istate.proxy import MutableProxy, StateProxy @dataclasses.dataclass @@ -35,3 +40,29 @@ def test_mutable_proxy_pickle_preserves_object_identity(): assert unpickled["direct"][0].id == 1 assert unpickled["proxied"][0].id == 1 assert unpickled["direct"][0] is unpickled["proxied"][0] + + +@pytest.mark.usefixtures("mock_app") +@pytest.mark.asyncio +async def test_state_proxy_recovery(): + """Ensure that `async with self` can be re-entered after a lock issue.""" + state = ProxyTestState() + state_proxy = StateProxy(state) + + with patch("reflex.app.App.modify_state") as mock_modify_state: + + @asynccontextmanager + async def mock_modify_state_context(*args, **kwargs): # noqa: RUF029 + msg = "Simulated lock issue" + raise CancelledError(msg) + yield + + mock_modify_state.side_effect = mock_modify_state_context + + with pytest.raises(CancelledError, match="Simulated lock issue"): + async with state_proxy: + pass + + # After the exception, we should be able to enter the context again without issues + async with state_proxy: + pass