diff --git a/reflex/app.py b/reflex/app.py index 206d88865c7..6d84e6793d6 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -604,6 +604,11 @@ def __call__(self) -> ASGIApp: self._compile(prerender_routes=is_prod_mode()) + config = get_config() + + for plugin in config.plugins: + plugin.post_compile(app=self) + # We will not be making more vars, so we can clear the global cache to free up memory. GLOBAL_CACHE.clear() diff --git a/reflex/istate/manager.py b/reflex/istate/manager.py index 9c97feb048d..fac0f64586b 100644 --- a/reflex/istate/manager.py +++ b/reflex/istate/manager.py @@ -143,6 +143,8 @@ async def set_state(self, token: str, state: BaseState): token: The token to set the state for. state: The state to set. """ + token = _split_substate_key(token)[0] + self.states[token] = state @override @contextlib.asynccontextmanager @@ -165,7 +167,6 @@ async def modify_state(self, token: str) -> AsyncIterator[BaseState]: async with self._states_locks[token]: state = await self.get_state(token) yield state - await self.set_state(token, state) def _default_token_expiration() -> int: diff --git a/reflex/plugins/__init__.py b/reflex/plugins/__init__.py index b5d75217f7b..754409046b8 100644 --- a/reflex/plugins/__init__.py +++ b/reflex/plugins/__init__.py @@ -1,5 +1,6 @@ """Reflex Plugin System.""" +from ._screenshot import ScreenshotPlugin as _ScreenshotPlugin from .base import CommonContext, Plugin, PreCompileContext from .sitemap import SitemapPlugin from .tailwind_v3 import TailwindV3Plugin @@ -12,4 +13,5 @@ "SitemapPlugin", "TailwindV3Plugin", "TailwindV4Plugin", + "_ScreenshotPlugin", ] diff --git a/reflex/plugins/_screenshot.py b/reflex/plugins/_screenshot.py new file mode 100644 index 00000000000..faafbaa6135 --- /dev/null +++ b/reflex/plugins/_screenshot.py @@ -0,0 +1,144 @@ +"""Plugin to enable screenshot functionality.""" + +from typing import TYPE_CHECKING + +from reflex.plugins.base import Plugin as BasePlugin + +if TYPE_CHECKING: + from starlette.requests import Request + from starlette.responses import Response + from typing_extensions import Unpack + + from reflex.app import App + from reflex.plugins.base import PostCompileContext + from reflex.state import BaseState + +ACTIVE_CONNECTIONS = "/_active_connections" +CLONE_STATE = "/_clone_state" + + +def _deep_copy(state: "BaseState") -> "BaseState": + """Create a deep copy of the state. + + Args: + state: The state to copy. + + Returns: + A deep copy of the state. + """ + import copy + + copy_of_state = copy.deepcopy(state) + + def copy_substate(substate: "BaseState") -> "BaseState": + substate_copy = _deep_copy(substate) + + substate_copy.parent_state = copy_of_state + + return substate_copy + + copy_of_state.substates = { + substate_name: copy_substate(substate) + for substate_name, substate in state.substates.items() + } + + return copy_of_state + + +class ScreenshotPlugin(BasePlugin): + """Plugin to handle screenshot functionality.""" + + def post_compile(self, **context: "Unpack[PostCompileContext]") -> None: + """Called after the compilation of the plugin. + + Args: + context: The context for the plugin. + """ + app = context["app"] + self._add_active_connections_endpoint(app) + self._add_clone_state_endpoint(app) + + @staticmethod + def _add_active_connections_endpoint(app: "App") -> None: + """Add an endpoint to the app that returns the active connections. + + Args: + app: The application instance to which the endpoint will be added. + """ + if not app._api: + return + + async def active_connections(_request: "Request") -> "Response": + from starlette.responses import JSONResponse + + if not app.event_namespace: + return JSONResponse({}) + + return JSONResponse(app.event_namespace.token_to_sid) + + app._api.add_route( + ACTIVE_CONNECTIONS, + active_connections, + methods=["GET"], + ) + + @staticmethod + def _add_clone_state_endpoint(app: "App") -> None: + """Add an endpoint to the app that clones the current state. + + Args: + app: The application instance to which the endpoint will be added. + """ + if not app._api: + return + + async def clone_state(request: "Request") -> "Response": + import uuid + + from starlette.responses import JSONResponse + + from reflex.state import _substate_key + + if not app.event_namespace: + return JSONResponse({}) + + token_to_clone = await request.json() + + if not isinstance(token_to_clone, str): + return JSONResponse( + {"error": "Token to clone must be a string."}, status_code=400 + ) + + old_state = await app.state_manager.get_state(token_to_clone) + + new_state = _deep_copy(old_state) + + new_token = uuid.uuid4().hex + + all_states = [new_state] + + found_new = True + + while found_new: + found_new = False + + for state in all_states: + for substate in state.substates.values(): + substate._was_touched = True + + if substate not in all_states: + all_states.append(substate) + + found_new = True + + await app.state_manager.set_state( + _substate_key(new_token, new_state), new_state + ) + + return JSONResponse(new_token) + + app._api.add_route( + CLONE_STATE, + clone_state, + methods=["POST"], + ) diff --git a/reflex/plugins/base.py b/reflex/plugins/base.py index 8b1ac6005fb..52dfa8d7805 100644 --- a/reflex/plugins/base.py +++ b/reflex/plugins/base.py @@ -7,7 +7,7 @@ from typing_extensions import Unpack if TYPE_CHECKING: - from reflex.app import UnevaluatedPage + from reflex.app import App, UnevaluatedPage class CommonContext(TypedDict): @@ -44,6 +44,12 @@ class PreCompileContext(CommonContext): unevaluated_pages: Sequence["UnevaluatedPage"] +class PostCompileContext(CommonContext): + """Context for post-compile hooks.""" + + app: "App" + + class Plugin: """Base class for all plugins.""" @@ -104,6 +110,13 @@ def pre_compile(self, **context: Unpack[PreCompileContext]) -> None: context: The context for the plugin. """ + def post_compile(self, **context: Unpack[PostCompileContext]) -> None: + """Called after the compilation of the plugin. + + Args: + context: The context for the plugin. + """ + def __repr__(self): """Return a string representation of the plugin.