Skip to content

Commit 54ba3f1

Browse files
committed
feat(files): stream uploads to disk instead of buffering
Previously /api/files/create read the entire UploadFile into memory before writing to disk, peaking at ~100 MB for a 100 MB upload. Now we drain the upload in 1 MiB chunks straight to a `.part` temp file and atomically rename on success, so peak memory stays bounded regardless of file size and a failed upload never leaves a half-written file at the final path.

- `OSFileSystem.stream_create_file` does the chunked write with atomic rename and cleanup on failure
- `parse_multipart_request` now hands callers un-read `UploadFile` handles (instead of pre-read bytes), so streaming is possible without giving up the small-payload `.read()` path
- Directories and the default-template notebook still go through the in-memory `create_file_or_directory` path; only real file content streams
1 parent 16d224b commit 54ba3f1

8 files changed

Lines changed: 277 additions & 44 deletions

File tree

marimo/_server/api/endpoints/file_explorer.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -120,15 +120,20 @@ async def create_file_or_directory(
120120
$ref: "#/components/schemas/FileCreateResponse"
121121
"""
122122
try:
123-
parsed = await parse_multipart_request(
123+
async with parse_multipart_request(
124124
request, FileCreateMultipartRequest
125-
)
126-
info = file_system.create_file_or_directory(
127-
parsed.body.path,
128-
parsed.body.type,
129-
parsed.body.name,
130-
parsed.files.get("file"),
131-
)
125+
) as parsed:
126+
upload = parsed.files.get("file")
127+
# Stream when there's actual file content; the in-memory create
128+
# path still handles directories and the default-template notebook.
129+
if upload is not None and parsed.body.type in ("file", "notebook"):
130+
info = await file_system.stream_create_file(
131+
parsed.body.path, parsed.body.name, upload
132+
)
133+
else:
134+
info = file_system.create_file_or_directory(
135+
parsed.body.path, parsed.body.type, parsed.body.name, None
136+
)
132137
return FileCreateResponse(success=True, info=info)
133138
except Exception as e:
134139
LOGGER.error(f"Error creating file or directory: {e}")

marimo/_server/api/utils.py

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import subprocess
66
import sys
77
import webbrowser
8+
from contextlib import asynccontextmanager
89
from dataclasses import dataclass
910
from pathlib import Path
1011
from shutil import which
@@ -25,6 +26,9 @@
2526
from marimo._utils.parse_dataclass import parse_raw
2627

2728
if TYPE_CHECKING:
29+
from collections.abc import AsyncIterator
30+
31+
from starlette.datastructures import UploadFile
2832
from starlette.requests import Request
2933

3034
from marimo._session.session import Session
@@ -44,20 +48,31 @@ async def parse_request(
4448

4549
@dataclass
4650
class MultipartRequest(Generic[S]):
47-
"""Result of parsing a multipart/form-data request body."""
51+
"""Result of parsing a multipart/form-data request body.
52+
53+
`files` carries the raw `UploadFile` objects (not yet read) so callers
54+
can stream chunks via `.read(size)`. For small payloads, callers can
55+
simply `await upload.read()` to materialize the whole body.
56+
"""
4857

4958
body: S
50-
files: dict[str, bytes]
59+
files: dict[str, UploadFile]
5160

5261

62+
@asynccontextmanager
5363
async def parse_multipart_request(
5464
request: Request, cls: type[S]
55-
) -> MultipartRequest[S]:
56-
"""Parse a multipart/form-data body into a msgspec.Struct + file bytes.
65+
) -> AsyncIterator[MultipartRequest[S]]:
66+
"""Parse a multipart/form-data body into a msgspec.Struct + uploads.
5767
5868
String form fields are validated against `cls`. File upload parts are
59-
read fully into memory and returned in `files`, keyed by form-field
60-
name (callers look them up explicitly rather than via the struct).
69+
yielded as `UploadFile` objects (un-read) in `files`, keyed by
70+
form-field name, so callers can stream them to disk instead of buffering.
71+
72+
Used as an async context manager so the underlying form (and its
73+
spooled temp files / fds) are closed automatically when the caller
74+
is done — `UploadFile` parts remain readable for the entire body of
75+
the `async with` block.
6176
6277
Raises msgspec.ValidationError if required string fields are missing
6378
or invalid.
@@ -66,18 +81,16 @@ async def parse_multipart_request(
6681
# without starlette (e.g. pyodide).
6782
from starlette.datastructures import UploadFile
6883

69-
# Use as an async context manager so any spooled temp files backing
70-
# UploadFile parts are closed after parsing.
7184
async with request.form() as form:
7285
string_payload: dict[str, Any] = {}
73-
files: dict[str, bytes] = {}
86+
files: dict[str, UploadFile] = {}
7487
for key, value in form.multi_items():
7588
if isinstance(value, UploadFile):
76-
files[key] = await value.read()
89+
files[key] = value
7790
elif isinstance(value, str):
7891
string_payload[key] = value
79-
body = msgspec.convert(string_payload, cls, strict=False)
80-
return MultipartRequest(body=body, files=files)
92+
body = msgspec.convert(string_payload, cls, strict=False)
93+
yield MultipartRequest(body=body, files=files)
8194

8295

8396
@runtime_checkable

marimo/_server/files/os_file_system.py

Lines changed: 92 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,10 @@
88
import re
99
import shutil
1010
import subprocess
11+
import tempfile
1112
from collections import deque
1213
from pathlib import Path
13-
from typing import Literal
14+
from typing import Literal, Protocol
1415

1516
from marimo import _loggers
1617
from marimo._server.files.file_system import FileSystem
@@ -36,6 +37,25 @@
3637
"..",
3738
]
3839

40+
# 1 MiB. Large enough to amortize syscall overhead, small enough to keep
41+
# peak memory bounded when streaming.
42+
_STREAM_CHUNK_SIZE = 1024 * 1024
43+
44+
# Hard cap on streamed uploads. Streaming removes the implicit OOM ceiling
45+
# that buffered uploads had, so without a cap an authenticated client could
46+
# exhaust disk. 1 GiB covers normal notebook-data use cases with margin.
47+
MAX_UPLOAD_BYTES = 1024 * 1024 * 1024
48+
49+
50+
class AsyncByteSource(Protocol):
51+
"""Anything that can be drained chunk-by-chunk into a file.
52+
53+
Starlette's `UploadFile` satisfies this; so does any object exposing
54+
an async `read(size)` returning bytes.
55+
"""
56+
57+
async def read(self, size: int = -1, /) -> bytes: ...
58+
3959

4060
class OSFileSystem(FileSystem):
4161
def get_root(self) -> str:
@@ -133,33 +153,32 @@ def open_file(self, path: str, encoding: str | None = None) -> str | bytes:
133153
except UnicodeDecodeError:
134154
return file_path.read_bytes()
135155

136-
def create_file_or_directory(
137-
self,
138-
path: str,
139-
file_type: Literal["file", "directory", "notebook"],
140-
name: str,
141-
contents: bytes | None,
142-
) -> FileInfo:
156+
@staticmethod
157+
def _validate_create_name(name: str) -> None:
158+
"""Reject names that are empty, reserved, or traverse out of the
159+
parent. Centralized so HTTP, WASM, and streaming paths all share it.
160+
"""
143161
if name in DISALLOWED_NAMES:
144162
raise ValueError(
145163
f"Cannot create file or directory with name {name}"
146164
)
147165
if name.strip() == "":
148166
raise ValueError("Cannot create file or directory with empty name")
149-
# Names that traverse out of `path` or escape via separators are
150-
# rejected. Validation belongs here (not in the endpoint) so every
151-
# caller of OSFileSystem — HTTP, WASM bridge, scripts — is covered.
152-
if (
153-
"/" in name
154-
or "\\" in name
155-
or "\x00" in name
156-
or name in (".", "..")
157-
):
167+
if "/" in name or "\\" in name or "\x00" in name:
158168
raise ValueError(
159169
f"Invalid name {name!r}: must not contain path separators "
160170
"or refer to a parent directory"
161171
)
162172

173+
def create_file_or_directory(
174+
self,
175+
path: str,
176+
file_type: Literal["file", "directory", "notebook"],
177+
name: str,
178+
contents: bytes | None,
179+
) -> FileInfo:
180+
self._validate_create_name(name)
181+
163182
full_path = Path(path) / name
164183
full_path = _generate_unique_path(full_path)
165184

@@ -192,6 +211,62 @@ def create_file_or_directory(
192211
),
193212
).file
194213

214+
async def stream_create_file(
215+
self,
216+
path: str,
217+
name: str,
218+
source: AsyncByteSource,
219+
) -> FileInfo:
220+
"""Stream-write an uploaded file to disk, chunk by chunk.
221+
222+
Avoids loading the full payload into memory (the HTTP multipart
223+
path can otherwise buffer 100 MB at once). Writes to a ``.part``
224+
temp file and atomically renames on success so a failed upload
225+
doesn't leave a half-written file at the final path.
226+
"""
227+
self._validate_create_name(name)
228+
229+
full_path = Path(path) / name
230+
full_path = _generate_unique_path(full_path)
231+
full_path.parent.mkdir(parents=True, exist_ok=True)
232+
233+
# `NamedTemporaryFile` gives us a guaranteed-unique sibling path so
234+
# concurrent uploads racing through `_generate_unique_path` can't
235+
# collide on the same `.part` file.
236+
tmp = tempfile.NamedTemporaryFile(
237+
dir=full_path.parent,
238+
prefix=full_path.name + ".",
239+
suffix=".part",
240+
delete=False,
241+
)
242+
tmp_path = tmp.name
243+
try:
244+
# Sync writes are bounded to ~1 MiB per chunk, with an `await`
245+
# in between; event loop blockage is brief and an async file
246+
# library would only add a dependency for marginal gain.
247+
written = 0
248+
with tmp:
249+
while chunk := await source.read(_STREAM_CHUNK_SIZE):
250+
written += len(chunk)
251+
if written > MAX_UPLOAD_BYTES:
252+
raise ValueError(
253+
f"Upload exceeds maximum size of "
254+
f"{MAX_UPLOAD_BYTES} bytes"
255+
)
256+
tmp.write(chunk)
257+
os.replace(tmp_path, full_path)
258+
except BaseException:
259+
try:
260+
os.unlink(tmp_path)
261+
except FileNotFoundError:
262+
pass
263+
raise
264+
265+
# Use the metadata-only helper: `get_details` would re-read the
266+
# file contents (and base64-encode binary), defeating the point of
267+
# streaming for large uploads.
268+
return self._get_file_info(str(full_path))
269+
195270
def delete_file_or_directory(self, path: str) -> bool:
196271
if os.path.isdir(path):
197272
safe_rmtree(path)

marimo/_server/models/files.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,13 @@ class FileCreateRequest(msgspec.Struct, rename="camel"):
5858

5959

6060
class FileCreateMultipartRequest(msgspec.Struct, rename="camel"):
61-
"""multipart/form-data body for POST /api/files/create."""
61+
"""multipart/form-data body for POST /api/files/create.
62+
63+
Schema-only: this struct exists to describe the multipart shape in
64+
OpenAPI. At runtime, the endpoint reads the string fields from
65+
``MultipartRequest.body`` and the uploaded bytes from
66+
``MultipartRequest.files["file"]`` — ``body.file`` is never populated.
67+
"""
6268

6369
# The path where to create the file or directory
6470
path: str
@@ -68,6 +74,9 @@ class FileCreateMultipartRequest(msgspec.Struct, rename="camel"):
6874
name: str
6975
# The raw file bytes (optional). When omitted, an empty file is created
7076
# (or, for 'notebook' type, a default notebook template).
77+
# NOTE: this field is OpenAPI-only — see class docstring. The
78+
# ``format: binary`` annotation makes the generated spec emit a proper
79+
# file-upload schema rather than a base64 string.
7180
file: Annotated[
7281
str | None, msgspec.Meta(extra_json_schema={"format": "binary"})
7382
] = None

packages/openapi/api.yaml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1665,7 +1665,11 @@ components:
16651665
title: FileCopyResponse
16661666
type: object
16671667
FileCreateMultipartRequest:
1668-
description: multipart/form-data body for POST /api/files/create.
1668+
description: "multipart/form-data body for POST /api/files/create.\n\n Schema-only:\
1669+
\ this struct exists to describe the multipart shape in\n OpenAPI. At runtime,\
1670+
\ the endpoint reads the string fields from\n ``MultipartRequest.body``\
1671+
\ and the uploaded bytes from\n ``MultipartRequest.files[\"file\"]`` \u2014\
1672+
\ ``body.file`` is never populated."
16691673
properties:
16701674
file:
16711675
anyOf:

packages/openapi/src/api.ts

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4359,6 +4359,11 @@ export interface components {
43594359
/**
43604360
* FileCreateMultipartRequest
43614361
* @description multipart/form-data body for POST /api/files/create.
4362+
*
4363+
* Schema-only: this struct exists to describe the multipart shape in
4364+
* OpenAPI. At runtime, the endpoint reads the string fields from
4365+
* ``MultipartRequest.body`` and the uploaded bytes from
4366+
* ``MultipartRequest.files["file"]`` — ``body.file`` is never populated.
43624367
*/
43634368
FileCreateMultipartRequest: {
43644369
/**

tests/_server/api/test_api_utils.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,16 +23,19 @@ class _SampleForm(msgspec.Struct):
2323

2424
def _build_app(captured: dict[str, object]) -> TestClient:
2525
async def endpoint(request: Request) -> JSONResponse:
26-
parsed = await parse_multipart_request(request, _SampleForm)
27-
captured["body"] = parsed.body
28-
captured["files"] = parsed.files
26+
async with parse_multipart_request(request, _SampleForm) as parsed:
27+
captured["body"] = parsed.body
28+
captured["files"] = dict(parsed.files)
29+
upload = parsed.files.get("upload")
30+
if upload is not None:
31+
captured["upload_bytes"] = await upload.read()
2932
return JSONResponse({"ok": True})
3033

3134
app = Starlette(routes=[Route("/test", endpoint, methods=["POST"])])
3235
return TestClient(app)
3336

3437

35-
def test_parse_multipart_request_strings_and_file_bytes() -> None:
38+
def test_parse_multipart_request_strings_and_file_upload() -> None:
3639
captured: dict[str, object] = {}
3740
client = _build_app(captured)
3841
response = client.post(
@@ -45,7 +48,11 @@ def test_parse_multipart_request_strings_and_file_bytes() -> None:
4548
assert isinstance(body, _SampleForm)
4649
assert body.name == "marimo"
4750
assert body.count == 42
48-
assert captured["files"] == {"upload": b"\x00\x01\x02\xff"}
51+
files = captured["files"]
52+
assert isinstance(files, dict)
53+
assert set(files.keys()) == {"upload"}
54+
# File handle is returned un-read; bytes loaded inside the endpoint above.
55+
assert captured["upload_bytes"] == b"\x00\x01\x02\xff"
4956

5057

5158
def test_parse_multipart_request_omitted_file_yields_empty_dict() -> None:

0 commit comments

Comments
 (0)