diff --git a/reflex/istate/manager/redis.py b/reflex/istate/manager/redis.py index 06587b69b08..63fc90586aa 100644 --- a/reflex/istate/manager/redis.py +++ b/reflex/istate/manager/redis.py @@ -279,7 +279,7 @@ async def get_state( token, state_path = _split_substate_key(token) if state_path: # Get the State class associated with the given path. - state_cls = self.state.get_class_substate(state_path) + requested_state_cls = self.state.get_class_substate(state_path) else: msg = f"StateManagerRedis requires token to be specified in the form of {{token}}_{{state_full_name}}, but got {token}" raise RuntimeError(msg) @@ -291,7 +291,7 @@ async def get_state( # Determine which states from the tree need to be fetched. required_state_classes = sorted( - self._get_required_state_classes(state_cls, subclasses=True) + self._get_required_state_classes(requested_state_cls, subclasses=True) - {type(s) for s in flat_state_tree.values()}, key=lambda x: x.get_full_name(), ) @@ -337,7 +337,7 @@ async def get_state( # the top-level state which should always be fetched or already cached. if top_level: return flat_state_tree[self.state.get_full_name()] - return flat_state_tree[state_cls.get_full_name()] + return flat_state_tree[requested_state_cls.get_full_name()] @override async def set_state( diff --git a/reflex/vars/base.py b/reflex/vars/base.py index 412c0e41c70..3a9e8a985f9 100644 --- a/reflex/vars/base.py +++ b/reflex/vars/base.py @@ -2488,12 +2488,18 @@ def add_dependency(self, objclass: type[BaseState], dep: Var): var_name = all_var_data.field_name if var_name: self._static_deps.setdefault(state_name, set()).add(var_name) - objclass.get_root_state().get_class_substate( + target_state_class = objclass.get_root_state().get_class_substate( state_name - )._var_dependencies.setdefault(var_name, set()).add(( + ) + target_state_class._var_dependencies.setdefault( + var_name, set() + ).add(( objclass.get_full_name(), self._name, )) + target_state_class._potentially_dirty_states.add( + objclass.get_full_name() + ) return msg = ( "ComputedVar dependencies must be Var instances with a state and " diff --git a/tests/units/test_state.py b/tests/units/test_state.py index 2bb3d05179b..a984a739dfb 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -55,7 +55,7 @@ ) from reflex.utils.format import json_dumps from reflex.utils.token_manager import SocketRecord -from reflex.vars.base import Var, computed_var +from reflex.vars.base import Field, Var, computed_var, field from tests.units.mock_redis import mock_redis from .states import GenState @@ -306,12 +306,12 @@ def test_base_class_vars(test_state): fields = test_state.get_fields() cls = type(test_state) - for field in fields: - if field.startswith("_") or field in cls.get_skip_vars(): + for field_name in fields: + if field_name.startswith("_") or field_name in cls.get_skip_vars(): continue - prop = getattr(cls, field) + prop = getattr(cls, field_name) assert isinstance(prop, Var) - assert prop._js_expr.split(".")[-1] == field + FIELD_MARKER + assert prop._js_expr.split(".")[-1] == field_name + FIELD_MARKER assert cls.num1._var_type is int assert cls.num2._var_type is float @@ -4304,6 +4304,8 @@ class OtherState(rx.State): state = await mock_app.state_manager.get_state(_substate_key(token, OtherState)) other_state = await state.get_state(OtherState) assert comp.State is not None + # The state should have been pre-cached from the dependency. + assert comp.State.get_name() in state.substates comp_state = await state.get_state(comp.State) assert comp_state.dirty_vars == set() @@ -4329,3 +4331,35 @@ class SecondCvState(CvMixin, rx.State): assert first_cv is not second_cv assert first_cv._static_deps is not second_cv._static_deps + + +@pytest.mark.asyncio +async def test_add_dependency_get_state_regression(mock_app: rx.App, token: str): + """Ensure that a state class can be fetched separately when it's is explicit dep.""" + + class DataState(rx.State): + """A state with a var.""" + + data: Field[list[int]] = field(default_factory=lambda: [1, 2, 3]) + + class StatsState(rx.State): + """A state with a computed var depending on DataState.""" + + @rx.var(cache=True) + async def total(self) -> int: + data_state = await self.get_state(DataState) + return sum(data_state.data) + + StatsState.computed_vars["total"].add_dependency(StatsState, DataState.data) + + class OtherState(rx.State): + """A state that gets DataState.""" + + @rx.event + async def fetch_data_state(self) -> None: + print(await self.get_state(DataState)) + + mock_app.state_manager.state = mock_app._state = rx.State + state = await mock_app.state_manager.get_state(_substate_key(token, OtherState)) + other_state = await state.get_state(OtherState) + await other_state.fetch_data_state() # Should not raise exception.