diff --git a/docs/utility_methods/lifespan_tasks.md b/docs/utility_methods/lifespan_tasks.md index bf5f84eb8ad..63be078aef9 100644 --- a/docs/utility_methods/lifespan_tasks.md +++ b/docs/utility_methods/lifespan_tasks.md @@ -10,6 +10,8 @@ Lifespan tasks are defined as async coroutines or async contextmanagers. To avoi blocking the event thread, never use `time.sleep` or perform non-async I/O within a lifespan task. +Tasks execute in the order they are registered. + In dev mode, lifespan tasks will stop and restart when a hot-reload occurs. ## Tasks @@ -38,14 +40,23 @@ async def long_running_task(foo, bar): To register a lifespan task, use `app.register_lifespan_task(coro_func, **kwargs)`. Any keyword arguments specified during registration will be passed to the task. -If the task accepts the special argument, `app`, it will be an instance of the `FastAPI` object -associated with the app. +If the task accepts the special argument, `app`, it will be passed the `Starlette` +application instance. ```python app = rx.App() app.register_lifespan_task(long_running_task, foo=42, bar=os.environ["BAR_PARAM"]) ``` +All tasks must be registered before the app starts. Calling +`register_lifespan_task` after the lifespan has begun (for example, from an +event handler or from within another lifespan task) will raise a `RuntimeError`. + +### Inspecting Registered Tasks + +To get the currently registered lifespan tasks, use `app.get_lifespan_tasks()`, +which returns a `tuple` of tasks in registration order. + ## Context Managers Lifespan tasks can also be defined as async contextmanagers. This is useful for @@ -55,9 +66,6 @@ protocol. Code up to the first `yield` will run when the backend comes up. As the backend is shutting down, the code after the `yield` will run to clean up. -Here is an example borrowed from the FastAPI docs and modified to work with this -interface. - ```python from contextlib import asynccontextmanager @@ -70,7 +78,7 @@ ml_models = \{} @asynccontextmanager -async def setup_model(app: FastAPI): +async def setup_model(app): # Load the ML model ml_models["answer_to_everything"] = fake_answer_to_everything_ml_model yield diff --git a/reflex/app_mixins/lifespan.py b/reflex/app_mixins/lifespan.py index c8ea12962bf..a62195469c2 100644 --- a/reflex/app_mixins/lifespan.py +++ b/reflex/app_mixins/lifespan.py @@ -9,6 +9,7 @@ import inspect import time from collections.abc import Callable, Coroutine +from typing import TYPE_CHECKING from reflex_base.utils import console from reflex_base.utils.exceptions import InvalidLifespanTaskTypeError @@ -16,30 +17,85 @@ from .mixin import AppMixin +if TYPE_CHECKING: + from typing_extensions import deprecated + + +def _get_task_name(task: asyncio.Task | Callable) -> str: + """Get a display name for a lifespan task. + + Args: + task: The task to get the name for. + + Returns: + The name of the task. + """ + if isinstance(task, asyncio.Task): + return task.get_name() + return task.__name__ # pyright: ignore[reportAttributeAccessIssue] + @dataclasses.dataclass class LifespanMixin(AppMixin): """A Mixin that allow tasks to run during the whole app lifespan. Attributes: - lifespan_tasks: Lifespan tasks that are planned to run. + lifespan_tasks: Set of lifespan tasks that are planned to run (deprecated). """ - lifespan_tasks: set[asyncio.Task | Callable] = dataclasses.field( - default_factory=set + _lifespan_tasks: dict[asyncio.Task | Callable, None] = dataclasses.field( + default_factory=dict, init=False, repr=False + ) + _lifespan_tasks_started: bool = dataclasses.field( + default=False, init=False, repr=False ) + if TYPE_CHECKING: + # Static deprecation warning for IDE/type checkers. + @property + @deprecated("Use get_lifespan_tasks method instead.") + def lifespan_tasks(self) -> frozenset[asyncio.Task | Callable]: + """Get a copy of registered lifespan tasks (deprecated).""" + ... + + else: + + @property + def lifespan_tasks(self) -> frozenset[asyncio.Task | Callable]: + """Get a copy of registered lifespan tasks. + + Returns: + A frozenset of registered lifespan tasks. + """ + # Runtime deprecation warning prints to the console when accessed. + console.deprecate( + feature_name="LifespanMixin.lifespan_tasks", + reason="Use get_lifespan_tasks method instead to get a copy of registered lifespan tasks.", + deprecation_version="0.9.0", + removal_version="1.0", + ) + return frozenset(self._lifespan_tasks) + + def get_lifespan_tasks(self) -> tuple[asyncio.Task | Callable, ...]: + """Get a copy of currently registered lifespan tasks. + + Returns: + A tuple of registered lifespan tasks. + """ + return tuple(self._lifespan_tasks) + @contextlib.asynccontextmanager async def _run_lifespan_tasks(self, app: Starlette): + self._lifespan_tasks_started = True running_tasks = [] try: async with contextlib.AsyncExitStack() as stack: - for task in self.lifespan_tasks: - run_msg = f"Started lifespan task: {task.__name__} as {{type}}" # pyright: ignore [reportAttributeAccessIssue] + for task in self._lifespan_tasks: + task_name = _get_task_name(task) + run_msg = f"Started lifespan task: {task_name} as {{type}}" if isinstance(task, asyncio.Task): running_tasks.append(task) else: - task_name = task.__name__ signature = inspect.signature(task) if "app" in signature.parameters: task = functools.partial(task, app=app) @@ -90,15 +146,22 @@ def register_lifespan_task(self, task: Callable | asyncio.Task, **task_kwargs): Raises: InvalidLifespanTaskTypeError: If the task is a generator function. + RuntimeError: If lifespan tasks are already running. """ + if self._lifespan_tasks_started: + msg = ( + f"Cannot register lifespan task {_get_task_name(task)!r} after " + "lifespan has started. Register all tasks before the app starts." + ) + raise RuntimeError(msg) if inspect.isgeneratorfunction(task) or inspect.isasyncgenfunction(task): msg = f"Task {task.__name__} of type generator must be decorated with contextlib.asynccontextmanager." raise InvalidLifespanTaskTypeError(msg) - task_name = task.__name__ # pyright: ignore [reportAttributeAccessIssue] + task_name = _get_task_name(task) if task_kwargs: original_task = task task = functools.partial(task, **task_kwargs) # pyright: ignore [reportArgumentType] functools.update_wrapper(task, original_task) # pyright: ignore [reportArgumentType] - self.lifespan_tasks.add(task) + self._lifespan_tasks[task] = None console.debug(f"Registered lifespan task: {task_name}") diff --git a/tests/integration/test_lifespan.py b/tests/integration/test_lifespan.py index 61be4b87519..dd20dbdba91 100644 --- a/tests/integration/test_lifespan.py +++ b/tests/integration/test_lifespan.py @@ -26,9 +26,12 @@ def LifespanApp( from contextlib import asynccontextmanager import reflex as rx + from reflex.istate.manager.token import BaseStateToken lifespan_task_global = 0 lifespan_context_global = 0 + raw_asyncio_task_global = 0 + connected_tokens: set[str] = set() @asynccontextmanager async def lifespan_context(app, inc: int = 1): # noqa: RUF029 @@ -52,13 +55,47 @@ async def lifespan_task(inc: int = 1): print(f"Lifespan global cancelled: {ce}.") lifespan_task_global = 0 + async def raw_asyncio_task_coro(): + global raw_asyncio_task_global + print("Raw asyncio task started.") + try: + while True: + raw_asyncio_task_global += 1 # pyright: ignore[reportUnboundVariable, reportPossiblyUnboundVariable] + await asyncio.sleep(0.1) + except asyncio.CancelledError as ce: + print(f"Raw asyncio task cancelled: {ce}.") + raw_asyncio_task_global = 0 + + @asynccontextmanager + async def assert_register_blocked_during_lifespan(app): # noqa: RUF029 + """Negative test: registering a task after lifespan has started must raise.""" + from reflex.utils.prerequisites import get_app + + reflex_app = get_app().app + task = asyncio.create_task(raw_asyncio_task_coro(), name="raw_asyncio_task") + try: + reflex_app.register_lifespan_task(task) + except RuntimeError as exc: + print(f"Expected RuntimeError: {exc}") + else: + msg = "register_lifespan_task should have raised RuntimeError" + raise AssertionError(msg) + finally: + task.cancel() + yield + class LifespanState(rx.State): interval: int = 100 + modify_count: int = 0 @rx.event def set_interval(self, interval: int): self.interval = interval + @rx.event + def register_token(self): + connected_tokens.add(self.router.session.client_token) + @rx.var(cache=False) def task_global(self) -> int: return lifespan_task_global @@ -67,14 +104,36 @@ def task_global(self) -> int: def context_global(self) -> int: return lifespan_context_global + @rx.var(cache=False) + def asyncio_task_global(self) -> int: + return raw_asyncio_task_global + @rx.event def tick(self, date): pass + async def modify_state_task(): + from reflex.utils.prerequisites import get_app + + reflex_app = get_app().app + try: + while True: + for token in list(connected_tokens): + async with reflex_app.modify_state( + BaseStateToken(ident=token, cls=LifespanState) + ) as state: + lifespan_state = await state.get_state(LifespanState) + lifespan_state.modify_count += 1 + await asyncio.sleep(0.1) + except asyncio.CancelledError: + print("modify_state_task cancelled.") + def index(): return rx.vstack( rx.text(LifespanState.task_global, id="task_global"), rx.text(LifespanState.context_global, id="context_global"), + rx.text(LifespanState.modify_count, id="modify_count"), + rx.text(LifespanState.asyncio_task_global, id="asyncio_task_global"), rx.button( rx.moment( interval=LifespanState.interval, on_change=LifespanState.tick @@ -84,6 +143,7 @@ def index(): ), id="toggle-tick", ), + on_mount=LifespanState.register_token, ) from fastapi import FastAPI @@ -95,6 +155,9 @@ def index(): app.register_lifespan_task(lifespan_task) app.register_lifespan_task(lifespan_context, inc=2) + app.register_lifespan_task(raw_asyncio_task_coro) + app.register_lifespan_task(assert_register_blocked_during_lifespan) + app.register_lifespan_task(modify_state_task) app.add_page(index) @@ -160,6 +223,63 @@ def lifespan_app( yield harness +def test_lifespan_modify_state(lifespan_app: AppHarness): + """Test that a lifespan task can use app.modify_state to push state updates. + + Args: + lifespan_app: harness for LifespanApp app + """ + assert lifespan_app.app_module is not None, "app module is not found" + assert lifespan_app.app_instance is not None, "app is not running" + driver = lifespan_app.frontend() + + ss = SessionStorage(driver) + assert AppHarness._poll_for(lambda: ss.get("token") is not None), "token not found" + + modify_count = driver.find_element(By.ID, "modify_count") + + # Wait for modify_count to become non-zero (lifespan task is pushing updates) + assert lifespan_app.poll_for_content(modify_count, exp_not_equal="0") + + # Verify it continues to increase + first_value = modify_count.text + next_value = lifespan_app.poll_for_content(modify_count, exp_not_equal=first_value) + assert int(next_value) > int(first_value) + + +def test_lifespan_raw_asyncio_task(lifespan_app: AppHarness): + """Test that a coroutine function registered as a lifespan task runs as an asyncio.Task. + + Args: + lifespan_app: harness for LifespanApp app + """ + assert lifespan_app.app_module is not None, "app module is not found" + assert lifespan_app.app_instance is not None, "app is not running" + driver = lifespan_app.frontend() + + ss = SessionStorage(driver) + assert AppHarness._poll_for(lambda: ss.get("token") is not None), "token not found" + + asyncio_task_global = driver.find_element(By.ID, "asyncio_task_global") + + # Wait for asyncio_task_global to become non-zero + assert lifespan_app.poll_for_content(asyncio_task_global, exp_not_equal="0") + + # Verify it continues to increase + first_value = asyncio_task_global.text + next_value = lifespan_app.poll_for_content( + asyncio_task_global, exp_not_equal=first_value + ) + assert int(next_value) > int(first_value) + assert lifespan_app.app_module.raw_asyncio_task_global > 0 + + +# --- test_lifespan MUST be the last test in this file. --- +# It shuts down the backend and asserts cancellation of lifespan tasks. +# The lifespan_app fixture is session-scoped (expensive to rebuild), so all +# other tests that need a running backend must be defined ABOVE this point. + + def test_lifespan(lifespan_app: AppHarness): """Test the lifespan integration. @@ -195,3 +315,9 @@ def test_lifespan(lifespan_app: AppHarness): # Check that the lifespan tasks have been cancelled assert lifespan_app.app_module.lifespan_task_global == 0 assert lifespan_app.app_module.lifespan_context_global == 4 + assert lifespan_app.app_module.raw_asyncio_task_global == 0 + + +# --- Do NOT add new test cases below this line. --- +# test_lifespan (above) kills the backend; any test defined after it will +# find the harness in a stopped state and fail.