Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions reflex/components/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,8 +515,10 @@ def _deterministic_hash(value: object) -> str:
return _hash_str(
str((value._js_expr, _deterministic_hash(value._get_all_var_data())))
)
if isinstance(value, VarData):
return _hash_dict(dataclasses.asdict(value))
if dataclasses.is_dataclass(value):
return _hash_dict({
k.name: getattr(value, k.name) for k in dataclasses.fields(value)
})
if isinstance(value, BaseComponent):
# If the value is a component, hash its rendered code.
return _hash_dict(value.render())
Expand Down
27 changes: 21 additions & 6 deletions reflex/istate/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import dataclasses
from collections.abc import Mapping
from types import MappingProxyType
from typing import TYPE_CHECKING
from urllib.parse import _NetlocResultMixinStr, parse_qsl, urlsplit

Expand All @@ -12,20 +13,34 @@

@dataclasses.dataclass(frozen=True, init=False)
class _FrozenDictStrStr(Mapping[str, str]):
_data: tuple[tuple[str, str], ...]
_data: MappingProxyType[str, str]

def __init__(self, **kwargs):
object.__setattr__(self, "_data", tuple(sorted(kwargs.items())))
object.__setattr__(
self, "_data", MappingProxyType(dict(sorted(kwargs.items())))
)

def __getitem__(self, key: str) -> str:
return dict(self._data)[key]
return self._data[key]

def __iter__(self):
return (x[0] for x in self._data)
return iter(self._data)

def __len__(self):
return len(self._data)

def __hash__(self) -> int:
return hash(frozenset(self._data.items()))

def __getstate__(self) -> object:
return dict(self._data)

def __setstate__(self, state: object) -> None:
if not isinstance(state, dict):
msg = "Invalid state for _FrozenDictStrStr"
raise TypeError(msg)
object.__setattr__(self, "_data", MappingProxyType(state))
Comment thread
adhami3310 marked this conversation as resolved.


@dataclasses.dataclass(frozen=True)
class _HeaderData:
Expand Down Expand Up @@ -170,7 +185,7 @@ def from_router_data(cls, router_data: dict) -> "PageData":

@serializer(to=dict)
def _serialize_page_data(obj: PageData) -> dict:
return dataclasses.asdict(obj)
return {key.name: getattr(obj, key.name) for key in dataclasses.fields(obj)}


@dataclasses.dataclass(frozen=True)
Expand Down Expand Up @@ -200,7 +215,7 @@ def from_router_data(cls, router_data: dict) -> "SessionData":

@serializer(to=dict)
def _serialize_session_data(obj: SessionData) -> dict:
return dataclasses.asdict(obj)
return {key.name: getattr(obj, key.name) for key in dataclasses.fields(obj)}


@dataclasses.dataclass(frozen=True)
Expand Down
8 changes: 4 additions & 4 deletions reflex/vars/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1508,8 +1508,8 @@ def _create_literal_var(
if dataclasses.is_dataclass(value) and not isinstance(value, type):
return LiteralObjectVar.create(
{
k: (None if callable(v) else v)
for k, v in dataclasses.asdict(value).items()
k.name: (None if callable(v := getattr(value, k.name)) else v)
for k in dataclasses.fields(value)
},
_var_type=type(value),
_var_data=_var_data,
Expand Down Expand Up @@ -1591,8 +1591,8 @@ def _get_all_var_data_without_creating_var_dispatch(

if dataclasses.is_dataclass(value) and not isinstance(value, type):
return LiteralObjectVar._get_all_var_data_without_creating_var({
k: (None if callable(v) else v)
for k, v in dataclasses.asdict(value).items()
k.name: (None if callable(v := getattr(value, k.name)) else v)
for k in dataclasses.fields(value)
})

if isinstance(value, range):
Expand Down
18 changes: 10 additions & 8 deletions tests/units/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from reflex.constants.state import FIELD_MARKER
from reflex.environment import environment
from reflex.event import Event, EventHandler
from reflex.istate.data import HeaderData, _FrozenDictStrStr
from reflex.istate.manager import StateManager
from reflex.istate.manager.disk import StateManagerDisk
from reflex.istate.manager.memory import StateManagerMemory
Expand Down Expand Up @@ -925,7 +926,11 @@ def test_get_sid(test_state, router_data):
assert test_state.router.session.session_id == "9fpxSzPb9aFMb4wFAAAH"


def test_get_headers(test_state, router_data, router_data_headers):
def test_get_headers(
test_state: TestState,
router_data: dict[str, str | dict],
router_data_headers: dict[str, str],
):
"""Test getting client headers.

Args:
Expand All @@ -936,13 +941,10 @@ def test_get_headers(test_state, router_data, router_data_headers):
print(router_data_headers)
test_state.router = RouterData.from_router_data(router_data)
print(test_state.router.headers)
assert dataclasses.asdict(test_state.router.headers) == {
format.to_snake_case(k): v for k, v in router_data_headers.items()
} | {
"raw_headers": {
"_data": tuple(sorted((k, v) for k, v in router_data_headers.items()))
}
}
assert test_state.router.headers == HeaderData(
**{format.to_snake_case(k): v for k, v in router_data_headers.items()},
raw_headers=_FrozenDictStrStr(**router_data_headers),
)


def test_get_client_ip(test_state, router_data):
Expand Down
Loading