From 78b656c6aa52ff92d52cc13157f486077a7acbde Mon Sep 17 00:00:00 2001 From: Farhan Ali Raza Date: Wed, 18 Mar 2026 21:57:36 +0500 Subject: [PATCH] feat: add chunked file upload support to stream large files without buffering in memory Introduce `rx.upload_files_chunk` and the `/_upload_chunk` endpoint to allow uploading files in 8 MB chunks streamed directly to a background event handler via `rx.UploadChunkIterator`. This avoids buffering entire files in server memory and enables incremental processing as data arrives. - Add `UploadChunk`, `UploadChunkIterator`, and `UploadFilesChunk` to the public API - Add `uploadFilesChunk` JS client function that splits files into chunks with session tracking, progress reporting, and cancellation - Add `upload_chunk()` backend endpoint with multi-request session management and streaming multipart parsing - Refactor shared upload helpers (`_require_upload_headers`, `_get_upload_runtime_handler`, `_seed_upload_router_data`) - Add `sync_web_runtime_templates()` to keep `.web` runtime files in sync during `setup_frontend` - Add unit and integration tests for chunked uploads, cancellation, validation, and custom parameter names --- pyi_hashes.json | 4 +- reflex/.templates/web/utils/helpers/upload.js | 516 +++++++++++-- reflex/.templates/web/utils/state.js | 33 +- reflex/__init__.py | 3 + reflex/app.py | 709 +++++++++++++++++- reflex/components/core/upload.py | 23 +- reflex/constants/event.py | 1 + reflex/event.py | 325 +++++++- reflex/utils/build.py | 11 +- reflex/utils/frontend_skeleton.py | 45 ++ tests/integration/test_upload.py | 214 +++++- tests/units/components/core/test_upload.py | 121 ++- tests/units/states/upload.py | 51 ++ tests/units/test_app.py | 235 +++++- tests/units/test_prerequisites.py | 72 ++ 15 files changed, 2231 insertions(+), 132 deletions(-) diff --git a/pyi_hashes.json b/pyi_hashes.json index ec9bbff8850..8c52bc2428a 100644 --- a/pyi_hashes.json +++ b/pyi_hashes.json @@ -1,5 +1,5 @@ { - "reflex/__init__.pyi": "0a3ae880e256b9fd3b960e12a2cb51a7", + "reflex/__init__.pyi": "276759cf35be6503c710e2203405adb6", "reflex/components/__init__.pyi": "ac05995852baa81062ba3d18fbc489fb", "reflex/components/base/__init__.pyi": "16e47bf19e0d62835a605baa3d039c5a", "reflex/components/base/app_wrap.pyi": "22e94feaa9fe675bcae51c412f5b67f1", @@ -19,7 +19,7 @@ "reflex/components/core/helmet.pyi": "43f8497c8fafe51e29dca1dd535d143a", "reflex/components/core/html.pyi": "86eb9d4c1bb4807547b2950d9a32e9fd", "reflex/components/core/sticky.pyi": "cb763b986a9b0654d1a3f33440dfcf60", - "reflex/components/core/upload.pyi": "6dc28804a6dddf903e31162e87c1b023", + "reflex/components/core/upload.pyi": "3ac15718ed593895c0ce77552499505b", "reflex/components/core/window_events.pyi": "af33ccec866b9540ee7fbec6dbfbd151", "reflex/components/datadisplay/__init__.pyi": "52755871369acbfd3a96b46b9a11d32e", "reflex/components/datadisplay/code.pyi": "b86769987ef4d1cbdddb461be88539fd", diff --git a/reflex/.templates/web/utils/helpers/upload.js b/reflex/.templates/web/utils/helpers/upload.js index 6bbfc746ed6..6e40268699c 100644 --- a/reflex/.templates/web/utils/helpers/upload.js +++ b/reflex/.templates/web/utils/helpers/upload.js @@ -1,47 +1,16 @@ import JSON5 from "json5"; import env from "$/env.json"; -/** - * Upload files to the server. - * - * @param state The state to apply the delta to. - * @param handler The handler to use. - * @param upload_id The upload id to use. - * @param on_upload_progress The function to call on upload progress. - * @param socket the websocket connection - * @param extra_headers Extra headers to send with the request. - * @param refs The refs object to store the abort controller in. - * @param getBackendURL Function to get the backend URL. - * @param getToken Function to get the Reflex token. - * - * @returns The response from posting to the UPLOADURL endpoint. - */ -export const uploadFiles = async ( - handler, - files, - upload_id, - on_upload_progress, - extra_headers, - socket, - refs, - getBackendURL, - getToken, -) => { - // return if there's no file to upload - if (files === undefined || files.length === 0) { - return false; - } - - const upload_ref_name = `__upload_controllers_${upload_id}`; - - if (refs[upload_ref_name]) { - console.log("Upload already in progress for ", upload_id); - return false; - } +const UPLOAD_CHUNK_SIZE = 8 * 1024 * 1024; +const logUpload = (message, details = {}) => { + console.log(`[reflex upload] ${message}`, details); +}; +const trackUploadResponse = (socket) => { // Track how many partial updates have been processed for this upload. let resp_idx = 0; - const eventHandler = (progressEvent) => { + + return (progressEvent) => { const event_callbacks = socket._callbacks.$event; // Whenever called, responseText will contain the entire response so far. const chunks = progressEvent.event.target.responseText.trim().split("\n"); @@ -73,24 +42,48 @@ export const uploadFiles = async ( } }); }; +}; - const controller = new AbortController(); - const formdata = new FormData(); +const sendUploadRequest = async ({ + handler, + upload_id, + on_upload_progress, + extra_headers, + refs, + getToken, + formdata, + url, + responseHandler, +}) => { + const upload_ref_name = `__upload_controllers_${upload_id}`; - // Add the token and handler to the file name. - files.forEach((file) => { - formdata.append("files", file, file.path || file.name); - }); + if (refs[upload_ref_name]) { + logUpload("upload already in progress", { upload_id, handler }); + return false; + } + + const controller = new AbortController(); - // Send the file to the server. refs[upload_ref_name] = controller; return new Promise((resolve, reject) => { const xhr = new XMLHttpRequest(); - // Set up event handlers + logUpload("sending request", { + handler, + upload_id, + url: String(url), + request_kind: responseHandler ? "classic" : "chunked", + }); + xhr.onload = function () { if (xhr.status >= 200 && xhr.status < 300) { + logUpload("request completed", { + handler, + upload_id, + status: xhr.status, + response_length: xhr.responseText.length, + }); resolve({ data: xhr.responseText, status: xhr.status, @@ -105,49 +98,45 @@ export const uploadFiles = async ( }; xhr.onerror = function () { + logUpload("request failed", { handler, upload_id }); reject(new Error("Network error")); }; xhr.onabort = function () { + logUpload("request aborted", { handler, upload_id }); reject(new Error("Upload aborted")); }; - // Handle upload progress if (on_upload_progress) { xhr.upload.onprogress = function (event) { if (event.lengthComputable) { - const progressEvent = { + on_upload_progress({ loaded: event.loaded, total: event.total, progress: event.loaded / event.total, - }; - on_upload_progress(progressEvent); + }); } }; } - // Handle download progress with streaming response parsing - xhr.onprogress = function (event) { - if (eventHandler) { - const progressEvent = { + if (responseHandler) { + xhr.onprogress = function (event) { + responseHandler({ event: { target: { responseText: xhr.responseText, }, }, progress: event.lengthComputable ? event.loaded / event.total : 0, - }; - eventHandler(progressEvent); - } - }; + }); + }; + } - // Handle abort controller controller.signal.addEventListener("abort", () => { xhr.abort(); }); - // Configure and send request - xhr.open("POST", getBackendURL(env.UPLOAD)); + xhr.open("POST", url); xhr.setRequestHeader("Reflex-Client-Token", getToken()); xhr.setRequestHeader("Reflex-Event-Handler", handler); for (const [key, value] of Object.entries(extra_headers || {})) { @@ -168,3 +157,408 @@ export const uploadFiles = async ( delete refs[upload_ref_name]; }); }; + +const createChunkUploadController = () => ({ + cancelled: false, + currentXhr: null, + abort() { + this.cancelled = true; + this.currentXhr?.abort(); + }, +}); + +const createChunkUploadSessionId = () => { + if ( + typeof crypto !== "undefined" && + typeof crypto.randomUUID === "function" + ) { + return crypto.randomUUID(); + } + return `upload-${Date.now()}-${Math.random().toString(16).slice(2)}`; +}; + +const buildChunkUploadURL = ({ + getBackendURL, + sessionId, + upload_id, + filename, + offset, + complete = false, + cancel = false, +}) => { + const url = new URL(getBackendURL(env.UPLOAD_CHUNK)); + const searchParams = new URLSearchParams({ + session_id: sessionId, + }); + + if (upload_id) { + searchParams.set("upload_id", upload_id); + } + if (filename !== undefined) { + searchParams.set("filename", filename); + } + if (offset !== undefined) { + searchParams.set("offset", String(offset)); + } + if (complete) { + searchParams.set("complete", "1"); + } + if (cancel) { + searchParams.set("cancel", "1"); + } + + url.search = searchParams.toString(); + return url; +}; + +const sendChunkUploadRequest = async ({ + handler, + upload_id, + extra_headers, + getToken, + url, + body, + contentType, + controller, + onRequestProgress, + details, +}) => + new Promise((resolve, reject) => { + if (controller.cancelled) { + reject(new Error("Upload aborted")); + return; + } + + const xhr = new XMLHttpRequest(); + const cleanup = () => { + if (controller.currentXhr === xhr) { + controller.currentXhr = null; + } + }; + controller.currentXhr = xhr; + + logUpload("sending request", { + handler, + upload_id, + url: String(url), + request_kind: "chunked", + ...details, + }); + + xhr.onload = function () { + cleanup(); + if (xhr.status >= 200 && xhr.status < 300) { + logUpload("request completed", { + handler, + upload_id, + status: xhr.status, + response_length: xhr.responseText.length, + ...details, + }); + resolve(xhr.status); + } else { + reject(new Error(`HTTP error! status: ${xhr.status}`)); + } + }; + + xhr.onerror = function () { + cleanup(); + logUpload("request failed", { handler, upload_id, ...details }); + reject(new Error("Network error")); + }; + + xhr.onabort = function () { + cleanup(); + logUpload("request aborted", { handler, upload_id, ...details }); + reject(new Error("Upload aborted")); + }; + + if (onRequestProgress) { + xhr.upload.onprogress = function (event) { + onRequestProgress(event); + }; + } + + xhr.open("POST", url); + xhr.setRequestHeader("Reflex-Client-Token", getToken()); + xhr.setRequestHeader("Reflex-Event-Handler", handler); + if (contentType) { + xhr.setRequestHeader("Content-Type", contentType); + } + for (const [key, value] of Object.entries(extra_headers || {})) { + xhr.setRequestHeader(key, value); + } + + try { + xhr.send(body); + } catch (error) { + cleanup(); + reject(error); + } + }); + +const notifyChunkUploadCancelled = async ({ + handler, + upload_id, + sessionId, + extra_headers, + getBackendURL, + getToken, +}) => { + const url = buildChunkUploadURL({ + getBackendURL, + sessionId, + upload_id, + cancel: true, + }); + + logUpload("sending cancel request", { + handler, + upload_id, + session_id: sessionId, + url: String(url), + }); + + try { + const response = await fetch(url, { + method: "POST", + headers: { + "Reflex-Client-Token": getToken(), + "Reflex-Event-Handler": handler, + ...extra_headers, + }, + keepalive: true, + }); + logUpload("cancel request completed", { + handler, + upload_id, + session_id: sessionId, + status: response.status, + }); + } catch (error) { + logUpload("cancel request failed", { + handler, + upload_id, + session_id: sessionId, + error: error.message, + }); + } +}; + +/** + * Upload files to the server. + * + * @param handler The handler to use. + * @param upload_id The upload id to use. + * @param on_upload_progress The function to call on upload progress. + * @param extra_headers Extra headers to send with the request. + * @param socket The websocket connection. + * @param refs The refs object to store the abort controller in. + * @param getBackendURL Function to get the backend URL. + * @param getToken Function to get the Reflex token. + * + * @returns The response from posting to the upload endpoint. + */ +export const uploadFiles = async ( + handler, + files, + upload_id, + on_upload_progress, + extra_headers, + socket, + refs, + getBackendURL, + getToken, +) => { + if (files === undefined || files.length === 0) { + logUpload("classic upload skipped because there are no files", { + handler, + upload_id, + }); + return false; + } + + const formdata = new FormData(); + + files.forEach((file) => { + formdata.append("files", file, file.path || file.name); + }); + + return sendUploadRequest({ + handler, + upload_id, + on_upload_progress, + extra_headers, + refs, + getToken, + formdata, + url: getBackendURL(env.UPLOAD), + responseHandler: trackUploadResponse(socket), + }); +}; + +/** + * Upload files to the streaming chunk endpoint. + * + * @param handler The handler to use. + * @param files The files to upload. + * @param upload_id The upload id to use. + * @param on_upload_progress The function to call on upload progress. + * @param extra_headers Extra headers to send with the request. + * @param _socket The websocket connection. + * @param refs The refs object to store the abort controller in. + * @param getBackendURL Function to get the backend URL. + * @param getToken Function to get the Reflex token. + * + * @returns The response from posting to the chunk upload endpoint. + */ +export const uploadFilesChunk = async ( + handler, + files, + upload_id, + on_upload_progress, + extra_headers, + _socket, + refs, + getBackendURL, + getToken, +) => { + if (files === undefined || files.length === 0) { + logUpload("chunked upload skipped because there are no files", { + handler, + upload_id, + }); + return false; + } + + const maxSize = Math.max(...files.map((file) => file.size), 0); + const totalBytes = files.reduce((sum, file) => sum + file.size, 0); + const totalRequestCount = files.reduce( + (sum, file) => sum + Math.max(1, Math.ceil(file.size / UPLOAD_CHUNK_SIZE)), + 0, + ); + const upload_ref_name = `__upload_controllers_${upload_id}`; + + if (refs[upload_ref_name]) { + logUpload("upload already in progress", { upload_id, handler }); + return false; + } + + const controller = createChunkUploadController(); + const sessionId = createChunkUploadSessionId(); + refs[upload_ref_name] = controller; + + logUpload("prepared chunked upload plan", { + handler, + upload_id, + session_id: sessionId, + file_count: files.length, + max_file_size: maxSize, + total_size: totalBytes, + chunk_size: UPLOAD_CHUNK_SIZE, + request_count: totalRequestCount, + files: files.map((file) => ({ + name: file.path || file.name, + size: file.size, + type: file.type, + })), + }); + + let uploadedBytes = 0; + let requestIndex = 0; + let completed = false; + const maxIterations = Math.max(maxSize, 1); + + try { + for (let offset = 0; offset < maxIterations; offset += UPLOAD_CHUNK_SIZE) { + for (const file of files) { + if (controller.cancelled) { + throw new Error("Upload aborted"); + } + + const filename = file.path || file.name; + let chunkBlob; + if (file.size === 0) { + if (offset !== 0) { + continue; + } + chunkBlob = file.slice(0, 0, file.type); + } else { + if (offset >= file.size) { + continue; + } + chunkBlob = file.slice(offset, offset + UPLOAD_CHUNK_SIZE, file.type); + } + + const isFinalRequest = requestIndex === totalRequestCount - 1; + const url = buildChunkUploadURL({ + getBackendURL, + sessionId, + upload_id, + filename, + offset, + complete: isFinalRequest, + }); + + await sendChunkUploadRequest({ + handler, + upload_id, + extra_headers, + getToken, + url, + body: chunkBlob, + contentType: + chunkBlob.type || file.type || "application/octet-stream", + controller, + details: { + session_id: sessionId, + file_name: filename, + offset, + request_index: requestIndex, + complete: isFinalRequest, + }, + onRequestProgress: (event) => { + if (!on_upload_progress || !event.lengthComputable) { + return; + } + const loaded = uploadedBytes + event.loaded; + const total = totalBytes; + on_upload_progress({ + loaded, + total, + progress: total === 0 ? 1 : loaded / total, + }); + }, + }); + + uploadedBytes += chunkBlob.size; + requestIndex += 1; + } + } + + if (on_upload_progress) { + on_upload_progress({ + loaded: totalBytes, + total: totalBytes, + progress: 1, + }); + } + completed = true; + return true; + } catch (error) { + console.log("Upload error:", error.message); + return false; + } finally { + if (!completed && requestIndex > 0) { + await notifyChunkUploadCancelled({ + handler, + upload_id, + sessionId, + extra_headers, + getBackendURL, + getToken, + }); + } + delete refs[upload_ref_name]; + } +}; diff --git a/reflex/.templates/web/utils/state.js b/reflex/.templates/web/utils/state.js index 9e937ed62cd..6ad5a883f42 100644 --- a/reflex/.templates/web/utils/state.js +++ b/reflex/.templates/web/utils/state.js @@ -20,7 +20,7 @@ import { } from "$/utils/context"; import debounce from "$/utils/helpers/debounce"; import throttle from "$/utils/helpers/throttle"; -import { uploadFiles } from "$/utils/helpers/upload"; +import { uploadFiles, uploadFilesChunk } from "$/utils/helpers/upload"; // Endpoint URLs. const EVENTURL = env.EVENT; @@ -418,11 +418,30 @@ export const applyEvent = async (event, socket, navigate, params) => { */ export const applyRestEvent = async (event, socket, navigate, params) => { let eventSent = false; - if (event.handler === "uploadFiles") { - if (event.payload.files === undefined || event.payload.files.length === 0) { + if (event.handler === "uploadFiles" || event.handler === "uploadFilesChunk") { + const filePayloadKey = event.payload.upload_param_name || "files"; + const uploadFilesPayload = + event.payload.files ?? event.payload[filePayloadKey]; + + console.log("[reflex upload] applyRestEvent", { + event_name: event.name, + handler: event.handler, + upload_id: event.payload.upload_id, + file_payload_key: filePayloadKey, + file_count: uploadFilesPayload?.length ?? 0, + payload_keys: Object.keys(event.payload || {}), + }); + + if (uploadFilesPayload === undefined || uploadFilesPayload.length === 0) { + console.warn("[reflex upload] no files available for REST upload", { + event_name: event.name, + handler: event.handler, + upload_id: event.payload.upload_id, + file_payload_key: filePayloadKey, + }); // Submit the event over the websocket to trigger the event handler. return await applyEvent( - ReflexEvent(event.name, { files: [] }), + ReflexEvent(event.name, { [filePayloadKey]: [] }), socket, navigate, params, @@ -430,9 +449,11 @@ export const applyRestEvent = async (event, socket, navigate, params) => { } // Start upload, but do not wait for it, which would block other events. - uploadFiles( + const uploadFn = + event.handler === "uploadFilesChunk" ? uploadFilesChunk : uploadFiles; + uploadFn( event.name, - event.payload.files, + uploadFilesPayload, event.payload.upload_id, event.payload.on_upload_progress, event.payload.extra_headers, diff --git a/reflex/__init__.py b/reflex/__init__.py index 066df110f02..e3c832ced19 100644 --- a/reflex/__init__.py +++ b/reflex/__init__.py @@ -301,6 +301,8 @@ "event", "EventChain", "EventHandler", + "UploadChunk", + "UploadChunkIterator", "call_script", "call_function", "run_script", @@ -320,6 +322,7 @@ "set_value", "stop_propagation", "upload_files", + "upload_files_chunk", "window_alert", ], "istate.storage": [ diff --git a/reflex/app.py b/reflex/app.py index 54682543a7d..087536d1118 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -29,8 +29,9 @@ from pathlib import Path from timeit import default_timer as timer from types import SimpleNamespace -from typing import TYPE_CHECKING, Any, BinaryIO, ParamSpec, get_args, get_type_hints +from typing import TYPE_CHECKING, Any, BinaryIO, ParamSpec, cast +from python_multipart.multipart import MultipartParser, parse_options_header from rich.progress import MofNCompleteColumn, Progress, TimeElapsedColumn from socketio import ASGIApp as EngineIOApp from socketio import AsyncNamespace, AsyncServer @@ -38,6 +39,7 @@ from starlette.datastructures import Headers from starlette.datastructures import UploadFile as StarletteUploadFile from starlette.exceptions import HTTPException +from starlette.formparsers import MultiPartException, _user_safe_decode from starlette.middleware import cors from starlette.requests import ClientDisconnect, Request from starlette.responses import JSONResponse, Response, StreamingResponse @@ -77,11 +79,16 @@ from reflex.event import ( _EVENT_FIELDS, Event, + EventHandler, EventSpec, EventType, IndividualEventType, + UploadChunk, + UploadChunkIterator, get_hydrate_event, noop, + resolve_upload_chunk_handler_param, + resolve_upload_handler_param, ) from reflex.istate.proxy import StateProxy from reflex.page import DECORATED_PAGES @@ -110,7 +117,6 @@ js_runtimes, path_ops, prerequisites, - types, ) from reflex.utils.exec import ( get_compile_context, @@ -285,6 +291,422 @@ def name(self) -> str | None: return None +@dataclasses.dataclass +class _UploadChunkPart: + """Track the current multipart file part for chunk uploads.""" + + content_disposition: bytes | None = None + field_name: str = "" + filename: str | None = None + content_type: str = "" + item_headers: list[tuple[bytes, bytes]] = dataclasses.field(default_factory=list) + offset: int = 0 + bytes_emitted: int = 0 + is_upload_chunk: bool = False + + +class _UploadChunkMultipartParser: + """Streaming multipart parser for chunked uploads.""" + + def __init__( + self, + headers: Headers, + stream: AsyncGenerator[bytes, None], + chunk_iter: UploadChunkIterator, + ) -> None: + self.headers = headers + self.stream = stream + self.chunk_iter = chunk_iter + self._charset = "" + self._current_partial_header_name = b"" + self._current_partial_header_value = b"" + self._current_part = _UploadChunkPart() + self._chunks_to_emit: list[UploadChunk] = [] + self._seen_upload_chunk = False + self._part_count = 0 + self._emitted_chunk_count = 0 + self._emitted_bytes = 0 + self._stream_chunk_count = 0 + + def on_part_begin(self) -> None: + """Reset parser state for a new multipart part.""" + self._current_part = _UploadChunkPart() + + def on_part_data(self, data: bytes, start: int, end: int) -> None: + """Record streamed chunk data for the current part.""" + if ( + not self._current_part.is_upload_chunk + or self._current_part.filename is None + ): + return + + message_bytes = data[start:end] + self._chunks_to_emit.append( + UploadChunk( + filename=self._current_part.filename, + offset=self._current_part.offset + self._current_part.bytes_emitted, + content_type=self._current_part.content_type, + data=message_bytes, + ) + ) + self._current_part.bytes_emitted += len(message_bytes) + self._emitted_chunk_count += 1 + self._emitted_bytes += len(message_bytes) + + def on_part_end(self) -> None: + """Emit a zero-byte chunk for empty file parts.""" + if ( + self._current_part.is_upload_chunk + and self._current_part.filename is not None + and self._current_part.bytes_emitted == 0 + ): + self._chunks_to_emit.append( + UploadChunk( + filename=self._current_part.filename, + offset=self._current_part.offset, + content_type=self._current_part.content_type, + data=b"", + ) + ) + self._emitted_chunk_count += 1 + + def on_header_field(self, data: bytes, start: int, end: int) -> None: + """Accumulate multipart header field bytes.""" + self._current_partial_header_name += data[start:end] + + def on_header_value(self, data: bytes, start: int, end: int) -> None: + """Accumulate multipart header value bytes.""" + self._current_partial_header_value += data[start:end] + + def on_header_end(self) -> None: + """Store the completed multipart header.""" + field = self._current_partial_header_name.lower() + if field == b"content-disposition": + self._current_part.content_disposition = self._current_partial_header_value + self._current_part.item_headers.append(( + field, + self._current_partial_header_value, + )) + self._current_partial_header_name = b"" + self._current_partial_header_value = b"" + + def on_headers_finished(self) -> None: + """Parse upload chunk metadata from multipart headers.""" + disposition, options = parse_options_header( + self._current_part.content_disposition + ) + if disposition != b"form-data": + msg = "Invalid upload chunk disposition." + raise MultiPartException(msg) + + try: + field_name = _user_safe_decode(options[b"name"], self._charset) + except KeyError as err: + msg = 'The Content-Disposition header field "name" must be provided.' + raise MultiPartException(msg) from err + + parts = field_name.split(":") + if len(parts) != 3 or parts[0] != "chunk": + msg = f"Invalid upload chunk field name: {field_name}" + raise MultiPartException(msg) + + try: + int(parts[1]) + offset = int(parts[2]) + except ValueError as err: + msg = f"Invalid upload chunk field name: {field_name}" + raise MultiPartException(msg) from err + + if offset < 0: + msg = f"Invalid upload chunk field name: {field_name}" + raise MultiPartException(msg) + + try: + filename = _user_safe_decode(options[b"filename"], self._charset) + except KeyError as err: + msg = f"Missing filename for upload chunk field: {field_name}" + raise MultiPartException(msg) from err + + content_type = "" + for header_name, header_value in self._current_part.item_headers: + if header_name == b"content-type": + content_type = _user_safe_decode(header_value, self._charset) + break + + self._current_part.field_name = field_name + self._current_part.filename = filename + self._current_part.content_type = content_type + self._current_part.offset = offset + self._current_part.bytes_emitted = 0 + self._current_part.is_upload_chunk = True + self._seen_upload_chunk = True + self._part_count += 1 + + def on_end(self) -> None: + """Finalize parser callbacks.""" + + def stats(self) -> dict[str, int | bool]: + """Return parser statistics for logging.""" + return { + "parts": self._part_count, + "emitted_chunks": self._emitted_chunk_count, + "emitted_bytes": self._emitted_bytes, + "request_chunks": self._stream_chunk_count, + "saw_upload_chunk": self._seen_upload_chunk, + } + + async def parse(self) -> None: + """Parse the incoming request stream and push chunks to the iterator. + + Raises: + MultiPartException: If the request is not valid multipart upload data. + RuntimeError: If the upload handler exits before consuming all chunks. + """ + _, params = parse_options_header(self.headers["Content-Type"]) + charset = params.get(b"charset", "utf-8") + if isinstance(charset, bytes): + charset = charset.decode("latin-1") + self._charset = charset + + try: + boundary = params[b"boundary"] + except KeyError as err: + msg = "Missing boundary in multipart." + raise MultiPartException(msg) from err + + callbacks = { + "on_part_begin": self.on_part_begin, + "on_part_data": self.on_part_data, + "on_part_end": self.on_part_end, + "on_header_field": self.on_header_field, + "on_header_value": self.on_header_value, + "on_header_end": self.on_header_end, + "on_headers_finished": self.on_headers_finished, + "on_end": self.on_end, + } + parser = MultipartParser(boundary, cast(Any, callbacks)) + + async for chunk in self.stream: + self._stream_chunk_count += 1 + parser.write(chunk) + while self._chunks_to_emit: + await self.chunk_iter.push(self._chunks_to_emit.pop(0)) + + parser.finalize() + while self._chunks_to_emit: + await self.chunk_iter.push(self._chunks_to_emit.pop(0)) + + if not self._seen_upload_chunk: + msg = "No file chunks were uploaded." + raise MultiPartException(msg) + + +@dataclasses.dataclass(frozen=True) +class _ChunkUploadRequestMetadata: + """Metadata describing one chunk-upload HTTP request.""" + + session_id: str + upload_id: str + filename: str | None + offset: int + is_complete: bool + is_cancel: bool + + +@dataclasses.dataclass +class _ChunkUploadSession: + """Bookkeeping for a logical chunk upload spanning multiple requests.""" + + session_id: str + upload_id: str + iterator: UploadChunkIterator + task: asyncio.Task[Any] + request_count: int = 0 + emitted_chunk_count: int = 0 + emitted_bytes: int = 0 + + def record_request(self, *, chunk_count: int, emitted_bytes: int) -> None: + """Update session counters after processing a request.""" + self.request_count += 1 + self.emitted_chunk_count += chunk_count + self.emitted_bytes += emitted_bytes + + def stats(self) -> dict[str, int | str]: + """Return summary stats for logging.""" + return { + "session_id": self.session_id, + "upload_id": self.upload_id, + "requests": self.request_count, + "emitted_chunks": self.emitted_chunk_count, + "emitted_bytes": self.emitted_bytes, + } + + +def _chunk_upload_session_key(token: str, handler_name: str, session_id: str) -> str: + """Build the internal key for a chunk upload session. + + Returns: + The app-local session key. + """ + return f"{token}:{handler_name}:{session_id}" + + +def _parse_chunk_upload_flag(value: str | None) -> bool: + """Parse a truthy chunk-upload flag from query params. + + Returns: + Whether the flag should be treated as enabled. + """ + return value is not None and value.lower() in {"1", "true", "yes", "on"} + + +def _extract_chunk_upload_request_metadata( + request: Request, +) -> _ChunkUploadRequestMetadata: + """Validate and parse query params for a chunk-upload request. + + Args: + request: The incoming request. + + Returns: + Parsed chunk-upload metadata. + + Raises: + ValueError: If the request is missing required chunk metadata. + """ + params = request.query_params + session_id = params.get("session_id", "") + if not session_id: + msg = "Missing session_id for upload chunk request." + raise ValueError(msg) + + upload_id = params.get("upload_id", "") + is_complete = _parse_chunk_upload_flag(params.get("complete")) + is_cancel = _parse_chunk_upload_flag(params.get("cancel")) + + if is_cancel: + return _ChunkUploadRequestMetadata( + session_id=session_id, + upload_id=upload_id, + filename=params.get("filename"), + offset=0, + is_complete=False, + is_cancel=True, + ) + + filename = params.get("filename") + if not filename: + msg = "Missing filename for upload chunk request." + raise ValueError(msg) + + offset_raw = params.get("offset") + if offset_raw is None: + msg = "Missing offset for upload chunk request." + raise ValueError(msg) + + try: + offset = int(offset_raw) + except ValueError as err: + msg = f"Invalid offset for upload chunk request: {offset_raw}" + raise ValueError(msg) from err + + if offset < 0: + msg = f"Invalid offset for upload chunk request: {offset_raw}" + raise ValueError(msg) + + return _ChunkUploadRequestMetadata( + session_id=session_id, + upload_id=upload_id, + filename=filename, + offset=offset, + is_complete=is_complete, + is_cancel=False, + ) + + +async def _cleanup_chunk_upload_session(app: App, session_key: str) -> None: + """Remove a chunk upload session from the app if it is still present.""" + async with app._upload_chunk_sessions_lock: + app._upload_chunk_sessions.pop(session_key, None) + + +async def _lookup_chunk_upload_session( + app: App, + token: str, + handler_name: str, + session_id: str, +) -> tuple[str, _ChunkUploadSession | None]: + """Look up an active chunk upload session. + + Returns: + The app-local session key and any existing upload session. + """ + session_key = _chunk_upload_session_key(token, handler_name, session_id) + async with app._upload_chunk_sessions_lock: + return session_key, app._upload_chunk_sessions.get(session_key) + + +async def _get_or_create_chunk_upload_session( + app: App, + token: str, + handler_name: str, + metadata: _ChunkUploadRequestMetadata, +) -> tuple[str, _ChunkUploadSession, bool]: + """Get or initialize the logical upload session for chunked uploads. + + Returns: + The internal session key, the upload session, and whether it was created. + """ + session_key = _chunk_upload_session_key(token, handler_name, metadata.session_id) + async with app._upload_chunk_sessions_lock: + existing = app._upload_chunk_sessions.get(session_key) + if existing is not None: + if metadata.upload_id and existing.upload_id != metadata.upload_id: + msg = "Chunk upload session does not match the requested upload_id." + raise ValueError(msg) + return session_key, existing, False + + _state, event_handler = await _get_upload_runtime_handler( + app, token, handler_name + ) + handler_upload_param = resolve_upload_chunk_handler_param(event_handler) + + chunk_iter = UploadChunkIterator(maxsize=8) + event = Event( + token=token, + name=handler_name, + payload={handler_upload_param[0]: chunk_iter}, + ) + + async with app.state_manager.modify_state_with_links( + event.substate_token, + event=event, + ) as state: + _seed_upload_router_data(state, token) + task = app._process_background(state, event) + + if task is None: + msg = f"@rx.event(background=True) is required for upload_files_chunk handler `{handler_name}`." + raise exceptions.UploadTypeError(msg) + + chunk_iter.set_consumer_task(task) + session = _ChunkUploadSession( + session_id=metadata.session_id, + upload_id=metadata.upload_id, + iterator=chunk_iter, + task=task, + ) + app._upload_chunk_sessions[session_key] = session + task.add_done_callback( + lambda finished_task, *, _app=app, _session_key=session_key: ( + finished_task.get_loop().create_task( + _cleanup_chunk_upload_session(_app, _session_key) + ) + ) + ) + return session_key, session, True + + @dataclasses.dataclass( frozen=True, ) @@ -435,6 +857,16 @@ class App(MiddlewareMixin, LifespanMixin): # Background tasks that are currently running. _background_tasks: set[asyncio.Task] = dataclasses.field(default_factory=set) + # Active logical sessions for chunked uploads that span multiple requests. + _upload_chunk_sessions: dict[str, _ChunkUploadSession] = dataclasses.field( + default_factory=dict + ) + + # Synchronize creation and cleanup of chunked upload sessions. + _upload_chunk_sessions_lock: asyncio.Lock = dataclasses.field( + default_factory=asyncio.Lock + ) + # Frontend Error Handler Function frontend_exception_handler: Callable[[Exception], None] = ( default_frontend_exception_handler @@ -706,6 +1138,11 @@ def _add_optional_endpoints(self): upload(self), methods=["POST"], ) + self._api.add_route( + str(constants.Endpoint.UPLOAD_CHUNK), + upload_chunk(self), + methods=["POST"], + ) # To access uploaded files. self._api.mount( @@ -1912,6 +2349,71 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: await self._on_finish() +def _require_upload_headers(request: Request) -> tuple[str, str]: + """Extract the required upload headers from a request. + + Args: + request: The incoming request. + + Returns: + The client token and event handler name. + + Raises: + HTTPException: If the upload headers are missing. + """ + token = request.headers.get("reflex-client-token") + handler = request.headers.get("reflex-event-handler") + + if not token or not handler: + raise HTTPException( + status_code=400, + detail="Missing reflex-client-token or reflex-event-handler header.", + ) + + return token, handler + + +async def _get_upload_runtime_handler( + app: App, + token: str, + handler_name: str, +) -> tuple[BaseState, EventHandler]: + """Resolve the runtime state and event handler for an upload request. + + Args: + app: The Reflex app. + token: The client token. + handler_name: The fully qualified event handler name. + + Returns: + The root state instance and resolved event handler. + """ + substate_token = _substate_key(token, handler_name.rpartition(".")[0]) + state = await app.state_manager.get_state(substate_token) + _current_state, event_handler = state._get_event_handler(handler_name) + return state, event_handler + + +def _seed_upload_router_data(state: BaseState, token: str) -> None: + """Ensure upload-launched handlers have the client token in router state. + + Background upload handlers use ``StateProxy`` which derives its mutable-state + token from ``self.router.session.client_token``. Upload requests do not flow + through the normal websocket event pipeline, so we seed the token here. + + Args: + state: The root state instance. + token: The client token from the upload request. + """ + router_data = dict(state.router_data) + if router_data.get(constants.RouteVar.CLIENT_TOKEN) == token: + return + + router_data[constants.RouteVar.CLIENT_TOKEN] = token + state.router_data = router_data + state.router = RouterData.from_router_data(router_data) + + def upload(app: App): """Upload a file. @@ -1937,7 +2439,7 @@ async def upload_file(request: Request): UploadTypeError: if a background task is used as the handler. HTTPException: when the request does not include token / handler headers. """ - from reflex.utils.exceptions import UploadTypeError, UploadValueError + from reflex.utils.exceptions import UploadValueError # Get the files from the request. try: @@ -1966,43 +2468,12 @@ async def _create_upload_event() -> Event: msg = "No files were uploaded." raise UploadValueError(msg) - token = request.headers.get("reflex-client-token") - handler = request.headers.get("reflex-event-handler") - - if not token or not handler: - raise HTTPException( - status_code=400, - detail="Missing reflex-client-token or reflex-event-handler header.", - ) - - # Get the state for the session. - substate_token = _substate_key(token, handler.rpartition(".")[0]) - state = await app.state_manager.get_state(substate_token) - - handler_upload_param = () - - _current_state, event_handler = state._get_event_handler(handler) + token, handler = _require_upload_headers(request) - if event_handler.is_background: - msg = f"@rx.event(background=True) is not supported for upload handler `{handler}`." - raise UploadTypeError(msg) - func = event_handler.fn - if isinstance(func, functools.partial): - func = func.func - for k, v in get_type_hints(func).items(): - if types.is_generic_alias(v) and types._issubclass( - get_args(v)[0], - UploadFile, - ): - handler_upload_param = (k, v) - break - - if not handler_upload_param: - msg = ( - f"`{handler}` handler should have a parameter annotated as " - "list[rx.UploadFile]" - ) - raise UploadValueError(msg) + _state, event_handler = await _get_upload_runtime_handler( + app, token, handler + ) + handler_upload_param = resolve_upload_handler_param(event_handler) # Keep the parsed form data alive until the upload event finishes so # the underlying Starlette temp files remain available to the handler. @@ -2059,6 +2530,166 @@ async def _ndjson_updates(): return upload_file +def upload_chunk(app: App): + """Upload file chunks to a background event handler. + + Args: + app: The app to upload the file for. + + Returns: + The streaming upload function. + """ + + async def upload_file_chunk(request: Request): + """Upload file chunks without buffering the full file in memory. + + Args: + request: The Starlette request object. + + Returns: + A response indicating whether the upload stream was accepted. + + Raises: + UploadTypeError: If the handler is not a background event. + UploadValueError: If the handler signature is invalid. + HTTPException: If the request is missing required headers. + """ + token, handler_name = _require_upload_headers(request) + try: + metadata = _extract_chunk_upload_request_metadata(request) + except ValueError as err: + console.error( + f"[chunk-upload] request failed handler={handler_name} error={err}" + ) + return JSONResponse({"detail": str(err)}, status_code=400) + console.info( + "[chunk-upload] request received " + f"handler={handler_name} token={token[:8]}... " + f"session_id={metadata.session_id} upload_id={metadata.upload_id or ''} " + f"complete={metadata.is_complete} cancel={metadata.is_cancel}" + ) + + session_key, existing_session = await _lookup_chunk_upload_session( + app, + token, + handler_name, + metadata.session_id, + ) + if metadata.is_cancel: + if existing_session is not None: + await existing_session.iterator.fail( + RuntimeError("Upload cancelled by client.") + ) + await _cleanup_chunk_upload_session(app, session_key) + console.info( + "[chunk-upload] upload cancelled " + f"handler={handler_name} session_id={metadata.session_id} " + f"session_found={existing_session is not None}" + ) + return Response(status_code=202) + + try: + session_key, session, created = await _get_or_create_chunk_upload_session( + app, + token, + handler_name, + metadata, + ) + except (exceptions.UploadTypeError, RuntimeError, ValueError) as err: + console.error( + "[chunk-upload] request failed " + f"handler={handler_name} session_id={metadata.session_id} error={err}" + ) + return JSONResponse({"detail": str(err)}, status_code=400) + + if created: + console.info( + "[chunk-upload] background task scheduled " + f"handler={handler_name} session_id={metadata.session_id}" + ) + + content_type = request.headers.get("content-type", "") + emitted_bytes = 0 + emitted_chunk_count = 0 + + try: + async for chunk_bytes in request.stream(): + if not chunk_bytes: + continue + await session.iterator.push( + UploadChunk( + filename=metadata.filename or "", + offset=metadata.offset + emitted_bytes, + content_type=content_type, + data=chunk_bytes, + ) + ) + emitted_bytes += len(chunk_bytes) + emitted_chunk_count += 1 + except ClientDisconnect as err: + console.info( + "[chunk-upload] client disconnected " + f"handler={handler_name} session_id={metadata.session_id} " + f"bytes={emitted_bytes} chunks={emitted_chunk_count}" + ) + await session.iterator.fail(err) + await _cleanup_chunk_upload_session(app, session_key) + return Response() + except (RuntimeError, ValueError) as err: + console.error( + "[chunk-upload] request failed " + f"handler={handler_name} session_id={metadata.session_id} " + f"error={err} bytes={emitted_bytes} chunks={emitted_chunk_count}" + ) + await session.iterator.fail( + err if isinstance(err, Exception) else RuntimeError() + ) + await _cleanup_chunk_upload_session(app, session_key) + return JSONResponse({"detail": str(err)}, status_code=400) + + if emitted_bytes == 0: + try: + await session.iterator.push( + UploadChunk( + filename=metadata.filename or "", + offset=metadata.offset, + content_type=content_type, + data=b"", + ) + ) + except (RuntimeError, ValueError) as err: + console.error( + "[chunk-upload] request failed " + f"handler={handler_name} session_id={metadata.session_id} " + f"error={err} bytes={emitted_bytes} chunks={emitted_chunk_count}" + ) + await session.iterator.fail( + err if isinstance(err, Exception) else RuntimeError() + ) + await _cleanup_chunk_upload_session(app, session_key) + return JSONResponse({"detail": str(err)}, status_code=400) + emitted_chunk_count = 1 + + session.record_request( + chunk_count=emitted_chunk_count, + emitted_bytes=emitted_bytes, + ) + + if metadata.is_complete: + await session.iterator.finish() + await _cleanup_chunk_upload_session(app, session_key) + + console.info( + "[chunk-upload] request completed " + f"handler={handler_name} session_id={metadata.session_id} " + f"complete={metadata.is_complete} request_bytes={emitted_bytes} " + f"request_chunks={emitted_chunk_count} stats={session.stats()}" + ) + return Response(status_code=202) + + return upload_file_chunk + + class EventNamespace(AsyncNamespace): """The event namespace.""" diff --git a/reflex/components/core/upload.py b/reflex/components/core/upload.py index 670112fad33..cc423a4459e 100644 --- a/reflex/components/core/upload.py +++ b/reflex/components/core/upload.py @@ -27,6 +27,7 @@ EventChain, EventHandler, EventSpec, + UploadChunkIterator, call_event_fn, call_event_handler, parse_args_spec, @@ -172,6 +173,10 @@ def get_upload_url(file_path: str | Var[str]) -> Var[str]: _on_drop_spec = passthrough_event_spec(list[UploadFile]) +_on_drop_args_spec = ( + _on_drop_spec, + passthrough_event_spec(UploadChunkIterator), +) def _default_drop_rejected(rejected_files: ArrayVar[list[dict[str, Any]]]) -> EventSpec: @@ -212,10 +217,10 @@ class GhostUpload(Fragment): """A ghost upload component.""" # Fired when files are dropped. - on_drop: EventHandler[_on_drop_spec] + on_drop: EventHandler[_on_drop_args_spec] # Fired when dropped files do not meet the specified criteria. - on_drop_rejected: EventHandler[_on_drop_spec] + on_drop_rejected: EventHandler[_on_drop_args_spec] class Upload(MemoizationLeaf): @@ -258,10 +263,10 @@ class Upload(MemoizationLeaf): is_used: ClassVar[bool] = False # Fired when files are dropped. - on_drop: EventHandler[_on_drop_spec] + on_drop: EventHandler[_on_drop_args_spec] # Fired when dropped files do not meet the specified criteria. - on_drop_rejected: EventHandler[_on_drop_spec] + on_drop_rejected: EventHandler[_on_drop_args_spec] # Style rules to apply when actively dragging. drag_active_style: Style | None = field(default=None, is_javascript_property=False) @@ -310,11 +315,15 @@ def create(cls, *children, **props) -> Component: if isinstance(event, EventHandler): event = event(upload_files(upload_id)) if isinstance(event, EventSpec): - # Call the lambda to get the event chain. - event = call_event_handler(event, _on_drop_spec) + if event.client_handler_name not in { + "uploadFiles", + "uploadFilesChunk", + }: + # Call the lambda to get the event chain. + event = call_event_handler(event, _on_drop_args_spec) elif isinstance(event, Callable): # Call the lambda to get the event chain. - event = call_event_fn(event, _on_drop_spec) + event = call_event_fn(event, _on_drop_args_spec) if isinstance(event, EventSpec): # Update the provided args for direct use with on_drop. event = event.with_args( diff --git a/reflex/constants/event.py b/reflex/constants/event.py index 6a0f71ec161..9fb5305c20a 100644 --- a/reflex/constants/event.py +++ b/reflex/constants/event.py @@ -10,6 +10,7 @@ class Endpoint(Enum): PING = "ping" EVENT = "_event" UPLOAD = "_upload" + UPLOAD_CHUNK = "_upload_chunk" AUTH_CODESPACE = "auth-codespace" HEALTH = "_health" ALL_ROUTES = "_all_routes" diff --git a/reflex/event.py b/reflex/event.py index ff75e3bd3cb..66cc0d7ae94 100644 --- a/reflex/event.py +++ b/reflex/event.py @@ -1,11 +1,13 @@ """Define event classes to connect the frontend and backend.""" +import asyncio import dataclasses import inspect import sys import types from base64 import b64encode -from collections.abc import Callable, Mapping, Sequence +from collections import deque +from collections.abc import AsyncIterator, Callable, Mapping, Sequence from functools import lru_cache, partial from typing import ( TYPE_CHECKING, @@ -92,6 +94,229 @@ def substate_token(self) -> str: EVENT_ACTIONS_MARKER = "_rx_event_actions" +@dataclasses.dataclass( + init=True, + frozen=True, +) +class UploadChunk: + """A chunk of uploaded file data.""" + + filename: str + offset: int + content_type: str + data: bytes + + +class UploadChunkIterator(AsyncIterator[UploadChunk]): + """An async iterator over uploaded file chunks.""" + + def __init__(self, *, maxsize: int = 8): + """Initialize the iterator. + + Args: + maxsize: Maximum number of chunks to buffer before blocking producers. + """ + self._maxsize = maxsize + self._chunks: deque[UploadChunk] = deque() + self._condition = asyncio.Condition() + self._closed = False + self._error: Exception | None = None + self._consumer_task: asyncio.Task[Any] | None = None + + def __aiter__(self) -> Self: + """Return the iterator itself.""" + return self + + async def __anext__(self) -> UploadChunk: + """Yield the next available upload chunk. + + Returns: + The next upload chunk. + + Raises: + StopAsyncIteration: When all chunks have been consumed. + """ + async with self._condition: + while not self._chunks and not self._closed: + await self._condition.wait() + + if self._chunks: + chunk = self._chunks.popleft() + self._condition.notify_all() + return chunk + + if self._error is not None: + raise self._error + raise StopAsyncIteration + + def set_consumer_task(self, task: asyncio.Task[Any]) -> None: + """Track the task consuming this iterator. + + Args: + task: The background task consuming upload chunks. + """ + self._consumer_task = task + task.add_done_callback(self._wake_waiters) + + async def push(self, chunk: UploadChunk) -> None: + """Push a new chunk into the iterator. + + Args: + chunk: The chunk to push. + + Raises: + RuntimeError: If the iterator is already closed or the consumer exited early. + """ + async with self._condition: + while len(self._chunks) >= self._maxsize and not self._closed: + self._raise_if_consumer_finished() + await self._condition.wait() + + if self._closed: + msg = "Upload chunk iterator is closed." + raise RuntimeError(msg) + + self._raise_if_consumer_finished() + self._chunks.append(chunk) + self._condition.notify_all() + + async def finish(self) -> None: + """Mark the iterator as complete.""" + async with self._condition: + if self._closed: + return + self._closed = True + self._condition.notify_all() + + async def fail(self, error: Exception) -> None: + """Mark the iterator as failed. + + Args: + error: The error to raise from the iterator. + """ + async with self._condition: + if self._closed: + return + self._closed = True + self._error = error + self._condition.notify_all() + + def _raise_if_consumer_finished(self) -> None: + """Raise if the consumer task exited before draining the iterator.""" + if self._consumer_task is None or not self._consumer_task.done(): + return + + try: + task_exc = self._consumer_task.exception() + except asyncio.CancelledError as err: + task_exc = err + + msg = "Upload handler returned before consuming all upload chunks." + if task_exc is not None: + raise RuntimeError(msg) from task_exc + raise RuntimeError(msg) + + def _wake_waiters(self, task: asyncio.Task[Any]) -> None: + """Wake any producers or consumers blocked on the iterator condition. + + Args: + task: The completed consumer task. + """ + task.get_loop().create_task(self._notify_waiters()) + + async def _notify_waiters(self) -> None: + """Notify tasks waiting on the iterator condition.""" + async with self._condition: + self._condition.notify_all() + + +def _handler_name(handler: "EventHandler") -> str: + """Get a stable fully qualified handler name for errors. + + Args: + handler: The handler to name. + + Returns: + The fully qualified handler name. + """ + if handler.state_full_name: + return f"{handler.state_full_name}.{handler.fn.__name__}" + return handler.fn.__qualname__ + + +def resolve_upload_handler_param(handler: "EventHandler") -> tuple[str, Any]: + """Validate and resolve the UploadFile list parameter for a handler. + + Args: + handler: The event handler to inspect. + + Returns: + The parameter name and annotation for the upload file argument. + + Raises: + UploadTypeError: If the handler is a background task. + UploadValueError: If the handler does not accept ``list[rx.UploadFile]``. + """ + from reflex.app import UploadFile + from reflex.utils.exceptions import UploadTypeError, UploadValueError + + handler_name = _handler_name(handler) + if handler.is_background: + msg = ( + f"@rx.event(background=True) is not supported for upload handler " + f"`{handler_name}`." + ) + raise UploadTypeError(msg) + + func = handler.fn.func if isinstance(handler.fn, partial) else handler.fn + for name, annotation in get_type_hints(func).items(): + if name == "return" or get_origin(annotation) is not list: + continue + args = get_args(annotation) + if len(args) == 1 and typehint_issubclass(args[0], UploadFile): + return name, annotation + + msg = ( + f"`{handler_name}` handler should have a parameter annotated as " + "list[rx.UploadFile]" + ) + raise UploadValueError(msg) + + +def resolve_upload_chunk_handler_param(handler: "EventHandler") -> tuple[str, type]: + """Validate and resolve the UploadChunkIterator parameter for a handler. + + Args: + handler: The event handler to inspect. + + Returns: + The parameter name and annotation for the iterator argument. + + Raises: + UploadTypeError: If the handler is not a background task. + UploadValueError: If the handler does not accept an UploadChunkIterator. + """ + from reflex.utils.exceptions import UploadTypeError, UploadValueError + + handler_name = _handler_name(handler) + if not handler.is_background: + msg = f"@rx.event(background=True) is required for upload_files_chunk handler `{handler_name}`." + raise UploadTypeError(msg) + + func = handler.fn.func if isinstance(handler.fn, partial) else handler.fn + for name, annotation in get_type_hints(func).items(): + if name == "return": + continue + if annotation is UploadChunkIterator: + return name, annotation + + msg = ( + f"`{handler_name}` handler should have a parameter annotated as " + "rx.UploadChunkIterator" + ) + raise UploadValueError(msg) + + @dataclasses.dataclass( init=True, frozen=True, @@ -282,7 +507,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> "EventSpec": values = [] for arg in [*args, *kwargs.values()]: # Special case for file uploads. - if isinstance(arg, FileUpload): + if isinstance(arg, (FileUpload, UploadFilesChunk)): return arg.as_event_spec(handler=self) # Otherwise, convert to JSON. @@ -858,14 +1083,22 @@ def on_upload_progress_args_spec(_prog: Var[dict[str, int | float | bool]]): """ return [_prog] - def as_event_spec(self, handler: EventHandler) -> EventSpec: - """Get the EventSpec for the file upload. + def _as_event_spec( + self, + handler: EventHandler, + *, + client_handler_name: str, + upload_param_name: str, + ) -> EventSpec: + """Create an upload EventSpec. Args: handler: The event handler. + client_handler_name: The client handler name. + upload_param_name: The upload argument name in the event handler. Returns: - The event spec for the handler. + The upload EventSpec. Raises: ValueError: If the on_upload_progress is not a valid event handler. @@ -876,14 +1109,19 @@ def as_event_spec(self, handler: EventHandler) -> EventSpec: ) upload_id = self.upload_id if self.upload_id is not None else DEFAULT_UPLOAD_ID + upload_files_var = Var( + _js_expr="filesById", + _var_type=dict[str, Any], + _var_data=VarData.merge(upload_files_context_var_data), + ).to(ObjectVar)[LiteralVar.create(upload_id)] spec_args = [ ( Var(_js_expr="files"), - Var( - _js_expr="filesById", - _var_type=dict[str, Any], - _var_data=VarData.merge(upload_files_context_var_data), - ).to(ObjectVar)[LiteralVar.create(upload_id)], + upload_files_var, + ), + ( + Var(_js_expr="upload_param_name"), + LiteralVar.create(upload_param_name), ), ( Var(_js_expr="upload_id"), @@ -896,6 +1134,14 @@ def as_event_spec(self, handler: EventHandler) -> EventSpec: ), ), ] + if upload_param_name != "files": + spec_args.insert( + 1, + ( + Var(_js_expr=upload_param_name), + upload_files_var, + ), + ) if self.on_upload_progress is not None: on_upload_progress = self.on_upload_progress if isinstance(on_upload_progress, EventHandler): @@ -931,16 +1177,65 @@ def as_event_spec(self, handler: EventHandler) -> EventSpec: ) return EventSpec( handler=handler, - client_handler_name="uploadFiles", + client_handler_name=client_handler_name, args=tuple(spec_args), event_actions=handler.event_actions.copy(), ) + def as_event_spec(self, handler: EventHandler) -> EventSpec: + """Get the EventSpec for the file upload. + + Args: + handler: The event handler. + + Returns: + The event spec for the handler. + """ + from reflex.utils.exceptions import UploadValueError + + try: + upload_param_name, _annotation = resolve_upload_handler_param(handler) + except UploadValueError: + upload_param_name = "files" + return self._as_event_spec( + handler, + client_handler_name="uploadFiles", + upload_param_name=upload_param_name, + ) + # Alias for rx.upload_files upload_files = FileUpload +@dataclasses.dataclass( + init=True, + frozen=True, +) +class UploadFilesChunk(FileUpload): + """Class to represent a streaming file upload.""" + + def as_event_spec(self, handler: EventHandler) -> EventSpec: + """Get the EventSpec for the streaming file upload. + + Args: + handler: The event handler. + + Returns: + The event spec for the handler. + """ + upload_param_name, _annotation = resolve_upload_chunk_handler_param(handler) + return self._as_event_spec( + handler, + client_handler_name="uploadFilesChunk", + upload_param_name=upload_param_name, + ) + + +# Alias for rx.upload_files_chunk +upload_files_chunk = UploadFilesChunk + + # Special server-side events. def server_side(name: str, sig: inspect.Signature, **kwargs) -> EventSpec: """A server-side event. @@ -2303,6 +2598,9 @@ class EventNamespace: # File Upload FileUpload = FileUpload + UploadChunk = UploadChunk + UploadChunkIterator = UploadChunkIterator + UploadFilesChunk = UploadFilesChunk # Type Aliases EventType = EventType @@ -2316,10 +2614,15 @@ class EventNamespace: _EVENT_FIELDS = _EVENT_FIELDS FORM_DATA = FORM_DATA upload_files = upload_files + upload_files_chunk = upload_files_chunk stop_propagation = stop_propagation prevent_default = prevent_default # Private/Internal Functions + resolve_upload_handler_param = staticmethod(resolve_upload_handler_param) + resolve_upload_chunk_handler_param = staticmethod( + resolve_upload_chunk_handler_param + ) _values_returned_from_event = staticmethod(_values_returned_from_event) _check_event_args_subclass_of_callback = staticmethod( _check_event_args_subclass_of_callback diff --git a/reflex/utils/build.py b/reflex/utils/build.py index 7b7408f8b3b..be0ef1b7ac9 100644 --- a/reflex/utils/build.py +++ b/reflex/utils/build.py @@ -10,7 +10,14 @@ from reflex import constants from reflex.config import get_config -from reflex.utils import console, js_runtimes, path_ops, prerequisites, processes +from reflex.utils import ( + console, + frontend_skeleton, + js_runtimes, + path_ops, + prerequisites, + processes, +) from reflex.utils.exec import is_in_app_harness @@ -260,6 +267,8 @@ def setup_frontend( Args: root: The root path of the project. """ + frontend_skeleton.sync_web_runtime_templates() + # Set the environment variables in client (env.json). set_env_json() diff --git a/reflex/utils/frontend_skeleton.py b/reflex/utils/frontend_skeleton.py index 37d66fac7db..07fd5873038 100644 --- a/reflex/utils/frontend_skeleton.py +++ b/reflex/utils/frontend_skeleton.py @@ -13,6 +13,20 @@ from reflex.utils.prerequisites import get_project_hash, get_web_dir from reflex.utils.registry import get_npm_registry +_WEB_RUNTIME_TEMPLATE_DIRS = ( + Path("components"), + Path(constants.Dirs.UTILS), +) + +_WEB_RUNTIME_TEMPLATE_FILES = ( + Path("app") / "entry.client.js", + Path("app") / "routes.js", + Path("styles") / "__reflex_style_reset.css", + Path("jsconfig.json"), + Path("postcss.config.js"), + Path("vite-plugin-safari-cachebust.js"), +) + def initialize_gitignore( gitignore_file: Path = constants.GitIgnore.FILE, @@ -124,6 +138,37 @@ def initialize_web_directory(): init_reflex_json(project_hash=project_hash) +def sync_web_runtime_templates(): + """Refresh the static runtime files in the web directory. + + This keeps shared frontend helpers in `.web` in sync with the framework + templates without wiping generated route modules or installed dependencies. + """ + template_dir = Path(constants.Templates.Dirs.WEB_TEMPLATE) + web_dir = get_web_dir() + + for relative_dir in _WEB_RUNTIME_TEMPLATE_DIRS: + source_dir = template_dir / relative_dir + if not source_dir.exists(): + continue + for source_file in source_dir.rglob("*"): + if source_file.is_dir(): + continue + relative_file = source_file.relative_to(template_dir) + target_file = web_dir / relative_file + console.debug(f"Syncing {source_file} to {target_file}") + path_ops.mkdir(target_file.parent) + path_ops.cp(source_file, target_file) + + for relative_file in _WEB_RUNTIME_TEMPLATE_FILES: + source_file = template_dir / relative_file + target_file = web_dir / relative_file + if source_file.exists(): + console.debug(f"Syncing {source_file} to {target_file}") + path_ops.mkdir(target_file.parent) + path_ops.cp(source_file, target_file) + + def update_react_router_config(prerender_routes: bool = False): """Update react-router.config.js config from Reflex config. diff --git a/tests/integration/test_upload.py b/tests/integration/test_upload.py index ed4a5456cd1..15cbc3bb86a 100644 --- a/tests/integration/test_upload.py +++ b/tests/integration/test_upload.py @@ -6,11 +6,13 @@ import time from collections.abc import Generator from pathlib import Path +from typing import Any, cast from urllib.parse import urlsplit import pytest from selenium.webdriver.common.by import By +import reflex as rx from reflex.constants.event import Endpoint from reflex.testing import AppHarness, WebDriver @@ -19,17 +21,18 @@ def UploadFile(): """App for testing dynamic routes.""" - import reflex as rx - LARGE_DATA = "DUMMY" * 1024 * 512 class UploadState(rx.State): _file_data: dict[str, str] = {} event_order: rx.Field[list[str]] = rx.field([]) progress_dicts: rx.Field[list[dict]] = rx.field([]) + stream_progress_dicts: rx.Field[list[dict]] = rx.field([]) disabled: rx.Field[bool] = rx.field(False) large_data: rx.Field[str] = rx.field("") quaternary_names: rx.Field[list[str]] = rx.field([]) + stream_chunk_records: rx.Field[list[str]] = rx.field([]) + stream_completed_files: rx.Field[list[str]] = rx.field([]) @rx.event async def handle_upload(self, files: list[rx.UploadFile]): @@ -57,6 +60,11 @@ def chain_event(self): self.large_data = "" self.event_order.append("chain_event") + @rx.event + def stream_upload_progress(self, progress): + assert progress + self.stream_progress_dicts.append(progress) + @rx.event async def handle_upload_tertiary(self, files: list[rx.UploadFile]): for file in files: @@ -68,6 +76,36 @@ async def handle_upload_tertiary(self, files: list[rx.UploadFile]): async def handle_upload_quaternary(self, files: list[rx.UploadFile]): self.quaternary_names = [file.name for file in files if file.name] + @rx.event(background=True) + async def handle_upload_stream(self, chunk_iter: rx.UploadChunkIterator): + upload_dir = rx.get_upload_dir() / "streaming" + file_buffers: dict[str, bytearray] = {} + + async for chunk in chunk_iter: + buf = file_buffers.get(chunk.filename) + if buf is None: + buf = bytearray() + file_buffers[chunk.filename] = buf + + end = chunk.offset + len(chunk.data) + if end > len(buf): + buf.extend(b"\x00" * (end - len(buf))) + buf[chunk.offset : end] = chunk.data + + async with self: + self.stream_chunk_records.append( + f"{chunk.filename}:{chunk.offset}:{len(chunk.data)}" + ) + + for filename, buf in file_buffers.items(): + path = upload_dir / filename + path.parent.mkdir(parents=True, exist_ok=True) + with path.open("wb") as fh: + fh.write(buf) + + async with self: + self.stream_completed_files = sorted(file_buffers) + @rx.event def do_download(self): return rx.download(rx.get_upload_url("test.txt")) @@ -188,6 +226,44 @@ def index(): UploadState.quaternary_names.to_string(), id="quaternary_files", ), + rx.heading("Streaming Upload"), + rx.upload.root( + rx.vstack( + rx.button("Select File"), + rx.text("Drag and drop files here or click to select files"), + ), + id="streaming", + ), + rx.button( + "Upload", + on_click=UploadState.handle_upload_stream( + rx.upload_files_chunk( # pyright: ignore [reportArgumentType] + upload_id="streaming", + on_upload_progress=UploadState.stream_upload_progress, + ) + ), + id="upload_button_streaming", + ), + rx.box( + rx.foreach( + rx.selected_files("streaming"), + lambda f: rx.text(f, as_="p"), + ), + id="selected_files_streaming", + ), + rx.button( + "Cancel", + on_click=rx.cancel_upload("streaming"), + id="cancel_button_streaming", + ), + rx.text( + UploadState.stream_chunk_records.to_string(), + id="stream_chunk_records", + ), + rx.text( + UploadState.stream_completed_files.to_string(), + id="stream_completed_files", + ), rx.text(UploadState.event_order.to_string(), id="event-order"), ) @@ -487,6 +563,140 @@ async def _progress_dicts(): target_file.unlink() +@pytest.mark.asyncio +async def test_upload_chunk_file(tmp_path, upload_file: AppHarness, driver: WebDriver): + """Submit a streaming upload and check that chunks are processed incrementally.""" + assert upload_file.app_instance is not None + token = poll_for_token(driver, upload_file) + state_name = upload_file.get_state_name("_upload_state") + state_full_name = upload_file.get_full_state_name(["_upload_state"]) + substate_token = f"{token}_{state_full_name}" + + upload_box = driver.find_elements(By.XPATH, "//input[@type='file']")[4] + upload_button = driver.find_element(By.ID, "upload_button_streaming") + selected_files = driver.find_element(By.ID, "selected_files_streaming") + chunk_records_display = driver.find_element(By.ID, "stream_chunk_records") + completed_files_display = driver.find_element(By.ID, "stream_completed_files") + + exp_files = { + "stream1.txt": "ABCD" * 262_144, + "stream2.txt": "WXYZ" * 262_144, + } + for exp_name, exp_contents in exp_files.items(): + target_file = tmp_path / exp_name + target_file.write_text(exp_contents) + upload_box.send_keys(str(target_file)) + + await asyncio.sleep(0.2) + + assert [Path(name).name for name in selected_files.text.split("\n")] == [ + Path(name).name for name in exp_files + ] + + upload_button.click() + + AppHarness.expect(lambda: "stream1.txt" in chunk_records_display.text) + + async def _stream_completed(): + state = await upload_file.get_state(substate_token) + return ( + len( + state.substates[state_name].stream_completed_files # pyright: ignore[reportAttributeAccessIssue] + ) + == 2 + ) + + await AppHarness._poll_for_async(_stream_completed) + + state = await upload_file.get_state(substate_token) + substate = cast(Any, state.substates[state_name]) + chunk_records = substate.stream_chunk_records + + assert len(chunk_records) > 2 + assert [Path(record.split(":")[0]).name for record in chunk_records[:2]] == [ + "stream1.txt", + "stream2.txt", + ] + assert substate.stream_completed_files == ["stream1.txt", "stream2.txt"] + + AppHarness.expect( + lambda: ( + "stream1.txt" in completed_files_display.text + and "stream2.txt" in completed_files_display.text + ) + ) + + for exp_name, exp_contents in exp_files.items(): + assert ( + rx.get_upload_dir() / "streaming" / exp_name + ).read_text() == exp_contents + + +@pytest.mark.asyncio +async def test_cancel_upload_chunk( + tmp_path, + upload_file: AppHarness, + driver: WebDriver, +): + """Submit a large streaming upload and cancel it.""" + assert upload_file.app_instance is not None + driver.execute_cdp_cmd("Network.enable", {}) + driver.execute_cdp_cmd( + "Network.emulateNetworkConditions", + { + "offline": False, + "downloadThroughput": 1024 * 1024 / 8, # 1 Mbps + "uploadThroughput": 1024 * 1024 / 8, # 1 Mbps + "latency": 200, # 200ms + }, + ) + token = poll_for_token(driver, upload_file) + state_name = upload_file.get_state_name("_upload_state") + state_full_name = upload_file.get_full_state_name(["_upload_state"]) + substate_token = f"{token}_{state_full_name}" + + upload_box = driver.find_elements(By.XPATH, "//input[@type='file']")[4] + upload_button = driver.find_element(By.ID, "upload_button_streaming") + cancel_button = driver.find_element(By.ID, "cancel_button_streaming") + + exp_name = "cancel_stream.txt" + target_file = tmp_path / exp_name + with target_file.open("wb") as f: + f.seek(2 * 1024 * 1024) + f.write(b"0") + + upload_box.send_keys(str(target_file)) + upload_button.click() + await asyncio.sleep(1) + cancel_button.click() + + await asyncio.sleep(12) + + async def _stream_progress_dicts(): + state = await upload_file.get_state(substate_token) + return ( + state.substates[state_name].stream_progress_dicts # pyright: ignore[reportAttributeAccessIssue] + ) + + assert await AppHarness._poll_for_async(_stream_progress_dicts) + + for progress in await _stream_progress_dicts(): + assert progress["progress"] != 1 + + state = await upload_file.get_state(substate_token) + substate = cast(Any, state.substates[state_name]) + assert substate.stream_completed_files == [] + assert substate.stream_chunk_records + + partial_path = rx.get_upload_dir() / "streaming" / exp_name + assert partial_path.exists() + assert partial_path.stat().st_size < target_file.stat().st_size + + target_file.unlink() + if partial_path.exists(): + partial_path.unlink() + + def test_upload_download_file( tmp_path, upload_file: AppHarness, diff --git a/tests/units/components/core/test_upload.py b/tests/units/components/core/test_upload.py index 3b03362d6e4..b9dfb4bb976 100644 --- a/tests/units/components/core/test_upload.py +++ b/tests/units/components/core/test_upload.py @@ -1,5 +1,8 @@ -from typing import Any +from typing import Any, cast +import pytest + +import reflex as rx from reflex import event from reflex.components.core.upload import ( StyledUpload, @@ -9,7 +12,7 @@ cancel_upload, get_upload_url, ) -from reflex.event import EventSpec +from reflex.event import EventChain, EventHandler, EventSpec from reflex.state import State from reflex.vars.base import LiteralVar, Var @@ -33,6 +36,31 @@ def not_drop_handler(self, not_files: Any): not_files: The files dropped. """ + @event + async def upload_alias_handler(self, uploads: list[rx.UploadFile]): + """Handle uploaded files with a non-default parameter name.""" + + +class StreamingUploadStateTest(State): + """Test state for streaming uploads.""" + + @event(background=True) + async def chunk_drop_handler(self, chunk_iter: rx.UploadChunkIterator): + """Handle streamed upload chunks.""" + + @event(background=True) + async def chunk_upload_alias_handler(self, stream: rx.UploadChunkIterator): + """Handle streamed upload chunks with a non-default parameter name.""" + + async def chunk_drop_handler_not_background( + self, chunk_iter: rx.UploadChunkIterator + ): + """Invalid handler used to validate background-task requirement.""" + + @event(background=True) + async def chunk_drop_handler_missing_annotation(self, chunk_iter): + """Invalid handler missing the UploadChunkIterator annotation.""" + def test_cancel_upload(): spec = cancel_upload("foo_id") @@ -44,10 +72,54 @@ def test_get_upload_url(): assert isinstance(url, Var) +def test_upload_files_chunk_export(): + chunk = rx.UploadChunk( + filename="foo.txt", + offset=0, + content_type="text/plain", + data=b"hello", + ) + + assert chunk.filename == "foo.txt" + assert isinstance(rx.UploadChunkIterator(), rx.UploadChunkIterator) + assert callable(rx.upload_files_chunk) + + def test__on_drop_spec(): assert isinstance(_on_drop_spec(LiteralVar.create([])), tuple) +def test_upload_files_chunk_requires_background(): + with pytest.raises(TypeError) as err: + event.resolve_upload_chunk_handler_param( + cast( + EventHandler, StreamingUploadStateTest.chunk_drop_handler_not_background + ) + ) + + assert ( + err.value.args[0] + == "@rx.event(background=True) is required for upload_files_chunk handler " + f"`{StreamingUploadStateTest.get_full_name()}.chunk_drop_handler_not_background`." + ) + + +def test_upload_files_chunk_requires_iterator_annotation(): + with pytest.raises(ValueError) as err: + event.resolve_upload_chunk_handler_param( + cast( + EventHandler, + StreamingUploadStateTest.chunk_drop_handler_missing_annotation, + ) + ) + + assert ( + err.value.args[0] + == f"`{StreamingUploadStateTest.get_full_name()}.chunk_drop_handler_missing_annotation` " + "handler should have a parameter annotated as rx.UploadChunkIterator" + ) + + def test_upload_create(): up_comp_1 = Upload.create() assert isinstance(up_comp_1, Upload) @@ -83,6 +155,51 @@ def test_upload_create(): assert isinstance(up_comp_4, Upload) assert up_comp_4.is_used + # reset is_used + Upload.is_used = False + + up_comp_5 = Upload.create( + id="foo_id", + on_drop=StreamingUploadStateTest.chunk_drop_handler( + cast(Any, rx.upload_files_chunk(upload_id="foo_id")) + ), + ) + assert isinstance(up_comp_5, Upload) + assert up_comp_5.is_used + + up_comp_6 = Upload.create( + id="foo_id", + on_drop=StreamingUploadStateTest.chunk_upload_alias_handler( + cast(Any, rx.upload_files_chunk(upload_id="foo_id")) + ), + ) + assert isinstance(up_comp_6, Upload) + assert up_comp_6.is_used + + +def test_upload_button_handlers_allow_custom_param_names(): + legacy_button = rx.button( + "Upload", + on_click=UploadStateTest.upload_alias_handler( + cast(Any, rx.upload_files(upload_id="foo_id")) + ), + ) + legacy_chain = cast(EventChain, legacy_button.event_triggers["on_click"]) + legacy_event = cast(EventSpec, legacy_chain.events[0]) + legacy_arg_names = [arg[0]._js_expr for arg in legacy_event.args] + assert legacy_arg_names[:3] == ["files", "uploads", "upload_param_name"] + + chunk_button = rx.button( + "Upload", + on_click=StreamingUploadStateTest.chunk_upload_alias_handler( + cast(Any, rx.upload_files_chunk(upload_id="foo_id")) + ), + ) + chunk_chain = cast(EventChain, chunk_button.event_triggers["on_click"]) + chunk_event = cast(EventSpec, chunk_chain.events[0]) + chunk_arg_names = [arg[0]._js_expr for arg in chunk_event.args] + assert chunk_arg_names[:3] == ["files", "stream", "upload_param_name"] + def test_styled_upload_create(): styled_up_comp_1 = StyledUpload.create() diff --git a/tests/units/states/upload.py b/tests/units/states/upload.py index 6c732796a73..d53304a3be6 100644 --- a/tests/units/states/upload.py +++ b/tests/units/states/upload.py @@ -1,6 +1,7 @@ """Test states for upload-related tests.""" from pathlib import Path +from typing import BinaryIO import reflex as rx from reflex.state import BaseState, State @@ -78,6 +79,56 @@ class FileUploadState(_FileUploadMixin, State): """The base state for uploading a file.""" +class _ChunkUploadMixin(BaseState, mixin=True): + """Common fields and handlers for chunk upload tests.""" + + chunk_records: list[str] + completed_files: list[str] + _tmp_path: Path = Path() + + @rx.event(background=True) + async def chunk_handle_upload(self, chunk_iter: rx.UploadChunkIterator): + """Handle a chunked upload in the background.""" + file_handles: dict[str, BinaryIO] = {} + + try: + async for chunk in chunk_iter: + outfile = self._tmp_path / chunk.filename + outfile.parent.mkdir(parents=True, exist_ok=True) + + fh = file_handles.get(chunk.filename) + if fh is None: + fh = outfile.open("r+b") if outfile.exists() else outfile.open("wb") + file_handles[chunk.filename] = fh + + fh.seek(chunk.offset) + fh.write(chunk.data) + + async with self: + self.chunk_records.append( + f"{chunk.filename}:{chunk.offset}:{len(chunk.data)}:{chunk.content_type}" + ) + finally: + for fh in file_handles.values(): + fh.close() + + async with self: + self.completed_files = sorted(file_handles) + + async def chunk_handle_upload_not_background( + self, chunk_iter: rx.UploadChunkIterator + ): + """Invalid streaming upload handler used for compile-time validation tests.""" + + @rx.event(background=True) + async def chunk_handle_upload_missing_annotation(self, chunk_iter): + """Invalid streaming upload handler missing the iterator annotation.""" + + +class ChunkUploadState(_ChunkUploadMixin, State): + """The base state for streaming chunk uploads.""" + + class FileStateBase1(State): """The base state for a child FileUploadState.""" diff --git a/tests/units/test_app.py b/tests/units/test_app.py index 25c71c0d17e..a803a81c96f 100644 --- a/tests/units/test_app.py +++ b/tests/units/test_app.py @@ -10,13 +10,14 @@ from contextlib import nullcontext as does_not_raise from importlib.util import find_spec from pathlib import Path +from types import SimpleNamespace from typing import TYPE_CHECKING, Any, ClassVar from unittest.mock import AsyncMock import pytest from pytest_mock import MockerFixture from starlette.applications import Starlette -from starlette.datastructures import FormData, UploadFile +from starlette.datastructures import FormData, Headers, UploadFile from starlette.responses import StreamingResponse import reflex as rx @@ -27,6 +28,7 @@ default_overlay_component, process, upload, + upload_chunk, ) from reflex.components import Component from reflex.components.base.bare import Bare @@ -57,6 +59,7 @@ from .states import GenState from .states.upload import ( ChildFileUploadState, + ChunkUploadState, FileStateBase1, FileUploadState, GrandChildFileUploadState, @@ -1271,6 +1274,236 @@ async def form(): # noqa: RUF029 await app.state_manager.close() +def _build_chunk_upload_body( + data: bytes, +) -> bytes: + """Build the raw request body for chunk upload tests. + + Returns: + The bytes that should be sent as the request body. + """ + return data + + +def _make_chunk_upload_request( + token: str, + handler_name: str, + session_id: str, + body: bytes, + *, + filename: str | None = None, + offset: int | None = None, + complete: bool = False, + cancel: bool = False, + upload_id: str = "streaming", + content_type: str = "application/octet-stream", + stream_chunk_size: int = 17, +): + """Create a mocked request for the chunk upload endpoint. + + Returns: + A mocked Starlette request object. + """ + request_mock = unittest.mock.Mock() + request_mock.headers = Headers({ + "content-type": content_type, + "reflex-client-token": token, + "reflex-event-handler": handler_name, + }) + request_mock.query_params = { + "session_id": session_id, + "upload_id": upload_id, + **({"filename": filename} if filename is not None else {}), + **({"offset": str(offset)} if offset is not None else {}), + **({"complete": "1"} if complete else {}), + **({"cancel": "1"} if cancel else {}), + } + + async def stream(): + for index in range(0, len(body), stream_chunk_size): + yield body[index : index + stream_chunk_size] + yield b"" + await asyncio.sleep(0) + + request_mock.stream = stream + return request_mock + + +async def _drain_background_tasks(app: App): + """Wait for all background tasks associated with an app. + + Returns: + The gathered background task results. + """ + tasks = tuple(app._background_tasks) + if tasks: + return await asyncio.gather(*tasks, return_exceptions=True) + return [] + + +@pytest.mark.asyncio +async def test_upload_chunk_streams_chunks(tmp_path, token: str, mocker: MockerFixture): + """Test streaming upload chunks through the background upload endpoint.""" + mocker.patch( + "reflex.state.State.class_subclasses", + {ChunkUploadState}, + ) + app = App() + mocker.patch( + "reflex.utils.prerequisites.get_and_validate_app", + return_value=SimpleNamespace(app=app), + ) + app.event_namespace.emit_update = AsyncMock() # pyright: ignore [reportOptionalMemberAccess] + + async with app.modify_state(_substate_key(token, ChunkUploadState)) as root_state: + substate = root_state.get_substate(ChunkUploadState.get_full_name().split(".")) + substate._tmp_path = tmp_path + substate.chunk_records = [] + substate.completed_files = [] + + upload_fn = upload_chunk(app) + session_id = "stream-session" + for filename, offset, data in ( + ("alpha.txt", 0, b"ab"), + ("beta.txt", 0, b"12"), + ("alpha.txt", 2, b"cde"), + ("beta.txt", 2, b"345"), + ): + response = await upload_fn( + _make_chunk_upload_request( + token, + f"{ChunkUploadState.get_full_name()}.chunk_handle_upload", + session_id, + _build_chunk_upload_body(data), + filename=filename, + offset=offset, + complete=filename == "beta.txt" and offset == 2, + content_type="text/plain", + stream_chunk_size=1, + ) + ) + assert response.status_code == 202 + + task_results = await _drain_background_tasks(app) + assert task_results == [None] + + state = await app.state_manager.get_state(_substate_key(token, ChunkUploadState)) + substate = ( + state + if isinstance(state, ChunkUploadState) + else state.get_substate(ChunkUploadState.get_full_name().split(".")) + ) + assert isinstance(substate, ChunkUploadState) + parsed_chunk_records = [ + (filename, int(offset), int(size), content_type) + for filename, offset, size, content_type in ( + record.rsplit(":", 3) for record in substate.chunk_records + ) + ] + assert len(parsed_chunk_records) >= 4 + assert {filename for filename, *_ in parsed_chunk_records} == { + "alpha.txt", + "beta.txt", + } + assert all( + content_type == "text/plain" for *_, content_type in parsed_chunk_records + ) + assert ( + sum( + size + for filename, _offset, size, _content_type in parsed_chunk_records + if filename == "alpha.txt" + ) + == 5 + ) + assert ( + sum( + size + for filename, _offset, size, _content_type in parsed_chunk_records + if filename == "beta.txt" + ) + == 5 + ) + beta_initial_chunk_index = next( + index + for index, (filename, offset, _size, _content_type) in enumerate( + parsed_chunk_records + ) + if filename == "beta.txt" and offset == 0 + ) + alpha_later_chunk_index = next( + index + for index, (filename, offset, _size, _content_type) in enumerate( + parsed_chunk_records + ) + if filename == "alpha.txt" and offset >= 2 + ) + assert beta_initial_chunk_index < alpha_later_chunk_index + assert substate.completed_files == ["alpha.txt", "beta.txt"] + assert (tmp_path / "alpha.txt").read_bytes() == b"abcde" + assert (tmp_path / "beta.txt").read_bytes() == b"12345" + assert app.event_namespace.emit_update.await_count >= 1 # pyright: ignore [reportOptionalMemberAccess] + assert not app._background_tasks + + await app.state_manager.close() + + +@pytest.mark.asyncio +async def test_upload_chunk_invalid_offset_returns_400( + token: str, + mocker: MockerFixture, +): + """Test that malformed chunk metadata fails the upload request.""" + mocker.patch( + "reflex.state.State.class_subclasses", + {ChunkUploadState}, + ) + app = App() + mocker.patch( + "reflex.utils.prerequisites.get_and_validate_app", + return_value=SimpleNamespace(app=app), + ) + app.event_namespace.emit_update = AsyncMock() # pyright: ignore [reportOptionalMemberAccess] + + async with app.modify_state(_substate_key(token, ChunkUploadState)) as root_state: + substate = root_state.get_substate(ChunkUploadState.get_full_name().split(".")) + substate.chunk_records = [] + substate.completed_files = [] + + upload_fn = upload_chunk(app) + response = await upload_fn( + _make_chunk_upload_request( + token, + f"{ChunkUploadState.get_full_name()}.chunk_handle_upload", + "bad-session", + _build_chunk_upload_body(b"abc"), + filename="alpha.txt", + offset=-1, + content_type="text/plain", + ) + ) + + assert response.status_code == 400 + assert json.loads(bytes(response.body).decode()) == { + "detail": "Invalid offset for upload chunk request: -1" + } + + await _drain_background_tasks(app) + + state = await app.state_manager.get_state(_substate_key(token, ChunkUploadState)) + substate = ( + state + if isinstance(state, ChunkUploadState) + else state.get_substate(ChunkUploadState.get_full_name().split(".")) + ) + assert isinstance(substate, ChunkUploadState) + assert substate.chunk_records == [] + assert substate.completed_files == [] + assert not app._background_tasks + + await app.state_manager.close() + + class DynamicState(BaseState): """State class for testing dynamic route var. diff --git a/tests/units/test_prerequisites.py b/tests/units/test_prerequisites.py index cf22404bd69..1744dc48406 100644 --- a/tests/units/test_prerequisites.py +++ b/tests/units/test_prerequisites.py @@ -1,3 +1,4 @@ +import os import shutil import tempfile from pathlib import Path @@ -14,6 +15,7 @@ _compile_package_json, _compile_vite_config, _update_react_router_config, + sync_web_runtime_templates, ) from reflex.utils.rename import rename_imports_and_app_name from reflex.utils.telemetry import CpuInfo, get_cpu_info @@ -92,6 +94,76 @@ def test_initialise_vite_config(config, expected_output): assert expected_output in output +def test_sync_web_runtime_templates_preserves_generated_routes( + temp_directory, monkeypatch +): + template_dir = temp_directory / "template_web" + web_dir = temp_directory / ".web" + + template_files = { + "utils/state.js": "new state helper", + "utils/helpers/upload.js": "new upload helper", + "components/reflex/radix_themes_color_mode_provider.js": "new component", + "app/entry.client.js": "new entry client", + "app/routes.js": "new routes file", + "styles/__reflex_style_reset.css": "new reset", + "jsconfig.json": '{"compilerOptions": {}}', + "postcss.config.js": "module.exports = {};", + "vite-plugin-safari-cachebust.js": "export default function plugin() {}", + } + + for relative_path, content in template_files.items(): + file_path = template_dir / relative_path + file_path.parent.mkdir(parents=True, exist_ok=True) + file_path.write_text(content) + + stale_files = { + "utils/state.js": "old state helper", + "utils/helpers/upload.js": "old upload helper", + "components/reflex/radix_themes_color_mode_provider.js": "old component", + "app/entry.client.js": "old entry client", + "app/routes.js": "old routes file", + } + for relative_path, content in stale_files.items(): + file_path = web_dir / relative_path + file_path.parent.mkdir(parents=True, exist_ok=True) + file_path.write_text(content) + + generated_route = web_dir / "app" / "routes" / "compare._index.jsx" + generated_route.parent.mkdir(parents=True, exist_ok=True) + generated_route.write_text("generated route module") + generated_context = web_dir / "utils" / "context.js" + generated_context.parent.mkdir(parents=True, exist_ok=True) + generated_context.write_text("generated context module") + generated_component = web_dir / "components" / "utils.js" + generated_component.parent.mkdir(parents=True, exist_ok=True) + generated_component.write_text("generated component module") + + monkeypatch.setattr("reflex.utils.frontend_skeleton.get_web_dir", lambda: web_dir) + monkeypatch.setattr( + "reflex.utils.frontend_skeleton.constants.Templates.Dirs.WEB_TEMPLATE", + os.fspath(template_dir), + ) + + sync_web_runtime_templates() + + assert (web_dir / "utils" / "state.js").read_text() == "new state helper" + assert (web_dir / "utils" / "helpers" / "upload.js").read_text() == ( + "new upload helper" + ) + assert ( + web_dir / "components" / "reflex" / "radix_themes_color_mode_provider.js" + ).read_text() == "new component" + assert (web_dir / "app" / "entry.client.js").read_text() == "new entry client" + assert (web_dir / "app" / "routes.js").read_text() == "new routes file" + assert generated_route.exists() + assert generated_route.read_text() == "generated route module" + assert generated_context.exists() + assert generated_context.read_text() == "generated context module" + assert generated_component.exists() + assert generated_component.read_text() == "generated component module" + + @pytest.mark.parametrize( ("frontend_path", "expected_command"), [