Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 27 additions & 6 deletions reflex/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -1402,17 +1402,37 @@ def reset(self):
for substate in self.substates.values():
substate.reset()

@classmethod
@functools.lru_cache
def _is_client_storage(cls, prop_name_or_field: str | ModelField) -> bool:
"""Check if the var is a client storage var.

Args:
prop_name_or_field: The name of the var or the field itself.

Returns:
Whether the var is a client storage var.
"""
if isinstance(prop_name_or_field, str):
field = cls.get_fields().get(prop_name_or_field)
else:
field = prop_name_or_field
return field is not None and (
isinstance(field.default, ClientStorageBase)
or (
isinstance(field.type_, type)
and issubclass(field.type_, ClientStorageBase)
)
)

def _reset_client_storage(self):
"""Reset client storage base vars to their default values."""
# Client-side storage is reset during hydrate so that clearing cookies
# on the browser also resets the values on the backend.
fields = self.get_fields()
for prop_name in self.base_vars:
field = fields[prop_name]
if isinstance(field.default, ClientStorageBase) or (
isinstance(field.type_, type)
and issubclass(field.type_, ClientStorageBase)
):
if self._is_client_storage(field):
setattr(self, prop_name, copy.deepcopy(field.default))

# Recursively reset the substate client storage.
Expand Down Expand Up @@ -2393,8 +2413,9 @@ async def update_vars_internal(self, vars: dict[str, Any]) -> None:
for var, value in vars.items():
state_name, _, var_name = var.rpartition(".")
var_state_cls = State.get_class_substate(state_name)
var_state = await self.get_state(var_state_cls)
setattr(var_state, var_name, value)
if var_state_cls._is_client_storage(var_name):
var_state = await self.get_state(var_state_cls)
setattr(var_state, var_name, value)


class OnLoadInternalState(State):
Expand Down