Skip to content

Commit ac7dc96

Browse files
committed
Fix: Epath implements walk and glob, including tests. Fix sqlite tmp path handling. Fix S3 emulator timestamp handling
1 parent 598db74 commit ac7dc96

8 files changed

Lines changed: 263 additions & 44 deletions

File tree

src/megatron/energon/epathlib/epath.py

Lines changed: 37 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -266,17 +266,15 @@ def url(self) -> str:
266266
return f"msc://{self.profile}{int_path_str}"
267267

268268
def is_local(self) -> bool:
269-
if self.profile == "dss":
270-
# For now, a DSS path is always considered local.
271-
# Note that this does not mean it exists on the local filesystem.
272-
return True
273-
else:
274-
return self.profile == DEFAULT_PROFILE_NAME
269+
# It will return a posix path if the fs is local, otherwise None
270+
return self.fs.get_posix_path(self._internal_str_path) is not None
275271

276272
def local_path(self) -> PathlibPath:
277-
if not self.is_local():
273+
# This resolves the path if it exists, probably ok.
274+
posix_path = self.fs.get_posix_path(self._internal_str_path)
275+
if posix_path is None:
278276
raise ValueError(f"Path {self} is not local")
279-
return PathlibPath(self._internal_str_path)
277+
return PathlibPath(posix_path)
280278

281279
def is_dir(self) -> bool:
282280
try:
@@ -290,25 +288,46 @@ def is_file(self) -> bool:
290288
def mkdir(self, exist_ok: bool = True, parents: bool = False):
291289
pass
292290

293-
def glob(self, pattern) -> Generator["EPath", None, None]:
291+
def walk(self) -> Generator["EPath", None, None]:
292+
"""Returns all files within this path (no folders)."""
293+
# Prefix to be removed from found paths to remap to relative paths
294+
root_prefix = self._internal_str_path.lstrip("/")
295+
296+
for obj in self.fs.list_recursive(self._internal_str_path):
297+
rel = obj.key
298+
if root_prefix:
299+
if rel.startswith(root_prefix + "/"):
300+
rel = rel[len(root_prefix) + 1 :]
301+
elif rel.startswith("/" + root_prefix + "/"):
302+
rel = rel[len(root_prefix) + 2 :]
303+
elif rel == root_prefix or rel == "/" + root_prefix:
304+
rel = "."
305+
306+
path = EPath(self)
307+
path.internal_path = self._resolve(self.internal_path / PurePosixPath(rel))
308+
yield path
309+
310+
def glob(self, pattern: str) -> Generator["EPath", None, None]:
311+
"""Returns all files matching the pattern within this path (no folders)."""
294312
search_path_pattern = (self / pattern)._internal_str_path
295-
# MSC S3 glob matches keys like ``bucket/key``; a leading ``/`` breaks wcmatch (pattern
313+
# MSC glob matches keys like ``bucket/key``; a leading ``/`` breaks wcmatch (pattern
296314
# ``/b/**`` never matches ``b/parts/x``). Returned keys may repeat the bucket prefix; strip
297315
# it before joining with ``internal_path`` so we do not get ``/b/b/parts/...``.
298-
if not self.is_local() and search_path_pattern.startswith("/"):
299-
search_path_pattern = search_path_pattern.lstrip("/")
316+
search_path_pattern = search_path_pattern.lstrip("/")
300317

301-
root_prefix = str(self.internal_path).lstrip("/")
318+
# Prefix to be removed from found paths to remap to relative paths
319+
root_prefix = self._internal_str_path.lstrip("/")
302320

303321
for path in self.fs.glob(search_path_pattern):
304322
assert isinstance(path, str)
305323

306324
rel = path
307-
if not self.is_local() and root_prefix:
308-
pfx = root_prefix + "/"
309-
if rel.startswith(pfx):
310-
rel = rel[len(pfx) :]
311-
elif rel == root_prefix:
325+
if root_prefix:
326+
if rel.startswith(root_prefix + "/"):
327+
rel = rel[len(root_prefix) + 1 :]
328+
elif rel.startswith("/" + root_prefix + "/"):
329+
rel = rel[len(root_prefix) + 2 :]
330+
elif rel == root_prefix or rel == "/" + root_prefix:
312331
rel = "."
313332

314333
new_path = EPath(self)

src/megatron/energon/flavors/webdataset/prepare.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -521,7 +521,7 @@ def prepare_dataset(
521521
tar_index_only: Only create tar-index, then exit
522522
media_filter: Media filter configuration
523523
fix_duplicates: If True, fix duplicate keys in the dataset by renaming the files in the shards.
524-
index_sqlite_tmp_path: When ``parent_path`` is remote, directory used to build ``index.sqlite``
524+
index_sqlite_tmp_path: When ``parent_path`` is remote, temp file path used to build ``index.sqlite``
525525
locally before upload. If omitted, a new directory under ``/tmp`` is created and removed
526526
after a successful run.
527527
@@ -784,7 +784,17 @@ def add_media_metadata(
784784
progress_fn: Callable[[Iterator[Any], int], Iterator[T]] = (lambda x, y: x),
785785
index_sqlite_tmp_path: Optional[Path] = None,
786786
) -> int:
787-
"""Add or refresh media metadata in an existing WebDataset index."""
787+
"""Add or refresh media metadata in an existing WebDataset index.
788+
789+
Args:
790+
parent_path: WebDataset root path.
791+
media_filter: Media filtering configuration.
792+
workers: Number of parallel workers.
793+
progress_fn: Callback for progress updates.
794+
index_sqlite_tmp_path: When ``parent_path`` is remote, sqlite file path used to build
795+
``index.sqlite`` locally before upload. If omitted, a new directory under
796+
``/tmp`` is created and removed after a successful run.
797+
"""
788798

789799
parent_path = EPath(parent_path)
790800

src/megatron/energon/media/filesystem_prepare.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
from __future__ import annotations
55

6-
import os
76
import shutil
87
import tempfile
98
from functools import partial
@@ -41,6 +40,9 @@ def prepare_filesystem_dataset(
4140
root_path: Dataset root directory.
4241
media_filter: Media filtering configuration.
4342
progress: Whether to display a tqdm progress bar.
43+
index_sqlite_tmp_path: When ``root_path`` is remote, temp file path used to build
44+
``index.sqlite`` locally before upload. If omitted, a new directory under
45+
``/tmp`` is created and removed after a successful run.
4446
4547
Returns:
4648
Number of metadata entries written to the database.
@@ -137,23 +139,23 @@ def _collect_media_files(
137139

138140
progress_bar = tqdm(total=None, unit="file", desc="Collecting media files")
139141

140-
if root.is_local():
141-
paths = (
142-
EPath(f"{path}/{file}")
143-
for path, _dirs, files in os.walk(root.local_path(), followlinks=False)
144-
for file in files
145-
)
146-
else:
147-
paths = root.glob("**/*")
142+
# if root.is_local() and not root.profile == "dss":
143+
# paths = (
144+
# EPath(path) / file
145+
# for path, _dirs, files in os.walk(root.local_path(), followlinks=False)
146+
# for file in files
147+
# )
148+
# else:
149+
# paths = root.glob("**/*")
148150

149-
for file in paths:
151+
for file in root.walk():
150152
if progress_bar is not None:
151153
progress_bar.update()
152154

153-
if not consider_all and not media_filter.should_consider_media(file.name):
155+
if ("/" + MAIN_FOLDER_NAME + "/") in file.url:
154156
continue
155157

156-
if ("/" + MAIN_FOLDER_NAME + "/") in file.url:
158+
if not consider_all and not media_filter.should_consider_media(file.name):
157159
continue
158160

159161
files.append(file)

src/megatron/energon/tools/prepare.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,11 @@ def progress_fn(els, length=None):
348348
def progress_fn(els, length=None):
349349
return els
350350

351+
if tmp_path is not None:
352+
index_sqlite_tmp_path = tmp_path / "index.sqlite"
353+
else:
354+
index_sqlite_tmp_path = None
355+
351356
found_types = BaseWebdatasetFactory.prepare_dataset(
352357
path,
353358
all_tars,
@@ -359,7 +364,7 @@ def progress_fn(els, length=None):
359364
workers=num_workers,
360365
media_filter=media_filter_config,
361366
fix_duplicates=fix_duplicates,
362-
index_sqlite_tmp_path=tmp_path,
367+
index_sqlite_tmp_path=index_sqlite_tmp_path,
363368
)
364369

365370
found_types = list(found_types)

src/megatron/energon/tools/prepare_media.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,11 @@ def command(
7171
media_metadata_by_glob, media_metadata_by_header, media_metadata_by_extension
7272
)
7373

74+
if tmp_path is not None:
75+
index_sqlite_tmp_path = tmp_path / "index.sqlite"
76+
else:
77+
index_sqlite_tmp_path = None
78+
7479
ds_type = get_dataset_type(path)
7580
if ds_type == EnergonDatasetType.WEBDATASET:
7681
click.echo("Preparing webdataset and computing media metadata...")
@@ -96,7 +101,7 @@ def progress_fn(els, length=None):
96101
media_filter=media_filter_config,
97102
workers=num_workers,
98103
progress_fn=progress_fn,
99-
index_sqlite_tmp_path=tmp_path,
104+
index_sqlite_tmp_path=index_sqlite_tmp_path,
100105
)
101106

102107
click.echo(f"Done. Stored metadata for {count} files.")
@@ -112,6 +117,7 @@ def progress_fn(els, length=None):
112117
media_filter_config,
113118
progress=progress,
114119
workers=num_workers,
120+
index_sqlite_tmp_path=index_sqlite_tmp_path,
115121
)
116122
click.echo(f"Done. Stored metadata for {stored} files.")
117123

tests/s3_emulator/handler.py

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# SPDX-License-Identifier: BSD-3-Clause
33
import urllib.parse as _up
44
from datetime import datetime, timezone
5-
from email.utils import formatdate
5+
from email.utils import format_datetime
66
from hashlib import md5
77
from http import HTTPStatus
88
from http.server import BaseHTTPRequestHandler
@@ -130,13 +130,14 @@ def _handle_write(self):
130130
return
131131
try:
132132
data = self.server.state.copy_object(bucket, key, src_bucket, src_key)
133+
last_modified = self.server.state.get_object_last_modified(bucket, key)
133134
except FileNotFoundError:
134135
self._send_error(HTTPStatus.NOT_FOUND, "NoSuchKey")
135136
return
136137
xml = (
137138
'<?xml version="1.0" encoding="UTF-8"?>'
138139
"<CopyObjectResult>"
139-
f"<LastModified>{_escape_xml(formatdate(usegmt=True))}</LastModified>"
140+
f"<LastModified>{_escape_xml(_s3_datetime(last_modified))}</LastModified>"
140141
f"<ETag>&quot;{_escape_xml(_etag(data))}&quot;</ETag>"
141142
"</CopyObjectResult>"
142143
).encode()
@@ -207,6 +208,7 @@ def _handle_read(self, listing: bool, only_headers: bool = False):
207208

208209
try:
209210
data = self.server.state.get_object(bucket, key)
211+
last_modified = self.server.state.get_object_last_modified(bucket, key)
210212
except FileNotFoundError:
211213
self._send_error(HTTPStatus.NOT_FOUND, "Not found")
212214
return
@@ -234,10 +236,10 @@ def _handle_read(self, listing: bool, only_headers: bool = False):
234236
"Accept-Ranges": "bytes",
235237
"Content-Length": str(len(slice_data)),
236238
"ETag": _etag(data),
239+
"Last-Modified": _http_datetime(last_modified),
237240
}
238241
if only_headers:
239242
headers.setdefault("Content-Type", "application/octet-stream")
240-
headers.setdefault("Last-Modified", formatdate(usegmt=True))
241243
self._send_status(HTTPStatus.PARTIAL_CONTENT, extra_headers=headers)
242244
else:
243245
self._send_bytes(
@@ -254,15 +256,19 @@ def _handle_read(self, listing: bool, only_headers: bool = False):
254256
"Content-Length": str(len(data)),
255257
"Accept-Ranges": "bytes",
256258
"Content-Type": "application/octet-stream",
257-
"Last-Modified": formatdate(usegmt=True),
259+
"Last-Modified": _http_datetime(last_modified),
258260
"ETag": _etag(data),
259261
},
260262
)
261263
else:
262264
self._send_bytes(
263265
data,
264266
content_type="application/octet-stream",
265-
extra_headers={"Accept-Ranges": "bytes"},
267+
extra_headers={
268+
"Accept-Ranges": "bytes",
269+
"Last-Modified": _http_datetime(last_modified),
270+
"ETag": _etag(data),
271+
},
266272
)
267273

268274
def _handle_delete(self):
@@ -414,7 +420,6 @@ def _render_list_bucket_result(
414420
start_after = (qs.get("start-after") or [""])[0]
415421
exclusive_after = continuation or start_after
416422

417-
now = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%S.000Z")
418423
state = self.server.state
419424

420425
items: list[tuple[Literal["cp", "key"], str]] = []
@@ -471,15 +476,17 @@ def _render_list_bucket_result(
471476
else:
472477
try:
473478
data = state.get_object(bucket, path)
479+
last_modified = state.get_object_last_modified(bucket, path)
474480
size = len(data)
475481
etag = _etag(data)
476482
except Exception: # noqa: BLE001
477483
size = 0
478484
etag = '""'
485+
last_modified = datetime.fromtimestamp(0, tz=timezone.utc)
479486
fragments.append(
480487
"<Contents>"
481488
f"<Key>{_escape_xml(path)}</Key>"
482-
f"<LastModified>{now}</LastModified>"
489+
f"<LastModified>{_s3_datetime(last_modified)}</LastModified>"
483490
f"<ETag>{etag}</ETag>"
484491
f"<Size>{size}</Size>"
485492
"</Contents>"
@@ -541,6 +548,18 @@ def _escape_xml(text: str) -> str: # noqa: D401
541548
)
542549

543550

551+
def _http_datetime(value: datetime) -> str:
552+
"""Format an aware datetime for HTTP Last-Modified headers."""
553+
554+
return format_datetime(value.astimezone(timezone.utc), usegmt=True)
555+
556+
557+
def _s3_datetime(value: datetime) -> str:
558+
"""Format an aware datetime for S3 XML LastModified fields."""
559+
560+
return value.astimezone(timezone.utc).strftime("%Y-%m-%dT%H:%M:%S.000Z")
561+
562+
544563
def _etag(data: bytes) -> str: # noqa: D401
545564
"""Generate an ETag for binary data.
546565

0 commit comments

Comments
 (0)