diff --git a/reflex/compiler/utils.py b/reflex/compiler/utils.py index edfd25c7cdc..8d406d4639a 100644 --- a/reflex/compiler/utils.py +++ b/reflex/compiler/utils.py @@ -11,8 +11,6 @@ from typing import Any from urllib.parse import urlparse -from pydantic.v1.fields import ModelField - from reflex import constants from reflex.components.base import ( Body, @@ -34,7 +32,7 @@ from reflex.utils.exec import is_in_app_harness from reflex.utils.imports import ImportVar, ParsedImportDict from reflex.utils.prerequisites import get_web_dir -from reflex.vars.base import Var +from reflex.vars.base import Field, Var # To re-export this function. merge_imports = imports.merge_imports @@ -212,7 +210,7 @@ def compile_state(state: type[BaseState]) -> dict: def _compile_client_storage_field( - field: ModelField, + field: Field, ) -> tuple[ type[Cookie] | type[LocalStorage] | type[SessionStorage] | None, dict[str, Any] | None, diff --git a/reflex/components/component.py b/reflex/components/component.py index 02f4ce64914..3c8118e8862 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -7,7 +7,6 @@ import dataclasses import functools import inspect -import sys import typing from abc import ABC, ABCMeta, abstractmethod from collections.abc import Callable, Iterator, Mapping, Sequence @@ -20,10 +19,8 @@ Annotated, Any, ClassVar, - ForwardRef, Generic, TypeVar, - _eval_type, # pyright: ignore [reportAttributeAccessIssue] cast, get_args, get_origin, @@ -74,46 +71,6 @@ from reflex.vars.object import ObjectVar from reflex.vars.sequence import LiteralArrayVar, LiteralStringVar, StringVar - -def resolve_annotations( - raw_annotations: Mapping[str, type[Any]], module_name: str | None -) -> dict[str, type[Any]]: - """Partially taken from typing.get_type_hints. - - Resolve string or ForwardRef annotations into type objects if possible. - - Args: - raw_annotations: The raw annotations to resolve. - module_name: The name of the module. - - Returns: - The resolved annotations. - """ - module = sys.modules.get(module_name, None) if module_name is not None else None - - base_globals: dict[str, Any] | None = ( - module.__dict__ if module is not None else None - ) - - annotations = {} - for name, value in raw_annotations.items(): - if isinstance(value, str): - if sys.version_info == (3, 10, 0): - value = ForwardRef(value, is_argument=False) - else: - value = ForwardRef(value, is_argument=False, is_class=True) - try: - if sys.version_info >= (3, 13): - value = _eval_type(value, base_globals, None, type_params=()) - else: - value = _eval_type(value, base_globals, None) - except NameError: - # this is ok, it can be fixed with update_forward_refs - pass - annotations[name] = value - return annotations - - FIELD_TYPE = TypeVar("FIELD_TYPE") @@ -228,7 +185,7 @@ def __new__(cls, name: str, bases: tuple[type], namespace: dict[str, Any]) -> ty # Add the field to the class inherited_fields: dict[str, ComponentField] = {} own_fields: dict[str, ComponentField] = {} - resolved_annotations = resolve_annotations( + resolved_annotations = types.resolve_annotations( namespace.get("__annotations__", {}), namespace["__module__"] ) diff --git a/reflex/event.py b/reflex/event.py index 615a2a9e47d..55441d7a2e6 100644 --- a/reflex/event.py +++ b/reflex/event.py @@ -95,6 +95,7 @@ def substate_token(self) -> str: @dataclasses.dataclass( init=True, frozen=True, + kw_only=True, ) class EventActionsMixin: """Mixin for DOM event actions.""" @@ -170,6 +171,7 @@ def temporal(self) -> Self: @dataclasses.dataclass( init=True, frozen=True, + kw_only=True, ) class EventHandler(EventActionsMixin): """An event handler responds to an event to update the state.""" @@ -270,6 +272,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> EventSpec: @dataclasses.dataclass( init=True, frozen=True, + kw_only=True, ) class EventSpec(EventActionsMixin): """An event specification. diff --git a/reflex/state.py b/reflex/state.py index 4e407c1303a..696c557a11a 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -13,20 +13,10 @@ import sys import typing import warnings -from abc import ABC from collections.abc import AsyncIterator, Callable, Sequence from hashlib import md5 from types import FunctionType -from typing import ( - TYPE_CHECKING, - Any, - BinaryIO, - ClassVar, - TypeVar, - cast, - get_args, - get_type_hints, -) +from typing import TYPE_CHECKING, Any, BinaryIO, ClassVar, TypeVar, cast, get_type_hints import pydantic.v1 as pydantic from pydantic import BaseModel as BaseModelV2 @@ -68,17 +58,12 @@ ) from reflex.utils.exceptions import ImmutableStateError as ImmutableStateError from reflex.utils.exec import is_testing_env -from reflex.utils.types import ( - _isinstance, - get_origin, - is_union, - true_type_for_pydantic_field, - value_inside_optional, -) -from reflex.vars import VarData +from reflex.utils.types import _isinstance, is_union, value_inside_optional +from reflex.vars import Field, VarData, field from reflex.vars.base import ( ComputedVar, DynamicRouteVar, + EvenMoreBasicBaseState, Var, computed_var, dispatch, @@ -255,38 +240,23 @@ def __call__(self, *args: Any) -> EventSpec: from pydantic.v1.fields import ModelField -def _unwrap_field_type(type_: types.GenericType) -> type: - """Unwrap rx.Field type annotations. - - Args: - type_: The type to unwrap. - - Returns: - The unwrapped type. - """ - from reflex.vars import Field - - if get_origin(type_) is Field: - return get_args(type_)[0] - return type_ - - -def get_var_for_field(cls: type[BaseState], f: ModelField): - """Get a Var instance for a Pydantic field. +def get_var_for_field(cls: type[BaseState], name: str, f: Field): + """Get a Var instance for a state field. Args: cls: The state class. - f: The Pydantic field. + name: The name of the field. + f: The Field instance. Returns: The Var instance. """ - field_name = format.format_state_name(cls.get_full_name()) + "." + f.name + field_name = format.format_state_name(cls.get_full_name()) + "." + name return dispatch( field_name=field_name, - var_data=VarData.from_state(cls, f.name), - result_var_type=_unwrap_field_type(true_type_for_pydantic_field(f)), + var_data=VarData.from_state(cls, name), + result_var_type=f.outer_type_, ) @@ -312,7 +282,7 @@ async def _resolve_delta(delta: Delta) -> Delta: all_base_state_classes: dict[str, None] = {} -class BaseState(Base, ABC, extra=pydantic.Extra.allow): +class BaseState(EvenMoreBasicBaseState): """The state of the app.""" # A map from the var name to the var. @@ -352,31 +322,34 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): _potentially_dirty_states: ClassVar[set[str]] = set() # The parent state. - parent_state: BaseState | None = None + parent_state: BaseState | None = field(default=None, is_var=False) # The substates of the state. - substates: builtins.dict[str, BaseState] = {} + substates: builtins.dict[str, BaseState] = field( + default_factory=builtins.dict, is_var=False + ) # The set of dirty vars. - dirty_vars: set[str] = set() + dirty_vars: set[str] = field(default_factory=set, is_var=False) # The set of dirty substates. - dirty_substates: set[str] = set() + dirty_substates: set[str] = field(default_factory=set, is_var=False) # The routing path that triggered the state - router_data: builtins.dict[str, Any] = {} + router_data: builtins.dict[str, Any] = field( + default_factory=builtins.dict, is_var=False + ) # Per-instance copy of backend base variable values - _backend_vars: builtins.dict[str, Any] = {} + _backend_vars: builtins.dict[str, Any] = field( + default_factory=builtins.dict, is_var=False + ) # The router data for the current page - router: RouterData = RouterData() + router: Field[RouterData] = field(default_factory=RouterData) # Whether the state has ever been touched since instantiation. - _was_touched: bool = False - - # Whether this state class is a mixin and should not be instantiated. - _mixin: ClassVar[bool] = False + _was_touched: bool = field(default=False, is_var=False) # A special event handler for setting base vars. setvar: ClassVar[EventHandler] @@ -409,13 +382,11 @@ def __init__( "See https://reflex.dev/docs/state/ for further information." ) raise ReflexRuntimeError(msg) - if type(self)._mixin: + if self._mixin: msg = f"{type(self).__name__} is a state mixin and cannot be instantiated directly." raise ReflexRuntimeError(msg) kwargs["parent_state"] = parent_state - super().__init__() - for name, value in kwargs.items(): - setattr(self, name, value) + super().__init__(**kwargs) # Setup the substates (for memory state manager only). if init_substates: @@ -481,8 +452,7 @@ def __init_subclass__(cls, mixin: bool = False, **kwargs): super().__init_subclass__(**kwargs) - cls._mixin = mixin - if mixin: + if cls._mixin: return # Handle locally-defined states for pickling. @@ -548,9 +518,9 @@ def __init_subclass__(cls, mixin: bool = False, **kwargs): # Set the base and computed vars. cls.base_vars = { - f.name: get_var_for_field(cls, f) - for f in cls.get_fields().values() - if f.name not in cls.get_skip_vars() + name: get_var_for_field(cls, name, f) + for name, f in cls.get_fields().items() + if name not in cls.get_skip_vars() and f.is_var and not name.startswith("_") } cls.computed_vars = { v._js_expr: v._replace(merge_var_data=VarData.from_state(cls)) @@ -898,7 +868,7 @@ def get_parent_state(cls) -> type[BaseState] | None: if issubclass(base, BaseState) and base is not BaseState and not base._mixin ] if len(parent_states) >= 2: - msg = f"Only one parent state is allowed {parent_states}." + msg = f"Only one parent state of is allowed. Found {parent_states} parents of {cls}." raise ValueError(msg) # The first non-mixin state in the mro is our parent. for base in cls.mro()[1:]: @@ -1131,17 +1101,8 @@ def _set_default_value(cls, prop: Var): """ # Get the pydantic field for the var. field = cls.get_fields()[prop._var_field_name] - if field.required: - default_value = prop._get_default_value() - if default_value is not None: - field.required = False - field.default = default_value - if ( - not field.required - and field.default is None - and field.default_factory is None - and not types.is_optional(prop._var_type) - ): + + if field.default is None and not types.is_optional(prop._var_type): # Ensure frontend uses null coalescing when accessing. object.__setattr__(prop, "_var_type", prop._var_type | None) @@ -1160,7 +1121,7 @@ def _get_var_default(cls, name: str, annotation_value: Any) -> Any: return getattr(cls, name) except AttributeError: try: - return Var("", _var_type=annotation_value)._get_default_value() + return types.get_default_value_for_type(annotation_value) except TypeError: pass return None @@ -1276,10 +1237,6 @@ def __getattribute__(self, name: str) -> Any: Returns: The value of the var. """ - # If the state hasn't been initialized yet, return the default value. - if not super().__getattribute__("__dict__"): - return super().__getattribute__(name) - # Fast path for dunder if name.startswith("__"): return super().__getattribute__(name) @@ -1306,7 +1263,7 @@ def __getattribute__(self, name: str) -> Any: fn.__qualname__ = handler.fn.__qualname__ return fn - backend_vars = super().__getattribute__("_backend_vars") + backend_vars = super().__getattribute__("_backend_vars") or {} if name in backend_vars: value = backend_vars[name] else: @@ -1370,9 +1327,8 @@ def __setattr__(self, name: str, value: Any): fields = self.get_fields() - if name in fields: - field = fields[name] - field_type = _unwrap_field_type(true_type_for_pydantic_field(field)) + if (field := fields.get(name)) is not None and field.is_var: + field_type = field.outer_type_ if not _isinstance(value, field_type, nested=1, treat_var_as_type=False): console.error( f"Expected field '{type(self).__name__}.{name}' to receive type '{escape(str(field_type))}'," @@ -1380,7 +1336,7 @@ def __setattr__(self, name: str, value: Any): ) # Set the attribute. - super().__setattr__(name, value) + object.__setattr__(self, name, value) # Add the var to the dirty list. if name in self.base_vars: @@ -2058,7 +2014,7 @@ def get_value(self, key: str) -> Any: Returns: The value of the field. """ - value = super().get_value(key) + value = getattr(self, key) if isinstance(value, MutableProxy): return value.__wrapped__ return value @@ -2147,19 +2103,19 @@ def __getstate__(self): Returns: The state dict for serialization. """ - state = super().__getstate__() - state["__dict__"] = state["__dict__"].copy() - if state["__dict__"].get("parent_state") is not None: + state = self.__dict__ + state = state.copy() + if state.get("parent_state") is not None: # Do not serialize router data in substates (only the root state). - state["__dict__"].pop("router", None) - state["__dict__"].pop("router_data", None) + state.pop("router", None) + state.pop("router_data", None) # Never serialize parent_state or substates. - state["__dict__"].pop("parent_state", None) - state["__dict__"].pop("substates", None) - state["__dict__"].pop("_was_touched", None) + state.pop("parent_state", None) + state.pop("substates", None) + state.pop("_was_touched", None) # Remove all inherited vars. for inherited_var_name in self.inherited_vars: - state["__dict__"].pop(inherited_var_name, None) + state.pop(inherited_var_name, None) return state def __setstate__(self, state: dict[str, Any]): @@ -2170,9 +2126,10 @@ def __setstate__(self, state: dict[str, Any]): Args: state: The state dict for deserialization. """ - state["__dict__"]["parent_state"] = None - state["__dict__"]["substates"] = {} - super().__setstate__(state) + state["parent_state"] = None + state["substates"] = {} + for key, value in state.items(): + object.__setattr__(self, key, value) def _check_state_size( self, @@ -2213,17 +2170,11 @@ def _to_schema(cls) -> str: def _field_tuple( field_name: str, - ) -> tuple[str, str, Any, bool | None, Any]: + ) -> tuple[str, Any, Any]: model_field = cls.__fields__[field_name] return ( field_name, - model_field.name, _serialize_type(model_field.type_), - ( - model_field.required - if isinstance(model_field.required, bool) - else None - ), (model_field.default if is_serializable(model_field.default) else None), ) @@ -2518,7 +2469,7 @@ def __init__(self, *args, **kwargs): Raises: ReflexRuntimeError: If the ComponentState is initialized directly. """ - if type(self)._mixin: + if self._mixin: raise ReflexRuntimeError( f"{ComponentState.__name__} {type(self).__name__} is not meant to be initialized directly. " + "Use the `create` method to create a new instance and access the state via the `State` attribute." diff --git a/reflex/utils/serializers.py b/reflex/utils/serializers.py index 6304018629a..5ae7ddd5ae1 100644 --- a/reflex/utils/serializers.py +++ b/reflex/utils/serializers.py @@ -9,7 +9,7 @@ import inspect import json import warnings -from collections.abc import Callable, Sequence +from collections.abc import Callable, Mapping, Sequence from datetime import date, datetime, time, timedelta from enum import Enum from pathlib import Path @@ -335,6 +335,19 @@ def serialize_sequence(value: Sequence) -> list: return list(value) +@serializer(to=dict) +def serialize_mapping(value: Mapping) -> dict: + """Serialize a mapping type to a dictionary. + + Args: + value: The mapping instance to serialize. + + Returns: + A new dictionary containing the same key-value pairs as the input mapping. + """ + return {**value} + + @serializer(to=str) def serialize_datetime(dt: date | datetime | time | timedelta) -> str: """Serialize a datetime to a JSON string. diff --git a/reflex/utils/types.py b/reflex/utils/types.py index 62fb0339a1c..71c11cd5289 100644 --- a/reflex/utils/types.py +++ b/reflex/utils/types.py @@ -4,6 +4,7 @@ import dataclasses import inspect +import sys import types from collections.abc import Callable, Iterable, Mapping, Sequence from functools import cached_property, lru_cache, wraps @@ -23,6 +24,7 @@ Tuple, TypeVar, Union, + _eval_type, # pyright: ignore [reportAttributeAccessIssue] _GenericAlias, # pyright: ignore [reportAttributeAccessIssue] _SpecialGenericAlias, # pyright: ignore [reportAttributeAccessIssue] get_args, @@ -160,11 +162,7 @@ def __call__( dict: Dict, # noqa: UP006 } -RESERVED_BACKEND_VAR_NAMES = { - "_abc_impl", - "_backend_vars", - "_was_touched", -} +RESERVED_BACKEND_VAR_NAMES = {"_abc_impl", "_backend_vars", "_was_touched", "_mixin"} class Unset: @@ -326,7 +324,11 @@ def is_optional(cls: GenericType) -> bool: Returns: Whether the class is an Optional. """ - return is_union(cls) and type(None) in get_args(cls) + return ( + cls is None + or cls is type(None) + or (is_union(cls) and type(None) in get_args(cls)) + ) def is_classvar(a_type: Any) -> bool: @@ -1193,3 +1195,103 @@ def typehint_issubclass( ) if accepted_arg is not Any ) + + +def resolve_annotations( + raw_annotations: Mapping[str, type[Any]], module_name: str | None +) -> dict[str, type[Any]]: + """Partially taken from typing.get_type_hints. + + Resolve string or ForwardRef annotations into type objects if possible. + + Args: + raw_annotations: The raw annotations to resolve. + module_name: The name of the module. + + Returns: + The resolved annotations. + """ + module = sys.modules.get(module_name, None) if module_name is not None else None + + base_globals: dict[str, Any] | None = ( + module.__dict__ if module is not None else None + ) + + annotations = {} + for name, value in raw_annotations.items(): + if isinstance(value, str): + if sys.version_info == (3, 10, 0): + value = ForwardRef(value, is_argument=False) + else: + value = ForwardRef(value, is_argument=False, is_class=True) + try: + if sys.version_info >= (3, 13): + value = _eval_type(value, base_globals, None, type_params=()) + else: + value = _eval_type(value, base_globals, None) + except NameError: + # this is ok, it can be fixed with update_forward_refs + pass + annotations[name] = value + return annotations + + +TYPES_THAT_HAS_DEFAULT_VALUE = (int, float, tuple, list, set, dict, str) + + +def get_default_value_for_type(t: GenericType) -> Any: + """Get the default value of the var. + + Args: + t: The type of the var. + + Returns: + The default value of the var, if it has one, else None. + + Raises: + ImportError: If the var is a dataframe and pandas is not installed. + """ + if is_optional(t): + return None + + origin = get_origin(t) if is_generic_alias(t) else t + if origin is Literal: + args = get_args(t) + return args[0] if args else None + if safe_issubclass(origin, TYPES_THAT_HAS_DEFAULT_VALUE): + return origin() + if safe_issubclass(origin, Mapping): + return {} + if is_dataframe(origin): + try: + import pandas as pd + + return pd.DataFrame() + except ImportError as e: + msg = "Please install pandas to use dataframes in your app." + raise ImportError(msg) from e + return None + + +IMMUTABLE_TYPES = ( + int, + float, + bool, + str, + bytes, + frozenset, + tuple, + type(None), +) + + +def is_immutable(i: Any) -> bool: + """Check if a value is immutable. + + Args: + i: The value to check. + + Returns: + Whether the value is immutable. + """ + return isinstance(i, IMMUTABLE_TYPES) diff --git a/reflex/vars/__init__.py b/reflex/vars/__init__.py index cb02319bc6a..66a3591fd8e 100644 --- a/reflex/vars/__init__.py +++ b/reflex/vars/__init__.py @@ -1,5 +1,7 @@ """Immutable-Based Var System.""" +from .base import BaseStateMeta as BaseStateMeta +from .base import EvenMoreBasicBaseState as EvenMoreBasicBaseState from .base import Field as Field from .base import LiteralVar as LiteralVar from .base import Var as Var diff --git a/reflex/vars/base.py b/reflex/vars/base.py index d3e106f1a69..4a400e26759 100644 --- a/reflex/vars/base.py +++ b/reflex/vars/base.py @@ -14,11 +14,14 @@ import string import uuid import warnings +from abc import ABCMeta from collections.abc import Callable, Coroutine, Iterable, Mapping, Sequence +from dataclasses import _MISSING_TYPE, MISSING from decimal import Decimal from types import CodeType, FunctionType from typing import ( # noqa: UP035 TYPE_CHECKING, + Annotated, Any, ClassVar, Dict, @@ -40,7 +43,7 @@ ) from rich.markup import escape -from typing_extensions import deprecated, override +from typing_extensions import dataclass_transform, deprecated, override from reflex import constants from reflex.base import Base @@ -952,48 +955,6 @@ def guess_type(self) -> Var: return self - def _get_default_value(self) -> Any: - """Get the default value of the var. - - Returns: - The default value of the var. - - Raises: - ImportError: If the var is a dataframe and pandas is not installed. - """ - if types.is_optional(self._var_type): - return None - - type_ = ( - get_origin(self._var_type) - if types.is_generic_alias(self._var_type) - else self._var_type - ) - if type_ is Literal: - args = get_args(self._var_type) - return args[0] if args else None - if safe_issubclass(type_, str): - return "" - if safe_issubclass(type_, types.get_args(int | float)): - return 0 - if safe_issubclass(type_, bool): - return False - if safe_issubclass(type_, list): - return [] - if safe_issubclass(type_, Mapping): - return {} - if safe_issubclass(type_, tuple): - return () - if types.is_dataframe(type_): - try: - import pandas as pd - - return pd.DataFrame() - except ImportError as e: - msg = "Please install pandas to use dataframes in your app." - raise ImportError(msg) from e - return set() if safe_issubclass(type_, set) else None - def _get_setter_name(self, include_state: bool = True) -> str: """Get the name of the var's generated setter function. @@ -3382,19 +3343,100 @@ def dispatch( MAPPING_TYPE = TypeVar("MAPPING_TYPE", bound=Mapping | None) V = TypeVar("V") + FIELD_TYPE = TypeVar("FIELD_TYPE") class Field(Generic[FIELD_TYPE]): - """Shadow class for Var to allow for type hinting in the IDE.""" + """A field for a state.""" - def __set__(self, instance: Any, value: FIELD_TYPE): - """Set the Var. + if TYPE_CHECKING: + type_: GenericType + default: FIELD_TYPE | _MISSING_TYPE + default_factory: Callable[[], FIELD_TYPE] | None + + def __init__( + self, + default: FIELD_TYPE | _MISSING_TYPE = MISSING, + default_factory: Callable[[], FIELD_TYPE] | None = None, + is_var: bool = True, + annotated_type: GenericType # pyright: ignore [reportRedeclaration] + | _MISSING_TYPE = MISSING, + ) -> None: + """Initialize the field. Args: - instance: The instance of the class setting the Var. - value: The value to set the Var to. + default: The default value for the field. + default_factory: The default factory for the field. + is_var: Whether the field is a Var. + annotated_type: The annotated type for the field. + """ + self.default = default + self.default_factory = default_factory + self.is_var = is_var + if annotated_type is not MISSING: + type_origin = get_origin(annotated_type) or annotated_type + if type_origin is Field and ( + args := getattr(annotated_type, "__args__", None) + ): + annotated_type: GenericType = args[0] + type_origin = get_origin(annotated_type) or annotated_type + + if self.default is MISSING and self.default_factory is None: + self.default = types.get_default_value_for_type(annotated_type) + if self.default is None and not types.is_optional(annotated_type): + annotated_type = annotated_type | None + self.outer_type_ = self.annotated_type = annotated_type + + if type_origin is Annotated: + type_origin = annotated_type.__origin__ # pyright: ignore [reportAttributeAccessIssue] + + self.type_ = self.type_origin = type_origin + else: + self.outer_type_ = self.annotated_type = self.type_ = self.type_origin = Any + + def default_value(self) -> FIELD_TYPE: + """Get the default value for the field. + + Returns: + The default value for the field. + + Raises: + ValueError: If no default value or factory is provided. """ + if self.default is not MISSING: + return self.default + if self.default_factory is not None: + return self.default_factory() + msg = "No default value or factory provided." + raise ValueError(msg) + + def __repr__(self) -> str: + """Represent the field in a readable format. + + Returns: + The string representation of the field. + """ + annotated_type_str = ( + f", annotated_type={self.annotated_type!r}" + if self.annotated_type is not MISSING + else "" + ) + if self.default is not MISSING: + return f"Field(default={self.default!r}, is_var={self.is_var}{annotated_type_str})" + return f"Field(default_factory={self.default_factory!r}, is_var={self.is_var}{annotated_type_str})" + + if TYPE_CHECKING: + + def __set__(self, instance: Any, value: FIELD_TYPE): + """Set the Var. + + Args: + instance: The instance of the class setting the Var. + value: The value to set the Var to. + + # noqa: DAR101 self + """ @overload def __get__(self: Field[None], instance: None, owner: Any) -> NoneVar: ... @@ -3484,13 +3526,252 @@ def __get__(self, instance: Any, owner: Any): # pyright: ignore [reportInconsis """ -def field(value: FIELD_TYPE) -> Field[FIELD_TYPE]: - """Create a Field with a value. +@overload +def field( + default: FIELD_TYPE | _MISSING_TYPE = MISSING, + *, + is_var: Literal[False], + default_factory: Callable[[], FIELD_TYPE] | None = None, +) -> FIELD_TYPE: ... + + +@overload +def field( + default: FIELD_TYPE | _MISSING_TYPE = MISSING, + *, + default_factory: Callable[[], FIELD_TYPE] | None = None, + is_var: Literal[True] = True, +) -> Field[FIELD_TYPE]: ... + + +def field( + default: FIELD_TYPE | _MISSING_TYPE = MISSING, + *, + default_factory: Callable[[], FIELD_TYPE] | None = None, + is_var: bool = True, +) -> Field[FIELD_TYPE] | FIELD_TYPE: + """Create a field for a state. Args: - value: The value of the Field. + default: The default value for the field. + default_factory: The default factory for the field. + is_var: Whether the field is a Var. Returns: - The Field. + The field for the state. + + Raises: + ValueError: If both default and default_factory are specified. """ - return value # pyright: ignore [reportReturnType] + if default is not MISSING and default_factory is not None: + msg = "cannot specify both default and default_factory" + raise ValueError(msg) + if default is not MISSING and not types.is_immutable(default): + console.warn( + "Mutable default values are not recommended. " + "Use default_factory instead to avoid unexpected behavior." + ) + return Field( + default_factory=functools.partial(copy.deepcopy, default), + is_var=is_var, + ) + return Field( + default=default, + default_factory=default_factory, + is_var=is_var, + ) + + +@dataclass_transform(kw_only_default=True, field_specifiers=(field,)) +class BaseStateMeta(ABCMeta): + """Meta class for BaseState.""" + + if TYPE_CHECKING: + __inherited_fields__: Mapping[str, Field] + __own_fields__: dict[str, Field] + __fields__: dict[str, Field] + + # Whether this state class is a mixin and should not be instantiated. + _mixin: bool = False + + def __new__( + cls, + name: str, + bases: tuple[type], + namespace: dict[str, Any], + mixin: bool = False, + ) -> type: + """Create a new class. + + Args: + name: The name of the class. + bases: The bases of the class. + namespace: The namespace of the class. + mixin: Whether the class is a mixin and should not be instantiated. + + Returns: + The new class. + """ + state_bases = [ + base for base in bases if issubclass(base, EvenMoreBasicBaseState) + ] + mixin = mixin or ( + bool(state_bases) and all(base._mixin for base in state_bases) + ) + # Add the field to the class + inherited_fields: dict[str, Field] = {} + own_fields: dict[str, Field] = {} + resolved_annotations = types.resolve_annotations( + namespace.get("__annotations__", {}), namespace["__module__"] + ) + + for base in bases[::-1]: + if hasattr(base, "__inherited_fields__"): + inherited_fields.update(base.__inherited_fields__) + for base in bases[::-1]: + if hasattr(base, "__own_fields__"): + inherited_fields.update(base.__own_fields__) + + for key, value in [ + (key, value) + for key, value in namespace.items() + if key not in resolved_annotations + ]: + if isinstance(value, Field): + if value.annotated_type is not Any: + new_value = value + elif value.default is not MISSING: + new_value = Field( + default=value.default, + is_var=value.is_var, + annotated_type=figure_out_type(value.default), + ) + else: + new_value = Field( + default_factory=value.default_factory, + is_var=value.is_var, + annotated_type=Any, + ) + elif ( + not key.startswith("__") + and not callable(value) + and not isinstance(value, (staticmethod, classmethod, property, Var)) + ): + if types.is_immutable(value): + new_value = Field( + default=value, + annotated_type=figure_out_type(value), + ) + else: + new_value = Field( + default_factory=functools.partial(copy.deepcopy, value), + annotated_type=figure_out_type(value), + ) + else: + continue + + own_fields[key] = new_value + + for key, annotation in resolved_annotations.items(): + value = namespace.get(key, MISSING) + + if types.is_classvar(annotation): + # If the annotation is a classvar, skip it. + continue + + if value is MISSING: + value = Field( + annotated_type=annotation, + ) + elif not isinstance(value, Field): + if types.is_immutable(value): + value = Field( + default=value, + annotated_type=annotation, + ) + else: + value = Field( + default_factory=functools.partial(copy.deepcopy, value), + annotated_type=annotation, + ) + else: + value = Field( + default=value.default, + default_factory=value.default_factory, + is_var=value.is_var, + annotated_type=annotation, + ) + + own_fields[key] = value + + namespace["__own_fields__"] = own_fields + namespace["__inherited_fields__"] = inherited_fields + namespace["__fields__"] = inherited_fields | own_fields + namespace["_mixin"] = mixin + return super().__new__(cls, name, bases, namespace) + + +class EvenMoreBasicBaseState(metaclass=BaseStateMeta): + """A simplified base state class that provides basic functionality.""" + + def __init__( + self, + **kwargs, + ): + """Initialize the state with the given kwargs. + + Args: + **kwargs: The kwargs to pass to the state. + """ + super().__init__() + for key, value in kwargs.items(): + object.__setattr__(self, key, value) + for name, value in type(self).get_fields().items(): + if name not in kwargs: + default_value = value.default_value() + object.__setattr__(self, name, default_value) + + def set(self, **kwargs): + """Mutate the state by setting the given kwargs. Returns the state. + + Args: + **kwargs: The kwargs to set. + + Returns: + The state with the fields set to the given kwargs. + """ + for key, value in kwargs.items(): + setattr(self, key, value) + return self + + @classmethod + def get_fields(cls) -> Mapping[str, Field]: + """Get the fields of the component. + + Returns: + The fields of the component. + """ + return cls.__fields__ + + @classmethod + def add_field(cls, var: Var, default_value: Any): + """Add a field to the class after class definition. + + Used by State.add_var() to correctly handle the new variable. + + Args: + var: The variable to add a field for. + default_value: The default value of the field. + """ + var_name = var._var_field_name + if types.is_immutable(default_value): + new_field = Field( + default=default_value, + annotated_type=var._var_type, + ) + else: + new_field = Field( + default_factory=functools.partial(copy.deepcopy, default_value), + annotated_type=var._var_type, + ) + cls.__fields__[var_name] = new_field diff --git a/tests/integration/test_upload.py b/tests/integration/test_upload.py index e0568019d8b..996b7abd196 100644 --- a/tests/integration/test_upload.py +++ b/tests/integration/test_upload.py @@ -432,6 +432,16 @@ async def test_cancel_upload(tmp_path, upload_file: AppHarness, driver: WebDrive driver: WebDriver instance. """ assert upload_file.app_instance is not None + driver.execute_cdp_cmd("Network.enable", {}) + driver.execute_cdp_cmd( + "Network.emulateNetworkConditions", + { + "offline": False, + "downloadThroughput": 1024 * 1024 / 8, # 1 Mbps + "uploadThroughput": 1024 * 1024 / 8, # 1 Mbps + "latency": 200, # 200ms + }, + ) token = poll_for_token(driver, upload_file) state_name = upload_file.get_state_name("_upload_state") state_full_name = upload_file.get_full_state_name(["_upload_state"]) @@ -444,16 +454,16 @@ async def test_cancel_upload(tmp_path, upload_file: AppHarness, driver: WebDrive exp_name = "large.txt" target_file = tmp_path / exp_name with target_file.open("wb") as f: - f.seek(1024 * 1024 * 256) + f.seek(1024 * 1024) # 1 MB file, should upload in ~8 seconds f.write(b"0") upload_box.send_keys(str(target_file)) upload_button.click() - await asyncio.sleep(0.3) + await asyncio.sleep(1) cancel_button.click() # Wait a bit for the upload to get cancelled. - await asyncio.sleep(0.5) + await asyncio.sleep(12) # Get interim progress dicts saved in the on_upload_progress handler. async def _progress_dicts(): diff --git a/tests/units/test_state.py b/tests/units/test_state.py index 0163fccde8b..50c3716104c 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -291,7 +291,7 @@ def test_base_class_vars(test_state): cls = type(test_state) for field in fields: - if field in test_state.get_skip_vars(): + if field.startswith("_") or field in cls.get_skip_vars(): continue prop = getattr(cls, field) assert isinstance(prop, Var) @@ -2776,7 +2776,7 @@ class UnionState(BaseState): assert ( str(UnionState.c3.c2r.c1r.foo) == f'{UnionState.c3!s}?.["c2r"]["c1r"]["foo"]' # pyright: ignore [reportOptionalMemberAccess] ) - assert str(UnionState.c3i.c2) == f'{UnionState.c3i!s}["c2"]' + assert str(UnionState.c3i.c2) == f'{UnionState.c3i!s}?.["c2"]' assert str(UnionState.c3r.c2) == f'{UnionState.c3r!s}["c2"]' assert UnionState.custom_union.foo is not None # pyright: ignore [reportAttributeAccessIssue] assert UnionState.custom_union.c1 is not None # pyright: ignore [reportAttributeAccessIssue] diff --git a/tests/units/test_var.py b/tests/units/test_var.py index de2896311bd..27a4bd985d4 100644 --- a/tests/units/test_var.py +++ b/tests/units/test_var.py @@ -19,6 +19,7 @@ UntypedComputedVarError, ) from reflex.utils.imports import ImportVar +from reflex.utils.types import get_default_value_for_type from reflex.vars import VarData from reflex.vars.base import ( ComputedVar, @@ -244,7 +245,7 @@ def test_default_value(prop: Var, expected): prop: The var to test. expected: The expected default value. """ - assert prop._get_default_value() == expected + assert get_default_value_for_type(prop._var_type) == expected @pytest.mark.parametrize(