Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
177 changes: 105 additions & 72 deletions reflex/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import dataclasses
import functools
import inspect
import io
import json
import operator
import sys
Expand All @@ -19,6 +18,7 @@
from collections.abc import (
AsyncGenerator,
AsyncIterator,
Awaitable,
Callable,
Coroutine,
Mapping,
Expand Down Expand Up @@ -1886,6 +1886,27 @@ async def health(_request: Request) -> JSONResponse:
return JSONResponse(content=health_status, status_code=status_code)


class _UploadStreamingResponse(StreamingResponse):
"""Streaming response that always releases upload form resources."""

_on_finish: Callable[[], Awaitable[None]]

def __init__(
self,
*args: Any,
on_finish: Callable[[], Awaitable[None]],
**kwargs: Any,
) -> None:
super().__init__(*args, **kwargs)
self._on_finish = on_finish

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
try:
await super().__call__(scope, receive, send)
finally:
await self._on_finish()


def upload(app: App):
"""Upload a file.

Expand Down Expand Up @@ -1915,87 +1936,98 @@ async def upload_file(request: Request):

# Get the files from the request.
try:
files = await request.form()
form_data = await request.form()
except ClientDisconnect:
return Response() # user cancelled
files = files.getlist("files")
if not files:
msg = "No files were uploaded."
raise UploadValueError(msg)

token = request.headers.get("reflex-client-token")
handler = request.headers.get("reflex-event-handler")

if not token or not handler:
raise HTTPException(
status_code=400,
detail="Missing reflex-client-token or reflex-event-handler header.",
)

# Get the state for the session.
substate_token = _substate_key(token, handler.rpartition(".")[0])
state = await app.state_manager.get_state(substate_token)
form_data_closed = False

handler_upload_param = ()
async def _close_form_data() -> None:
"""Close the parsed form data exactly once."""
nonlocal form_data_closed
if form_data_closed:
return
form_data_closed = True
await form_data.close()

_current_state, event_handler = state._get_event_handler(handler)
async def _create_upload_event() -> Event:
"""Create an upload event using the live Starlette temp files.

if event_handler.is_background:
msg = f"@rx.event(background=True) is not supported for upload handler `{handler}`."
raise UploadTypeError(msg)
func = event_handler.fn
if isinstance(func, functools.partial):
func = func.func
for k, v in get_type_hints(func).items():
if types.is_generic_alias(v) and types._issubclass(
get_args(v)[0],
UploadFile,
):
handler_upload_param = (k, v)
break
Returns:
The upload event backed by the original temp files.
"""
files = form_data.getlist("files")
if not files:
msg = "No files were uploaded."
raise UploadValueError(msg)

token = request.headers.get("reflex-client-token")
handler = request.headers.get("reflex-event-handler")

if not token or not handler:
raise HTTPException(
status_code=400,
detail="Missing reflex-client-token or reflex-event-handler header.",
)

if not handler_upload_param:
msg = (
f"`{handler}` handler should have a parameter annotated as "
"list[rx.UploadFile]"
)
raise UploadValueError(msg)

# Make a copy of the files as they are closed after the request.
# This behaviour changed from fastapi 0.103.0 to 0.103.1 as the
# AsyncExitStack was removed from the request scope and is now
# part of the routing function which closes this before the
# event is handled.
file_copies = []
for file in files:
if not isinstance(file, StarletteUploadFile):
raise UploadValueError(
"Uploaded file is not an UploadFile." + str(file)
# Get the state for the session.
substate_token = _substate_key(token, handler.rpartition(".")[0])
state = await app.state_manager.get_state(substate_token)

handler_upload_param = ()

_current_state, event_handler = state._get_event_handler(handler)

if event_handler.is_background:
msg = f"@rx.event(background=True) is not supported for upload handler `{handler}`."
raise UploadTypeError(msg)
func = event_handler.fn
if isinstance(func, functools.partial):
func = func.func
for k, v in get_type_hints(func).items():
if types.is_generic_alias(v) and types._issubclass(
get_args(v)[0],
UploadFile,
):
handler_upload_param = (k, v)
break

if not handler_upload_param:
msg = (
f"`{handler}` handler should have a parameter annotated as "
"list[rx.UploadFile]"
)
content_copy = io.BytesIO()
content_copy.write(await file.read())
content_copy.seek(0)
file_copies.append(
UploadFile(
file=content_copy,
path=Path(file.filename.lstrip("/")) if file.filename else None,
size=file.size,
headers=file.headers,
raise UploadValueError(msg)

# Keep the parsed form data alive until the upload event finishes so
# the underlying Starlette temp files remain available to the handler.
file_uploads = []
for file in files:
if not isinstance(file, StarletteUploadFile):
raise UploadValueError(
"Uploaded file is not an UploadFile." + str(file)
)
file_uploads.append(
UploadFile(
file=file.file,
path=Path(file.filename.lstrip("/")) if file.filename else None,
size=file.size,
headers=file.headers,
)
)
)

for file in files:
if not isinstance(file, StarletteUploadFile):
raise UploadValueError(
"Uploaded file is not an UploadFile." + str(file)
)
await file.close()
return Event(
token=token,
name=handler,
payload={handler_upload_param[0]: file_uploads},
)

event = Event(
token=token,
name=handler,
payload={handler_upload_param[0]: file_copies},
)
event: Event | None = None
try:
event = await _create_upload_event()
finally:
if event is None:
await _close_form_data()

async def _ndjson_updates():
"""Process the upload event, generating ndjson updates.
Expand All @@ -2013,9 +2045,10 @@ async def _ndjson_updates():
yield update.json() + "\n"

# Stream updates to client
return StreamingResponse(
return _UploadStreamingResponse(
_ndjson_updates(),
media_type="application/x-ndjson",
on_finish=_close_form_data,
)

return upload_file
Expand Down
86 changes: 10 additions & 76 deletions tests/units/states/upload.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Test states for upload-related tests."""

from pathlib import Path
from typing import ClassVar

import reflex as rx
from reflex.state import BaseState, State
Expand Down Expand Up @@ -35,11 +34,11 @@ async def handle_upload(self, files: list[rx.UploadFile]):
"""


class FileUploadState(State):
"""The base state for uploading a file."""
class _FileUploadMixin(BaseState, mixin=True):
"""Common fields and handlers for upload state tests."""

img_list: list[str]
_tmp_path: ClassVar[Path]
_tmp_path: Path = Path()

async def handle_upload2(self, files):
"""Handle the upload of a file.
Expand All @@ -64,6 +63,7 @@ async def multi_handle_upload(self, files: list[rx.UploadFile]):

# Update the img var.
self.img_list.append(file.name)
yield

@rx.event(background=True)
async def bg_upload(self, files: list[rx.UploadFile]):
Expand All @@ -74,87 +74,21 @@ async def bg_upload(self, files: list[rx.UploadFile]):
"""


class FileUploadState(_FileUploadMixin, State):
"""The base state for uploading a file."""


class FileStateBase1(State):
"""The base state for a child FileUploadState."""


class ChildFileUploadState(FileStateBase1):
class ChildFileUploadState(_FileUploadMixin, FileStateBase1):
"""The child state for uploading a file."""

img_list: list[str]
_tmp_path: ClassVar[Path]

async def handle_upload2(self, files):
"""Handle the upload of a file.

Args:
files: The uploaded files.
"""

async def multi_handle_upload(self, files: list[rx.UploadFile]):
"""Handle the upload of a file.

Args:
files: The uploaded files.
"""
for file in files:
upload_data = await file.read()
assert file.name is not None
outfile = self._tmp_path / file.name

# Save the file.
outfile.write_bytes(upload_data)

# Update the img var.
self.img_list.append(file.name)

@rx.event(background=True)
async def bg_upload(self, files: list[rx.UploadFile]):
"""Background task cannot be upload handler.

Args:
files: The uploaded files.
"""


class FileStateBase2(FileStateBase1):
"""The parent state for a grandchild FileUploadState."""


class GrandChildFileUploadState(FileStateBase2):
class GrandChildFileUploadState(_FileUploadMixin, FileStateBase2):
"""The child state for uploading a file."""

img_list: list[str]
_tmp_path: ClassVar[Path]

async def handle_upload2(self, files):
"""Handle the upload of a file.

Args:
files: The uploaded files.
"""

async def multi_handle_upload(self, files: list[rx.UploadFile]):
"""Handle the upload of a file.

Args:
files: The uploaded files.
"""
for file in files:
upload_data = await file.read()
assert file.name is not None
outfile = self._tmp_path / file.name

# Save the file.
outfile.write_bytes(upload_data)

# Update the img var.
self.img_list.append(file.name)

@rx.event(background=True)
async def bg_upload(self, files: list[rx.UploadFile]):
"""Background task cannot be upload handler.

Args:
files: The uploaded files.
"""
Loading
Loading