diff --git a/reflex/app.py b/reflex/app.py index 96556e2af7c..f20adf67185 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -488,7 +488,7 @@ def __post_init__(self): set_breakpoints(self.style.pop("breakpoints")) # Set up the API. - self._api = Starlette(lifespan=self._run_lifespan_tasks) + self._api = Starlette() App._add_cors(self._api) self._add_default_endpoints() @@ -629,6 +629,7 @@ def __call__(self) -> ASGIApp: if not self._api: raise ValueError("The app has not been initialized.") + if self._cached_fastapi_app is not None: asgi_app = self._cached_fastapi_app asgi_app.mount("", self._api) @@ -653,7 +654,11 @@ def __call__(self) -> ASGIApp: # Transform the asgi app. asgi_app = api_transformer(asgi_app) - return asgi_app + top_asgi_app = Starlette(lifespan=self._run_lifespan_tasks) + top_asgi_app.mount("", asgi_app) + App._add_cors(top_asgi_app) + + return top_asgi_app def _add_default_endpoints(self): """Add default api endpoints (ping).""" diff --git a/reflex/testing.py b/reflex/testing.py index 47ea2976b46..2c4d50b62be 100644 --- a/reflex/testing.py +++ b/reflex/testing.py @@ -45,6 +45,7 @@ ) from reflex.utils import console from reflex.utils.export import export +from reflex.utils.types import ASGIApp try: from selenium import webdriver @@ -110,6 +111,7 @@ class AppHarness: app_module_path: Path app_module: types.ModuleType | None = None app_instance: reflex.App | None = None + app_asgi: ASGIApp | None = None frontend_process: subprocess.Popen | None = None frontend_url: str | None = None frontend_output_thread: threading.Thread | None = None @@ -270,11 +272,14 @@ def _initialize_app(self): # Ensure the AppHarness test does not skip State assignment due to running via pytest os.environ.pop(reflex.constants.PYTEST_CURRENT_TEST, None) os.environ[reflex.constants.APP_HARNESS_FLAG] = "true" - self.app_module = reflex.utils.prerequisites.get_compiled_app( - # Do not reload the module for pre-existing apps (only apps generated from source) - reload=self.app_source is not None + # Ensure we actually compile the app during first initialization. + self.app_instance, self.app_module = ( + reflex.utils.prerequisites.get_and_validate_app( + # Do not reload the module for pre-existing apps (only apps generated from source) + reload=self.app_source is not None + ) ) - self.app_instance = self.app_module.app + self.app_asgi = self.app_instance() if self.app_instance and isinstance( self.app_instance._state_manager, StateManagerRedis ): @@ -300,10 +305,10 @@ def _get_backend_shutdown_handler(self): async def _shutdown(*args, **kwargs) -> None: # ensure redis is closed before event loop if self.app_instance is not None and isinstance( - self.app_instance.state_manager, StateManagerRedis + self.app_instance._state_manager, StateManagerRedis ): with contextlib.suppress(ValueError): - await self.app_instance.state_manager.close() + await self.app_instance._state_manager.close() # socketio shutdown handler if self.app_instance is not None and self.app_instance.sio is not None: @@ -323,11 +328,11 @@ async def _shutdown(*args, **kwargs) -> None: return _shutdown def _start_backend(self, port: int = 0): - if self.app_instance is None or self.app_instance._api is None: + if self.app_asgi is None: raise RuntimeError("App was not initialized.") self.backend = uvicorn.Server( uvicorn.Config( - app=self.app_instance._api, + app=self.app_asgi, host="127.0.0.1", port=port, ) @@ -349,13 +354,13 @@ async def _reset_backend_state_manager(self): if ( self.app_instance is not None and isinstance( - self.app_instance.state_manager, + self.app_instance._state_manager, StateManagerRedis, ) and self.app_instance._state is not None ): with contextlib.suppress(RuntimeError): - await self.app_instance.state_manager.close() + await self.app_instance._state_manager.close() self.app_instance._state_manager = StateManagerRedis.create( state=self.app_instance._state, ) @@ -959,12 +964,12 @@ def _wait_frontend(self): raise RuntimeError("Frontend did not start") def _start_backend(self): - if self.app_instance is None: + if self.app_asgi is None: raise RuntimeError("App was not initialized.") environment.REFLEX_SKIP_COMPILE.set(True) self.backend = uvicorn.Server( uvicorn.Config( - app=self.app_instance, + app=self.app_asgi, host="127.0.0.1", port=0, workers=reflex.utils.processes.get_num_workers(), diff --git a/tests/integration/test_connection_banner.py b/tests/integration/test_connection_banner.py index fdb31db747c..e6b27617699 100644 --- a/tests/integration/test_connection_banner.py +++ b/tests/integration/test_connection_banner.py @@ -1,6 +1,5 @@ """Test case for displaying the connection banner when the websocket drops.""" -import functools from collections.abc import Generator import pytest @@ -77,7 +76,7 @@ def connection_banner( with AppHarness.create( root=tmp_path, - app_source=functools.partial(ConnectionBanner), + app_source=ConnectionBanner, app_name=( "connection_banner_reflex_cloud" if simulate_compile_context == constants.CompileContext.DEPLOY diff --git a/tests/integration/test_lifespan.py b/tests/integration/test_lifespan.py index 50eacff6962..c084c48bcec 100644 --- a/tests/integration/test_lifespan.py +++ b/tests/integration/test_lifespan.py @@ -1,5 +1,6 @@ """Test cases for the Starlette lifespan integration.""" +import functools from collections.abc import Generator import pytest @@ -10,8 +11,15 @@ from .utils import SessionStorage -def LifespanApp(): - """App with lifespan tasks and context.""" +def LifespanApp( + mount_cached_fastapi: bool = False, mount_api_transformer: bool = False +) -> None: + """App with lifespan tasks and context. + + Args: + mount_cached_fastapi: Whether to mount the cached FastAPI app. + mount_api_transformer: Whether to mount the API transformer. + """ import asyncio from contextlib import asynccontextmanager @@ -72,25 +80,68 @@ def index(): ), ) - app = rx.App() + from fastapi import FastAPI + + app = rx.App(api_transformer=FastAPI() if mount_api_transformer else None) + + if mount_cached_fastapi: + assert app.api is not None + app.register_lifespan_task(lifespan_task) app.register_lifespan_task(lifespan_context, inc=2) app.add_page(index) +@pytest.fixture( + params=[False, True], ids=["no_api_transformer", "mount_api_transformer"] +) +def mount_api_transformer(request: pytest.FixtureRequest) -> bool: + """Whether to use api_transformer in the app. + + Args: + request: pytest fixture request object + + Returns: + bool: Whether to use api_transformer + """ + return request.param + + +@pytest.fixture(params=[False, True], ids=["no_fastapi", "mount_cached_fastapi"]) +def mount_cached_fastapi(request: pytest.FixtureRequest) -> bool: + """Whether to use cached FastAPI in the app (app.api). + + Args: + request: pytest fixture request object + + Returns: + Whether to use cached FastAPI + """ + return request.param + + @pytest.fixture() -def lifespan_app(tmp_path) -> Generator[AppHarness, None, None]: +def lifespan_app( + tmp_path, mount_api_transformer: bool, mount_cached_fastapi: bool +) -> Generator[AppHarness, None, None]: """Start LifespanApp app at tmp_path via AppHarness. Args: tmp_path: pytest tmp_path fixture + mount_api_transformer: Whether to mount the API transformer. + mount_cached_fastapi: Whether to mount the cached FastAPI app. Yields: running AppHarness instance """ with AppHarness.create( root=tmp_path, - app_source=LifespanApp, + app_source=functools.partial( + LifespanApp, + mount_cached_fastapi=mount_cached_fastapi, + mount_api_transformer=mount_api_transformer, + ), + app_name=f"lifespanapp_fastapi{mount_cached_fastapi}_transformer{mount_api_transformer}", ) as harness: yield harness @@ -112,7 +163,7 @@ async def test_lifespan(lifespan_app: AppHarness): context_global = driver.find_element(By.ID, "context_global") task_global = driver.find_element(By.ID, "task_global") - assert context_global.text == "2" + assert lifespan_app.poll_for_content(context_global, exp_not_equal="0") == "2" assert lifespan_app.app_module.lifespan_context_global == 2 original_task_global_text = task_global.text