diff --git a/docs/source/advanced/data_prep_api.md b/docs/source/advanced/data_prep_api.md index 8768f7ea..24a578cd 100644 --- a/docs/source/advanced/data_prep_api.md +++ b/docs/source/advanced/data_prep_api.md @@ -68,3 +68,7 @@ Then, run the script: if __name__ == "__main__": prepare_one_dataset(Path("/path/to/dataset"), 16, Path("/path/to/template_dir")) ``` + +## Skipping the SQLite samples tables for very large datasets + +Pass `enable_sample_tables=False` to `prepare_dataset` (default `True`) to skip populating the `samples` and `sample_parts` tables. Useful when preparation is bottlenecked by SQLite indexing on very large datasets. A dataset prepared this way cannot be used as an [auxiliary dataset](aux-data). diff --git a/docs/source/basic/data_prep.md b/docs/source/basic/data_prep.md index 73629696..1edb62b5 100644 --- a/docs/source/basic/data_prep.md +++ b/docs/source/basic/data_prep.md @@ -686,6 +686,16 @@ The `media_metadata` table is used to store the media metadata for the selected | 00002.mp4 | ... | ... | +#### Skipping the samples tables for very large datasets + +For very large datasets (100M+ samples), populating the `samples` and `sample_parts` tables can dominate `energon prepare` runtime. If the dataset will not be used as an auxiliary dataset or mounted, you can skip these tables: + +```sh +> energon prepare --no-sample-tables /path/to/dataset +``` + +A dataset prepared this way cannot be used as an [auxiliary dataset](aux-data) or with `energon mount`. + (data-on-disk-jsonl)= ## Dataset Format on Disk for JSONL Datasets diff --git a/src/megatron/energon/flavors/webdataset/indexing.py b/src/megatron/energon/flavors/webdataset/indexing.py index 4e5cebcd..c698a5f8 100644 --- a/src/megatron/energon/flavors/webdataset/indexing.py +++ b/src/megatron/energon/flavors/webdataset/indexing.py @@ -27,6 +27,25 @@ def __init__(self, sample_key: str) -> None: self.sample_key = sample_key +class MissingSamplesTableError(RuntimeError): + """Raised when a sample-key operation is attempted on a dataset prepared without sample tables. + + Datasets prepared with ``energon prepare --no-sample-tables`` (or programmatically with + ``prepare_dataset(..., enable_sample_tables=False)``) do not have the ``samples`` / + ``sample_parts`` tables in their SQLite index. Such datasets cannot be used as auxiliary + datasets and do not support sample-key lookups. + """ + + def __init__(self, sqlite_path: EPath) -> None: + super().__init__( + f"Dataset at {sqlite_path.parent.parent} was prepared without the SQLite samples tables " + f"(`energon prepare --no-sample-tables` / `enable_sample_tables=False`). " + f"Re-prepare the dataset without that option to use it as an auxiliary dataset " + f"or for any sample-key lookup." + ) + self.sqlite_path = sqlite_path + + class SqliteIndexWriter: sqlite_path: EPath db: Optional[sqlite3.Connection] @@ -79,6 +98,10 @@ def __init__( if self.enable_sample_tables: assert self.reset_tables, "Reset tables is required when enabling sample tables" + if self.reset_tables: + # Always drop on reset — stale tables from a previous prepare run with + # enable_sample_tables=True must not survive a re-prepare with + # enable_sample_tables=False (and vice versa). self.db.execute("DROP INDEX IF EXISTS idx_samples_sample_key") self.db.execute("DROP INDEX IF EXISTS idx_samples_by_tar_and_idx") self.db.execute("DROP TABLE IF EXISTS samples") @@ -87,6 +110,7 @@ def __init__( self.db.execute("DROP INDEX IF EXISTS idx_sample_parts_full") self.db.execute("DROP TABLE IF EXISTS sample_parts") + if self.enable_sample_tables: self.db.execute( """ CREATE TABLE IF NOT EXISTS samples ( @@ -139,10 +163,17 @@ def append_samples( self, rows: Sequence["IndexSample"], ) -> None: - """Insert multiple sample rows efficiently.""" + """Insert multiple sample rows efficiently. + + No-op when ``enable_sample_tables`` was set to False at construction time — the + ``samples`` table does not exist in that case. + """ assert self.db is not None, "Database is closed" + if not self.enable_sample_tables: + return + if len(rows) == 0: return @@ -177,10 +208,17 @@ def append_parts( self, rows: Sequence["IndexSamplePart"], ) -> None: - """Insert multiple sample part rows efficiently.""" + """Insert multiple sample part rows efficiently. + + No-op when ``enable_sample_tables`` was set to False at construction time — the + ``sample_parts`` table does not exist in that case. + """ assert self.db is not None, "Database is closed" + if not self.enable_sample_tables: + return + if len(rows) == 0: return @@ -375,6 +413,20 @@ def __init__(self, sqlite_path: EPath): self.db = ThreadLocalSqlite(path, is_uri=True) + def db_has_samples(self) -> bool: + """Check if the database has a samples table. + + Returns: + True if samples table exists, False otherwise. + """ + assert self.db is not None, "Database is closed" + + db_exists = self.db.select_one( + "SELECT name FROM sqlite_master WHERE type='table' AND name='samples'" + ) + self.db.thread_close() + return db_exists is not None + def db_has_sample_parts(self) -> bool: """Check if the database has a sample_parts table. diff --git a/src/megatron/energon/flavors/webdataset/itar_reader.py b/src/megatron/energon/flavors/webdataset/itar_reader.py index b6c837aa..3b8ae3c1 100644 --- a/src/megatron/energon/flavors/webdataset/itar_reader.py +++ b/src/megatron/energon/flavors/webdataset/itar_reader.py @@ -484,7 +484,7 @@ class SqliteITarEntryReader(ITarReader[str]): """ sqlite_reader: SqliteIndexReader - db_has_sample_parts: int + db_has_sample_parts: bool def __init__( self, @@ -496,7 +496,10 @@ def __init__( disable_cache: bool = False, ): from megatron.energon.flavors.webdataset.config import MAIN_FOLDER_NAME - from megatron.energon.flavors.webdataset.indexing import SqliteIndexReader + from megatron.energon.flavors.webdataset.indexing import ( + MissingSamplesTableError, + SqliteIndexReader, + ) # shard_name_to_info_idx = {name: i for i, name in enumerate(wds_meta.info_shard_files)} tar_filenames = get_info_shard_files(base_path) @@ -507,6 +510,8 @@ def __init__( self.sqlite_reader = SqliteIndexReader(sqlite_path) self.db_has_sample_parts = self.sqlite_reader.db_has_sample_parts() + if not self.sqlite_reader.db_has_samples(): + raise MissingSamplesTableError(sqlite_path) self.key_is_full_entryname = key_is_full_entryname diff --git a/src/megatron/energon/flavors/webdataset/prepare.py b/src/megatron/energon/flavors/webdataset/prepare.py index d3a16223..eb311417 100644 --- a/src/megatron/energon/flavors/webdataset/prepare.py +++ b/src/megatron/energon/flavors/webdataset/prepare.py @@ -489,6 +489,7 @@ def prepare_dataset( tar_index_only: bool = False, media_filter: Optional[MediaFilterConfig] = None, fix_duplicates: bool = False, + enable_sample_tables: bool = True, ) -> Tuple[Set[str], List[Tuple[str, int]]]: """ Preprocess the shards and write the split config. Preprocessing is done in parallel. @@ -507,6 +508,13 @@ def prepare_dataset( tar_index_only: Only create tar-index, then exit media_filter: Media filter configuration fix_duplicates: If True, fix duplicate keys in the dataset by renaming the files in the shards. + enable_sample_tables: If True (default), populate the ``samples`` and ``sample_parts`` + tables in the SQLite index. Set to False to skip these tables and their post-insert + btree builds — only the per-tar ``.tar.idx`` files, ``.info.json`` and split config + are produced. Use this for datasets consumed purely by the integer-indexed loader + (``ShardInfosITarReader``); sample-key lookups, polylithic joins and media-metadata + filtering will not work. Substantially reduces preparation time on very large + datasets (100M+ samples) where the SQLite inserts and index builds dominate runtime. Returns: The set of all parts found in the shards. But at most 50. @@ -566,6 +574,7 @@ def prepare_dataset( parent_path / MAIN_FOLDER_NAME / INDEX_SQLITE_FILENAME, total_tasks=len(paths), progress_fn=progress_fn, + enable_sample_tables=enable_sample_tables, enable_media_metadata=media_filter is not None, media_filter=media_filter, ) diff --git a/src/megatron/energon/tools/prepare.py b/src/megatron/energon/tools/prepare.py index 5e6643cf..dd24e73b 100644 --- a/src/megatron/energon/tools/prepare.py +++ b/src/megatron/energon/tools/prepare.py @@ -123,6 +123,17 @@ def printify_json(data: Any) -> Any: help="Only (re)generate the tar-index", is_flag=True, ) +@click.option( + "--no-sample-tables", + help=( + "Skip populating the SQLite samples and sample_parts tables. Only the per-tar " + ".tar.idx files, .info.json and split config are produced. Use for datasets " + "consumed purely by the integer-indexed loader; sample-key lookups, polylithic " + "joins and media-metadata filtering will not work. Substantially reduces " + "preparation time on very large datasets." + ), + is_flag=True, +) @click.option( "--shuffle-tars", help="If set, the tar files will be shuffled before splitting.", @@ -191,6 +202,7 @@ def command( exclude: str, num_workers: int, tar_index_only: bool, + no_sample_tables: bool, shuffle_tars: bool, media_metadata_by_glob: str | None, media_metadata_by_header: bool, @@ -220,6 +232,12 @@ def command( if do_media_metadata and tar_index_only: raise click.UsageError("--media-metadata-by-... cannot be combined with --tar-index-only") + if no_sample_tables and tar_index_only: + raise click.UsageError( + "--no-sample-tables cannot be combined with --tar-index-only " + "(--tar-index-only operates on an already-prepared dataset)" + ) + media_filter_config = ( MediaFilterConfig.parse( media_metadata_by_glob, media_metadata_by_header, media_metadata_by_extension @@ -348,6 +366,7 @@ def progress_fn(els, length=None): workers=num_workers, media_filter=media_filter_config, fix_duplicates=fix_duplicates, + enable_sample_tables=not no_sample_tables, ) found_types = list(found_types) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 01aa74ee..3cc9efb6 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -10,6 +10,7 @@ import logging import math import random +import sqlite3 import sys import tempfile import unittest @@ -47,8 +48,11 @@ ) from megatron.energon.dataset_config import get_dataset_from_config from megatron.energon.edataclass import edataclass +from megatron.energon.epathlib import EPath from megatron.energon.flavors import BaseWebdatasetFactory from megatron.energon.flavors.webdataset.config import MAIN_FOLDER_NAME +from megatron.energon.flavors.webdataset.indexing import MissingSamplesTableError, SqliteIndexReader +from megatron.energon.flavors.webdataset.itar_reader import SqliteITarEntryReader from megatron.energon.task_encoder.base import stateless from megatron.energon.tools.analyze_debug import command as analyze_debug_command from megatron.energon.tools.info import command as info_command @@ -1803,6 +1807,132 @@ def test_prepare_dataset_noninteractive_crude(self): content = f.read() assert "CrudeWebdataset" in content + def test_prepare_dataset_no_sample_tables(self): + """`--no-sample-tables` skips the SQLite samples/sample_parts tables. + + Verifies that: + - Prepare still succeeds and emits .info.json + the per-tar .tar.idx files. + - The SQLite database exists but does not contain the `samples` or `sample_parts` + tables (the bulk of the SQLite cost on large datasets). + - The flag rejects being combined with `--tar-index-only`. + """ + + runner = CliRunner() + result = runner.invoke( + prepare_command, + [ + str(self.dataset_path), + "--non-interactive", + "--force-overwrite", + "--split-ratio=1,0,0", + "--sample-type=CrudeWebdataset", + "--no-sample-tables", + ], + catch_exceptions=False, + ) + assert result.exit_code == 0, f"Prepare failed: {result.stdout}" + + # .info.json and per-tar .tar.idx files must still exist. + assert (self.dataset_path / MAIN_FOLDER_NAME / ".info.json").is_file() + tar_idx_files = list(self.dataset_path.glob("**/*.tar.idx")) + assert len(tar_idx_files) > 0, "Expected per-tar .tar.idx files to be produced" + + # SQLite file exists, but the samples / sample_parts tables must NOT be created. + sqlite_path = self.dataset_path / MAIN_FOLDER_NAME / "index.sqlite" + assert sqlite_path.is_file() + with sqlite3.connect(str(sqlite_path)) as conn: + tables = { + row[0] for row in conn.execute("SELECT name FROM sqlite_master WHERE type='table'") + } + assert "samples" not in tables, f"unexpected samples table: {tables}" + assert "sample_parts" not in tables, f"unexpected sample_parts table: {tables}" + + # --no-sample-tables + --tar-index-only must be rejected. + result = runner.invoke( + prepare_command, + [ + str(self.dataset_path), + "--non-interactive", + "--no-sample-tables", + "--tar-index-only", + ], + catch_exceptions=True, + ) + assert result.exit_code != 0 + assert "--no-sample-tables cannot be combined with --tar-index-only" in result.output + + # SqliteIndexReader exposes db_has_samples() as part of its public surface. + index_reader = SqliteIndexReader(EPath(str(sqlite_path))) + assert index_reader.db_has_samples() is False + + with self.assertRaises(MissingSamplesTableError) as ctx: + SqliteITarEntryReader(EPath(str(self.dataset_path))) + assert "no-sample-tables" in str(ctx.exception) + + def test_prepare_dataset_no_sample_tables_save_restore(self): + """Resume after a mid-iteration checkpoint must reach the same samples in the + same order on a dataset prepared with ``--no-sample-tables``. + + This is the load-bearing claim of the flag: the integer-indexed loader + (``ShardInfosITarReader``) and the savable state (``SliceState``) do not + touch the SQLite samples tables, so save/restore must work without them. + We re-prepare the fixture with ``--no-sample-tables`` plus a captioning + field-map (so ``get_train_dataset`` yields decodable samples), then compare + a save/restore round-trip against a reference run. + """ + + runner = CliRunner() + result = runner.invoke( + prepare_command, + [ + str(self.dataset_path), + "--non-interactive", + "--force-overwrite", + "--split-ratio=1,0,0", + "--sample-type=CaptioningSample", + '--field-map={"image": "png", "caption": "txt"}', + "--no-sample-tables", + ], + catch_exceptions=False, + ) + assert result.exit_code == 0, f"Prepare failed: {result.stdout}" + + def loader_factory(): + return get_savable_loader( + get_train_dataset( + self.dataset_path, + batch_size=2, + worker_config=no_worker_config, + shuffle_buffer_size=20, + max_samples_per_sequence=10, + ) + ) + + def keys_from(loader, n): + return [tuple(batch.__key__) for _, batch in zip(range(n), loader)] + + # Reference: a single uninterrupted run, used as the ground truth. + reference = keys_from(loader_factory(), 20) + + # Capture state mid-stream. + loader = loader_factory() + first_half = keys_from(loader, 10) + state = loader.save_state_rank() + post_save = keys_from(loader, 10) + + # Restore into a fresh loader and continue. The resumed sequence must + # match what the original loader produced after `save_state_rank()`. + resumed = loader_factory() + resumed.restore_state_rank(state) + post_restore = keys_from(resumed, 10) + + assert first_half + post_save == reference, ( + f"Uninterrupted iteration diverges from reference: {first_half + post_save} != {reference}" + ) + assert post_restore == post_save, ( + f"Resume diverged from continued iteration: {post_restore} != {post_save}" + ) + def test_preview_captioning_dataset(self): runner = CliRunner() result = runner.invoke(