Skip to content
9 changes: 7 additions & 2 deletions reflex/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,7 @@ def __post_init__(self):
set_breakpoints(self.style.pop("breakpoints"))

# Set up the API.
self._api = Starlette(lifespan=self._run_lifespan_tasks)
self._api = Starlette()
App._add_cors(self._api)
self._add_default_endpoints()

Expand Down Expand Up @@ -629,6 +629,7 @@ def __call__(self) -> ASGIApp:

if not self._api:
raise ValueError("The app has not been initialized.")

if self._cached_fastapi_app is not None:
asgi_app = self._cached_fastapi_app
asgi_app.mount("", self._api)
Expand All @@ -653,7 +654,11 @@ def __call__(self) -> ASGIApp:
# Transform the asgi app.
asgi_app = api_transformer(asgi_app)

return asgi_app
top_asgi_app = Starlette(lifespan=self._run_lifespan_tasks)
top_asgi_app.mount("", asgi_app)
App._add_cors(top_asgi_app)

return top_asgi_app

def _add_default_endpoints(self):
"""Add default api endpoints (ping)."""
Expand Down
29 changes: 17 additions & 12 deletions reflex/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
)
from reflex.utils import console
from reflex.utils.export import export
from reflex.utils.types import ASGIApp

try:
from selenium import webdriver
Expand Down Expand Up @@ -110,6 +111,7 @@ class AppHarness:
app_module_path: Path
app_module: types.ModuleType | None = None
app_instance: reflex.App | None = None
app_asgi: ASGIApp | None = None
frontend_process: subprocess.Popen | None = None
frontend_url: str | None = None
frontend_output_thread: threading.Thread | None = None
Expand Down Expand Up @@ -270,11 +272,14 @@ def _initialize_app(self):
# Ensure the AppHarness test does not skip State assignment due to running via pytest
os.environ.pop(reflex.constants.PYTEST_CURRENT_TEST, None)
os.environ[reflex.constants.APP_HARNESS_FLAG] = "true"
self.app_module = reflex.utils.prerequisites.get_compiled_app(
# Do not reload the module for pre-existing apps (only apps generated from source)
reload=self.app_source is not None
# Ensure we actually compile the app during first initialization.
self.app_instance, self.app_module = (
reflex.utils.prerequisites.get_and_validate_app(
# Do not reload the module for pre-existing apps (only apps generated from source)
reload=self.app_source is not None
)
)
self.app_instance = self.app_module.app
self.app_asgi = self.app_instance()
if self.app_instance and isinstance(
self.app_instance._state_manager, StateManagerRedis
):
Expand All @@ -300,10 +305,10 @@ def _get_backend_shutdown_handler(self):
async def _shutdown(*args, **kwargs) -> None:
# ensure redis is closed before event loop
if self.app_instance is not None and isinstance(
self.app_instance.state_manager, StateManagerRedis
self.app_instance._state_manager, StateManagerRedis
):
with contextlib.suppress(ValueError):
await self.app_instance.state_manager.close()
await self.app_instance._state_manager.close()

# socketio shutdown handler
if self.app_instance is not None and self.app_instance.sio is not None:
Expand All @@ -323,11 +328,11 @@ async def _shutdown(*args, **kwargs) -> None:
return _shutdown

def _start_backend(self, port: int = 0):
if self.app_instance is None or self.app_instance._api is None:
if self.app_asgi is None:
raise RuntimeError("App was not initialized.")
self.backend = uvicorn.Server(
uvicorn.Config(
app=self.app_instance._api,
app=self.app_asgi,
host="127.0.0.1",
port=port,
)
Expand All @@ -349,13 +354,13 @@ async def _reset_backend_state_manager(self):
if (
self.app_instance is not None
and isinstance(
self.app_instance.state_manager,
self.app_instance._state_manager,
StateManagerRedis,
)
and self.app_instance._state is not None
):
with contextlib.suppress(RuntimeError):
await self.app_instance.state_manager.close()
await self.app_instance._state_manager.close()
self.app_instance._state_manager = StateManagerRedis.create(
state=self.app_instance._state,
)
Expand Down Expand Up @@ -959,12 +964,12 @@ def _wait_frontend(self):
raise RuntimeError("Frontend did not start")

def _start_backend(self):
if self.app_instance is None:
if self.app_asgi is None:
raise RuntimeError("App was not initialized.")
environment.REFLEX_SKIP_COMPILE.set(True)
self.backend = uvicorn.Server(
uvicorn.Config(
app=self.app_instance,
app=self.app_asgi,
host="127.0.0.1",
port=0,
workers=reflex.utils.processes.get_num_workers(),
Expand Down
3 changes: 1 addition & 2 deletions tests/integration/test_connection_banner.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Test case for displaying the connection banner when the websocket drops."""

import functools
from collections.abc import Generator

import pytest
Expand Down Expand Up @@ -77,7 +76,7 @@ def connection_banner(

with AppHarness.create(
root=tmp_path,
app_source=functools.partial(ConnectionBanner),
app_source=ConnectionBanner,
app_name=(
"connection_banner_reflex_cloud"
if simulate_compile_context == constants.CompileContext.DEPLOY
Expand Down
63 changes: 57 additions & 6 deletions tests/integration/test_lifespan.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Test cases for the Starlette lifespan integration."""

import functools
from collections.abc import Generator

import pytest
Expand All @@ -10,8 +11,15 @@
from .utils import SessionStorage


def LifespanApp():
"""App with lifespan tasks and context."""
def LifespanApp(
mount_cached_fastapi: bool = False, mount_api_transformer: bool = False
) -> None:
"""App with lifespan tasks and context.

Args:
mount_cached_fastapi: Whether to mount the cached FastAPI app.
mount_api_transformer: Whether to mount the API transformer.
"""
import asyncio
from contextlib import asynccontextmanager

Expand Down Expand Up @@ -72,25 +80,68 @@ def index():
),
)

app = rx.App()
from fastapi import FastAPI

app = rx.App(api_transformer=FastAPI() if mount_api_transformer else None)

if mount_cached_fastapi:
assert app.api is not None

app.register_lifespan_task(lifespan_task)
app.register_lifespan_task(lifespan_context, inc=2)
app.add_page(index)


@pytest.fixture(
params=[False, True], ids=["no_api_transformer", "mount_api_transformer"]
)
def mount_api_transformer(request: pytest.FixtureRequest) -> bool:
"""Whether to use api_transformer in the app.

Args:
request: pytest fixture request object

Returns:
bool: Whether to use api_transformer
"""
return request.param


@pytest.fixture(params=[False, True], ids=["no_fastapi", "mount_cached_fastapi"])
def mount_cached_fastapi(request: pytest.FixtureRequest) -> bool:
"""Whether to use cached FastAPI in the app (app.api).

Args:
request: pytest fixture request object

Returns:
Whether to use cached FastAPI
"""
return request.param


@pytest.fixture()
def lifespan_app(tmp_path) -> Generator[AppHarness, None, None]:
def lifespan_app(
tmp_path, mount_api_transformer: bool, mount_cached_fastapi: bool
) -> Generator[AppHarness, None, None]:
"""Start LifespanApp app at tmp_path via AppHarness.

Args:
tmp_path: pytest tmp_path fixture
mount_api_transformer: Whether to mount the API transformer.
mount_cached_fastapi: Whether to mount the cached FastAPI app.

Yields:
running AppHarness instance
"""
with AppHarness.create(
root=tmp_path,
app_source=LifespanApp,
app_source=functools.partial(
LifespanApp,
mount_cached_fastapi=mount_cached_fastapi,
mount_api_transformer=mount_api_transformer,
),
app_name=f"lifespanapp_fastapi{mount_cached_fastapi}_transformer{mount_api_transformer}",
) as harness:
yield harness

Expand All @@ -112,7 +163,7 @@ async def test_lifespan(lifespan_app: AppHarness):
context_global = driver.find_element(By.ID, "context_global")
task_global = driver.find_element(By.ID, "task_global")

assert context_global.text == "2"
assert lifespan_app.poll_for_content(context_global, exp_not_equal="0") == "2"
assert lifespan_app.app_module.lifespan_context_global == 2

original_task_global_text = task_global.text
Expand Down
Loading