diff --git a/reflex/components/markdown/markdown.py b/reflex/components/markdown/markdown.py index 0b7ac34ddfb..d700d22747d 100644 --- a/reflex/components/markdown/markdown.py +++ b/reflex/components/markdown/markdown.py @@ -38,6 +38,84 @@ NO_PROPS_TAGS = ("ul", "ol", "li") +def _h1(value: object): + from reflex.components.radix.themes.typography.heading import Heading + + return Heading.create(value, as_="h1", size="6", margin_y="0.5em") + + +def _h2(value: object): + from reflex.components.radix.themes.typography.heading import Heading + + return Heading.create(value, as_="h2", size="5", margin_y="0.5em") + + +def _h3(value: object): + from reflex.components.radix.themes.typography.heading import Heading + + return Heading.create(value, as_="h3", size="4", margin_y="0.5em") + + +def _h4(value: object): + from reflex.components.radix.themes.typography.heading import Heading + + return Heading.create(value, as_="h4", size="3", margin_y="0.5em") + + +def _h5(value: object): + from reflex.components.radix.themes.typography.heading import Heading + + return Heading.create(value, as_="h5", size="2", margin_y="0.5em") + + +def _h6(value: object): + from reflex.components.radix.themes.typography.heading import Heading + + return Heading.create(value, as_="h6", size="1", margin_y="0.5em") + + +def _p(value: object): + from reflex.components.radix.themes.typography.text import Text + + return Text.create(value, margin_y="1em") + + +def _ul(value: object): + from reflex.components.radix.themes.layout.list import UnorderedList + + return UnorderedList.create(value, margin_y="1em") + + +def _ol(value: object): + from reflex.components.radix.themes.layout.list import OrderedList + + return OrderedList.create(value, margin_y="1em") + + +def _li(value: object): + from reflex.components.radix.themes.layout.list import ListItem + + return ListItem.create(value, margin_y="0.5em") + + +def _a(value: object): + from reflex.components.radix.themes.typography.link import Link + + return Link.create(value) + + +def _code(value: object): + from reflex.components.radix.themes.typography.code import Code + + return Code.create(value) + + +def _codeblock(value: object, **props): + from reflex.components.datadisplay.code import CodeBlock + + return CodeBlock.create(value, margin_y="1em", wrap_long_lines=True, **props) + + # Component Mapping @lru_cache def get_base_component_map() -> dict[str, Callable]: @@ -46,33 +124,20 @@ def get_base_component_map() -> dict[str, Callable]: Returns: The base component map. """ - from reflex.components.datadisplay.code import CodeBlock - from reflex.components.radix.themes.layout.list import ( - ListItem, - OrderedList, - UnorderedList, - ) - from reflex.components.radix.themes.typography.code import Code - from reflex.components.radix.themes.typography.heading import Heading - from reflex.components.radix.themes.typography.link import Link - from reflex.components.radix.themes.typography.text import Text - return { - "h1": lambda value: Heading.create(value, as_="h1", size="6", margin_y="0.5em"), - "h2": lambda value: Heading.create(value, as_="h2", size="5", margin_y="0.5em"), - "h3": lambda value: Heading.create(value, as_="h3", size="4", margin_y="0.5em"), - "h4": lambda value: Heading.create(value, as_="h4", size="3", margin_y="0.5em"), - "h5": lambda value: Heading.create(value, as_="h5", size="2", margin_y="0.5em"), - "h6": lambda value: Heading.create(value, as_="h6", size="1", margin_y="0.5em"), - "p": lambda value: Text.create(value, margin_y="1em"), - "ul": lambda value: UnorderedList.create(value, margin_y="1em"), - "ol": lambda value: OrderedList.create(value, margin_y="1em"), - "li": lambda value: ListItem.create(value, margin_y="0.5em"), - "a": lambda value: Link.create(value), - "code": lambda value: Code.create(value), - "codeblock": lambda value, **props: CodeBlock.create( - value, margin_y="1em", wrap_long_lines=True, **props - ), + "h1": _h1, + "h2": _h2, + "h3": _h3, + "h4": _h4, + "h5": _h5, + "h6": _h6, + "p": _p, + "ul": _ul, + "ol": _ol, + "li": _li, + "a": _a, + "code": _code, + "codeblock": _codeblock, } @@ -413,7 +478,16 @@ def _get_map_fn_custom_code_from_children( @staticmethod def _component_map_hash(component_map: dict) -> str: inp = str( - {tag: component(_MOCK_ARG) for tag, component in component_map.items()} + { + tag: ( + f"{component.__module__}.{component.__qualname__}" + if ( + "<" not in component.__name__ + ) # simple way to check against lambdas + else component(_MOCK_ARG) + ) + for tag, component in component_map.items() + } ).encode() return md5(inp).hexdigest() diff --git a/reflex/utils/types.py b/reflex/utils/types.py index bfd2d858bb0..4c72118d83d 100644 --- a/reflex/utils/types.py +++ b/reflex/utils/types.py @@ -291,6 +291,7 @@ def is_literal(cls: GenericType) -> bool: return getattr(cls, "__origin__", None) is Literal +@lru_cache def has_args(cls: type) -> bool: """Check if the class has generic parameters. diff --git a/reflex/vars/base.py b/reflex/vars/base.py index 0a54bd845d6..9fe3dcae4f3 100644 --- a/reflex/vars/base.py +++ b/reflex/vars/base.py @@ -1527,6 +1527,82 @@ def _create_literal_var( def __post_init__(self): """Post-initialize the var.""" + @classmethod + def _get_all_var_data_without_creating_var( + cls, + value: Any, + ) -> VarData | None: + return cls.create(value)._get_all_var_data() + + @classmethod + def _get_all_var_data_without_creating_var_dispatch( + cls, + value: Any, + ) -> VarData | None: + """Get all the var data without creating a var. + + Args: + value: The value to get the var data from. + + Returns: + The var data or None. + + Raises: + TypeError: If the value is not a supported type for LiteralVar. + """ + from .object import LiteralObjectVar + from .sequence import LiteralStringVar + + if isinstance(value, Var): + return value._get_all_var_data() + + for literal_subclass, var_subclass in _var_literal_subclasses[::-1]: + if isinstance(value, var_subclass.python_types): + return literal_subclass._get_all_var_data_without_creating_var(value) + + if ( + (as_var_method := getattr(value, "_as_var", None)) is not None + and callable(as_var_method) + and isinstance((resulting_var := as_var_method()), Var) + ): + return resulting_var._get_all_var_data() + + from reflex.event import EventHandler + from reflex.utils.format import get_event_handler_parts + + if isinstance(value, EventHandler): + return Var( + _js_expr=".".join(filter(None, get_event_handler_parts(value))) + )._get_all_var_data() + + serialized_value = serializers.serialize(value) + if serialized_value is not None: + if isinstance(serialized_value, Mapping): + return LiteralObjectVar._get_all_var_data_without_creating_var( + serialized_value + ) + if isinstance(serialized_value, str): + return LiteralStringVar._get_all_var_data_without_creating_var( + serialized_value + ) + return LiteralVar._get_all_var_data_without_creating_var_dispatch( + serialized_value + ) + + if dataclasses.is_dataclass(value) and not isinstance(value, type): + return LiteralObjectVar._get_all_var_data_without_creating_var( + { + k: (None if callable(v) else v) + for k, v in dataclasses.asdict(value).items() + } + ) + + if isinstance(value, range): + return None + + msg = f"Unsupported type {type(value)} for LiteralVar. Tried to create a LiteralVar from {value}." + raise TypeError(msg) + @property def _var_value(self) -> Any: msg = "LiteralVar subclasses must implement the _var_value property." @@ -1688,30 +1764,30 @@ def figure_out_type(value: Any) -> types.GenericType: Returns: The type of the value. """ - if isinstance(value, Var): - return value._var_type - type_ = type(value) - if has_args(type_): - return type_ - if isinstance(value, list): - if not value: - return Sequence[NoReturn] - return Sequence[unionize(*(figure_out_type(v) for v in value))] - if isinstance(value, set): - return set[unionize(*(figure_out_type(v) for v in value))] - if isinstance(value, tuple): - if not value: - return tuple[NoReturn, ...] - if len(value) <= 5: - return tuple[tuple(figure_out_type(v) for v in value)] - return tuple[unionize(*(figure_out_type(v) for v in value)), ...] - if isinstance(value, Mapping): - if not value: - return Mapping[NoReturn, NoReturn] - return Mapping[ - unionize(*(figure_out_type(k) for k in value)), - unionize(*(figure_out_type(v) for v in value.values())), - ] + if isinstance(value, (list, set, tuple, Mapping, Var)): + if isinstance(value, Var): + return value._var_type + if has_args(value_type := type(value)): + return value_type + if isinstance(value, list): + if not value: + return Sequence[NoReturn] + return Sequence[unionize(*{figure_out_type(v) for v in value[:100]})] + if isinstance(value, set): + return set[unionize(*{figure_out_type(v) for v in value})] + if isinstance(value, tuple): + if not value: + return tuple[NoReturn, ...] + if len(value) <= 5: + return tuple[tuple(figure_out_type(v) for v in value)] + return tuple[unionize(*{figure_out_type(v) for v in value[:100]}), ...] + if isinstance(value, Mapping): + if not value: + return Mapping[NoReturn, NoReturn] + return Mapping[ + unionize(*{figure_out_type(k) for k in list(value.keys())[:100]}), + unionize(*{figure_out_type(v) for v in list(value.values())[:100]}), + ] return type(value) @@ -2883,6 +2959,10 @@ def json(self) -> str: """ return "null" + @classmethod + def _get_all_var_data_without_creating_var(cls, value: None) -> VarData | None: + return None + @classmethod def create( cls, diff --git a/reflex/vars/color.py b/reflex/vars/color.py index f699eac35ad..0684e47990d 100644 --- a/reflex/vars/color.py +++ b/reflex/vars/color.py @@ -29,6 +29,23 @@ class LiteralColorVar(CachedVarOperation, LiteralVar, ColorVar): _var_value: Color = dataclasses.field(default_factory=lambda: Color(color="black")) + @classmethod + def _get_all_var_data_without_creating_var( + cls, + value: Color, + ) -> VarData | None: + return VarData.merge( + LiteralStringVar._get_all_var_data_without_creating_var(value.color) + if isinstance(value.color, str) + else value.color._get_all_var_data(), + value.alpha._get_all_var_data() + if not isinstance(value.alpha, bool) + else None, + value.shade._get_all_var_data() + if not isinstance(value.shade, int) + else None, + ) + @classmethod def create( cls, @@ -111,14 +128,17 @@ def _cached_get_all_var_data(self) -> VarData | None: The var data. """ return VarData.merge( - *[ - LiteralVar.create(var)._get_all_var_data() - for var in ( - self._var_value.color, - self._var_value.alpha, - self._var_value.shade, - ) - ], + LiteralStringVar._get_all_var_data_without_creating_var( + self._var_value.color + ) + if isinstance(self._var_value.color, str) + else self._var_value.color._get_all_var_data(), + self._var_value.alpha._get_all_var_data() + if not isinstance(self._var_value.alpha, bool) + else None, + self._var_value.shade._get_all_var_data() + if not isinstance(self._var_value.shade, int) + else None, self._var_data, ) diff --git a/reflex/vars/datetime.py b/reflex/vars/datetime.py index f71172e9c51..89a787e3e06 100644 --- a/reflex/vars/datetime.py +++ b/reflex/vars/datetime.py @@ -174,10 +174,14 @@ def date_compare_operation( class LiteralDatetimeVar(LiteralVar, DateTimeVar): """Base class for immutable datetime and date vars.""" - _var_value: datetime | date = dataclasses.field(default=datetime.now()) + _var_value: date = dataclasses.field(default=datetime.now()) @classmethod - def create(cls, value: datetime | date, _var_data: VarData | None = None): + def _get_all_var_data_without_creating_var(cls, value: date) -> VarData | None: + return None + + @classmethod + def create(cls, value: date, _var_data: VarData | None = None): """Create a new instance of the class. Args: diff --git a/reflex/vars/number.py b/reflex/vars/number.py index 93e60beb766..4f639c799d0 100644 --- a/reflex/vars/number.py +++ b/reflex/vars/number.py @@ -973,6 +973,20 @@ def __hash__(self) -> int: """ return hash((type(self).__name__, self._var_value)) + @classmethod + def _get_all_var_data_without_creating_var( + cls, value: float | int | decimal.Decimal + ) -> VarData | None: + """Get all the var data without creating the var. + + Args: + value: The value of the var. + + Returns: + The var data. + """ + return None + @classmethod def create( cls, value: float | int | decimal.Decimal, _var_data: VarData | None = None @@ -1027,6 +1041,18 @@ def __hash__(self) -> int: """ return hash((type(self).__name__, self._var_value)) + @classmethod + def _get_all_var_data_without_creating_var(cls, value: bool) -> VarData | None: + """Get all the var data without creating the var. + + Args: + value: The value of the var. + + Returns: + The var data. + """ + return None + @classmethod def create(cls, value: bool, _var_data: VarData | None = None): """Create the boolean var. diff --git a/reflex/vars/object.py b/reflex/vars/object.py index 026bae5aa47..e5df56540df 100644 --- a/reflex/vars/object.py +++ b/reflex/vars/object.py @@ -40,7 +40,7 @@ var_operation_return, ) from .number import BooleanVar, NumberVar, raise_unsupported_operand_types -from .sequence import ArrayVar, StringVar +from .sequence import ArrayVar, LiteralArrayVar, StringVar OBJECT_TYPE = TypeVar("OBJECT_TYPE", covariant=True) @@ -437,6 +437,24 @@ def __hash__(self) -> int: """ return hash((type(self).__name__, self._js_expr)) + @classmethod + def _get_all_var_data_without_creating_var( + cls, + value: Mapping, + ) -> VarData | None: + """Get all the var data without creating a var. + + Args: + value: The value to get the var data from. + + Returns: + The var data. + """ + return VarData.merge( + LiteralArrayVar._get_all_var_data_without_creating_var(value), + LiteralArrayVar._get_all_var_data_without_creating_var(value.values()), + ) + @cached_property_no_lock def _cached_get_all_var_data(self) -> VarData | None: """Get all the var data. @@ -445,11 +463,10 @@ def _cached_get_all_var_data(self) -> VarData | None: The var data. """ return VarData.merge( - *[LiteralVar.create(var)._get_all_var_data() for var in self._var_value], - *[ - LiteralVar.create(var)._get_all_var_data() - for var in self._var_value.values() - ], + LiteralArrayVar._get_all_var_data_without_creating_var(self._var_value), + LiteralArrayVar._get_all_var_data_without_creating_var( + self._var_value.values() + ), self._var_data, ) diff --git a/reflex/vars/sequence.py b/reflex/vars/sequence.py index dc9018dd0c3..135f58377f6 100644 --- a/reflex/vars/sequence.py +++ b/reflex/vars/sequence.py @@ -495,6 +495,23 @@ def _cached_var_name(self) -> str: + "]" ) + @classmethod + def _get_all_var_data_without_creating_var(cls, value: Iterable) -> VarData | None: + """Get all the VarData associated with the Var without creating a Var. + + Args: + value: The value to get the VarData for. + + Returns: + The VarData associated with the Var. + """ + return VarData.merge( + *[ + LiteralVar._get_all_var_data_without_creating_var_dispatch(element) + for element in value + ] + ) + @cached_property_no_lock def _cached_get_all_var_data(self) -> VarData | None: """Get all the VarData associated with the Var. @@ -504,7 +521,7 @@ def _cached_get_all_var_data(self) -> VarData | None: """ return VarData.merge( *[ - LiteralVar.create(element)._get_all_var_data() + LiteralVar._get_all_var_data_without_creating_var_dispatch(element) for element in self._var_value ], self._var_data, @@ -1147,6 +1164,20 @@ class LiteralStringVar(LiteralVar, StringVar[str]): _var_value: str = dataclasses.field(default="") + @classmethod + def _get_all_var_data_without_creating_var(cls, value: str) -> VarData | None: + """Get all the VarData associated with the Var without creating a Var. + + Args: + value: The value to get the VarData for. + + Returns: + The VarData associated with the Var. + """ + if REFLEX_VAR_OPENING_TAG not in value: + return None + return cls.create(value)._get_all_var_data() + @classmethod def create( cls,