|
8 | 8 | import re |
9 | 9 | import shutil |
10 | 10 | import subprocess |
| 11 | +import tempfile |
11 | 12 | from collections import deque |
12 | 13 | from pathlib import Path |
13 | | -from typing import Literal |
| 14 | +from typing import Literal, Protocol |
14 | 15 |
|
15 | 16 | from marimo import _loggers |
16 | 17 | from marimo._server.files.file_system import FileSystem |
|
36 | 37 | "..", |
37 | 38 | ] |
38 | 39 |
|
# 1 MiB. Large enough to amortize syscall overhead, small enough to keep
# peak memory bounded when streaming.
_STREAM_CHUNK_SIZE = 1024 * 1024

# Hard cap on streamed uploads. Streaming removes the implicit OOM ceiling
# that buffered uploads had, so without a cap an authenticated client could
# exhaust disk. 1 GiB covers normal notebook-data use cases with margin.
# NOTE: public (no leading underscore) so callers/tests can reference the
# limit when surfacing errors to users.
MAX_UPLOAD_BYTES = 1024 * 1024 * 1024
| 49 | + |
class AsyncByteSource(Protocol):
    """Structural type for an async, chunked byte reader.

    Satisfied by Starlette's `UploadFile`, and by any other object that
    exposes an async `read(size)` method returning bytes — which is all
    a streaming sink needs to drain it into a file.
    """

    async def read(self, size: int = -1, /) -> bytes: ...
| 58 | + |
39 | 59 |
|
40 | 60 | class OSFileSystem(FileSystem): |
41 | 61 | def get_root(self) -> str: |
@@ -133,33 +153,32 @@ def open_file(self, path: str, encoding: str | None = None) -> str | bytes: |
133 | 153 | except UnicodeDecodeError: |
134 | 154 | return file_path.read_bytes() |
135 | 155 |
|
136 | | - def create_file_or_directory( |
137 | | - self, |
138 | | - path: str, |
139 | | - file_type: Literal["file", "directory", "notebook"], |
140 | | - name: str, |
141 | | - contents: bytes | None, |
142 | | - ) -> FileInfo: |
| 156 | + @staticmethod |
| 157 | + def _validate_create_name(name: str) -> None: |
| 158 | + """Reject names that are empty, reserved, or traverse out of the |
| 159 | + parent. Centralized so HTTP, WASM, and streaming paths all share it. |
| 160 | + """ |
143 | 161 | if name in DISALLOWED_NAMES: |
144 | 162 | raise ValueError( |
145 | 163 | f"Cannot create file or directory with name {name}" |
146 | 164 | ) |
147 | 165 | if name.strip() == "": |
148 | 166 | raise ValueError("Cannot create file or directory with empty name") |
149 | | - # Names that traverse out of `path` or escape via separators are |
150 | | - # rejected. Validation belongs here (not in the endpoint) so every |
151 | | - # caller of OSFileSystem — HTTP, WASM bridge, scripts — is covered. |
152 | | - if ( |
153 | | - "/" in name |
154 | | - or "\\" in name |
155 | | - or "\x00" in name |
156 | | - or name in (".", "..") |
157 | | - ): |
| 167 | + if "/" in name or "\\" in name or "\x00" in name: |
158 | 168 | raise ValueError( |
159 | 169 | f"Invalid name {name!r}: must not contain path separators " |
160 | 170 | "or refer to a parent directory" |
161 | 171 | ) |
162 | 172 |
|
| 173 | + def create_file_or_directory( |
| 174 | + self, |
| 175 | + path: str, |
| 176 | + file_type: Literal["file", "directory", "notebook"], |
| 177 | + name: str, |
| 178 | + contents: bytes | None, |
| 179 | + ) -> FileInfo: |
| 180 | + self._validate_create_name(name) |
| 181 | + |
163 | 182 | full_path = Path(path) / name |
164 | 183 | full_path = _generate_unique_path(full_path) |
165 | 184 |
|
@@ -192,6 +211,62 @@ def create_file_or_directory( |
192 | 211 | ), |
193 | 212 | ).file |
194 | 213 |
|
| 214 | + async def stream_create_file( |
| 215 | + self, |
| 216 | + path: str, |
| 217 | + name: str, |
| 218 | + source: AsyncByteSource, |
| 219 | + ) -> FileInfo: |
| 220 | + """Stream-write an uploaded file to disk, chunk by chunk. |
| 221 | +
|
| 222 | + Avoids loading the full payload into memory (the HTTP multipart |
| 223 | + path can otherwise buffer 100 MB at once). Writes to a ``.part`` |
| 224 | + temp file and atomically renames on success so a failed upload |
| 225 | + doesn't leave a half-written file at the final path. |
| 226 | + """ |
| 227 | + self._validate_create_name(name) |
| 228 | + |
| 229 | + full_path = Path(path) / name |
| 230 | + full_path = _generate_unique_path(full_path) |
| 231 | + full_path.parent.mkdir(parents=True, exist_ok=True) |
| 232 | + |
| 233 | + # `NamedTemporaryFile` gives us a guaranteed-unique sibling path so |
| 234 | + # concurrent uploads racing through `_generate_unique_path` can't |
| 235 | + # collide on the same `.part` file. |
| 236 | + tmp = tempfile.NamedTemporaryFile( |
| 237 | + dir=full_path.parent, |
| 238 | + prefix=full_path.name + ".", |
| 239 | + suffix=".part", |
| 240 | + delete=False, |
| 241 | + ) |
| 242 | + tmp_path = tmp.name |
| 243 | + try: |
| 244 | + # Sync writes are bounded to ~1 MiB per chunk, with an `await` |
| 245 | + # in between; event loop blockage is brief and an async file |
| 246 | + # library would only add a dependency for marginal gain. |
| 247 | + written = 0 |
| 248 | + with tmp: |
| 249 | + while chunk := await source.read(_STREAM_CHUNK_SIZE): |
| 250 | + written += len(chunk) |
| 251 | + if written > MAX_UPLOAD_BYTES: |
| 252 | + raise ValueError( |
| 253 | + f"Upload exceeds maximum size of " |
| 254 | + f"{MAX_UPLOAD_BYTES} bytes" |
| 255 | + ) |
| 256 | + tmp.write(chunk) |
| 257 | + os.replace(tmp_path, full_path) |
| 258 | + except BaseException: |
| 259 | + try: |
| 260 | + os.unlink(tmp_path) |
| 261 | + except FileNotFoundError: |
| 262 | + pass |
| 263 | + raise |
| 264 | + |
| 265 | + # Use the metadata-only helper: `get_details` would re-read the |
| 266 | + # file contents (and base64-encode binary), defeating the point of |
| 267 | + # streaming for large uploads. |
| 268 | + return self._get_file_info(str(full_path)) |
| 269 | + |
195 | 270 | def delete_file_or_directory(self, path: str) -> bool: |
196 | 271 | if os.path.isdir(path): |
197 | 272 | safe_rmtree(path) |
|
0 commit comments