diff --git a/reflex/app.py b/reflex/app.py index 6245a9f0d1d..97db3cfc81b 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -9,7 +9,6 @@ import dataclasses import functools import inspect -import io import json import operator import sys @@ -19,6 +18,7 @@ from collections.abc import ( AsyncGenerator, AsyncIterator, + Awaitable, Callable, Coroutine, Mapping, @@ -1886,6 +1886,27 @@ async def health(_request: Request) -> JSONResponse: return JSONResponse(content=health_status, status_code=status_code) +class _UploadStreamingResponse(StreamingResponse): + """Streaming response that always releases upload form resources.""" + + _on_finish: Callable[[], Awaitable[None]] + + def __init__( + self, + *args: Any, + on_finish: Callable[[], Awaitable[None]], + **kwargs: Any, + ) -> None: + super().__init__(*args, **kwargs) + self._on_finish = on_finish + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + try: + await super().__call__(scope, receive, send) + finally: + await self._on_finish() + + def upload(app: App): """Upload a file. @@ -1915,87 +1936,98 @@ async def upload_file(request: Request): # Get the files from the request. try: - files = await request.form() + form_data = await request.form() except ClientDisconnect: return Response() # user cancelled - files = files.getlist("files") - if not files: - msg = "No files were uploaded." - raise UploadValueError(msg) - - token = request.headers.get("reflex-client-token") - handler = request.headers.get("reflex-event-handler") - - if not token or not handler: - raise HTTPException( - status_code=400, - detail="Missing reflex-client-token or reflex-event-handler header.", - ) - # Get the state for the session. - substate_token = _substate_key(token, handler.rpartition(".")[0]) - state = await app.state_manager.get_state(substate_token) + form_data_closed = False - handler_upload_param = () + async def _close_form_data() -> None: + """Close the parsed form data exactly once.""" + nonlocal form_data_closed + if form_data_closed: + return + form_data_closed = True + await form_data.close() - _current_state, event_handler = state._get_event_handler(handler) + async def _create_upload_event() -> Event: + """Create an upload event using the live Starlette temp files. - if event_handler.is_background: - msg = f"@rx.event(background=True) is not supported for upload handler `{handler}`." - raise UploadTypeError(msg) - func = event_handler.fn - if isinstance(func, functools.partial): - func = func.func - for k, v in get_type_hints(func).items(): - if types.is_generic_alias(v) and types._issubclass( - get_args(v)[0], - UploadFile, - ): - handler_upload_param = (k, v) - break + Returns: + The upload event backed by the original temp files. + """ + files = form_data.getlist("files") + if not files: + msg = "No files were uploaded." + raise UploadValueError(msg) + + token = request.headers.get("reflex-client-token") + handler = request.headers.get("reflex-event-handler") + + if not token or not handler: + raise HTTPException( + status_code=400, + detail="Missing reflex-client-token or reflex-event-handler header.", + ) - if not handler_upload_param: - msg = ( - f"`{handler}` handler should have a parameter annotated as " - "list[rx.UploadFile]" - ) - raise UploadValueError(msg) - - # Make a copy of the files as they are closed after the request. - # This behaviour changed from fastapi 0.103.0 to 0.103.1 as the - # AsyncExitStack was removed from the request scope and is now - # part of the routing function which closes this before the - # event is handled. - file_copies = [] - for file in files: - if not isinstance(file, StarletteUploadFile): - raise UploadValueError( - "Uploaded file is not an UploadFile." + str(file) + # Get the state for the session. + substate_token = _substate_key(token, handler.rpartition(".")[0]) + state = await app.state_manager.get_state(substate_token) + + handler_upload_param = () + + _current_state, event_handler = state._get_event_handler(handler) + + if event_handler.is_background: + msg = f"@rx.event(background=True) is not supported for upload handler `{handler}`." + raise UploadTypeError(msg) + func = event_handler.fn + if isinstance(func, functools.partial): + func = func.func + for k, v in get_type_hints(func).items(): + if types.is_generic_alias(v) and types._issubclass( + get_args(v)[0], + UploadFile, + ): + handler_upload_param = (k, v) + break + + if not handler_upload_param: + msg = ( + f"`{handler}` handler should have a parameter annotated as " + "list[rx.UploadFile]" ) - content_copy = io.BytesIO() - content_copy.write(await file.read()) - content_copy.seek(0) - file_copies.append( - UploadFile( - file=content_copy, - path=Path(file.filename.lstrip("/")) if file.filename else None, - size=file.size, - headers=file.headers, + raise UploadValueError(msg) + + # Keep the parsed form data alive until the upload event finishes so + # the underlying Starlette temp files remain available to the handler. + file_uploads = [] + for file in files: + if not isinstance(file, StarletteUploadFile): + raise UploadValueError( + "Uploaded file is not an UploadFile." + str(file) + ) + file_uploads.append( + UploadFile( + file=file.file, + path=Path(file.filename.lstrip("/")) if file.filename else None, + size=file.size, + headers=file.headers, + ) ) - ) - for file in files: - if not isinstance(file, StarletteUploadFile): - raise UploadValueError( - "Uploaded file is not an UploadFile." + str(file) - ) - await file.close() + return Event( + token=token, + name=handler, + payload={handler_upload_param[0]: file_uploads}, + ) - event = Event( - token=token, - name=handler, - payload={handler_upload_param[0]: file_copies}, - ) + event: Event | None = None + try: + event = await _create_upload_event() + finally: + if event is None: + await _close_form_data() async def _ndjson_updates(): """Process the upload event, generating ndjson updates. @@ -2013,9 +2045,10 @@ async def _ndjson_updates(): yield update.json() + "\n" # Stream updates to client - return StreamingResponse( + return _UploadStreamingResponse( _ndjson_updates(), media_type="application/x-ndjson", + on_finish=_close_form_data, ) return upload_file diff --git a/tests/units/states/upload.py b/tests/units/states/upload.py index 1c2d32a3bb6..6c732796a73 100644 --- a/tests/units/states/upload.py +++ b/tests/units/states/upload.py @@ -1,7 +1,6 @@ """Test states for upload-related tests.""" from pathlib import Path -from typing import ClassVar import reflex as rx from reflex.state import BaseState, State @@ -35,11 +34,11 @@ async def handle_upload(self, files: list[rx.UploadFile]): """ -class FileUploadState(State): - """The base state for uploading a file.""" +class _FileUploadMixin(BaseState, mixin=True): + """Common fields and handlers for upload state tests.""" img_list: list[str] - _tmp_path: ClassVar[Path] + _tmp_path: Path = Path() async def handle_upload2(self, files): """Handle the upload of a file. @@ -64,6 +63,7 @@ async def multi_handle_upload(self, files: list[rx.UploadFile]): # Update the img var. self.img_list.append(file.name) + yield @rx.event(background=True) async def bg_upload(self, files: list[rx.UploadFile]): @@ -74,87 +74,21 @@ async def bg_upload(self, files: list[rx.UploadFile]): """ +class FileUploadState(_FileUploadMixin, State): + """The base state for uploading a file.""" + + class FileStateBase1(State): """The base state for a child FileUploadState.""" -class ChildFileUploadState(FileStateBase1): +class ChildFileUploadState(_FileUploadMixin, FileStateBase1): """The child state for uploading a file.""" - img_list: list[str] - _tmp_path: ClassVar[Path] - - async def handle_upload2(self, files): - """Handle the upload of a file. - - Args: - files: The uploaded files. - """ - - async def multi_handle_upload(self, files: list[rx.UploadFile]): - """Handle the upload of a file. - - Args: - files: The uploaded files. - """ - for file in files: - upload_data = await file.read() - assert file.name is not None - outfile = self._tmp_path / file.name - - # Save the file. - outfile.write_bytes(upload_data) - - # Update the img var. - self.img_list.append(file.name) - - @rx.event(background=True) - async def bg_upload(self, files: list[rx.UploadFile]): - """Background task cannot be upload handler. - - Args: - files: The uploaded files. - """ - class FileStateBase2(FileStateBase1): """The parent state for a grandchild FileUploadState.""" -class GrandChildFileUploadState(FileStateBase2): +class GrandChildFileUploadState(_FileUploadMixin, FileStateBase2): """The child state for uploading a file.""" - - img_list: list[str] - _tmp_path: ClassVar[Path] - - async def handle_upload2(self, files): - """Handle the upload of a file. - - Args: - files: The uploaded files. - """ - - async def multi_handle_upload(self, files: list[rx.UploadFile]): - """Handle the upload of a file. - - Args: - files: The uploaded files. - """ - for file in files: - upload_data = await file.read() - assert file.name is not None - outfile = self._tmp_path / file.name - - # Save the file. - outfile.write_bytes(upload_data) - - # Update the img var. - self.img_list.append(file.name) - - @rx.event(background=True) - async def bg_upload(self, files: list[rx.UploadFile]): - """Background task cannot be upload handler. - - Args: - files: The uploaded files. - """ diff --git a/tests/units/test_app.py b/tests/units/test_app.py index 6efd006f1fa..25c71c0d17e 100644 --- a/tests/units/test_app.py +++ b/tests/units/test_app.py @@ -1,7 +1,9 @@ from __future__ import annotations +import asyncio import functools import io +import json import unittest.mock import uuid from collections.abc import Generator @@ -14,7 +16,7 @@ import pytest from pytest_mock import MockerFixture from starlette.applications import Starlette -from starlette.datastructures import UploadFile +from starlette.datastructures import FormData, UploadFile from starlette.responses import StreamingResponse import reflex as rx @@ -939,7 +941,7 @@ async def test_upload_file(tmp_path, state, delta, token: str, mocker: MockerFix Args: tmp_path: Temporary path. state: The state class. - delta: Expected delta + delta: Expected delta after processing all files. token: a Token. mocker: pytest mocker object. """ @@ -947,17 +949,13 @@ async def test_upload_file(tmp_path, state, delta, token: str, mocker: MockerFix "reflex.state.State.class_subclasses", {state if state is FileUploadState else FileStateBase1}, ) - state._tmp_path = tmp_path # The App state must be the "root" of the state tree app = App() app.event_namespace.emit = AsyncMock() # pyright: ignore [reportOptionalMemberAccess] - current_state = await app.state_manager.get_state(_substate_key(token, state)) + async with app.modify_state(_substate_key(token, state)) as root_state: + root_state.get_substate(state.get_full_name().split("."))._tmp_path = tmp_path data = b"This is binary data" - # Create a binary IO object and write data to it - bio = io.BytesIO() - bio.write(data) - request_mock = unittest.mock.Mock() request_mock.headers = { "reflex-client-token": token, @@ -966,44 +964,231 @@ async def test_upload_file(tmp_path, state, delta, token: str, mocker: MockerFix file1 = UploadFile( filename="image1.jpg", - file=bio, + file=io.BytesIO(data), ) file2 = UploadFile( filename="image2.jpg", - file=bio, + file=io.BytesIO(data), ) async def form(): # noqa: RUF029 - files_mock = unittest.mock.Mock() + return FormData([("files", file1), ("files", file2)]) + + request_mock.form = form + + upload_fn = upload(app) + streaming_response = await upload_fn(request_mock) + assert isinstance(streaming_response, StreamingResponse) + # Handler yields after each file, producing intermediate + final updates. + updates = [] + async for state_update in streaming_response.body_iterator: + updates.append(json.loads(str(state_update))) + # 2 intermediate yields + 1 final + assert len(updates) == 3 + assert all(not u["final"] for u in updates[:-1]) + assert updates[-1]["final"] + + # The last intermediate update should contain the full cumulative delta. + assert updates[1]["delta"] == delta + + await app.state_manager.close() + + +@pytest.mark.asyncio +async def test_upload_file_keeps_form_open_until_stream_completes( + tmp_path, + token: str, + mocker: MockerFixture, +): + """Test that upload files are not eagerly copied into memory. + + Uses two distinct BinaryIO instances, sets _tmp_path via modify_state, + and verifies that both file handles remain open during streaming and are + closed (along with correct file content) after the stream completes. + + Args: + tmp_path: Temporary path. + token: A token. + mocker: pytest mocker object. + """ + mocker.patch( + "reflex.state.State.class_subclasses", + {FileUploadState}, + ) + app = App() + app.event_namespace.emit = AsyncMock() # pyright: ignore [reportOptionalMemberAccess] - def getlist(key: str): - assert key == "files" - return [file1, file2] + # Set _tmp_path via modify_state instead of setting class attribute directly. + async with app.modify_state(_substate_key(token, FileUploadState)) as root_state: + root_state.get_substate( + FileUploadState.get_full_name().split(".") + )._tmp_path = tmp_path - files_mock.getlist = getlist + request_mock = unittest.mock.Mock() + request_mock.headers = { + "reflex-client-token": token, + "reflex-event-handler": f"{FileUploadState.get_full_name()}.multi_handle_upload", + } - return files_mock + data1 = b"contents of image one" + data2 = b"contents of image two" + bio1 = io.BytesIO(data1) + bio2 = io.BytesIO(data2) + file1 = UploadFile(filename="image1.jpg", file=bio1) + file2 = UploadFile(filename="image2.jpg", file=bio2) + + form_data = FormData([("files", file1), ("files", file2)]) + original_close = form_data.close + form_close = AsyncMock(side_effect=original_close) + form_data.close = form_close + + async def form(): # noqa: RUF029 + return form_data request_mock.form = form upload_fn = upload(app) streaming_response = await upload_fn(request_mock) + assert isinstance(streaming_response, StreamingResponse) - async for state_update in streaming_response.body_iterator: - assert ( - state_update - == StateUpdate(delta=delta, events=[], final=True).json() + "\n" - ) + # Before streaming starts, nothing should be read or closed. + assert form_close.await_count == 0 + assert not bio1.closed + assert not bio2.closed - if environment.REFLEX_OPLOCK_ENABLED.get(): - await app.state_manager.close() + # Drive the response through the full ASGI lifecycle so that + # _UploadStreamingResponse.__call__ invokes the on_finish callback. + scope = {"type": "http"} + done = asyncio.Event() + + async def receive(): + await done.wait() + return {"type": "http.disconnect"} + + async def send(message): # noqa: RUF029 + if message.get("type") == "http.response.body" and not message.get("body"): + done.set() + + await streaming_response(scope, receive, send) + + # After the ASGI call completes, form_data.close() should have been called, + # closing both underlying file handles. + assert form_close.await_count == 1 + assert bio1.closed + assert bio2.closed + + # Verify files were written to the tmp dir with the correct content. + assert (tmp_path / "image1.jpg").read_bytes() == data1 + assert (tmp_path / "image2.jpg").read_bytes() == data2 + + await app.state_manager.close() + + +@pytest.mark.asyncio +async def test_upload_file_closes_form_on_event_creation_cancellation( + token: str, + mocker: MockerFixture, +): + """Test that cancellation during upload event creation closes form data.""" + mocker.patch( + "reflex.state.State.class_subclasses", + {FileUploadState}, + ) + app = App() + + request_mock = unittest.mock.Mock() + request_mock.headers = { + "reflex-client-token": token, + "reflex-event-handler": f"{FileUploadState.get_full_name()}.multi_handle_upload", + } + + file1 = UploadFile(filename="image1.jpg", file=io.BytesIO(b"data")) + form_data = FormData([("files", file1)]) + original_close = form_data.close + form_close = AsyncMock(side_effect=original_close) + form_data.close = form_close + + async def form(): # noqa: RUF029 + return form_data + + async def cancelled_get_state(*_args, **_kwargs): + await asyncio.sleep(0) + raise asyncio.CancelledError + + request_mock.form = form + mocker.patch.object(app.state_manager, "get_state", side_effect=cancelled_get_state) + + upload_fn = upload(app) + with pytest.raises(asyncio.CancelledError): + await upload_fn(request_mock) + + assert form_close.await_count == 1 + assert file1.file.closed + + await app.state_manager.close() + + +@pytest.mark.asyncio +async def test_upload_file_closes_form_if_response_cancelled_before_stream_starts( + tmp_path, + token: str, + mocker: MockerFixture, +): + """Test that response cancellation before iteration still closes form data.""" + mocker.patch( + "reflex.state.State.class_subclasses", + {FileUploadState}, + ) + app = App() + app.event_namespace.emit = AsyncMock() # pyright: ignore [reportOptionalMemberAccess] + + async with app.modify_state(_substate_key(token, FileUploadState)) as root_state: + root_state.get_substate( + FileUploadState.get_full_name().split(".") + )._tmp_path = tmp_path + + request_mock = unittest.mock.Mock() + request_mock.headers = { + "reflex-client-token": token, + "reflex-event-handler": f"{FileUploadState.get_full_name()}.multi_handle_upload", + } + + bio = io.BytesIO(b"contents of image one") + file1 = UploadFile(filename="image1.jpg", file=bio) + form_data = FormData([("files", file1)]) + original_close = form_data.close + form_close = AsyncMock(side_effect=original_close) + form_data.close = form_close + + async def form(): # noqa: RUF029 + return form_data + + async def receive(): + await asyncio.sleep(0) + return {"type": "http.disconnect"} + + async def send(_message): + await asyncio.sleep(0) + raise asyncio.CancelledError + + request_mock.form = form - current_state = await app.state_manager.get_state(_substate_key(token, state)) - state_dict = current_state.dict()[state.get_full_name()] - assert state_dict["img_list" + FIELD_MARKER] == [ - "image1.jpg", - "image2.jpg", - ] + upload_fn = upload(app) + streaming_response = await upload_fn(request_mock) + + assert isinstance(streaming_response, StreamingResponse) + assert form_close.await_count == 0 + assert not bio.closed + + with pytest.raises(asyncio.CancelledError): + await streaming_response( + {"type": "http", "asgi": {"spec_version": "2.4"}}, + receive, + send, + ) + + assert form_close.await_count == 1 + assert bio.closed await app.state_manager.close() @@ -1021,7 +1206,6 @@ async def test_upload_file_without_annotation(state, tmp_path, token): tmp_path: Temporary path. token: a Token. """ - state._tmp_path = tmp_path app = App(_state=State) request_mock = unittest.mock.Mock() @@ -1030,16 +1214,10 @@ async def test_upload_file_without_annotation(state, tmp_path, token): "reflex-event-handler": f"{state.get_full_name()}.handle_upload2", } - async def form(): # noqa: RUF029 - files_mock = unittest.mock.Mock() - - def getlist(key: str): - assert key == "files" - return [unittest.mock.Mock(filename="image1.jpg")] - - files_mock.getlist = getlist + file1 = UploadFile(filename="image1.jpg", file=io.BytesIO(b"data")) - return files_mock + async def form(): # noqa: RUF029 + return FormData([("files", file1)]) request_mock.form = form @@ -1067,7 +1245,6 @@ async def test_upload_file_background(state, tmp_path, token): tmp_path: Temporary path. token: a Token. """ - state._tmp_path = tmp_path app = App(_state=State) request_mock = unittest.mock.Mock() @@ -1076,16 +1253,10 @@ async def test_upload_file_background(state, tmp_path, token): "reflex-event-handler": f"{state.get_full_name()}.bg_upload", } - async def form(): # noqa: RUF029 - files_mock = unittest.mock.Mock() + file1 = UploadFile(filename="image1.jpg", file=io.BytesIO(b"data")) - def getlist(key: str): - assert key == "files" - return [unittest.mock.Mock(filename="image1.jpg")] - - files_mock.getlist = getlist - - return files_mock + async def form(): # noqa: RUF029 + return FormData([("files", file1)]) request_mock.form = form