Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from __future__ import annotations

from typing import Any, overload
from typing import Any, TypeVar, overload

from reflex_core.components.component import BaseComponent, Component, field
from reflex_core.components.tags import CondTag, Tag
Expand Down Expand Up @@ -100,15 +100,27 @@ def cond(condition: Any, c1: Component, /) -> Component: ...


@overload
def cond(condition: Any, c1: Var[Component], c2: Var[Component], /) -> Component: ... # pyright: ignore [reportOverlappingOverload]
def cond(condition: Any, c1: Any, c2: Component, /) -> Component: ... # pyright: ignore [reportOverlappingOverload]


T = TypeVar("T", covariant=True)
U = TypeVar("U", covariant=True)


@overload
def cond(condition: Any, c1: Any, c2: Component, /) -> Component: ... # pyright: ignore [reportOverlappingOverload]
def cond(condition: Any, c1: Var[T], c2: Var[U], /) -> Var[T | U]: ... # pyright: ignore [reportOverlappingOverload]


@overload
def cond(condition: Any, c1: T, c2: Var[U], /) -> Var[T | U]: ... # pyright: ignore [reportOverlappingOverload]


@overload
def cond(condition: Any, c1: Var[T], c2: U, /) -> Var[T | U]: ... # pyright: ignore [reportOverlappingOverload]


@overload
def cond(condition: Any, c1: Any, c2: Any, /) -> Var: ...
def cond(condition: Any, c1: T, c2: U, /) -> Var[T | U]: ...


def cond(condition: Any, c1: Any, c2: Any = types.Unset(), /) -> Component | Var:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2945,7 +2945,7 @@ def render_dict_to_var(tag: dict | Component | str) -> Var:
frozen=True,
slots=True,
)
class LiteralComponentVar(CachedVarOperation, LiteralVar, ComponentVar):
class LiteralComponentVar(CachedVarOperation, LiteralVar[Component], ComponentVar):
"""A Var that represents a Component."""

_var_value: BaseComponent = dataclasses.field(default_factory=empty_component)
Expand Down
11 changes: 6 additions & 5 deletions packages/reflex-core/src/reflex_core/vars/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
)

from rich.markup import escape
from typing_extensions import dataclass_transform, override
from typing_extensions import LiteralString, dataclass_transform, override

from reflex_core import constants
from reflex_core.constants.compiler import Hooks
Expand Down Expand Up @@ -85,6 +85,7 @@
VAR_TYPE = TypeVar("VAR_TYPE", covariant=True)
OTHER_VAR_TYPE = TypeVar("OTHER_VAR_TYPE")
STRING_T = TypeVar("STRING_T", bound=str)
LITERAL_STRING_T = TypeVar("LITERAL_STRING_T", bound=LiteralString)
SEQUENCE_TYPE = TypeVar("SEQUENCE_TYPE", bound=Sequence)

warnings.filterwarnings("ignore", message="fields may not start with an underscore")
Expand Down Expand Up @@ -651,9 +652,9 @@ def create( # pyright: ignore [reportOverlappingOverload]
@classmethod
def create( # pyright: ignore [reportOverlappingOverload]
cls,
value: str,
value: LITERAL_STRING_T,
_var_data: VarData | None = None,
) -> LiteralStringVar: ...
) -> LiteralStringVar[LITERAL_STRING_T]: ...

@overload
@classmethod
Expand Down Expand Up @@ -1391,7 +1392,7 @@ def create(
)


class LiteralVar(Var):
class LiteralVar(Var[VAR_TYPE]):
"""Base class for immutable literal vars."""

def __init_subclass__(cls, **kwargs):
Expand Down Expand Up @@ -2932,7 +2933,7 @@ class NoneVar(Var[None], python_types=type(None)):
frozen=True,
slots=True,
)
class LiteralNoneVar(LiteralVar, NoneVar):
class LiteralNoneVar(LiteralVar[None], NoneVar):
"""A var representing None."""

_var_value: None = None
Expand Down
2 changes: 1 addition & 1 deletion packages/reflex-core/src/reflex_core/vars/color.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class ColorVar(StringVar[Color], python_types=Color):
frozen=True,
slots=True,
)
class LiteralColorVar(CachedVarOperation, LiteralVar, ColorVar):
class LiteralColorVar(CachedVarOperation, LiteralVar[Color], ColorVar):
"""Base class for immutable literal color vars."""

_var_value: Color = dataclasses.field(default_factory=lambda: Color(color="black"))
Expand Down
2 changes: 1 addition & 1 deletion packages/reflex-core/src/reflex_core/vars/datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def date_compare_operation(
frozen=True,
slots=True,
)
class LiteralDatetimeVar(LiteralVar, DateTimeVar):
class LiteralDatetimeVar(LiteralVar[DATETIME_T], DateTimeVar[DATETIME_T]):
"""Base class for immutable datetime and date vars."""

_var_value: date = dataclasses.field(default=datetime.now())
Expand Down
4 changes: 2 additions & 2 deletions packages/reflex-core/src/reflex_core/vars/number.py
Original file line number Diff line number Diff line change
Expand Up @@ -944,7 +944,7 @@ def boolean_not_operation(value: BooleanVar):
frozen=True,
slots=True,
)
class LiteralNumberVar(LiteralVar, NumberVar[NUMBER_T]):
class LiteralNumberVar(LiteralVar[NUMBER_T], NumberVar[NUMBER_T]):
"""Base class for immutable literal number vars."""

_var_value: float | int | decimal.Decimal = dataclasses.field(default=0)
Expand Down Expand Up @@ -1020,7 +1020,7 @@ def create(
frozen=True,
slots=True,
)
class LiteralBooleanVar(LiteralVar, BooleanVar):
class LiteralBooleanVar(LiteralVar[bool], BooleanVar):
"""Base class for immutable literal boolean vars."""

_var_value: bool = dataclasses.field(default=False)
Expand Down
4 changes: 3 additions & 1 deletion packages/reflex-core/src/reflex_core/vars/object.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,9 @@ class RestProp(ObjectVar[dict[str, Any]]):
frozen=True,
slots=True,
)
class LiteralObjectVar(CachedVarOperation, ObjectVar[OBJECT_TYPE], LiteralVar):
class LiteralObjectVar(
CachedVarOperation, ObjectVar[OBJECT_TYPE], LiteralVar[OBJECT_TYPE]
):
"""Base class for immutable literal object vars."""

_var_value: Mapping[Var | Any, Var | Any] = dataclasses.field(default_factory=dict)
Expand Down
10 changes: 6 additions & 4 deletions packages/reflex-core/src/reflex_core/vars/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,9 @@ def foreach(self, fn: Any):
frozen=True,
slots=True,
)
class LiteralArrayVar(CachedVarOperation, LiteralVar, ArrayVar[ARRAY_VAR_TYPE]):
class LiteralArrayVar(
CachedVarOperation, LiteralVar[ARRAY_VAR_TYPE], ArrayVar[ARRAY_VAR_TYPE]
):
"""Base class for immutable literal array vars."""

_var_value: Sequence[Var | Any] = dataclasses.field(default=())
Expand Down Expand Up @@ -571,7 +573,7 @@ def create(
)


STRING_TYPE = TypingExtensionsTypeVar("STRING_TYPE", default=str)
STRING_TYPE = TypingExtensionsTypeVar("STRING_TYPE", default=str, covariant=True)


class StringVar(Var[STRING_TYPE], python_types=str):
Expand Down Expand Up @@ -1151,7 +1153,7 @@ def get_decimal_string_operation(
frozen=True,
slots=True,
)
class LiteralStringVar(LiteralVar, StringVar[str]):
class LiteralStringVar(LiteralVar[STRING_TYPE], StringVar[STRING_TYPE]):
"""Base class for immutable literal string vars."""

_var_value: str = dataclasses.field(default="")
Expand Down Expand Up @@ -1770,7 +1772,7 @@ class RangeVar(ArrayVar[Sequence[int]], python_types=range):
frozen=True,
slots=True,
)
class LiteralRangeVar(CachedVarOperation, LiteralVar, RangeVar):
class LiteralRangeVar(CachedVarOperation, LiteralVar[Sequence[int]], RangeVar):
"""Base class for immutable literal range vars."""

_var_value: range = dataclasses.field(default_factory=lambda: range(0))
Expand Down
2 changes: 1 addition & 1 deletion pyi_hashes.json
Original file line number Diff line number Diff line change
Expand Up @@ -120,5 +120,5 @@
"packages/reflex-components-sonner/src/reflex_components_sonner/toast.pyi": "f09c503c4ab880c13c13d6fa67d708b8",
"reflex/__init__.pyi": "7696c38fd9c04a598518b49c5185c414",
"reflex/components/__init__.pyi": "55bb242d5e5428db329b88b4923c2ba5",
"reflex/experimental/memo.pyi": "d16eccf33993c781e2f8bc2dd8bbd4d4"
"reflex/experimental/memo.pyi": "2566571c9f68fce5e8178e8acac11944"
}
10 changes: 9 additions & 1 deletion reflex/experimental/memo.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from reflex_core.constants.state import CAMEL_CASE_MEMO_MARKER
from reflex_core.utils import format
from reflex_core.utils.imports import ImportVar
from reflex_core.utils.types import safe_issubclass
from reflex_core.vars import VarData
from reflex_core.vars.base import LiteralVar, Var
from reflex_core.vars.function import (
Expand Down Expand Up @@ -253,7 +254,14 @@ def _is_component_annotation(annotation: Any) -> bool:
Whether the annotation resolves to Component.
"""
origin = get_origin(annotation) or annotation
return isinstance(origin, type) and issubclass(origin, Component)
return isinstance(origin, type) and (
safe_issubclass(origin, Component)
or bool(
safe_issubclass(origin, Var)
and (args := get_args(annotation))
and safe_issubclass(args[0], Component)
)
)


def _children_annotation_is_valid(annotation: Any) -> bool:
Expand Down
45 changes: 44 additions & 1 deletion tests/units/components/core/test_cond.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import json
from typing import Any
from typing import Any, Literal

import pytest
from reflex_components_core.base.fragment import Fragment
from reflex_components_core.core.cond import Cond, cond
from reflex_components_radix.themes.typography.text import Text
from reflex_core.components.component import Component
from reflex_core.constants.state import FIELD_MARKER
from reflex_core.utils.format import format_state_name
from reflex_core.vars.base import LiteralVar, Var, computed_var
from typing_extensions import assert_type

from reflex.state import BaseState

Expand Down Expand Up @@ -145,3 +147,44 @@ def computed_str(self) -> str:
)

assert comp._var_type == int | str


def test_cond_assert_types() -> None:
"""Test that pyright infers the correct return types for cond overloads."""
text_comp = Text.create("hello")
text_comp2 = Text.create("world")
var_int: Var[int] = LiteralVar.create(1)
var_str: Var[str] = LiteralVar.create("a")

# Component, Component -> Component
_ = assert_type(cond(True, text_comp, text_comp2), Component)

# Component, non-component -> Component
_ = assert_type(cond(True, text_comp, "fallback"), Component)

# Component only (no else) -> Component
_ = assert_type(cond(True, text_comp), Component)

# non-component, Component -> Component
_ = assert_type(cond(True, "hello", text_comp), Component)

# T, T -> Var[T]
_ = assert_type(cond(True, "hello", "world"), Var[str])

# T, U -> Var[T | U]
_ = assert_type(cond(True, "hello", 3), Var[str | int])

# T, Var[T] -> Var[T]
_ = assert_type(cond(True, "hello", var_str), Var[str])

# Var[T], T -> Var[T]
_ = assert_type(cond(True, var_str, "world"), Var[str])

# T, Var[U] -> Var[T | U]
_ = assert_type(cond(True, "hello", var_int), Var[str | int])

# Var[T], U -> Var[T | U]
_ = assert_type(cond(True, var_str, 3), Var[int | Literal["a"]])

# Var[T], Var[U] -> Var[T | U]
_ = assert_type(cond(True, var_int, var_str), Var[int | Literal["a"]])
2 changes: 1 addition & 1 deletion tests/units/experimental/test_memo.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def conditional_slot(
show: rx.Var[bool],
first: rx.Var[rx.Component],
second: rx.Var[rx.Component],
) -> rx.Component:
) -> rx.Var[rx.Component]:
return rx.cond(show, first, second)

definition = EXPERIMENTAL_MEMOS["ConditionalSlot"]
Expand Down
Loading