Skip to content
Merged
2 changes: 1 addition & 1 deletion pyi_hashes.json
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,6 @@
"reflex/components/recharts/general.pyi": "e992c57e3df1dc47b6a544c535d87319",
"reflex/components/recharts/polar.pyi": "cb74280e343562f52f087973eff27389",
"reflex/components/recharts/recharts.pyi": "2ab851e1b2cb63ae690e89144319a02c",
"reflex/components/sonner/toast.pyi": "a9f5529905f171fddd93e78b6e7ad53e",
"reflex/components/sonner/toast.pyi": "1681a04ee927febbe7b015dc224ba00b",
"reflex/components/suneditor/editor.pyi": "447a2c210d7218816236b7e31e6c135d"
}
148 changes: 71 additions & 77 deletions reflex/components/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,7 @@
from functools import wraps
from hashlib import md5
from types import SimpleNamespace
from typing import (
TYPE_CHECKING,
Annotated,
Any,
ClassVar,
Generic,
TypeVar,
cast,
get_args,
get_origin,
)
from typing import TYPE_CHECKING, Any, ClassVar, TypeVar, cast, get_args, get_origin

from rich.markup import escape
from typing_extensions import dataclass_transform
Expand All @@ -33,6 +23,7 @@
from reflex.compiler.templates import STATEFUL_COMPONENT
from reflex.components.core.breakpoints import Breakpoints
from reflex.components.dynamic import load_dynamic_serializer
from reflex.components.field import BaseField, FieldBasedMeta
from reflex.components.tags import Tag
from reflex.constants import (
Dirs,
Expand Down Expand Up @@ -75,7 +66,7 @@
FIELD_TYPE = TypeVar("FIELD_TYPE")


class ComponentField(Generic[FIELD_TYPE]):
class ComponentField(BaseField[FIELD_TYPE]):
"""A field for a component."""

def __init__(
Expand All @@ -93,30 +84,8 @@ def __init__(
is_javascript: Whether the field is a javascript property.
annotated_type: The annotated type for the field.
"""
self.default = default
self.default_factory = default_factory
super().__init__(default, default_factory, annotated_type)
self.is_javascript = is_javascript
self.outer_type_ = self.annotated_type = annotated_type
type_origin = get_origin(annotated_type) or annotated_type
if type_origin is Annotated:
type_origin = annotated_type.__origin__ # pyright: ignore [reportAttributeAccessIssue]
self.type_ = self.type_origin = type_origin

def default_value(self) -> FIELD_TYPE:
"""Get the default value for the field.

Returns:
The default value for the field.

Raises:
ValueError: If no default value or factory is provided.
"""
if self.default is not MISSING:
return self.default
if self.default_factory is not None:
return self.default_factory()
msg = "No default value or factory provided."
raise ValueError(msg)

def __repr__(self) -> str:
"""Represent the field in a readable format.
Expand Down Expand Up @@ -163,7 +132,7 @@ def field(


@dataclass_transform(kw_only_default=True, field_specifiers=(field,))
class BaseComponentMeta(ABCMeta):
class BaseComponentMeta(FieldBasedMeta, ABCMeta):
"""Meta class for BaseComponent."""

if TYPE_CHECKING:
Expand All @@ -172,46 +141,24 @@ class BaseComponentMeta(ABCMeta):
_fields: Mapping[str, ComponentField]
_js_fields: Mapping[str, ComponentField]

def __new__(cls, name: str, bases: tuple[type], namespace: dict[str, Any]) -> type:
"""Create a new class.

Args:
name: The name of the class.
bases: The bases of the class.
namespace: The namespace of the class.

Returns:
The new class.
"""
# Add the field to the class
inherited_fields: dict[str, ComponentField] = {}
own_fields: dict[str, ComponentField] = {}
resolved_annotations = types.resolve_annotations(
@classmethod
def _resolve_annotations(
cls, namespace: dict[str, Any], name: str
) -> dict[str, Any]:
return types.resolve_annotations(
namespace.get("__annotations__", {}), namespace["__module__"]
)

for base in bases[::-1]:
if hasattr(base, "_inherited_fields"):
inherited_fields.update(base._inherited_fields)
for base in bases[::-1]:
if hasattr(base, "_own_fields"):
inherited_fields.update(base._own_fields)

for key, value, inherited_field in [
(key, value, inherited_field)
for key, value in namespace.items()
if key not in resolved_annotations
and ((inherited_field := inherited_fields.get(key)) is not None)
]:
new_value = ComponentField(
default=value,
is_javascript=inherited_field.is_javascript,
annotated_type=inherited_field.annotated_type,
)

own_fields[key] = new_value
@classmethod
def _process_annotated_fields(
cls,
namespace: dict[str, Any],
annotations: dict[str, Any],
inherited_fields: dict[str, ComponentField],
) -> dict[str, ComponentField]:
own_fields: dict[str, ComponentField] = {}

for key, annotation in resolved_annotations.items():
for key, annotation in annotations.items():
value = namespace.get(key, MISSING)

if types.is_classvar(annotation):
Expand Down Expand Up @@ -244,16 +191,63 @@ def __new__(cls, name: str, bases: tuple[type], namespace: dict[str, Any]) -> ty

own_fields[key] = value

namespace["_own_fields"] = own_fields
namespace["_inherited_fields"] = inherited_fields
all_fields = inherited_fields | own_fields
namespace["_fields"] = all_fields
return own_fields

@classmethod
def _create_field(
cls,
annotated_type: Any,
default: Any = MISSING,
default_factory: Callable[[], Any] | None = None,
) -> ComponentField:
return ComponentField(
annotated_type=annotated_type,
default=default,
default_factory=default_factory,
is_javascript=True, # Default for components
)

@classmethod
def _process_field_overrides(
cls,
namespace: dict[str, Any],
annotations: dict[str, Any],
inherited_fields: dict[str, Any],
) -> dict[str, ComponentField]:
own_fields: dict[str, ComponentField] = {}

for key, value, inherited_field in [
(key, value, inherited_field)
for key, value in namespace.items()
if key not in annotations
and ((inherited_field := inherited_fields.get(key)) is not None)
]:
new_field = ComponentField(
default=value,
is_javascript=inherited_field.is_javascript,
annotated_type=inherited_field.annotated_type,
)
own_fields[key] = new_field

return own_fields

@classmethod
def _finalize_fields(
cls,
namespace: dict[str, Any],
inherited_fields: dict[str, ComponentField],
own_fields: dict[str, ComponentField],
) -> None:
# Call parent implementation
super()._finalize_fields(namespace, inherited_fields, own_fields)

# Add JavaScript fields mapping
all_fields = namespace["_fields"]
namespace["_js_fields"] = {
key: value
for key, value in all_fields.items()
if value.is_javascript is True
}
return super().__new__(cls, name, bases, namespace)


class BaseComponent(metaclass=BaseComponentMeta):
Expand Down
175 changes: 175 additions & 0 deletions reflex/components/field.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
"""Shared field infrastructure for components and props."""

from __future__ import annotations

from collections.abc import Callable
from dataclasses import _MISSING_TYPE, MISSING
from typing import Annotated, Any, Generic, TypeVar, get_origin

FIELD_TYPE = TypeVar("FIELD_TYPE")


class BaseField(Generic[FIELD_TYPE]):
"""Base field class used by internal metadata classes."""

def __init__(
self,
default: FIELD_TYPE | _MISSING_TYPE = MISSING,
default_factory: Callable[[], FIELD_TYPE] | None = None,
annotated_type: type[Any] | _MISSING_TYPE = MISSING,
) -> None:
"""Initialize the field.

Args:
default: The default value for the field.
default_factory: The default factory for the field.
annotated_type: The annotated type for the field.
"""
self.default = default
self.default_factory = default_factory
self.outer_type_ = self.annotated_type = annotated_type

# Process type annotation
type_origin = get_origin(annotated_type) or annotated_type
if type_origin is Annotated:
type_origin = annotated_type.__origin__ # pyright: ignore [reportAttributeAccessIssue]
# For Annotated types, use the actual type inside the annotation
self.type_ = annotated_type
else:
# For other types (including Union), preserve the original type
self.type_ = annotated_type
self.type_origin = type_origin

def default_value(self) -> FIELD_TYPE:
"""Get the default value for the field.

Returns:
The default value for the field.

Raises:
ValueError: If no default value or factory is provided.
"""
if self.default is not MISSING:
return self.default
if self.default_factory is not None:
return self.default_factory()
msg = "No default value or factory provided."
raise ValueError(msg)


class FieldBasedMeta(type):
"""Shared metaclass for field-based classes like components and props.

Provides common field inheritance and processing logic for both
PropsBaseMeta and BaseComponentMeta.
"""

def __new__(cls, name: str, bases: tuple[type], namespace: dict[str, Any]) -> type:
"""Create a new field-based class.

Args:
name: The name of the class.
bases: The base classes.
namespace: The class namespace.

Returns:
The new class.
"""
# Collect inherited fields from base classes
inherited_fields = cls._collect_inherited_fields(bases)

# Get annotations from the namespace
annotations = cls._resolve_annotations(namespace, name)

# Process field overrides (fields with values but no annotations)
own_fields = cls._process_field_overrides(
namespace, annotations, inherited_fields
)

# Process annotated fields
own_fields.update(
cls._process_annotated_fields(namespace, annotations, inherited_fields)
)

# Finalize fields and store on class
cls._finalize_fields(namespace, inherited_fields, own_fields)

return super().__new__(cls, name, bases, namespace)

@classmethod
def _collect_inherited_fields(cls, bases: tuple[type]) -> dict[str, Any]:
inherited_fields: dict[str, Any] = {}

# Collect inherited fields from base classes
for base in bases[::-1]:
if hasattr(base, "_inherited_fields"):
inherited_fields.update(base._inherited_fields)
for base in bases[::-1]:
if hasattr(base, "_own_fields"):
inherited_fields.update(base._own_fields)

return inherited_fields

@classmethod
def _resolve_annotations(
cls, namespace: dict[str, Any], name: str
) -> dict[str, Any]:
return namespace.get("__annotations__", {})

@classmethod
def _process_field_overrides(
cls,
namespace: dict[str, Any],
annotations: dict[str, Any],
inherited_fields: dict[str, Any],
) -> dict[str, Any]:
own_fields: dict[str, Any] = {}

for key, value in namespace.items():
if key not in annotations and key in inherited_fields:
inherited_field = inherited_fields[key]
new_field = cls._create_field(
annotated_type=inherited_field.annotated_type,
default=value,
default_factory=None,
)
own_fields[key] = new_field

return own_fields

@classmethod
def _process_annotated_fields(
cls,
namespace: dict[str, Any],
annotations: dict[str, Any],
inherited_fields: dict[str, Any],
) -> dict[str, Any]:
raise NotImplementedError

@classmethod
def _create_field(
cls,
annotated_type: Any,
default: Any = MISSING,
default_factory: Callable[[], Any] | None = None,
) -> Any:
raise NotImplementedError

@classmethod
def _finalize_fields(
cls,
namespace: dict[str, Any],
inherited_fields: dict[str, Any],
own_fields: dict[str, Any],
) -> None:
# Combine all fields
all_fields = inherited_fields | own_fields

# Set field names for compatibility
for field_name, field in all_fields.items():
field._name = field_name

# Store field mappings on the class
namespace["_own_fields"] = own_fields
namespace["_inherited_fields"] = inherited_fields
namespace["_fields"] = all_fields
Loading