Skip to content

Commit 8c093f4

Browse files
committed
Implement energon prepare with a remote dataset (using local temp)
1 parent dedcd0f commit 8c093f4

9 files changed

Lines changed: 546 additions & 209 deletions

File tree

src/megatron/energon/epathlib/epath.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -272,12 +272,27 @@ def mkdir(self, exist_ok: bool = True, parents: bool = False):
272272

273273
def glob(self, pattern) -> Generator["EPath", None, None]:
274274
search_path_pattern = (self / pattern)._internal_str_path
275+
# MSC S3 glob matches keys like ``bucket/key``; a leading ``/`` breaks wcmatch (pattern
276+
# ``/b/**`` never matches ``b/parts/x``). Returned keys may repeat the bucket prefix; strip
277+
# it before joining with ``internal_path`` so we do not get ``/b/b/parts/...``.
278+
if not self.is_local() and search_path_pattern.startswith("/"):
279+
search_path_pattern = search_path_pattern.lstrip("/")
280+
281+
root_prefix = str(self.internal_path).lstrip("/")
275282

276283
for path in self.fs.glob(search_path_pattern):
277284
assert isinstance(path, str)
278285

286+
rel = path
287+
if not self.is_local() and root_prefix:
288+
pfx = root_prefix + "/"
289+
if rel.startswith(pfx):
290+
rel = rel[len(pfx) :]
291+
elif rel == root_prefix:
292+
rel = "."
293+
279294
new_path = EPath(self)
280-
new_path.internal_path = self._resolve(self.internal_path / PurePosixPath(path))
295+
new_path.internal_path = self._resolve(self.internal_path / PurePosixPath(rel))
281296

282297
yield new_path
283298

src/megatron/energon/flavors/webdataset/indexing.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,6 @@ def __init__(
6464

6565
# Initialize SQLite connection
6666
# Only supporting local file system, because sqlite does not support remote file systems.
67-
# TODO: Implement remote file systems. Maybe create locally in tmp then upload?
6867
path = self.sqlite_path.local_path()
6968
path.parent.mkdir(parents=True, exist_ok=True)
7069
self.db = sqlite3.connect(path)

src/megatron/energon/flavors/webdataset/prepare.py

Lines changed: 212 additions & 165 deletions
Large diffs are not rendered by default.

src/megatron/energon/tools/prepare.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import json
77
import re
88
import typing
9+
from pathlib import Path
910
from types import FunctionType
1011
from typing import Any, List, Optional, Tuple, Type
1112

@@ -184,6 +185,15 @@ def printify_json(data: Any) -> Any:
184185
help='Field mapping in JSON format (e.g., \'{"image": "jpg", "caption": "txt"}\'). If not set in non-interactive mode, a sample loader from template will be created. Use with --sample-type. Only applies if sample_type is not set to CrudeWebdataset.',
185186
default=None,
186187
)
188+
@click.option(
189+
"--tmp-path",
190+
type=click.Path(path_type=Path),
191+
default=None,
192+
help=(
193+
"When PATH is remote (e.g. S3), directory where index.sqlite is built locally before upload. "
194+
"Default: a new directory under /tmp (removed after success)."
195+
),
196+
)
187197
def command(
188198
path: EPath,
189199
progress: bool,
@@ -203,6 +213,7 @@ def command(
203213
field_map: Optional[str],
204214
skip_dataset_yaml: bool,
205215
dataset_yaml_name: str,
216+
tmp_path: Optional[Path],
206217
):
207218
"""Prepare WebDataset for use with energon.
208219
@@ -348,6 +359,7 @@ def progress_fn(els, length=None):
348359
workers=num_workers,
349360
media_filter=media_filter_config,
350361
fix_duplicates=fix_duplicates,
362+
index_sqlite_tmp_path=tmp_path,
351363
)
352364

353365
found_types = list(found_types)

src/megatron/energon/tools/prepare_media.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22
# SPDX-License-Identifier: BSD-3-Clause
33

44

5+
from pathlib import Path
6+
from typing import Optional
7+
58
import click
69

710
from megatron.energon.epathlib import EPath
@@ -44,13 +47,23 @@
4447
is_flag=True,
4548
help="Media detection by standard file extensions.",
4649
)
50+
@click.option(
51+
"--tmp-path",
52+
type=click.Path(path_type=Path),
53+
default=None,
54+
help=(
55+
"When PATH is remote, directory where index.sqlite is built locally before upload. "
56+
"Default: a new directory under /tmp (removed after success)."
57+
),
58+
)
4759
def command(
4860
path: EPath,
4961
progress: bool,
5062
num_workers: int,
5163
media_metadata_by_glob: str | None,
5264
media_metadata_by_header: bool,
5365
media_metadata_by_extension: bool,
66+
tmp_path: Optional[Path],
5467
):
5568
"""Prepare a filesystem dataset by collecting media metadata."""
5669

@@ -83,6 +96,7 @@ def progress_fn(els, length=None):
8396
media_filter=media_filter_config,
8497
workers=num_workers,
8598
progress_fn=progress_fn,
99+
index_sqlite_tmp_path=tmp_path,
86100
)
87101

88102
click.echo(f"Done. Stored metadata for {count} files.")

tests/s3_emulator/handler.py

Lines changed: 154 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from hashlib import md5
77
from http import HTTPStatus
88
from http.server import BaseHTTPRequestHandler
9-
from typing import Protocol
9+
from typing import Literal, Protocol
1010

1111
from .auth import InvalidSignature, S3Auth
1212
from .state import S3State
@@ -112,6 +112,37 @@ def _handle_write(self):
112112

113113
qs = _up.parse_qs(parsed.query, keep_blank_values=True)
114114

115+
# S3 CopyObject: PUT to destination key with x-amz-copy-source and empty body.
116+
copy_src = (
117+
self.headers.get("x-amz-copy-source") or self.headers.get("X-Amz-Copy-Source") or ""
118+
).strip()
119+
if copy_src:
120+
if not bucket:
121+
self._send_error(HTTPStatus.BAD_REQUEST, "Bucket must be specified")
122+
return
123+
if key == "":
124+
self._send_error(HTTPStatus.BAD_REQUEST, "CopyObject requires an object key")
125+
return
126+
try:
127+
src_bucket, src_key = _parse_copy_source(copy_src)
128+
except ValueError as err:
129+
self._send_error(HTTPStatus.BAD_REQUEST, str(err))
130+
return
131+
try:
132+
data = self.server.state.copy_object(bucket, key, src_bucket, src_key)
133+
except FileNotFoundError:
134+
self._send_error(HTTPStatus.NOT_FOUND, "NoSuchKey")
135+
return
136+
xml = (
137+
'<?xml version="1.0" encoding="UTF-8"?>'
138+
"<CopyObjectResult>"
139+
f"<LastModified>{_escape_xml(formatdate(usegmt=True))}</LastModified>"
140+
f"<ETag>&quot;{_escape_xml(_etag(data))}&quot;</ETag>"
141+
"</CopyObjectResult>"
142+
).encode()
143+
self._send_bytes(xml, status=HTTPStatus.OK, content_type="application/xml")
144+
return
145+
115146
# Multipart: upload part
116147
if "uploadId" in qs and "partNumber" in qs:
117148
upload_id = qs["uploadId"][0]
@@ -160,15 +191,15 @@ def _handle_read(self, listing: bool, only_headers: bool = False):
160191
self._send_error(HTTPStatus.BAD_REQUEST, "Bucket must be specified")
161192
return
162193

163-
if key == "": # List bucket contents
194+
if key == "": # List bucket contents (ListObjects / ListObjectsV2)
164195
if not listing:
165-
# We treat listing with GET only
166196
try:
167197
objects = self.server.state.list_objects(bucket)
168198
except KeyError:
169199
self._send_error(HTTPStatus.NOT_FOUND, "Bucket not found")
170200
return
171-
xml_body = self._render_bucket_list(bucket, objects)
201+
qs = _up.parse_qs(parsed.query, keep_blank_values=True)
202+
xml_body = self._render_list_bucket_result(bucket, objects, qs)
172203
self._send_bytes(xml_body, content_type="application/xml")
173204
else:
174205
self._send_error(HTTPStatus.NOT_IMPLEMENTED, "Listing not implemented")
@@ -359,51 +390,135 @@ def _send_bytes(
359390
if self.command != "HEAD":
360391
self.wfile.write(data)
361392

362-
@staticmethod
363-
def _render_bucket_list(bucket: str, objects: list[str]) -> bytes:
364-
"""Generate an XML listing of objects in a bucket.
393+
def _render_list_bucket_result(
394+
self,
395+
bucket: str,
396+
all_keys: list[str],
397+
qs: dict[str, list[str]],
398+
) -> bytes:
399+
"""Build ListBucketResult XML (ListObjectsV2-compatible).
400+
401+
Clients (e.g. MSC) send ``delimiter=/`` and ``prefix=`` and expect
402+
``CommonPrefixes`` for nested keys such as ``parts/data-0.tar``, not
403+
only flat ``Contents``.
404+
"""
405+
prefix = (qs.get("prefix") or [""])[0]
406+
delimiter = (qs.get("delimiter") or [None])[0]
407+
max_keys_s = (qs.get("max-keys") or qs.get("maxkeys") or ["1000"])[0]
408+
try:
409+
max_keys = max(1, min(int(max_keys_s), 1000))
410+
except ValueError:
411+
max_keys = 1000
365412

366-
Args:
367-
bucket: The bucket name.
368-
objects: List of object keys in the bucket.
413+
continuation = (qs.get("continuation-token") or [""])[0]
414+
start_after = (qs.get("start-after") or [""])[0]
415+
exclusive_after = continuation or start_after
369416

370-
Returns:
371-
The XML document as bytes.
372-
"""
373-
entries = []
374417
now = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%S.000Z")
375-
for key in objects:
376-
try:
377-
data = S3RequestHandler.server.state.get_object(bucket, key) # type: ignore[attr-defined]
378-
size = len(data)
379-
etag = _etag(data)
380-
except Exception: # noqa: BLE001
381-
size = 0
382-
etag = '""'
383-
entries.append(
384-
"<Contents>"
385-
f"<Key>{_escape_xml(key)}</Key>"
386-
f"<LastModified>{now}</LastModified>"
387-
f"<ETag>{etag}</ETag>"
388-
f"<Size>{size}</Size>"
389-
"</Contents>"
390-
)
391-
obj_elems = "".join(entries)
392-
xml = (
393-
'<?xml version="1.0" encoding="UTF-8"?>'
394-
"<ListBucketResult>"
395-
f"<Name>{_escape_xml(bucket)}</Name>"
396-
f"{obj_elems}"
397-
"</ListBucketResult>"
398-
)
399-
return xml.encode()
418+
state = self.server.state
419+
420+
items: list[tuple[Literal["cp", "key"], str]] = []
421+
if not delimiter:
422+
for k in sorted(all_keys):
423+
if k.startswith(prefix):
424+
items.append(("key", k))
425+
else:
426+
common: set[str] = set()
427+
contents: list[str] = []
428+
for k in sorted(all_keys):
429+
if not k.startswith(prefix):
430+
continue
431+
relative = k[len(prefix) :]
432+
if delimiter in relative:
433+
idx = relative.index(delimiter)
434+
common.add(prefix + relative[: idx + len(delimiter)])
435+
else:
436+
contents.append(k)
437+
for cp in sorted(common):
438+
items.append(("cp", cp))
439+
for ck in sorted(contents):
440+
items.append(("key", ck))
441+
items.sort(key=lambda x: x[1])
442+
443+
if exclusive_after:
444+
items = [it for it in items if it[1] > exclusive_after]
445+
446+
page = items[:max_keys]
447+
truncated = len(items) > max_keys
448+
next_token = page[-1][1] if truncated and page else ""
449+
450+
fragments: list[str] = [
451+
'<?xml version="1.0" encoding="UTF-8"?>',
452+
"<ListBucketResult>",
453+
f"<Name>{_escape_xml(bucket)}</Name>",
454+
f"<Prefix>{_escape_xml(prefix)}</Prefix>",
455+
f"<KeyCount>{len(page)}</KeyCount>",
456+
f"<MaxKeys>{max_keys}</MaxKeys>",
457+
f"<IsTruncated>{str(truncated).lower()}</IsTruncated>",
458+
]
459+
if delimiter:
460+
fragments.append(f"<Delimiter>{_escape_xml(delimiter)}</Delimiter>")
461+
if truncated and next_token:
462+
fragments.append(f"<NextContinuationToken>{_escape_xml(next_token)}</NextContinuationToken>")
463+
464+
for kind, path in page:
465+
if kind == "cp":
466+
fragments.append(f"<CommonPrefixes><Prefix>{_escape_xml(path)}</Prefix></CommonPrefixes>")
467+
else:
468+
try:
469+
data = state.get_object(bucket, path)
470+
size = len(data)
471+
etag = _etag(data)
472+
except Exception: # noqa: BLE001
473+
size = 0
474+
etag = '""'
475+
fragments.append(
476+
"<Contents>"
477+
f"<Key>{_escape_xml(path)}</Key>"
478+
f"<LastModified>{now}</LastModified>"
479+
f"<ETag>{etag}</ETag>"
480+
f"<Size>{size}</Size>"
481+
"</Contents>"
482+
)
483+
484+
fragments.append("</ListBucketResult>")
485+
return "".join(fragments).encode()
400486

401487

402488
class S3ServerProtocol(Protocol): # noqa: D101
403489
state: S3State
404490
auth: S3Auth
405491

406492

493+
def _parse_copy_source(raw: str) -> tuple[str, str]:
494+
"""Parse ``x-amz-copy-source`` into ``(bucket, key)``.
495+
496+
Accepts ``/bucket/key``, ``bucket/key``, URL-encoded keys, and strips ``?versionId=``.
497+
498+
Args:
499+
raw: Raw header value.
500+
501+
Returns:
502+
Source bucket and object key.
503+
504+
Raises:
505+
ValueError: If the value cannot be parsed.
506+
"""
507+
s = raw.strip()
508+
if not s:
509+
raise ValueError("Empty x-amz-copy-source")
510+
s = s.split("?", 1)[0]
511+
s = _up.unquote(s)
512+
if s.startswith("/"):
513+
s = s[1:]
514+
if "/" not in s:
515+
raise ValueError("x-amz-copy-source must be /bucket/key")
516+
src_bucket, src_key = s.split("/", 1)
517+
if not src_bucket or not src_key:
518+
raise ValueError("Invalid x-amz-copy-source")
519+
return src_bucket, src_key
520+
521+
407522
def _escape_xml(text: str) -> str: # noqa: D401
408523
"""Escape special characters for XML.
409524

tests/s3_emulator/state.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,35 @@ def get_object(self, bucket: str, key: str) -> bytes:
111111
except KeyError as exc:
112112
raise FileNotFoundError(f"{bucket}/{key}") from exc
113113

114+
def copy_object(self, dest_bucket: str, dest_key: str, src_bucket: str, src_key: str) -> bytes:
115+
"""Copy an object to another key (S3 CopyObject).
116+
117+
Args:
118+
dest_bucket: Destination bucket.
119+
dest_key: Destination object key.
120+
src_bucket: Source bucket.
121+
src_key: Source object key.
122+
123+
Returns:
124+
Copied object bytes (for ETag in the CopyObject XML response).
125+
126+
Raises:
127+
FileNotFoundError: If the source object does not exist.
128+
"""
129+
with self._lock:
130+
try:
131+
payload = bytes(self._fs[src_bucket][src_key])
132+
except KeyError as exc:
133+
raise FileNotFoundError(f"{src_bucket}/{src_key}") from exc
134+
if dest_bucket not in self._fs:
135+
self._fs[dest_bucket] = {}
136+
self._fs[dest_bucket][dest_key] = payload
137+
if self._root_dir is not None:
138+
obj_path = (self._root_dir / dest_bucket / dest_key).resolve()
139+
obj_path.parent.mkdir(parents=True, exist_ok=True)
140+
obj_path.write_bytes(payload)
141+
return payload
142+
114143
def delete_object(self, bucket: str, key: str) -> None:
115144
"""Delete an object from a bucket.
116145

0 commit comments

Comments
 (0)