|
10 | 10 | import subprocess |
11 | 11 | from collections import deque |
12 | 12 | from pathlib import Path |
13 | | -from typing import Literal |
| 13 | +from typing import Literal, Protocol, runtime_checkable |
14 | 14 |
|
15 | 15 | from marimo import _loggers |
16 | 16 | from marimo._server.files.file_system import FileSystem |
|
36 | 36 | "..", |
37 | 37 | ] |
38 | 38 |
|
# 1 MiB streaming chunk. Large enough to amortize syscall overhead, small
# enough to keep peak memory bounded when streaming.
_STREAM_CHUNK_SIZE = 1024 * 1024

# Hard cap on streamed uploads. Streaming removes the implicit OOM ceiling
# that buffered uploads had, so without a cap an authenticated client could
# exhaust disk. 1 GiB covers normal notebook-data use cases with margin.
MAX_UPLOAD_BYTES = 1024 * 1024 * 1024
| 47 | + |
| 48 | + |
@runtime_checkable
class AsyncByteSource(Protocol):
    """Structural type for an async, chunk-wise byte producer.

    Any object exposing an async ``read(size)`` that returns ``bytes``
    conforms — Starlette's ``UploadFile`` is the canonical example. Used
    to drain an upload into a file without buffering it whole in memory.
    """

    async def read(self, size: int = -1, /) -> bytes: ...
| 58 | + |
39 | 59 |
|
40 | 60 | class OSFileSystem(FileSystem): |
41 | 61 | def get_root(self) -> str: |
@@ -133,33 +153,32 @@ def open_file(self, path: str, encoding: str | None = None) -> str | bytes: |
133 | 153 | except UnicodeDecodeError: |
134 | 154 | return file_path.read_bytes() |
135 | 155 |
|
136 | | - def create_file_or_directory( |
137 | | - self, |
138 | | - path: str, |
139 | | - file_type: Literal["file", "directory", "notebook"], |
140 | | - name: str, |
141 | | - contents: bytes | None, |
142 | | - ) -> FileInfo: |
| 156 | + @staticmethod |
| 157 | + def _validate_create_name(name: str) -> None: |
| 158 | + """Reject names that are empty, reserved, or traverse out of the |
| 159 | + parent. Centralized so HTTP, WASM, and streaming paths all share it. |
| 160 | + """ |
143 | 161 | if name in DISALLOWED_NAMES: |
144 | 162 | raise ValueError( |
145 | 163 | f"Cannot create file or directory with name {name}" |
146 | 164 | ) |
147 | 165 | if name.strip() == "": |
148 | 166 | raise ValueError("Cannot create file or directory with empty name") |
149 | | - # Names that traverse out of `path` or escape via separators are |
150 | | - # rejected. Validation belongs here (not in the endpoint) so every |
151 | | - # caller of OSFileSystem — HTTP, WASM bridge, scripts — is covered. |
152 | | - if ( |
153 | | - "/" in name |
154 | | - or "\\" in name |
155 | | - or "\x00" in name |
156 | | - or name in (".", "..") |
157 | | - ): |
| 167 | + if "/" in name or "\\" in name or "\x00" in name: |
158 | 168 | raise ValueError( |
159 | 169 | f"Invalid name {name!r}: must not contain path separators " |
160 | 170 | "or refer to a parent directory" |
161 | 171 | ) |
162 | 172 |
|
| 173 | + def create_file_or_directory( |
| 174 | + self, |
| 175 | + path: str, |
| 176 | + file_type: Literal["file", "directory", "notebook"], |
| 177 | + name: str, |
| 178 | + contents: bytes | None, |
| 179 | + ) -> FileInfo: |
| 180 | + self._validate_create_name(name) |
| 181 | + |
163 | 182 | full_path = Path(path) / name |
164 | 183 | full_path = _generate_unique_path(full_path) |
165 | 184 |
|
@@ -192,6 +211,49 @@ def create_file_or_directory( |
192 | 211 | ), |
193 | 212 | ).file |
194 | 213 |
|
| 214 | + async def stream_create_file( |
| 215 | + self, |
| 216 | + path: str, |
| 217 | + name: str, |
| 218 | + source: AsyncByteSource, |
| 219 | + ) -> FileInfo: |
| 220 | + """Stream-write an uploaded file to disk, chunk by chunk. |
| 221 | +
|
| 222 | + Avoids loading the full payload into memory (the HTTP multipart |
| 223 | + path can otherwise buffer 100 MB at once). Writes to a ``.part`` |
| 224 | + temp file and atomically renames on success so a failed upload |
| 225 | + doesn't leave a half-written file at the final path. |
| 226 | + """ |
| 227 | + self._validate_create_name(name) |
| 228 | + |
| 229 | + full_path = Path(path) / name |
| 230 | + full_path = _generate_unique_path(full_path) |
| 231 | + full_path.parent.mkdir(parents=True, exist_ok=True) |
| 232 | + |
| 233 | + tmp_path = full_path.with_name(full_path.name + ".part") |
| 234 | + try: |
| 235 | + # Sync writes are bounded to ~1 MiB per chunk, with an `await` |
| 236 | + # in between; event loop blockage is brief and an async file |
| 237 | + # library would only add a dependency for marginal gain. |
| 238 | + written = 0 |
| 239 | + with open(tmp_path, "wb") as out: # noqa: ASYNC230 |
| 240 | + while chunk := await source.read(_STREAM_CHUNK_SIZE): |
| 241 | + written += len(chunk) |
| 242 | + if written > MAX_UPLOAD_BYTES: |
| 243 | + raise ValueError( |
| 244 | + f"Upload exceeds maximum size of " |
| 245 | + f"{MAX_UPLOAD_BYTES} bytes" |
| 246 | + ) |
| 247 | + out.write(chunk) |
| 248 | + tmp_path.replace(full_path) |
| 249 | + except BaseException: |
| 250 | + tmp_path.unlink(missing_ok=True) |
| 251 | + raise |
| 252 | + |
| 253 | + # Read details fresh from disk; we deliberately don't pass `contents` |
| 254 | + # since the file may be too large to round-trip through memory. |
| 255 | + return self.get_details(str(full_path)).file |
| 256 | + |
195 | 257 | def delete_file_or_directory(self, path: str) -> bool: |
196 | 258 | if os.path.isdir(path): |
197 | 259 | safe_rmtree(path) |
|
0 commit comments