Skip to content
Closed
40 changes: 40 additions & 0 deletions reflex/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,11 @@
Optional,
Set,
Type,
TypeVar,
Union,
get_args,
get_type_hints,
overload,
)

from fastapi import FastAPI, HTTPException, Request, UploadFile
Expand Down Expand Up @@ -1102,6 +1104,44 @@ async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
sid=state.router.session.session_id,
)

S = TypeVar("S", bound=BaseState)

@overload
async def modify_states(
self, substate_cls: Type[S], from_state: None
) -> AsyncIterator[S]: ...

@overload
async def modify_states(
self, substate_cls: None, from_state: BaseState
) -> AsyncIterator[BaseState]: ...

async def modify_states(
self,
substate_cls: Type[S] | Type[BaseState] | None = None,
from_state: BaseState | None = None,
) -> AsyncIterator[S] | AsyncIterator[BaseState]:
"""Iterate over the states.

Args:
substate_cls: The substate class to iterate over.
from_state: The state from which this method is called.

Yields:
The states to modify.
"""
async for token in self.state_manager.iter_state_tokens():
# avoid deadlock when calling from event handler/background task
if from_state is not None and token.startswith(
from_state.router.session.client_token
):
state = from_state
continue
async with self.modify_state(token) as state:
if substate_cls is not None:
state = state.get_substate(substate_cls)
yield state

def _process_background(
self, state: BaseState, event: Event
) -> asyncio.Task | None:
Expand Down
6 changes: 3 additions & 3 deletions reflex/app_module_for_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@
from reflex import constants
from reflex.utils import telemetry
from reflex.utils.exec import is_prod_mode
from reflex.utils.prerequisites import get_app
from reflex.utils.prerequisites import get_app_module

if constants.CompileVars.APP != "app":
raise AssertionError("unexpected variable name for 'app'")

telemetry.send("compile")
app_module = get_app(reload=False)
app_module = get_app_module(reload=False)
app = getattr(app_module, constants.CompileVars.APP)
# For py3.8 and py3.9 compatibility when redis is used, we MUST add any decorator pages
# before compiling the app in a thread to avoid event loop error (REF-2172).
Expand All @@ -30,7 +30,7 @@
# ensure only "app" is exposed.
del app_module
del compile_future
del get_app
del get_app_module
del is_prod_mode
del telemetry
del constants
Expand Down
117 changes: 103 additions & 14 deletions reflex/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import uuid
from abc import ABC, abstractmethod
from collections import defaultdict
from collections.abc import Iterator
from pathlib import Path
from types import FunctionType, MethodType
from typing import (
Expand All @@ -27,8 +28,10 @@
Set,
Tuple,
Type,
TypeVar,
Union,
cast,
overload,
)

import dill
Expand Down Expand Up @@ -71,6 +74,7 @@
LockExpiredError,
)
from reflex.utils.exec import is_testing_env
from reflex.utils.format import remove_prefix
from reflex.utils.serializers import SerializedType, serialize, serializer
from reflex.utils.types import override
from reflex.vars import VarData
Expand Down Expand Up @@ -336,6 +340,9 @@ def __call__(self, *args: Any) -> EventSpec:
return super().__call__(*args)


S = TypeVar("S", bound="BaseState")


class BaseState(Base, ABC, extra=pydantic.Extra.allow):
"""The state of the app."""

Expand Down Expand Up @@ -1266,18 +1273,28 @@ def _reset_client_storage(self):
for substate in self.substates.values():
substate._reset_client_storage()

def get_substate(self, path: Sequence[str]) -> BaseState:
@overload
def get_substate(self, path: Sequence[str]) -> BaseState: ...

@overload
def get_substate(self, path: Type[S]) -> S: ...

def get_substate(self, path: Sequence[str] | Type[S]) -> BaseState | S:
"""Get the substate.

Args:
path: The path to the substate.
path: The path to the substate or the class of the substate.

Returns:
The substate.

Raises:
ValueError: If the substate is not found.
"""
if isinstance(path, type):
path = remove_prefix(
text=path.get_full_name(), prefix=f"{self.get_full_name()}."
).split(".")
if len(path) == 0:
return self
if path[0] == self.get_name():
Expand Down Expand Up @@ -1418,7 +1435,7 @@ def _get_state_from_cache(self, state_cls: Type[BaseState]) -> BaseState:
The instance of state_cls associated with this state's client_token.
"""
root_state = self._get_root_state()
return root_state.get_substate(state_cls.get_full_name().split("."))
return root_state.get_substate(state_cls)

async def _get_state_from_redis(self, state_cls: Type[BaseState]) -> BaseState:
"""Get a state instance from redis.
Expand Down Expand Up @@ -1588,7 +1605,7 @@ def _as_state_update(
except Exception as ex:
state._clean()

app_instance = getattr(prerequisites.get_app(), constants.CompileVars.APP)
app_instance = prerequisites.get_app()

event_specs = app_instance.backend_exception_handler(ex)

Expand Down Expand Up @@ -1662,7 +1679,7 @@ async def _process_event(
except Exception as ex:
telemetry.send_error(ex, context="backend")

app_instance = getattr(prerequisites.get_app(), constants.CompileVars.APP)
app_instance = prerequisites.get_app()

event_specs = app_instance.backend_exception_handler(ex)

Expand Down Expand Up @@ -1985,7 +2002,7 @@ def handle_frontend_exception(self, stack: str) -> None:
stack: The stack trace of the exception.

"""
app_instance = getattr(prerequisites.get_app(), constants.CompileVars.APP)
app_instance = prerequisites.get_app()
app_instance.frontend_exception_handler(Exception(stack))


Expand Down Expand Up @@ -2024,7 +2041,7 @@ def on_load_internal(self) -> list[Event | EventSpec] | None:
The list of events to queue for on load handling.
"""
# Do not app._compile()! It should be already compiled by now.
app = getattr(prerequisites.get_app(), constants.CompileVars.APP)
app = prerequisites.get_app()
load_events = app.get_load_events(self.router.page.path)
if not load_events:
self.is_hydrated = True
Expand Down Expand Up @@ -2167,7 +2184,7 @@ def __init__(
"""
super().__init__(state_instance)
# compile is not relevant to backend logic
self._self_app = getattr(prerequisites.get_app(), constants.CompileVars.APP)
self._self_app = prerequisites.get_app()
self._self_substate_path = tuple(state_instance.get_full_name().split("."))
self._self_actx = None
self._self_mutable = False
Expand Down Expand Up @@ -2483,6 +2500,20 @@ async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
"""
yield self.state()

@abstractmethod
def iter_state_tokens(
self, substate_cls: Type[BaseState] | None = None
) -> AsyncIterator[str]:
"""Iterate over the state names.

Args:
substate_cls: The subclass of BaseState to filter by.

Raises:
NotImplementedError: Always, because this method must be implemented by subclasses.
"""
raise NotImplementedError


class StateManagerMemory(StateManager):
"""A state manager that stores states in memory."""
Expand Down Expand Up @@ -2552,6 +2583,21 @@ async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
yield state
await self.set_state(token, state)

@override
async def iter_state_tokens(
self, substate_cls: Type[BaseState] | None = None
) -> AsyncIterator[str]:
"""Iterate over the state names.

Args:
substate_cls: The subclass of BaseState to filter by.

Yields:
The state names.
"""
for token in self.states:
yield token


def _default_token_expiration() -> int:
"""Get the default token expiration time.
Expand Down Expand Up @@ -2663,15 +2709,21 @@ def states_directory(self) -> Path:
"""
return prerequisites.get_web_dir() / constants.Dirs.STATES

def _iter_pkl_files(self) -> Iterator[Path]:
"""Iterate over the pkl files in the states directory.

Yields:
The pkl files.
"""
for path in path_ops.ls(self.states_directory):
if path.suffix == ".pkl":
yield path

def _purge_expired_states(self):
"""Purge expired states from the disk."""
import time

for path in path_ops.ls(self.states_directory):
# check path is a pickle file
if path.suffix != ".pkl":
continue

for path in self._iter_pkl_files():
# load last edited field from file
last_edited = path.stat().st_mtime

Expand Down Expand Up @@ -2812,6 +2864,24 @@ async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
yield state
await self.set_state(token, state)

@override
async def iter_state_tokens(
self, substate_cls: Type[BaseState] | None = None
) -> AsyncIterator[str]:
"""Iterate over the state names.

Args:
substate_cls: The subclass of BaseState to filter by.

Yields:
The state names.
"""
for path in self._iter_pkl_files():
token = path.stem
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unfortunately this wont work anymore, because we started hashing the token to avoid file path limits on windows.

however, we should be able to use the same implementation here as we do for StateManagerMemory as the on-disk pickles are only read when the backend starts after a hot reload

if substate_cls is not None and not token.endswith(substate_cls.get_name()):
continue
yield token


# Workaround https://github.com/cloudpipe/cloudpickle/issues/408 for dynamic pydantic classes
if not isinstance(State.validate.__func__, FunctionType):
Expand Down Expand Up @@ -3111,6 +3181,25 @@ async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
yield state
await self.set_state(token, state, lock_id)

@override
async def iter_state_tokens(
self, substate_cls: Type[BaseState] | None = None
) -> AsyncIterator[str]:
"""Iterate over the state names.

Args:
substate_cls: The subclass of BaseState to filter by.

Yields:
The state names.
"""
if substate_cls is None:
substate_cls = self.state
async for token in self.redis.scan_iter(
match=f"*_{substate_cls.get_name()}", _type="STRING"
):
yield token.decode()

@staticmethod
def _lock_key(token: str) -> bytes:
"""Get the redis key for a token's lock.
Expand Down Expand Up @@ -3236,7 +3325,7 @@ def get_state_manager() -> StateManager:
Returns:
The state manager.
"""
app = getattr(prerequisites.get_app(), constants.CompileVars.APP)
app = prerequisites.get_app()
return app.state_manager


Expand Down
19 changes: 19 additions & 0 deletions reflex/utils/format.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import json
import os
import re
import sys
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union

from reflex import constants
Expand Down Expand Up @@ -781,3 +782,21 @@ def format_data_editor_cell(cell: Any):
"kind": Var(_js_expr="GridCellKind.Text"),
"data": cell,
}


def remove_prefix(text: str, prefix: str) -> str:
"""Remove a prefix from a string, if present.
This can be removed once we drop support for Python 3.8.

Args:
text: The string to remove the prefix from.
prefix: The prefix to remove.

Returns:
The string with the prefix removed, if present.
"""
if sys.version_info >= (3, 9):
return text.removeprefix(prefix)
if text.startswith(prefix):
return text[len(prefix) :]
return text
Loading