Skip to content

Commit 4c7ab4a

Browse files
Address second review round: clearer docs, eager check in SqliteITarEntryReader
- data_prep.md: replace "consumed sequentially" with the precise constraint ("not used as auxiliary or mounted"). - SqliteIndexReader: expose has_sample_tables as a constructor-time attribute (mirrors db_has_sample_parts); drop the per-method _check_samples_table guard. - SqliteITarEntryReader: raise MissingSamplesTableError at __init__ when the samples table is missing — fail fast at the boundary that actually requires it. - Test updated to assert at SqliteITarEntryReader construction. Signed-off-by: Pei Li <pei.li@kaiko.ai>
1 parent 88245f2 commit 4c7ab4a

4 files changed

Lines changed: 21 additions & 42 deletions

File tree

docs/source/basic/data_prep.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -688,13 +688,13 @@ The `media_metadata` table is used to store the media metadata for the selected
688688

689689
#### Skipping the samples tables for very large datasets
690690

691-
For very large datasets (100M+ samples), populating the `samples` and `sample_parts` tables can dominate `energon prepare` runtime. If the dataset will only be consumed sequentially (not as an auxiliary dataset), you can skip these tables:
691+
For very large datasets (100M+ samples), populating the `samples` and `sample_parts` tables can dominate `energon prepare` runtime. If the dataset will not be used as an auxiliary dataset or mounted, you can skip these tables:
692692

693693
```sh
694694
> energon prepare --no-sample-tables /path/to/dataset
695695
```
696696

697-
A dataset prepared this way cannot be used as an [auxiliary dataset](aux-data).
697+
A dataset prepared this way cannot be used as an [auxiliary dataset](aux-data) or with `energon mount`.
698698

699699
(data-on-disk-jsonl)=
700700
## Dataset Format on Disk for JSONL Datasets

src/megatron/energon/flavors/webdataset/indexing.py

Lines changed: 2 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -397,7 +397,7 @@ class SqliteIndexReader:
397397

398398
sqlite_path: EPath
399399
db: ThreadLocalSqlite
400-
_samples_table_checked: bool
400+
has_sample_tables: bool
401401

402402
def __init__(self, sqlite_path: EPath):
403403
"""Initialize the SQLite database reader.
@@ -413,25 +413,10 @@ def __init__(self, sqlite_path: EPath):
413413
path = f"file:{path}?mode=ro&immutable=1"
414414

415415
self.db = ThreadLocalSqlite(path, is_uri=True)
416-
self._samples_table_checked = False
417-
418-
def _check_samples_table(self) -> None:
419-
"""Verify the ``samples`` table is present, raising a clear error if not.
420-
421-
Called by every method that queries the ``samples`` / ``sample_parts`` tables, so callers
422-
accessing a dataset prepared with ``--no-sample-tables`` get a descriptive error instead
423-
of a raw ``sqlite3.OperationalError: no such table: samples``. The check runs once per
424-
reader instance.
425-
"""
426-
if self._samples_table_checked:
427-
return
428-
assert self.db is not None, "Database is closed"
429416
row = self.db.select_one(
430417
"SELECT name FROM sqlite_master WHERE type='table' AND name='samples'"
431418
)
432-
if row is None:
433-
raise MissingSamplesTableError(self.sqlite_path)
434-
self._samples_table_checked = True
419+
self.has_sample_tables = row is not None
435420

436421
def db_has_sample_parts(self) -> bool:
437422
"""Check if the database has a sample_parts table.
@@ -466,7 +451,6 @@ def list_all_samples(self) -> Generator[Tuple[str, int, int], None, None]:
466451
"""
467452

468453
assert self.db is not None, "Database is closed"
469-
self._check_samples_table()
470454

471455
for row in self.db.select_all("SELECT sample_key, byte_size, tar_file_id FROM samples"):
472456
yield row[0], row[1], row[2]
@@ -479,7 +463,6 @@ def list_all_sample_parts(self) -> Generator[Tuple[str, int, int], None, None]:
479463
"""
480464

481465
assert self.db is not None, "Database is closed"
482-
self._check_samples_table()
483466

484467
# Select all parts (sorted by tar_file_id, sample_index) but joined with the sample_key names
485468
for row in self.db.select_all(
@@ -505,7 +488,6 @@ def list_sample_parts(self, sample_key: str) -> Generator[Tuple[str, int, int],
505488
"""
506489

507490
assert self.db is not None, "Database is closed"
508-
self._check_samples_table()
509491

510492
# Select all parts (sorted by tar_file_id, sample_index) but joined with the sample_key names
511493
for row in self.db.select_all(
@@ -525,15 +507,13 @@ def list_sample_parts(self, sample_key: str) -> Generator[Tuple[str, int, int],
525507
def get_total_size(self) -> int:
526508
"""Get the total size of all samples in the database."""
527509
assert self.db is not None, "Database is closed"
528-
self._check_samples_table()
529510

530511
count = self.db.select_one("SELECT SUM(byte_size) FROM samples")
531512
return count[0] if count else 0
532513

533514
def get_sample_count(self) -> int:
534515
"""Get the total number of samples in the database."""
535516
assert self.db is not None, "Database is closed"
536-
self._check_samples_table()
537517

538518
count = self.db.select_one("SELECT COUNT(*) FROM samples")
539519
return count[0] if count else 0
@@ -549,7 +529,6 @@ def get_sample_part(self, key: str, part_name: str) -> ITarRawSamplePartPointer:
549529
Pointer to the sample part raw data.
550530
"""
551531
assert self.db is not None, "Database is closed"
552-
self._check_samples_table()
553532

554533
row = self.db.select_one(
555534
"SELECT sp.tar_file_id, sp.content_byte_offset, sp.content_byte_size "
@@ -579,7 +558,6 @@ def get_sample_pointer_by_key(self, key: str) -> ITarSamplePointer:
579558
Tuple of (tar_file_id, sample_key, sample_index, byte_offset, byte_size)
580559
"""
581560
assert self.db is not None, "Database is closed"
582-
self._check_samples_table()
583561

584562
sample = self.db.select_one(
585563
"SELECT tar_file_id, sample_key, sample_index, byte_offset, byte_size "

src/megatron/energon/flavors/webdataset/itar_reader.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -485,6 +485,7 @@ class SqliteITarEntryReader(ITarReader[str]):
485485

486486
sqlite_reader: SqliteIndexReader
487487
db_has_sample_parts: int
488+
has_sample_tables: bool
488489

489490
def __init__(
490491
self,
@@ -496,7 +497,10 @@ def __init__(
496497
disable_cache: bool = False,
497498
):
498499
from megatron.energon.flavors.webdataset.config import MAIN_FOLDER_NAME
499-
from megatron.energon.flavors.webdataset.indexing import SqliteIndexReader
500+
from megatron.energon.flavors.webdataset.indexing import (
501+
MissingSamplesTableError,
502+
SqliteIndexReader,
503+
)
500504

501505
# shard_name_to_info_idx = {name: i for i, name in enumerate(wds_meta.info_shard_files)}
502506
tar_filenames = get_info_shard_files(base_path)
@@ -507,6 +511,10 @@ def __init__(
507511
self.sqlite_reader = SqliteIndexReader(sqlite_path)
508512

509513
self.db_has_sample_parts = self.sqlite_reader.db_has_sample_parts()
514+
self.has_sample_tables = self.sqlite_reader.has_sample_tables
515+
516+
if not self.has_sample_tables:
517+
raise MissingSamplesTableError(sqlite_path)
510518

511519
self.key_is_full_entryname = key_is_full_entryname
512520

tests/test_dataset.py

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
from megatron.energon.edataclass import edataclass
5151
from megatron.energon.flavors import BaseWebdatasetFactory
5252
from megatron.energon.flavors.webdataset.config import MAIN_FOLDER_NAME
53+
from megatron.energon.flavors.webdataset.itar_reader import SqliteITarEntryReader
5354
from megatron.energon.task_encoder.base import stateless
5455
from megatron.energon.tools.analyze_debug import command as analyze_debug_command
5556
from megatron.energon.tools.info import command as info_command
@@ -1862,25 +1863,19 @@ def test_prepare_dataset_no_sample_tables(self):
18621863
result.stdout + (result.stderr or "")
18631864
) or isinstance(result.exception, click.UsageError)
18641865

1865-
# Sample-key operations against the SQLite reader must raise a descriptive error,
1866-
# not the raw `sqlite3.OperationalError: no such table: samples`.
1866+
# SqliteIndexReader exposes the `has_sample_tables` flag as part of its public surface.
18671867
from megatron.energon.epathlib import EPath
18681868
from megatron.energon.flavors.webdataset.indexing import (
18691869
MissingSamplesTableError,
18701870
SqliteIndexReader,
18711871
)
18721872

1873-
reader = SqliteIndexReader(EPath(str(sqlite_path)))
1874-
try:
1875-
with self.assertRaises(MissingSamplesTableError) as ctx:
1876-
reader.get_sample_pointer_by_key("any-key")
1877-
assert "no-sample-tables" in str(ctx.exception)
1878-
with self.assertRaises(MissingSamplesTableError):
1879-
reader.get_sample_count()
1880-
with self.assertRaises(MissingSamplesTableError):
1881-
reader.get_sample_part("any-key", "txt")
1882-
finally:
1883-
reader.close()
1873+
index_reader = SqliteIndexReader(EPath(str(sqlite_path)))
1874+
assert index_reader.has_sample_tables is False
1875+
1876+
with self.assertRaises(MissingSamplesTableError) as ctx:
1877+
SqliteITarEntryReader(EPath(str(self.dataset_path)))
1878+
assert "no-sample-tables" in str(ctx.exception)
18841879

18851880
def test_prepare_dataset_no_sample_tables_save_restore(self):
18861881
"""Resume after a mid-iteration checkpoint must reach the same samples in the
@@ -1894,8 +1889,6 @@ def test_prepare_dataset_no_sample_tables_save_restore(self):
18941889
a save/restore round-trip against a reference run.
18951890
"""
18961891

1897-
from megatron.energon import get_savable_loader, get_train_dataset
1898-
18991892
runner = CliRunner()
19001893
result = runner.invoke(
19011894
prepare_command,

0 commit comments

Comments
 (0)