diff --git a/packages/reflex-base/src/reflex_base/utils/types.py b/packages/reflex-base/src/reflex_base/utils/types.py index 0d2e63858f9..feeb3b0da50 100644 --- a/packages/reflex-base/src/reflex_base/utils/types.py +++ b/packages/reflex-base/src/reflex_base/utils/types.py @@ -845,7 +845,7 @@ def is_backend_base_variable(name: str, cls: type[BaseState]) -> bool: if name in cls.inherited_backend_vars: return False - from reflex_base.vars.base import is_computed_var + from reflex_base.vars.base import Field, Var, is_computed_var if name in cls.__dict__: value = cls.__dict__[name] @@ -864,6 +864,12 @@ def is_backend_base_variable(name: str, cls: type[BaseState]) -> bool: ) or is_computed_var(value): return False + # Custom descriptors should be invoked via their __get__/__set__ + # rather than shadowed by backend var storage. Field/Var define + # __get__ for type-checking but are not user descriptors. + if hasattr(type(value), "__get__") and not isinstance(value, (Field, Var)): + return False + return True diff --git a/reflex/state.py b/reflex/state.py index c09d68658f8..f105ec5dfc3 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -304,6 +304,38 @@ def _override_base_method(fn: Callable[PARAMS, RETURN]) -> Callable[PARAMS, RETU return fn +def _is_user_descriptor(value: Any) -> bool: + """Whether a class attribute is a user-defined descriptor. + + Excludes framework-recognized callables and var types so user-defined + descriptors (with __get__/__set__) are surfaced for computed-var dependency + tracking without being shadowed by backend-var storage. + + Args: + value: The class attribute value to check. + + Returns: + True if the value is a custom descriptor. + """ + if not hasattr(type(value), "__get__"): + return False + if isinstance( + value, + ( + FunctionType, + classmethod, + staticmethod, + property, + functools.cached_property, + EventHandler, + Var, + Field, + ), + ): + return False + return not is_computed_var(value) + + all_base_state_classes: dict[str, None] = {} CLASS_VAR_NAMES = frozenset({ @@ -529,6 +561,28 @@ def __init_subclass__(cls, mixin: bool = False, **kwargs): ) raise StateValueError(msg) + # A descriptor defined directly on this class overrides any same-named + # entry inherited from a parent state. Drop those names from the + # inherited maps so backend var assembly, dependency tracking, and the + # __setattr__ routing all resolve to the descriptor on this class. + hints = cls._get_type_hints() + own_descriptor_names = { + name + for name, value in cls.__dict__.items() + if name in hints and _is_user_descriptor(value) + } + if own_descriptor_names: + cls.inherited_vars = { + k: v + for k, v in cls.inherited_vars.items() + if k not in own_descriptor_names + } + cls.inherited_backend_vars = { + k: v + for k, v in cls.inherited_backend_vars.items() + if k not in own_descriptor_names + } + # Get computed vars. computed_vars = cls._get_computed_vars() cls._check_overridden_computed_vars() @@ -559,11 +613,40 @@ def __init_subclass__(cls, mixin: bool = False, **kwargs): for name, f in cls.get_fields().items() if name not in cls.get_skip_vars() and f.is_var and not name.startswith("_") } + # Surface user-defined descriptors as vars so computed vars can declare + # dependencies on them. Descriptors on this class always win over + # inherited or mixin-provided entries with the same name; mixin entries + # are skipped if already recorded. + descriptor_vars: dict[str, Var] = {} + for source_cls in (*cls._mixins(), cls): + is_self = source_cls is cls + for dname, dvalue in source_cls.__dict__.items(): + if ( + dname not in hints + or dname in cls.base_vars + or not _is_user_descriptor(dvalue) + ): + continue + if not is_self and ( + dname in descriptor_vars + or dname in cls.inherited_vars + or dname in cls.inherited_backend_vars + ): + continue + descriptor_vars[dname] = dispatch( + field_name=format.format_state_name(cls.get_full_name()) + + "." + + dname + + FIELD_MARKER, + var_data=VarData.from_state(cls, dname), + result_var_type=hints[dname], + ) cls.computed_vars = { name: v._replace(merge_var_data=VarData.from_state(cls)) for name, v in computed_vars } cls.vars = { + **descriptor_vars, **cls.inherited_vars, **cls.base_vars, **cls.computed_vars, @@ -675,6 +758,7 @@ def _item_is_event_handler(name: str, value: Any) -> bool: not name.startswith("_") and isinstance(value, Callable) and not isinstance(value, EventHandler) + and not getattr(value, "__override_base_method__", False) and hasattr(value, "__code__") ) diff --git a/tests/units/test_state.py b/tests/units/test_state.py index cb2bf87bdd2..ed73d2b5b75 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -4646,3 +4646,107 @@ async def test_rebind_mutable_proxy( assert state.data["a"] == [2, 3] # Object identity persists across serialization, so data["b"] is also mutated. assert state.data["b"] == [2, 3] + + +def test_override_base_method_skips_event_handler_wrapping(): + """A method marked with __override_base_method__ should not be wrapped as an EventHandler.""" + from reflex.state import _override_base_method + + class OverrideState(rx.State): + @_override_base_method + def custom_override(self) -> int: + return 42 + + # The marked method must remain a plain function, not an EventHandler. + assert not isinstance(OverrideState.__dict__["custom_override"], EventHandler) + assert "custom_override" not in OverrideState.event_handlers + assert OverrideState().custom_override() == 42 + + +def test_descriptor_attribute_not_in_backend_vars(): + """A custom descriptor on a state should appear in vars/base_vars but not backend_vars.""" + + class _IntDescriptor: + def __init__(self): + self._values: dict[int, int] = {} + + def __set_name__(self, owner, name): + self._name = name + + def __get__(self, instance, owner): + if instance is None: + return self + return self._values.get(id(instance), 0) + + def __set__(self, instance, value): + self._values[id(instance)] = value + + class DescriptorState(rx.State): + _desc_value: int = _IntDescriptor() # pyright: ignore[reportAssignmentType] + + @rx.var + def doubled(self) -> int: + return self._desc_value * 2 + + # The descriptor should not be tracked as a backend var or a base var + # (only pydantic-backed public fields go into base_vars). + assert "_desc_value" not in DescriptorState.backend_vars + assert "_desc_value" not in DescriptorState.base_vars + # But it should be visible in vars for dependency tracking. + assert "_desc_value" in DescriptorState.vars + # Descriptor remains the class-level attribute (not overwritten by a Var). + assert isinstance(DescriptorState.__dict__["_desc_value"], _IntDescriptor) + + # A computed var depending on the descriptor must register the dependency. + deps = DescriptorState._var_dependencies.get("_desc_value", set()) + assert (DescriptorState.get_full_name(), "doubled") in deps + + +def test_descriptor_overrides_inherited_descriptor(): + """A child state defining a descriptor with the same name as a parent overrides it.""" + + class _Sentinel: + def __init__(self, label: str): + self.label = label + self._values: dict[int, int] = {} + + def __get__(self, instance, owner): + if instance is None: + return self + return self._values.get(id(instance), 0) + + def __set__(self, instance, value): + self._values[id(instance)] = value + + parent_descriptor = _Sentinel("parent") + child_descriptor = _Sentinel("child") + + class ParentDescState(rx.State): + _shared: int = parent_descriptor # pyright: ignore[reportAssignmentType] + + @rx.var + def parent_view(self) -> int: + return self._shared + + class ChildDescState(ParentDescState): + _shared: int = child_descriptor # pyright: ignore[reportAssignmentType] + + @rx.var + def child_view(self) -> int: + return self._shared * 10 + + # The child class's descriptor wins on the class itself. + assert ChildDescState.__dict__["_shared"] is child_descriptor + # The override drops the inherited entry so dependency tracking attaches + # to the child class rather than the parent. + assert "_shared" not in ChildDescState.inherited_vars + assert "_shared" not in ChildDescState.inherited_backend_vars + assert "_shared" not in ChildDescState.backend_vars + # The child's Var (not the parent's) is what shows up in vars. + child_var = ChildDescState.vars["_shared"] + assert child_var is not ParentDescState.vars["_shared"] + # Child's computed var depends on child's _shared, parent's stays at parent. + child_deps = ChildDescState._var_dependencies.get("_shared", set()) + parent_deps = ParentDescState._var_dependencies.get("_shared", set()) + assert (ChildDescState.get_full_name(), "child_view") in child_deps + assert (ParentDescState.get_full_name(), "parent_view") in parent_deps