Skip to content

Commit 2612cf9

Browse files
committed
feat(files): stream uploads to disk instead of buffering
Previously /api/files/create read the entire UploadFile into memory before writing to disk, peaking at ~100 MB for a 100 MB upload. Now we drain the upload in 1 MiB chunks straight to a `.part` temp file and atomically rename on success, so peak memory stays bounded regardless of file size and a failed upload never leaves a half-written file at the final path.

- `OSFileSystem.stream_create_file` does the chunked write with atomic rename and cleanup on failure
- `parse_multipart_request` now hands callers un-read `UploadFile` handles (instead of pre-read bytes), so streaming is possible without giving up the small-payload `.read()` path
- Directories and the default-template notebook still go through the in-memory `create_file_or_directory` path; only real file content streams
1 parent 16d224b commit 2612cf9

8 files changed

Lines changed: 230 additions & 42 deletions

File tree

marimo/_server/api/endpoints/file_explorer.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -123,12 +123,17 @@ async def create_file_or_directory(
123123
parsed = await parse_multipart_request(
124124
request, FileCreateMultipartRequest
125125
)
126-
info = file_system.create_file_or_directory(
127-
parsed.body.path,
128-
parsed.body.type,
129-
parsed.body.name,
130-
parsed.files.get("file"),
131-
)
126+
upload = parsed.files.get("file")
127+
# Stream when there's actual file content; the in-memory create
128+
# path still handles directories and the default-template notebook.
129+
if upload is not None and parsed.body.type in ("file", "notebook"):
130+
info = await file_system.stream_create_file(
131+
parsed.body.path, parsed.body.name, upload
132+
)
133+
else:
134+
info = file_system.create_file_or_directory(
135+
parsed.body.path, parsed.body.type, parsed.body.name, None
136+
)
132137
return FileCreateResponse(success=True, info=info)
133138
except Exception as e:
134139
LOGGER.error(f"Error creating file or directory: {e}")

marimo/_server/api/utils.py

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from marimo._utils.parse_dataclass import parse_raw
2626

2727
if TYPE_CHECKING:
28+
from starlette.datastructures import UploadFile
2829
from starlette.requests import Request
2930

3031
from marimo._session.session import Session
@@ -44,20 +45,25 @@ async def parse_request(
4445

4546
@dataclass
4647
class MultipartRequest(Generic[S]):
47-
"""Result of parsing a multipart/form-data request body."""
48+
"""Result of parsing a multipart/form-data request body.
49+
50+
`files` carries the raw `UploadFile` objects (not yet read) so callers
51+
can stream chunks via `.read(size)`. For small payloads, callers can
52+
simply `await upload.read()` to materialize the whole body.
53+
"""
4854

4955
body: S
50-
files: dict[str, bytes]
56+
files: dict[str, UploadFile]
5157

5258

5359
async def parse_multipart_request(
5460
request: Request, cls: type[S]
5561
) -> MultipartRequest[S]:
56-
"""Parse a multipart/form-data body into a msgspec.Struct + file bytes.
62+
"""Parse a multipart/form-data body into a msgspec.Struct + uploads.
5763
5864
String form fields are validated against `cls`. File upload parts are
59-
read fully into memory and returned in `files`, keyed by form-field
60-
name (callers look them up explicitly rather than via the struct).
65+
returned as `UploadFile` objects (un-read) in `files`, keyed by
66+
form-field name, so callers can stream them to disk instead of buffering.
6167
6268
Raises msgspec.ValidationError if required string fields are missing
6369
or invalid.
@@ -66,16 +72,17 @@ async def parse_multipart_request(
6672
# without starlette (e.g. pyodide).
6773
from starlette.datastructures import UploadFile
6874

69-
# Use as an async context manager so any spooled temp files backing
70-
# UploadFile parts are closed after parsing.
71-
async with request.form() as form:
72-
string_payload: dict[str, Any] = {}
73-
files: dict[str, bytes] = {}
74-
for key, value in form.multi_items():
75-
if isinstance(value, UploadFile):
76-
files[key] = await value.read()
77-
elif isinstance(value, str):
78-
string_payload[key] = value
75+
# Do NOT use `async with request.form()` here: callers stream the
76+
# UploadFile parts after this function returns, and the context
77+
# manager would close their spooled temp files on exit.
78+
form = await request.form()
79+
string_payload: dict[str, Any] = {}
80+
files: dict[str, UploadFile] = {}
81+
for key, value in form.multi_items():
82+
if isinstance(value, UploadFile):
83+
files[key] = value
84+
elif isinstance(value, str):
85+
string_payload[key] = value
7986
body = msgspec.convert(string_payload, cls, strict=False)
8087
return MultipartRequest(body=body, files=files)
8188

marimo/_server/files/os_file_system.py

Lines changed: 79 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import subprocess
1111
from collections import deque
1212
from pathlib import Path
13-
from typing import Literal
13+
from typing import Literal, Protocol, runtime_checkable
1414

1515
from marimo import _loggers
1616
from marimo._server.files.file_system import FileSystem
@@ -36,6 +36,26 @@
3636
"..",
3737
]
3838

39+
# 1 MiB. Large enough to amortize syscall overhead, small enough to keep
40+
# peak memory bounded when streaming.
41+
_STREAM_CHUNK_SIZE = 1024 * 1024
42+
43+
# Hard cap on streamed uploads. Streaming removes the implicit OOM ceiling
44+
# that buffered uploads had, so without a cap an authenticated client could
45+
# exhaust disk. 1 GiB covers normal notebook-data use cases with margin.
46+
MAX_UPLOAD_BYTES = 1024 * 1024 * 1024
47+
48+
49+
@runtime_checkable
50+
class AsyncByteSource(Protocol):
51+
"""Anything that can be drained chunk-by-chunk into a file.
52+
53+
Starlette's `UploadFile` satisfies this; so does any object exposing
54+
an async `read(size)` returning bytes.
55+
"""
56+
57+
async def read(self, size: int = -1, /) -> bytes: ...
58+
3959

4060
class OSFileSystem(FileSystem):
4161
def get_root(self) -> str:
@@ -133,33 +153,32 @@ def open_file(self, path: str, encoding: str | None = None) -> str | bytes:
133153
except UnicodeDecodeError:
134154
return file_path.read_bytes()
135155

136-
def create_file_or_directory(
137-
self,
138-
path: str,
139-
file_type: Literal["file", "directory", "notebook"],
140-
name: str,
141-
contents: bytes | None,
142-
) -> FileInfo:
156+
@staticmethod
157+
def _validate_create_name(name: str) -> None:
158+
"""Reject names that are empty, reserved, or traverse out of the
159+
parent. Centralized so HTTP, WASM, and streaming paths all share it.
160+
"""
143161
if name in DISALLOWED_NAMES:
144162
raise ValueError(
145163
f"Cannot create file or directory with name {name}"
146164
)
147165
if name.strip() == "":
148166
raise ValueError("Cannot create file or directory with empty name")
149-
# Names that traverse out of `path` or escape via separators are
150-
# rejected. Validation belongs here (not in the endpoint) so every
151-
# caller of OSFileSystem — HTTP, WASM bridge, scripts — is covered.
152-
if (
153-
"/" in name
154-
or "\\" in name
155-
or "\x00" in name
156-
or name in (".", "..")
157-
):
167+
if "/" in name or "\\" in name or "\x00" in name:
158168
raise ValueError(
159169
f"Invalid name {name!r}: must not contain path separators "
160170
"or refer to a parent directory"
161171
)
162172

173+
def create_file_or_directory(
174+
self,
175+
path: str,
176+
file_type: Literal["file", "directory", "notebook"],
177+
name: str,
178+
contents: bytes | None,
179+
) -> FileInfo:
180+
self._validate_create_name(name)
181+
163182
full_path = Path(path) / name
164183
full_path = _generate_unique_path(full_path)
165184

@@ -192,6 +211,49 @@ def create_file_or_directory(
192211
),
193212
).file
194213

214+
async def stream_create_file(
215+
self,
216+
path: str,
217+
name: str,
218+
source: AsyncByteSource,
219+
) -> FileInfo:
220+
"""Stream-write an uploaded file to disk, chunk by chunk.
221+
222+
Avoids loading the full payload into memory (the HTTP multipart
223+
path can otherwise buffer 100 MB at once). Writes to a ``.part``
224+
temp file and atomically renames on success so a failed upload
225+
doesn't leave a half-written file at the final path.
226+
"""
227+
self._validate_create_name(name)
228+
229+
full_path = Path(path) / name
230+
full_path = _generate_unique_path(full_path)
231+
full_path.parent.mkdir(parents=True, exist_ok=True)
232+
233+
tmp_path = full_path.with_name(full_path.name + ".part")
234+
try:
235+
# Sync writes are bounded to ~1 MiB per chunk, with an `await`
236+
# in between; event loop blockage is brief and an async file
237+
# library would only add a dependency for marginal gain.
238+
written = 0
239+
with open(tmp_path, "wb") as out: # noqa: ASYNC230
240+
while chunk := await source.read(_STREAM_CHUNK_SIZE):
241+
written += len(chunk)
242+
if written > MAX_UPLOAD_BYTES:
243+
raise ValueError(
244+
f"Upload exceeds maximum size of "
245+
f"{MAX_UPLOAD_BYTES} bytes"
246+
)
247+
out.write(chunk)
248+
tmp_path.replace(full_path)
249+
except BaseException:
250+
tmp_path.unlink(missing_ok=True)
251+
raise
252+
253+
# Read details fresh from disk; we deliberately don't pass `contents`
254+
# since the file may be too large to round-trip through memory.
255+
return self.get_details(str(full_path)).file
256+
195257
def delete_file_or_directory(self, path: str) -> bool:
196258
if os.path.isdir(path):
197259
safe_rmtree(path)

marimo/_server/models/files.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,13 @@ class FileCreateRequest(msgspec.Struct, rename="camel"):
5858

5959

6060
class FileCreateMultipartRequest(msgspec.Struct, rename="camel"):
61-
"""multipart/form-data body for POST /api/files/create."""
61+
"""multipart/form-data body for POST /api/files/create.
62+
63+
Schema-only: this struct exists to describe the multipart shape in
64+
OpenAPI. At runtime, the endpoint reads the string fields from
65+
``MultipartRequest.body`` and the uploaded bytes from
66+
``MultipartRequest.files["file"]`` — ``body.file`` is never populated.
67+
"""
6268

6369
# The path where to create the file or directory
6470
path: str
@@ -68,6 +74,9 @@ class FileCreateMultipartRequest(msgspec.Struct, rename="camel"):
6874
name: str
6975
# The raw file bytes (optional). When omitted, an empty file is created
7076
# (or, for 'notebook' type, a default notebook template).
77+
# NOTE: this field is OpenAPI-only — see class docstring. The
78+
# ``format: binary`` annotation makes the generated spec emit a proper
79+
# file-upload schema rather than a base64 string.
7180
file: Annotated[
7281
str | None, msgspec.Meta(extra_json_schema={"format": "binary"})
7382
] = None

packages/openapi/api.yaml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1665,7 +1665,11 @@ components:
16651665
title: FileCopyResponse
16661666
type: object
16671667
FileCreateMultipartRequest:
1668-
description: multipart/form-data body for POST /api/files/create.
1668+
description: "multipart/form-data body for POST /api/files/create.\n\n Schema-only:\
1669+
\ this struct exists to describe the multipart shape in\n OpenAPI. At runtime,\
1670+
\ the endpoint reads the string fields from\n ``MultipartRequest.body``\
1671+
\ and the uploaded bytes from\n ``MultipartRequest.files[\"file\"]`` \u2014\
1672+
\ ``body.file`` is never populated."
16691673
properties:
16701674
file:
16711675
anyOf:

packages/openapi/src/api.ts

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4359,6 +4359,11 @@ export interface components {
43594359
/**
43604360
* FileCreateMultipartRequest
43614361
* @description multipart/form-data body for POST /api/files/create.
4362+
*
4363+
* Schema-only: this struct exists to describe the multipart shape in
4364+
* OpenAPI. At runtime, the endpoint reads the string fields from
4365+
* ``MultipartRequest.body`` and the uploaded bytes from
4366+
* ``MultipartRequest.files["file"]`` — ``body.file`` is never populated.
43624367
*/
43634368
FileCreateMultipartRequest: {
43644369
/**

tests/_server/api/test_api_utils.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,16 @@ async def endpoint(request: Request) -> JSONResponse:
2626
parsed = await parse_multipart_request(request, _SampleForm)
2727
captured["body"] = parsed.body
2828
captured["files"] = parsed.files
29+
upload = parsed.files.get("upload")
30+
if upload is not None:
31+
captured["upload_bytes"] = await upload.read()
2932
return JSONResponse({"ok": True})
3033

3134
app = Starlette(routes=[Route("/test", endpoint, methods=["POST"])])
3235
return TestClient(app)
3336

3437

35-
def test_parse_multipart_request_strings_and_file_bytes() -> None:
38+
def test_parse_multipart_request_strings_and_file_upload() -> None:
3639
captured: dict[str, object] = {}
3740
client = _build_app(captured)
3841
response = client.post(
@@ -45,7 +48,11 @@ def test_parse_multipart_request_strings_and_file_bytes() -> None:
4548
assert isinstance(body, _SampleForm)
4649
assert body.name == "marimo"
4750
assert body.count == 42
48-
assert captured["files"] == {"upload": b"\x00\x01\x02\xff"}
51+
files = captured["files"]
52+
assert isinstance(files, dict)
53+
assert set(files.keys()) == {"upload"}
54+
# File handle is returned un-read; bytes loaded inside the endpoint above.
55+
assert captured["upload_bytes"] == b"\x00\x01\x02\xff"
4956

5057

5158
def test_parse_multipart_request_omitted_file_yields_empty_dict() -> None:

0 commit comments

Comments
 (0)