diff --git a/packages/reflex-components-core/src/reflex_components_core/core/cond.py b/packages/reflex-components-core/src/reflex_components_core/core/cond.py index 9d8deb53fde..a6e996bafef 100644 --- a/packages/reflex-components-core/src/reflex_components_core/core/cond.py +++ b/packages/reflex-components-core/src/reflex_components_core/core/cond.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Any, overload +from typing import Any, TypeVar, overload from reflex_core.components.component import BaseComponent, Component, field from reflex_core.components.tags import CondTag, Tag @@ -100,15 +100,27 @@ def cond(condition: Any, c1: Component, /) -> Component: ... @overload -def cond(condition: Any, c1: Var[Component], c2: Var[Component], /) -> Component: ... # pyright: ignore [reportOverlappingOverload] +def cond(condition: Any, c1: Any, c2: Component, /) -> Component: ... # pyright: ignore [reportOverlappingOverload] + + +T = TypeVar("T", covariant=True) +U = TypeVar("U", covariant=True) @overload -def cond(condition: Any, c1: Any, c2: Component, /) -> Component: ... # pyright: ignore [reportOverlappingOverload] +def cond(condition: Any, c1: Var[T], c2: Var[U], /) -> Var[T | U]: ... # pyright: ignore [reportOverlappingOverload] + + +@overload +def cond(condition: Any, c1: T, c2: Var[U], /) -> Var[T | U]: ... # pyright: ignore [reportOverlappingOverload] + + +@overload +def cond(condition: Any, c1: Var[T], c2: U, /) -> Var[T | U]: ... # pyright: ignore [reportOverlappingOverload] @overload -def cond(condition: Any, c1: Any, c2: Any, /) -> Var: ... +def cond(condition: Any, c1: T, c2: U, /) -> Var[T | U]: ... def cond(condition: Any, c1: Any, c2: Any = types.Unset(), /) -> Component | Var: diff --git a/packages/reflex-core/src/reflex_core/components/component.py b/packages/reflex-core/src/reflex_core/components/component.py index 82d60203b22..623f5069ea5 100644 --- a/packages/reflex-core/src/reflex_core/components/component.py +++ b/packages/reflex-core/src/reflex_core/components/component.py @@ -2945,7 +2945,7 @@ def render_dict_to_var(tag: dict | Component | str) -> Var: frozen=True, slots=True, ) -class LiteralComponentVar(CachedVarOperation, LiteralVar, ComponentVar): +class LiteralComponentVar(CachedVarOperation, LiteralVar[Component], ComponentVar): """A Var that represents a Component.""" _var_value: BaseComponent = dataclasses.field(default_factory=empty_component) diff --git a/packages/reflex-core/src/reflex_core/vars/base.py b/packages/reflex-core/src/reflex_core/vars/base.py index e9792ff823b..ef93a789314 100644 --- a/packages/reflex-core/src/reflex_core/vars/base.py +++ b/packages/reflex-core/src/reflex_core/vars/base.py @@ -37,7 +37,7 @@ ) from rich.markup import escape -from typing_extensions import dataclass_transform, override +from typing_extensions import LiteralString, dataclass_transform, override from reflex_core import constants from reflex_core.constants.compiler import Hooks @@ -85,6 +85,7 @@ VAR_TYPE = TypeVar("VAR_TYPE", covariant=True) OTHER_VAR_TYPE = TypeVar("OTHER_VAR_TYPE") STRING_T = TypeVar("STRING_T", bound=str) +LITERAL_STRING_T = TypeVar("LITERAL_STRING_T", bound=LiteralString) SEQUENCE_TYPE = TypeVar("SEQUENCE_TYPE", bound=Sequence) warnings.filterwarnings("ignore", message="fields may not start with an underscore") @@ -651,9 +652,9 @@ def create( # pyright: ignore [reportOverlappingOverload] @classmethod def create( # pyright: ignore [reportOverlappingOverload] cls, - value: str, + value: LITERAL_STRING_T, _var_data: VarData | None = None, - ) -> LiteralStringVar: ... + ) -> LiteralStringVar[LITERAL_STRING_T]: ... @overload @classmethod @@ -1391,7 +1392,7 @@ def create( ) -class LiteralVar(Var): +class LiteralVar(Var[VAR_TYPE]): """Base class for immutable literal vars.""" def __init_subclass__(cls, **kwargs): @@ -2932,7 +2933,7 @@ class NoneVar(Var[None], python_types=type(None)): frozen=True, slots=True, ) -class LiteralNoneVar(LiteralVar, NoneVar): +class LiteralNoneVar(LiteralVar[None], NoneVar): """A var representing None.""" _var_value: None = None diff --git a/packages/reflex-core/src/reflex_core/vars/color.py b/packages/reflex-core/src/reflex_core/vars/color.py index 7db57acde2e..099380ea875 100644 --- a/packages/reflex-core/src/reflex_core/vars/color.py +++ b/packages/reflex-core/src/reflex_core/vars/color.py @@ -24,7 +24,7 @@ class ColorVar(StringVar[Color], python_types=Color): frozen=True, slots=True, ) -class LiteralColorVar(CachedVarOperation, LiteralVar, ColorVar): +class LiteralColorVar(CachedVarOperation, LiteralVar[Color], ColorVar): """Base class for immutable literal color vars.""" _var_value: Color = dataclasses.field(default_factory=lambda: Color(color="black")) diff --git a/packages/reflex-core/src/reflex_core/vars/datetime.py b/packages/reflex-core/src/reflex_core/vars/datetime.py index d528f678781..289583b7c0a 100644 --- a/packages/reflex-core/src/reflex_core/vars/datetime.py +++ b/packages/reflex-core/src/reflex_core/vars/datetime.py @@ -171,7 +171,7 @@ def date_compare_operation( frozen=True, slots=True, ) -class LiteralDatetimeVar(LiteralVar, DateTimeVar): +class LiteralDatetimeVar(LiteralVar[DATETIME_T], DateTimeVar[DATETIME_T]): """Base class for immutable datetime and date vars.""" _var_value: date = dataclasses.field(default=datetime.now()) diff --git a/packages/reflex-core/src/reflex_core/vars/number.py b/packages/reflex-core/src/reflex_core/vars/number.py index 3652a81ab4a..c5caa5b5045 100644 --- a/packages/reflex-core/src/reflex_core/vars/number.py +++ b/packages/reflex-core/src/reflex_core/vars/number.py @@ -944,7 +944,7 @@ def boolean_not_operation(value: BooleanVar): frozen=True, slots=True, ) -class LiteralNumberVar(LiteralVar, NumberVar[NUMBER_T]): +class LiteralNumberVar(LiteralVar[NUMBER_T], NumberVar[NUMBER_T]): """Base class for immutable literal number vars.""" _var_value: float | int | decimal.Decimal = dataclasses.field(default=0) @@ -1020,7 +1020,7 @@ def create( frozen=True, slots=True, ) -class LiteralBooleanVar(LiteralVar, BooleanVar): +class LiteralBooleanVar(LiteralVar[bool], BooleanVar): """Base class for immutable literal boolean vars.""" _var_value: bool = dataclasses.field(default=False) diff --git a/packages/reflex-core/src/reflex_core/vars/object.py b/packages/reflex-core/src/reflex_core/vars/object.py index 1e02cee562d..2f6f300eaee 100644 --- a/packages/reflex-core/src/reflex_core/vars/object.py +++ b/packages/reflex-core/src/reflex_core/vars/object.py @@ -370,7 +370,9 @@ class RestProp(ObjectVar[dict[str, Any]]): frozen=True, slots=True, ) -class LiteralObjectVar(CachedVarOperation, ObjectVar[OBJECT_TYPE], LiteralVar): +class LiteralObjectVar( + CachedVarOperation, ObjectVar[OBJECT_TYPE], LiteralVar[OBJECT_TYPE] +): """Base class for immutable literal object vars.""" _var_value: Mapping[Var | Any, Var | Any] = dataclasses.field(default_factory=dict) diff --git a/packages/reflex-core/src/reflex_core/vars/sequence.py b/packages/reflex-core/src/reflex_core/vars/sequence.py index 089a2b80813..793ff46ac41 100644 --- a/packages/reflex-core/src/reflex_core/vars/sequence.py +++ b/packages/reflex-core/src/reflex_core/vars/sequence.py @@ -469,7 +469,9 @@ def foreach(self, fn: Any): frozen=True, slots=True, ) -class LiteralArrayVar(CachedVarOperation, LiteralVar, ArrayVar[ARRAY_VAR_TYPE]): +class LiteralArrayVar( + CachedVarOperation, LiteralVar[ARRAY_VAR_TYPE], ArrayVar[ARRAY_VAR_TYPE] +): """Base class for immutable literal array vars.""" _var_value: Sequence[Var | Any] = dataclasses.field(default=()) @@ -571,7 +573,7 @@ def create( ) -STRING_TYPE = TypingExtensionsTypeVar("STRING_TYPE", default=str) +STRING_TYPE = TypingExtensionsTypeVar("STRING_TYPE", default=str, covariant=True) class StringVar(Var[STRING_TYPE], python_types=str): @@ -1151,7 +1153,7 @@ def get_decimal_string_operation( frozen=True, slots=True, ) -class LiteralStringVar(LiteralVar, StringVar[str]): +class LiteralStringVar(LiteralVar[STRING_TYPE], StringVar[STRING_TYPE]): """Base class for immutable literal string vars.""" _var_value: str = dataclasses.field(default="") @@ -1770,7 +1772,7 @@ class RangeVar(ArrayVar[Sequence[int]], python_types=range): frozen=True, slots=True, ) -class LiteralRangeVar(CachedVarOperation, LiteralVar, RangeVar): +class LiteralRangeVar(CachedVarOperation, LiteralVar[Sequence[int]], RangeVar): """Base class for immutable literal range vars.""" _var_value: range = dataclasses.field(default_factory=lambda: range(0)) diff --git a/pyi_hashes.json b/pyi_hashes.json index 62611f50a6d..ed5945ac549 100644 --- a/pyi_hashes.json +++ b/pyi_hashes.json @@ -120,5 +120,5 @@ "packages/reflex-components-sonner/src/reflex_components_sonner/toast.pyi": "f09c503c4ab880c13c13d6fa67d708b8", "reflex/__init__.pyi": "7696c38fd9c04a598518b49c5185c414", "reflex/components/__init__.pyi": "55bb242d5e5428db329b88b4923c2ba5", - "reflex/experimental/memo.pyi": "d16eccf33993c781e2f8bc2dd8bbd4d4" + "reflex/experimental/memo.pyi": "2566571c9f68fce5e8178e8acac11944" } diff --git a/reflex/experimental/memo.py b/reflex/experimental/memo.py index c145f85cc1d..956616f2c16 100644 --- a/reflex/experimental/memo.py +++ b/reflex/experimental/memo.py @@ -17,6 +17,7 @@ from reflex_core.constants.state import CAMEL_CASE_MEMO_MARKER from reflex_core.utils import format from reflex_core.utils.imports import ImportVar +from reflex_core.utils.types import safe_issubclass from reflex_core.vars import VarData from reflex_core.vars.base import LiteralVar, Var from reflex_core.vars.function import ( @@ -253,7 +254,14 @@ def _is_component_annotation(annotation: Any) -> bool: Whether the annotation resolves to Component. """ origin = get_origin(annotation) or annotation - return isinstance(origin, type) and issubclass(origin, Component) + return isinstance(origin, type) and ( + safe_issubclass(origin, Component) + or bool( + safe_issubclass(origin, Var) + and (args := get_args(annotation)) + and safe_issubclass(args[0], Component) + ) + ) def _children_annotation_is_valid(annotation: Any) -> bool: diff --git a/tests/units/components/core/test_cond.py b/tests/units/components/core/test_cond.py index b1f7d82d56c..b0ba72be752 100644 --- a/tests/units/components/core/test_cond.py +++ b/tests/units/components/core/test_cond.py @@ -1,13 +1,15 @@ import json -from typing import Any +from typing import Any, Literal import pytest from reflex_components_core.base.fragment import Fragment from reflex_components_core.core.cond import Cond, cond from reflex_components_radix.themes.typography.text import Text +from reflex_core.components.component import Component from reflex_core.constants.state import FIELD_MARKER from reflex_core.utils.format import format_state_name from reflex_core.vars.base import LiteralVar, Var, computed_var +from typing_extensions import assert_type from reflex.state import BaseState @@ -145,3 +147,44 @@ def computed_str(self) -> str: ) assert comp._var_type == int | str + + +def test_cond_assert_types() -> None: + """Test that pyright infers the correct return types for cond overloads.""" + text_comp = Text.create("hello") + text_comp2 = Text.create("world") + var_int: Var[int] = LiteralVar.create(1) + var_str: Var[str] = LiteralVar.create("a") + + # Component, Component -> Component + _ = assert_type(cond(True, text_comp, text_comp2), Component) + + # Component, non-component -> Component + _ = assert_type(cond(True, text_comp, "fallback"), Component) + + # Component only (no else) -> Component + _ = assert_type(cond(True, text_comp), Component) + + # non-component, Component -> Component + _ = assert_type(cond(True, "hello", text_comp), Component) + + # T, T -> Var[T] + _ = assert_type(cond(True, "hello", "world"), Var[str]) + + # T, U -> Var[T | U] + _ = assert_type(cond(True, "hello", 3), Var[str | int]) + + # T, Var[T] -> Var[T] + _ = assert_type(cond(True, "hello", var_str), Var[str]) + + # Var[T], T -> Var[T] + _ = assert_type(cond(True, var_str, "world"), Var[str]) + + # T, Var[U] -> Var[T | U] + _ = assert_type(cond(True, "hello", var_int), Var[str | int]) + + # Var[T], U -> Var[T | U] + _ = assert_type(cond(True, var_str, 3), Var[int | Literal["a"]]) + + # Var[T], Var[U] -> Var[T | U] + _ = assert_type(cond(True, var_int, var_str), Var[int | Literal["a"]]) diff --git a/tests/units/experimental/test_memo.py b/tests/units/experimental/test_memo.py index 103c1c56c0e..415d538c40f 100644 --- a/tests/units/experimental/test_memo.py +++ b/tests/units/experimental/test_memo.py @@ -117,7 +117,7 @@ def conditional_slot( show: rx.Var[bool], first: rx.Var[rx.Component], second: rx.Var[rx.Component], - ) -> rx.Component: + ) -> rx.Var[rx.Component]: return rx.cond(show, first, second) definition = EXPERIMENTAL_MEMOS["ConditionalSlot"]