Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion packages/reflex-base/src/reflex_base/utils/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -845,7 +845,7 @@ def is_backend_base_variable(name: str, cls: type[BaseState]) -> bool:
if name in cls.inherited_backend_vars:
return False

from reflex_base.vars.base import is_computed_var
from reflex_base.vars.base import Field, Var, is_computed_var

if name in cls.__dict__:
value = cls.__dict__[name]
Expand All @@ -864,6 +864,12 @@ def is_backend_base_variable(name: str, cls: type[BaseState]) -> bool:
) or is_computed_var(value):
return False

# Custom descriptors should be invoked via their __get__/__set__
# rather than shadowed by backend var storage. Field/Var define
# __get__ for type-checking but are not user descriptors.
if hasattr(type(value), "__get__") and not isinstance(value, (Field, Var)):
return False

return True


Expand Down
60 changes: 59 additions & 1 deletion reflex/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,38 @@ def _override_base_method(fn: Callable[PARAMS, RETURN]) -> Callable[PARAMS, RETU
return fn


def _is_user_descriptor(value: Any) -> bool:
"""Whether a class attribute is a user-defined descriptor.

Excludes framework-recognized callables and var types so user-defined
descriptors (with __get__/__set__) are surfaced for computed-var dependency
tracking without being shadowed by backend-var storage.

Args:
value: The class attribute value to check.

Returns:
True if the value is a custom descriptor.
"""
if not hasattr(type(value), "__get__"):
return False
if isinstance(
value,
(
FunctionType,
classmethod,
staticmethod,
property,
functools.cached_property,
EventHandler,
Var,
Field,
),
):
return False
return not is_computed_var(value)


all_base_state_classes: dict[str, None] = {}

CLASS_VAR_NAMES = frozenset({
Expand Down Expand Up @@ -559,18 +591,43 @@ def __init_subclass__(cls, mixin: bool = False, **kwargs):
for name, f in cls.get_fields().items()
if name not in cls.get_skip_vars() and f.is_var and not name.startswith("_")
}
# Surface user-defined descriptors as vars so computed vars can declare
# dependencies on them without routing through backend var storage.
descriptor_vars: dict[str, Var] = {}
hints = cls._get_type_hints()
for source_cls in (*cls._mixins(), cls):
for dname, dvalue in source_cls.__dict__.items():
if (
dname in cls.base_vars
or dname in descriptor_vars
or dname in cls.inherited_vars
or dname in cls.inherited_backend_vars
or dname not in hints
or not _is_user_descriptor(dvalue)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Child override of parent descriptor silently uses parent Var

When a subclass overrides a parent's descriptor with a different implementation, dname in cls.inherited_vars is True (because the parent's placeholder Var was merged into parent_state.vars, which becomes inherited_vars). The child's descriptor is therefore skipped and the dependency-tracking Var still refers to the parent-state field name in its VarData. Any computed var in the child that depends on the overridden descriptor would be attached to the parent's tracking dict rather than the child's, potentially missing invalidation.

This may be an accepted limitation, but it would be worth a comment or a # TODO to document the behaviour for future readers.

):
continue
descriptor_vars[dname] = dispatch(
field_name=format.format_state_name(cls.get_full_name())
+ "."
+ dname
+ FIELD_MARKER,
var_data=VarData.from_state(cls, dname),
result_var_type=hints[dname],
)
cls.computed_vars = {
name: v._replace(merge_var_data=VarData.from_state(cls))
for name, v in computed_vars
}
cls.vars = {
**descriptor_vars,
**cls.inherited_vars,
**cls.base_vars,
**cls.computed_vars,
}
cls.event_handlers = {}

# Setup the base vars at the class level.
# Setup the base vars at the class level (skip descriptors, which have
# no backing pydantic field and manage their own access).
for name, prop in cls.base_vars.items():
cls._init_var(name, prop)

Expand Down Expand Up @@ -675,6 +732,7 @@ def _item_is_event_handler(name: str, value: Any) -> bool:
not name.startswith("_")
and isinstance(value, Callable)
and not isinstance(value, EventHandler)
and not getattr(value, "__override_base_method__", False)
and hasattr(value, "__code__")
)

Expand Down
53 changes: 53 additions & 0 deletions tests/units/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -4646,3 +4646,56 @@ async def test_rebind_mutable_proxy(
assert state.data["a"] == [2, 3]
# Object identity persists across serialization, so data["b"] is also mutated.
assert state.data["b"] == [2, 3]


def test_override_base_method_skips_event_handler_wrapping():
"""A method marked with __override_base_method__ should not be wrapped as an EventHandler."""
from reflex.state import _override_base_method

class OverrideState(rx.State):
@_override_base_method
def custom_override(self) -> int:
return 42

# The marked method must remain a plain function, not an EventHandler.
assert not isinstance(OverrideState.__dict__["custom_override"], EventHandler)
assert "custom_override" not in OverrideState.event_handlers
assert OverrideState().custom_override() == 42


def test_descriptor_attribute_not_in_backend_vars():
"""A custom descriptor on a state should appear in vars/base_vars but not backend_vars."""

class _IntDescriptor:
def __init__(self):
self._values: dict[int, int] = {}

def __set_name__(self, owner, name):
self._name = name

def __get__(self, instance, owner):
if instance is None:
return self
return self._values.get(id(instance), 0)

def __set__(self, instance, value):
self._values[id(instance)] = value

class DescriptorState(rx.State):
_desc_value: int = _IntDescriptor() # pyright: ignore[reportAssignmentType]

@rx.var
def doubled(self) -> int:
return self._desc_value * 2

# The descriptor should not be tracked as a backend var.
assert "_desc_value" not in DescriptorState.backend_vars
# But it should be visible in base_vars/vars for dependency tracking.
assert "_desc_value" in DescriptorState.base_vars
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Incorrect base_vars assertion — test will fail

_desc_value starts with _, so it is explicitly excluded from base_vars by the not name.startswith("_") guard in __init_subclass__. Descriptor vars are added only to descriptor_vars, which is then merged into cls.vars. The assertion on line 4694 will therefore always be False and the test will fail.

The comment on line 4693 says "visible in base_vars/vars", but only the vars half is correct; the base_vars assertion should be inverted (or simply removed).

Suggested change
# But it should be visible in base_vars/vars for dependency tracking.
assert "_desc_value" in DescriptorState.base_vars
# But it should be visible in vars for dependency tracking.
assert "_desc_value" not in DescriptorState.base_vars
assert "_desc_value" in DescriptorState.vars

assert "_desc_value" in DescriptorState.vars
# Descriptor remains the class-level attribute (not overwritten by a Var).
assert isinstance(DescriptorState.__dict__["_desc_value"], _IntDescriptor)

# A computed var depending on the descriptor must register the dependency.
deps = DescriptorState._var_dependencies.get("_desc_value", set())
assert (DescriptorState.get_full_name(), "doubled") in deps
Loading