Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 21 additions & 22 deletions reflex/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,7 +489,7 @@ def __post_init__(self):

# Set up the API.
self._api = Starlette(lifespan=self._run_lifespan_tasks)
self._add_cors()
App._add_cors(self._api)
self._add_default_endpoints()

for clz in App.__mro__:
Expand Down Expand Up @@ -613,19 +613,6 @@ def __call__(self) -> ASGIApp:
Returns:
The backend api.
"""
if self._cached_fastapi_app is not None:
asgi_app = self._cached_fastapi_app

if not asgi_app or not self._api:
raise ValueError("The app has not been initialized.")

asgi_app.mount("", self._api)
else:
asgi_app = self._api

if not asgi_app:
raise ValueError("The app has not been initialized.")

# For py3.9 compatibility when redis is used, we MUST add any decorator pages
# before compiling the app in a thread to avoid event loop error (REF-2172).
self._apply_decorated_pages()
Expand All @@ -637,9 +624,17 @@ def __call__(self) -> ASGIApp:
# Force background compile errors to print eagerly
lambda f: f.result()
)
# Wait for the compile to finish in prod mode to ensure all optional endpoints are mounted.
if is_prod_mode():
compile_future.result()
# Wait for the compile to finish to ensure all optional endpoints are mounted.
compile_future.result()

if not self._api:
raise ValueError("The app has not been initialized.")
if self._cached_fastapi_app is not None:
asgi_app = self._cached_fastapi_app
asgi_app.mount("", self._api)
App._add_cors(asgi_app)
else:
asgi_app = self._api

if self.api_transformer is not None:
api_transformers: Sequence[Starlette | Callable[[ASGIApp], ASGIApp]] = (
Expand All @@ -651,6 +646,7 @@ def __call__(self) -> ASGIApp:
for api_transformer in api_transformers:
if isinstance(api_transformer, Starlette):
# Mount the api to the fastapi app.
App._add_cors(api_transformer)
api_transformer.mount("", asgi_app)
asgi_app = api_transformer
else:
Expand Down Expand Up @@ -709,11 +705,14 @@ def _add_optional_endpoints(self):
if environment.REFLEX_ADD_ALL_ROUTES_ENDPOINT.get():
self.add_all_routes_endpoint()

def _add_cors(self):
"""Add CORS middleware to the app."""
if not self._api:
return
self._api.add_middleware(
@staticmethod
def _add_cors(api: Starlette):
"""Add CORS middleware to the app.

Args:
api: The Starlette app to add CORS middleware to.
"""
api.add_middleware(
cors.CORSMiddleware,
allow_credentials=True,
allow_methods=["*"],
Expand Down
1 change: 1 addition & 0 deletions tests/units/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -1502,6 +1502,7 @@ def test_raise_on_state():
def test_call_app():
"""Test that the app can be called."""
app = App()
app._compile = unittest.mock.Mock()
api = app()
assert isinstance(api, Starlette)

Expand Down
Loading