Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
196 changes: 116 additions & 80 deletions reflex/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import dataclasses
import functools
import inspect
import io
import json
import operator
import sys
Expand All @@ -19,6 +18,7 @@
from collections.abc import (
AsyncGenerator,
AsyncIterator,
Awaitable,
Callable,
Coroutine,
Mapping,
Expand Down Expand Up @@ -1886,6 +1886,27 @@ async def health(_request: Request) -> JSONResponse:
return JSONResponse(content=health_status, status_code=status_code)


class _UploadStreamingResponse(StreamingResponse):
"""Streaming response that always releases upload form resources."""

_on_finish: Callable[[], Awaitable[None]]

def __init__(
self,
*args: Any,
on_finish: Callable[[], Awaitable[None]],
**kwargs: Any,
) -> None:
super().__init__(*args, **kwargs)
self._on_finish = on_finish

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
try:
await super().__call__(scope, receive, send)
finally:
await self._on_finish()


def upload(app: App):
"""Upload a file.

Expand Down Expand Up @@ -1915,107 +1936,122 @@ async def upload_file(request: Request):

# Get the files from the request.
try:
files = await request.form()
form_data = await request.form()
except ClientDisconnect:
return Response() # user cancelled
files = files.getlist("files")
if not files:
msg = "No files were uploaded."
raise UploadValueError(msg)

token = request.headers.get("reflex-client-token")
handler = request.headers.get("reflex-event-handler")

if not token or not handler:
raise HTTPException(
status_code=400,
detail="Missing reflex-client-token or reflex-event-handler header.",
)

# Get the state for the session.
substate_token = _substate_key(token, handler.rpartition(".")[0])
state = await app.state_manager.get_state(substate_token)
form_data_closed = False

async def _close_form_data() -> None:
"""Close the parsed form data exactly once."""
nonlocal form_data_closed
if form_data_closed:
return
form_data_closed = True
await form_data.close()

handler_upload_param = ()
async def _create_upload_event() -> Event:
"""Create an upload event using the live Starlette temp files.

_current_state, event_handler = state._get_event_handler(handler)
Returns:
The upload event backed by the original temp files.
"""
files = form_data.getlist("files")
if not files:
msg = "No files were uploaded."
raise UploadValueError(msg)

token = request.headers.get("reflex-client-token")
handler = request.headers.get("reflex-event-handler")

if not token or not handler:
raise HTTPException(
status_code=400,
detail="Missing reflex-client-token or reflex-event-handler header.",
)

if event_handler.is_background:
msg = f"@rx.event(background=True) is not supported for upload handler `{handler}`."
raise UploadTypeError(msg)
func = event_handler.fn
if isinstance(func, functools.partial):
func = func.func
for k, v in get_type_hints(func).items():
if types.is_generic_alias(v) and types._issubclass(
get_args(v)[0],
UploadFile,
):
handler_upload_param = (k, v)
break
# Get the state for the session.
substate_token = _substate_key(token, handler.rpartition(".")[0])
state = await app.state_manager.get_state(substate_token)

if not handler_upload_param:
msg = (
f"`{handler}` handler should have a parameter annotated as "
"list[rx.UploadFile]"
)
raise UploadValueError(msg)

# Make a copy of the files as they are closed after the request.
# This behaviour changed from fastapi 0.103.0 to 0.103.1 as the
# AsyncExitStack was removed from the request scope and is now
# part of the routing function which closes this before the
# event is handled.
file_copies = []
for file in files:
if not isinstance(file, StarletteUploadFile):
raise UploadValueError(
"Uploaded file is not an UploadFile." + str(file)
handler_upload_param = ()

_current_state, event_handler = state._get_event_handler(handler)

if event_handler.is_background:
msg = f"@rx.event(background=True) is not supported for upload handler `{handler}`."
raise UploadTypeError(msg)
func = event_handler.fn
if isinstance(func, functools.partial):
func = func.func
for k, v in get_type_hints(func).items():
if types.is_generic_alias(v) and types._issubclass(
get_args(v)[0],
UploadFile,
):
handler_upload_param = (k, v)
break

if not handler_upload_param:
msg = (
f"`{handler}` handler should have a parameter annotated as "
"list[rx.UploadFile]"
)
content_copy = io.BytesIO()
content_copy.write(await file.read())
content_copy.seek(0)
file_copies.append(
UploadFile(
file=content_copy,
path=Path(file.filename.lstrip("/")) if file.filename else None,
size=file.size,
headers=file.headers,
raise UploadValueError(msg)

# Keep the parsed form data alive until the upload event finishes so
# the underlying Starlette temp files remain available to the handler.
file_uploads = []
for file in files:
if not isinstance(file, StarletteUploadFile):
raise UploadValueError(
"Uploaded file is not an UploadFile." + str(file)
)
file_uploads.append(
UploadFile(
file=file.file,
path=Path(file.filename.lstrip("/")) if file.filename else None,
size=file.size,
headers=file.headers,
)
)
)

for file in files:
if not isinstance(file, StarletteUploadFile):
raise UploadValueError(
"Uploaded file is not an UploadFile." + str(file)
)
await file.close()
return Event(
token=token,
name=handler,
payload={handler_upload_param[0]: file_uploads},
)

event = Event(
token=token,
name=handler,
payload={handler_upload_param[0]: file_copies},
)
event: Event | None = None
try:
event = await _create_upload_event()
finally:
if event is None:
await _close_form_data()

async def _ndjson_updates():
"""Process the upload event, generating ndjson updates.

Yields:
Each state update as JSON followed by a new line.
"""
# Process the event.
async with app.state_manager.modify_state_with_links(
event.substate_token
) as state:
async for update in state._process(event):
# Postprocess the event.
update = await app._postprocess(state, event, update)
yield update.json() + "\n"
try:
# Process the event.
async with app.state_manager.modify_state_with_links(
event.substate_token
) as state:
async for update in state._process(event):
# Postprocess the event.
update = await app._postprocess(state, event, update)
yield update.json() + "\n"
finally:
await _close_form_data()
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah i think cleaning up this path since we handle it with _UploadStreamingResponse makes the code more maintainable

Suggested change
try:
# Process the event.
async with app.state_manager.modify_state_with_links(
event.substate_token
) as state:
async for update in state._process(event):
# Postprocess the event.
update = await app._postprocess(state, event, update)
yield update.json() + "\n"
finally:
await _close_form_data()
# Process the event.
async with app.state_manager.modify_state_with_links(
event.substate_token
) as state:
async for update in state._process(event):
# Postprocess the event.
update = await app._postprocess(state, event, update)
yield update.json() + "\n"

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done.


# Stream updates to client
return StreamingResponse(
return _UploadStreamingResponse(
_ndjson_updates(),
media_type="application/x-ndjson",
on_finish=_close_form_data,
)

return upload_file
Expand Down
86 changes: 10 additions & 76 deletions tests/units/states/upload.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Test states for upload-related tests."""

from pathlib import Path
from typing import ClassVar

import reflex as rx
from reflex.state import BaseState, State
Expand Down Expand Up @@ -35,11 +34,11 @@ async def handle_upload(self, files: list[rx.UploadFile]):
"""


class FileUploadState(State):
"""The base state for uploading a file."""
class _FileUploadMixin(BaseState, mixin=True):
"""Common fields and handlers for upload state tests."""

img_list: list[str]
_tmp_path: ClassVar[Path]
_tmp_path: Path = Path()

async def handle_upload2(self, files):
"""Handle the upload of a file.
Expand All @@ -64,6 +63,7 @@ async def multi_handle_upload(self, files: list[rx.UploadFile]):

# Update the img var.
self.img_list.append(file.name)
yield

@rx.event(background=True)
async def bg_upload(self, files: list[rx.UploadFile]):
Expand All @@ -74,87 +74,21 @@ async def bg_upload(self, files: list[rx.UploadFile]):
"""


class FileUploadState(_FileUploadMixin, State):
"""The base state for uploading a file."""


class FileStateBase1(State):
"""The base state for a child FileUploadState."""


class ChildFileUploadState(FileStateBase1):
class ChildFileUploadState(_FileUploadMixin, FileStateBase1):
"""The child state for uploading a file."""

img_list: list[str]
_tmp_path: ClassVar[Path]

async def handle_upload2(self, files):
"""Handle the upload of a file.

Args:
files: The uploaded files.
"""

async def multi_handle_upload(self, files: list[rx.UploadFile]):
"""Handle the upload of a file.

Args:
files: The uploaded files.
"""
for file in files:
upload_data = await file.read()
assert file.name is not None
outfile = self._tmp_path / file.name

# Save the file.
outfile.write_bytes(upload_data)

# Update the img var.
self.img_list.append(file.name)

@rx.event(background=True)
async def bg_upload(self, files: list[rx.UploadFile]):
"""Background task cannot be upload handler.

Args:
files: The uploaded files.
"""


class FileStateBase2(FileStateBase1):
"""The parent state for a grandchild FileUploadState."""


class GrandChildFileUploadState(FileStateBase2):
class GrandChildFileUploadState(_FileUploadMixin, FileStateBase2):
"""The child state for uploading a file."""

img_list: list[str]
_tmp_path: ClassVar[Path]

async def handle_upload2(self, files):
"""Handle the upload of a file.

Args:
files: The uploaded files.
"""

async def multi_handle_upload(self, files: list[rx.UploadFile]):
"""Handle the upload of a file.

Args:
files: The uploaded files.
"""
for file in files:
upload_data = await file.read()
assert file.name is not None
outfile = self._tmp_path / file.name

# Save the file.
outfile.write_bytes(upload_data)

# Update the img var.
self.img_list.append(file.name)

@rx.event(background=True)
async def bg_upload(self, files: list[rx.UploadFile]):
"""Background task cannot be upload handler.

Args:
files: The uploaded files.
"""
Loading
Loading