Skip to content

Commit 9cca0d0

Browse files
committed
Fix outdated local copy. Implement prepare media for filesystem paths on remote.
1 parent 7c220f0 commit 9cca0d0

10 files changed

Lines changed: 252 additions & 94 deletions

File tree

src/megatron/energon/cache/file_store.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -161,19 +161,18 @@ def get_path(self) -> str:
161161
def get_media_metadata(self, key: str) -> MediaMetadataBase:
162162
if self._media_metadata_available is None:
163163
try:
164-
has_metadata = self.sqlite_reader.db_has_media_metadata()
164+
self._media_metadata_available = self.sqlite_reader.db_has_media_metadata()
165165
except sqlite3.Error as exc: # pragma: no cover - defensive
166+
self._media_metadata_available = False
166167
raise RuntimeError(
167168
"Failed to inspect media metadata table. Re-run `energon prepare --media-metadata-by-...`."
168169
) from exc
169170

170-
if not has_metadata:
171-
raise RuntimeError(
172-
"Media metadata is not available for this dataset. "
173-
"Run `energon prepare --media-metadata-by-...` to generate it."
174-
)
175-
176-
self._media_metadata_available = True
171+
if not self._media_metadata_available:
172+
raise RuntimeError(
173+
"Media metadata is not available for this dataset. "
174+
"Run `energon prepare --media-metadata-by-...` to generate it."
175+
)
177176

178177
try:
179178
row = self.sqlite_reader.get_media_metadata(key)

src/megatron/energon/epathlib/epath.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,9 @@ def relative_to(self, other: "EPath") -> str:
336336

337337
return str(self.internal_path.relative_to(other.internal_path))
338338

339+
def stat(self) -> msc.types.ObjectMetadata:
340+
return self.fs.info(self._internal_str_path)
341+
339342
@property
340343
def display_name(self) -> str:
341344
if self.profile == "dss":

src/megatron/energon/flavors/webdataset/indexing.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,16 @@ def __init__(
4949
part_name TEXT,
5050
content_byte_offset INTEGER,
5151
content_byte_size INTEGER)
52+
If enable_media_metadata is True, it also creates the media_metadata table:
53+
- media_metadata(entry_key TEXT PRIMARY KEY,
54+
metadata_type TEXT NOT NULL,
55+
metadata_json TEXT NOT NULL)
56+
If enable_media_metadata is True, it also creates the media_filters table:
57+
- media_filters(filter_id INTEGER PRIMARY KEY AUTOINCREMENT,
58+
strategy TEXT NOT NULL,
59+
patterns TEXT,
60+
created_at_utc TEXT DEFAULT CURRENT_TIMESTAMP,
61+
UNIQUE(strategy, patterns))
5262
Also creates indexes:
5363
- samples(sample_key)
5464
- samples(tar_file_id, sample_index)
@@ -309,6 +319,9 @@ class SqliteIndexReader:
309319
part_name TEXT,
310320
content_byte_offset INTEGER,
311321
content_byte_size INTEGER)
322+
- media_metadata(entry_key TEXT PRIMARY KEY,
323+
metadata_type TEXT NOT NULL,
324+
metadata_json TEXT NOT NULL)
312325
"""
313326

314327
sqlite_path: EPath

src/megatron/energon/flavors/webdataset/prepare.py

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,8 @@ def on_start(self, aggregator_pool: AggregatorPool) -> None:
158158
local_sqlite = self.sqlite_path
159159
if self.sqlite_local_build_path is not None:
160160
local_sqlite = EPath(self.sqlite_local_build_path)
161+
if self.sqlite_path.is_file():
162+
self.sqlite_path.copy(local_sqlite)
161163
self.writer = SqliteIndexWriter(
162164
local_sqlite,
163165
enable_sample_tables=self.enable_sample_tables,
@@ -565,7 +567,9 @@ def prepare_dataset(
565567
remote_sqlite_tmp_dir: Optional[Path] = None
566568
if not parent_path.is_local():
567569
if index_sqlite_tmp_path is None:
568-
remote_sqlite_tmp_dir = Path(tempfile.mkdtemp(dir="/tmp", prefix="energon-prepare-"))
570+
remote_sqlite_tmp_dir = Path(
571+
tempfile.mkdtemp(dir="/tmp", prefix="energon-prepare-")
572+
)
569573
index_sqlite_tmp_path = remote_sqlite_tmp_dir / INDEX_SQLITE_FILENAME
570574
owns_remote_sqlite_tmp = True
571575
else:
@@ -618,7 +622,9 @@ def prepare_dataset(
618622
# Fix permissions if needed
619623
if fix_local_permissions:
620624
try:
621-
Path(str(parent_path / MAIN_FOLDER_NAME / INDEX_SQLITE_FILENAME)).chmod(file_perms)
625+
Path(str(parent_path / MAIN_FOLDER_NAME / INDEX_SQLITE_FILENAME)).chmod(
626+
file_perms
627+
)
622628
except OSError:
623629
pass
624630

@@ -703,7 +709,9 @@ def prepare_dataset(
703709
for split_part, split_ratio in split_parts_ratio:
704710
split_total += split_ratio
705711
split_end = int(len(shards) * split_total)
706-
split_shards[split_part] = [shard.name for shard in shards[split_offset:split_end]]
712+
split_shards[split_part] = [
713+
shard.name for shard in shards[split_offset:split_end]
714+
]
707715
split_offset = split_end
708716
else:
709717
assert split_parts_patterns is not None, (
@@ -749,7 +757,6 @@ def prepare_dataset(
749757
if owns_remote_sqlite_tmp and remote_sqlite_tmp_dir is not None:
750758
shutil.rmtree(remote_sqlite_tmp_dir, ignore_errors=True)
751759

752-
753760
@classmethod
754761
def add_media_metadata(
755762
cls,
@@ -781,12 +788,16 @@ def add_media_metadata(
781788
remote_sqlite_tmp_dir: Optional[Path] = None
782789
if not parent_path.is_local():
783790
if index_sqlite_tmp_path is None:
784-
remote_sqlite_tmp_dir = Path(tempfile.mkdtemp(dir="/tmp", prefix="energon-prepare-media-"))
791+
remote_sqlite_tmp_dir = Path(
792+
tempfile.mkdtemp(dir="/tmp", prefix="energon-prepare-media-")
793+
)
785794
index_sqlite_tmp_path = remote_sqlite_tmp_dir / INDEX_SQLITE_FILENAME
786795
owns_remote_sqlite_tmp = True
787796
else:
788797
index_sqlite_tmp_path = None
789798

799+
sqlite_path = parent_path / MAIN_FOLDER_NAME / INDEX_SQLITE_FILENAME
800+
790801
try:
791802
aggregator = SqliteIndexWriterAggregator(
792803
parent_path / MAIN_FOLDER_NAME / INDEX_SQLITE_FILENAME,
@@ -818,6 +829,18 @@ def add_media_metadata(
818829

819830
pool.process()
820831

832+
if sqlite_path.is_local():
833+
try:
834+
meta_dir = (parent_path / MAIN_FOLDER_NAME).local_path()
835+
# Copy the parent dir's mode bits (incl. group permissions), always granting the owner rwx
836+
meta_dir.chmod((parent_path.local_path().stat().st_mode | 0o700))
837+
# Just read/write, no execute
838+
sqlite_path.local_path().chmod(
839+
(parent_path.local_path().stat().st_mode | 0o600) & 0o666
840+
)
841+
except OSError:
842+
pass
843+
821844
return aggregator.media_metadata_written
822845
finally:
823846
if owns_remote_sqlite_tmp and remote_sqlite_tmp_dir is not None:

src/megatron/energon/local_copy.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,20 @@
33

44

55
import hashlib
6+
import os
67
from pathlib import Path
78

89
from filelock import FileLock
910

1011
from megatron.energon.epathlib import EPath
1112

12-
LOCAL_COPY_TMP_DIR = Path("/tmp/energon_local_copy")
13+
LOCAL_COPY_TMP_DIR = Path(os.environ.get("ENERGON_LOCAL_COPY_TMP_DIR", "/tmp/energon_local_copy"))
1314

1415

1516
def ensure_local_copy(path: EPath) -> EPath:
1617
"""If the path is not local, copy it to a temporary directory and return the
17-
path to the temporary directory.
18+
path to the temporary directory. Assumes that the local file is never modified
19+
after it is copied. The local copy is re-synced if the remote file is newer.
1820
1921
Args:
2022
path: The path to the file to copy.
@@ -33,10 +35,13 @@ def ensure_local_copy(path: EPath) -> EPath:
3335
final_path = LOCAL_COPY_TMP_DIR / f"{digest}.bin"
3436
lock_path = final_path.with_suffix(".lock")
3537
tmp_path = final_path.with_suffix(".part")
38+
mod_time = path.stat().last_modified.timestamp()
3639

3740
# Block until lock is free
3841
with FileLock(lock_path, timeout=60 * 5):
39-
if final_path.exists(): # someone else already produced it
42+
# someone else already produced it
43+
if final_path.exists() and final_path.stat().st_mtime >= mod_time:
44+
# The local copy is at least as new as the remote file, so reuse it
4045
return EPath(final_path)
4146

4247
# We are the downloader

src/megatron/energon/media/extractor.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,13 @@
1515
from PIL import Image, UnidentifiedImageError
1616

1717
from megatron.energon.av import AVDecoder
18+
from megatron.energon.epathlib import EPath
1819
from megatron.energon.media.metadata import ImageMetadata, MediaMetadataBase, MediaMetadataType
1920

2021
logger = logging.getLogger(__name__)
2122

2223

23-
SourceData = Union[bytes, Path, BinaryIO]
24+
SourceData = Union[bytes, EPath, BinaryIO]
2425

2526

2627
class MediaFilterStrategy(str, Enum):
@@ -210,9 +211,13 @@ def _build_metadata(
210211

211212

212213
def _build_image_metadata(source: SourceData) -> ImageMetadata | None:
214+
should_close = False
213215
try:
214216
if isinstance(source, (bytes, bytearray)):
215217
source = io.BytesIO(source)
218+
elif isinstance(source, EPath):
219+
source = source.open("rb")
220+
should_close = True
216221

217222
with Image.open(source) as image:
218223
image.load()
@@ -225,6 +230,9 @@ def _build_image_metadata(source: SourceData) -> ImageMetadata | None:
225230
except UnidentifiedImageError:
226231
logger.debug("Failed to parse image metadata", exc_info=True)
227232
return None
233+
finally:
234+
if should_close:
235+
source.close()
228236

229237

230238
def _build_av_metadata(source: SourceData) -> MediaMetadataBase | None:

0 commit comments

Comments
 (0)