From 90d01aeda698c07a3144b6f3842cbe0682217068 Mon Sep 17 00:00:00 2001 From: Khaleel Al-Adhami Date: Thu, 4 Dec 2025 12:26:23 -0800 Subject: [PATCH 1/3] optimize frozen dict get item --- reflex/istate/data.py | 23 +++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/reflex/istate/data.py b/reflex/istate/data.py index e73b588498b..4747fde7244 100644 --- a/reflex/istate/data.py +++ b/reflex/istate/data.py @@ -2,6 +2,7 @@ import dataclasses from collections.abc import Mapping +from types import MappingProxyType from typing import TYPE_CHECKING from urllib.parse import _NetlocResultMixinStr, parse_qsl, urlsplit @@ -12,20 +13,34 @@ @dataclasses.dataclass(frozen=True, init=False) class _FrozenDictStrStr(Mapping[str, str]): - _data: tuple[tuple[str, str], ...] + _data: MappingProxyType[str, str] def __init__(self, **kwargs): - object.__setattr__(self, "_data", tuple(sorted(kwargs.items()))) + object.__setattr__( + self, "_data", MappingProxyType(dict(sorted(kwargs.items()))) + ) def __getitem__(self, key: str) -> str: - return dict(self._data)[key] + return self._data[key] def __iter__(self): - return (x[0] for x in self._data) + return iter(self._data) def __len__(self): return len(self._data) + def __hash__(self) -> int: + return hash(frozenset(self._data.items())) + + def __getstate__(self) -> object: + return dict(self._data) + + def __setstate__(self, state: object) -> None: + if not isinstance(state, dict): + msg = "Invalid state for _FrozenDictStrStr" + raise TypeError(msg) + object.__setattr__(self, "_data", MappingProxyType(state)) + @dataclasses.dataclass(frozen=True) class _HeaderData: From bde39fe175bad426358ac417b7d27a72941191db Mon Sep 17 00:00:00 2001 From: Khaleel Al-Adhami Date: Thu, 4 Dec 2025 13:10:39 -0800 Subject: [PATCH 2/3] asdict is weird --- reflex/components/component.py | 4 +++- reflex/istate/data.py | 4 ++-- reflex/vars/base.py | 8 ++++---- tests/units/test_state.py | 18 ++++++++++-------- 4 files changed, 19 insertions(+), 15 deletions(-) diff --git a/reflex/components/component.py b/reflex/components/component.py index 89f6c465469..0ce70d81aa6 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -516,7 +516,9 @@ def _deterministic_hash(value: object) -> str: str((value._js_expr, _deterministic_hash(value._get_all_var_data()))) ) if isinstance(value, VarData): - return _hash_dict(dataclasses.asdict(value)) + return _hash_dict({ + k.name: getattr(value, k.name) for k in dataclasses.fields(value) + }) if isinstance(value, BaseComponent): # If the value is a component, hash its rendered code. return _hash_dict(value.render()) diff --git a/reflex/istate/data.py b/reflex/istate/data.py index 4747fde7244..ba75a8df4dc 100644 --- a/reflex/istate/data.py +++ b/reflex/istate/data.py @@ -185,7 +185,7 @@ def from_router_data(cls, router_data: dict) -> "PageData": @serializer(to=dict) def _serialize_page_data(obj: PageData) -> dict: - return dataclasses.asdict(obj) + return {key.name: getattr(obj, key.name) for key in dataclasses.fields(obj)} @dataclasses.dataclass(frozen=True) @@ -215,7 +215,7 @@ def from_router_data(cls, router_data: dict) -> "SessionData": @serializer(to=dict) def _serialize_session_data(obj: SessionData) -> dict: - return dataclasses.asdict(obj) + return {key.name: getattr(obj, key.name) for key in dataclasses.fields(obj)} @dataclasses.dataclass(frozen=True) diff --git a/reflex/vars/base.py b/reflex/vars/base.py index 3a9e8a985f9..54583d88980 100644 --- a/reflex/vars/base.py +++ b/reflex/vars/base.py @@ -1508,8 +1508,8 @@ def _create_literal_var( if dataclasses.is_dataclass(value) and not isinstance(value, type): return LiteralObjectVar.create( { - k: (None if callable(v) else v) - for k, v in dataclasses.asdict(value).items() + k.name: (None if callable(v := getattr(value, k.name)) else v) + for k in dataclasses.fields(value) }, _var_type=type(value), _var_data=_var_data, @@ -1591,8 +1591,8 @@ def _get_all_var_data_without_creating_var_dispatch( if dataclasses.is_dataclass(value) and not isinstance(value, type): return LiteralObjectVar._get_all_var_data_without_creating_var({ - k: (None if callable(v) else v) - for k, v in dataclasses.asdict(value).items() + k.name: (None if callable(v := getattr(value, k.name)) else v) + for k in dataclasses.fields(value) }) if isinstance(value, range): diff --git a/tests/units/test_state.py b/tests/units/test_state.py index a984a739dfb..941c569095f 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -28,6 +28,7 @@ from reflex.constants.state import FIELD_MARKER from reflex.environment import environment from reflex.event import Event, EventHandler +from reflex.istate.data import HeaderData, _FrozenDictStrStr from reflex.istate.manager import StateManager from reflex.istate.manager.disk import StateManagerDisk from reflex.istate.manager.memory import StateManagerMemory @@ -925,7 +926,11 @@ def test_get_sid(test_state, router_data): assert test_state.router.session.session_id == "9fpxSzPb9aFMb4wFAAAH" -def test_get_headers(test_state, router_data, router_data_headers): +def test_get_headers( + test_state: TestState, + router_data: dict[str, str | dict], + router_data_headers: dict[str, str], +): """Test getting client headers. Args: @@ -936,13 +941,10 @@ def test_get_headers(test_state, router_data, router_data_headers): print(router_data_headers) test_state.router = RouterData.from_router_data(router_data) print(test_state.router.headers) - assert dataclasses.asdict(test_state.router.headers) == { - format.to_snake_case(k): v for k, v in router_data_headers.items() - } | { - "raw_headers": { - "_data": tuple(sorted((k, v) for k, v in router_data_headers.items())) - } - } + assert test_state.router.headers == HeaderData( + **{format.to_snake_case(k): v for k, v in router_data_headers.items()}, + raw_headers=_FrozenDictStrStr(**router_data_headers), + ) def test_get_client_ip(test_state, router_data): From c1b17c2d8c265b2182988a770e9df0616a84a31a Mon Sep 17 00:00:00 2001 From: Khaleel Al-Adhami Date: Thu, 4 Dec 2025 13:39:32 -0800 Subject: [PATCH 3/3] make that into a dataclass generic --- reflex/components/component.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/reflex/components/component.py b/reflex/components/component.py index 0ce70d81aa6..e62783d2579 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -515,7 +515,7 @@ def _deterministic_hash(value: object) -> str: return _hash_str( str((value._js_expr, _deterministic_hash(value._get_all_var_data()))) ) - if isinstance(value, VarData): + if dataclasses.is_dataclass(value): return _hash_dict({ k.name: getattr(value, k.name) for k in dataclasses.fields(value) })