diff --git a/reflex/app.py b/reflex/app.py index 34ee12689b2..f8f1a63ef76 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -489,7 +489,7 @@ def __post_init__(self): # Set up the API. self._api = Starlette(lifespan=self._run_lifespan_tasks) - self._add_cors() + App._add_cors(self._api) self._add_default_endpoints() for clz in App.__mro__: @@ -613,19 +613,6 @@ def __call__(self) -> ASGIApp: Returns: The backend api. """ - if self._cached_fastapi_app is not None: - asgi_app = self._cached_fastapi_app - - if not asgi_app or not self._api: - raise ValueError("The app has not been initialized.") - - asgi_app.mount("", self._api) - else: - asgi_app = self._api - - if not asgi_app: - raise ValueError("The app has not been initialized.") - # For py3.9 compatibility when redis is used, we MUST add any decorator pages # before compiling the app in a thread to avoid event loop error (REF-2172). self._apply_decorated_pages() @@ -637,9 +624,17 @@ def __call__(self) -> ASGIApp: # Force background compile errors to print eagerly lambda f: f.result() ) - # Wait for the compile to finish in prod mode to ensure all optional endpoints are mounted. - if is_prod_mode(): - compile_future.result() + # Wait for the compile to finish to ensure all optional endpoints are mounted. + compile_future.result() + + if not self._api: + raise ValueError("The app has not been initialized.") + if self._cached_fastapi_app is not None: + asgi_app = self._cached_fastapi_app + asgi_app.mount("", self._api) + App._add_cors(asgi_app) + else: + asgi_app = self._api if self.api_transformer is not None: api_transformers: Sequence[Starlette | Callable[[ASGIApp], ASGIApp]] = ( @@ -651,6 +646,7 @@ def __call__(self) -> ASGIApp: for api_transformer in api_transformers: if isinstance(api_transformer, Starlette): # Mount the api to the fastapi app. + App._add_cors(api_transformer) api_transformer.mount("", asgi_app) asgi_app = api_transformer else: @@ -709,11 +705,14 @@ def _add_optional_endpoints(self): if environment.REFLEX_ADD_ALL_ROUTES_ENDPOINT.get(): self.add_all_routes_endpoint() - def _add_cors(self): - """Add CORS middleware to the app.""" - if not self._api: - return - self._api.add_middleware( + @staticmethod + def _add_cors(api: Starlette): + """Add CORS middleware to the app. + + Args: + api: The Starlette app to add CORS middleware to. + """ + api.add_middleware( cors.CORSMiddleware, allow_credentials=True, allow_methods=["*"], diff --git a/tests/units/test_app.py b/tests/units/test_app.py index 1e9fa33f1b3..7857e1606c8 100644 --- a/tests/units/test_app.py +++ b/tests/units/test_app.py @@ -1502,6 +1502,7 @@ def test_raise_on_state(): def test_call_app(): """Test that the app can be called.""" app = App() + app._compile = unittest.mock.Mock() api = app() assert isinstance(api, Starlette)