Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/source/advanced/data_prep_api.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,7 @@ Then, run the script:
if __name__ == "__main__":
prepare_one_dataset(Path("/path/to/dataset"), 16, Path("/path/to/template_dir"))
```

## Skipping the SQLite samples tables for very large datasets
Comment thread
pei-li-hedgehog marked this conversation as resolved.

Pass `enable_sample_tables=False` to `prepare_dataset` (the default is `True`) to skip populating the `samples` and `sample_parts` tables. This is useful when preparation is bottlenecked by SQLite indexing on very large datasets. A dataset prepared this way cannot be used as an [auxiliary dataset](aux-data) or with `energon mount`.
10 changes: 10 additions & 0 deletions docs/source/basic/data_prep.md
Original file line number Diff line number Diff line change
Expand Up @@ -686,6 +686,16 @@ The `media_metadata` table is used to store the media metadata for the selected
| 00002.mp4 | ... | ... |


#### Skipping the samples tables for very large datasets
Comment thread
pei-li-hedgehog marked this conversation as resolved.

For very large datasets (100M+ samples), populating the `samples` and `sample_parts` tables can dominate the runtime of `energon prepare`. If the dataset will neither be used as an auxiliary dataset nor be mounted with `energon mount`, you can skip these tables:

```sh
> energon prepare --no-sample-tables /path/to/dataset
```

A dataset prepared this way cannot be used as an [auxiliary dataset](aux-data) or with `energon mount`.

(data-on-disk-jsonl)=
## Dataset Format on Disk for JSONL Datasets

Expand Down
56 changes: 54 additions & 2 deletions src/megatron/energon/flavors/webdataset/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,25 @@ def __init__(self, sample_key: str) -> None:
self.sample_key = sample_key


class MissingSamplesTableError(RuntimeError):
    """Signal that a dataset's SQLite index was built without the sample tables.

    The ``samples`` / ``sample_parts`` tables are absent from the SQLite index when a
    dataset is prepared via ``energon prepare --no-sample-tables`` or, programmatically,
    via ``prepare_dataset(..., enable_sample_tables=False)``. A dataset in that state
    cannot act as an auxiliary dataset and cannot serve sample-key lookups.

    The path of the SQLite index is kept on the instance as ``sqlite_path``.
    """

    def __init__(self, sqlite_path: EPath) -> None:
        # The message names the directory two levels above the SQLite file.
        dataset_root = sqlite_path.parent.parent
        message = (
            f"Dataset at {dataset_root} was prepared without the SQLite samples tables "
            "(`energon prepare --no-sample-tables` / `enable_sample_tables=False`). "
            "Re-prepare the dataset without that option to use it as an auxiliary dataset "
            "or for any sample-key lookup."
        )
        super().__init__(message)
        self.sqlite_path = sqlite_path


class SqliteIndexWriter:
sqlite_path: EPath
db: Optional[sqlite3.Connection]
Expand Down Expand Up @@ -79,6 +98,10 @@ def __init__(
if self.enable_sample_tables:
assert self.reset_tables, "Reset tables is required when enabling sample tables"

if self.reset_tables:
# Always drop on reset — stale tables from a previous prepare run with
# enable_sample_tables=True must not survive a re-prepare with
# enable_sample_tables=False (and vice versa).
self.db.execute("DROP INDEX IF EXISTS idx_samples_sample_key")
self.db.execute("DROP INDEX IF EXISTS idx_samples_by_tar_and_idx")
self.db.execute("DROP TABLE IF EXISTS samples")
Expand All @@ -87,6 +110,7 @@ def __init__(
self.db.execute("DROP INDEX IF EXISTS idx_sample_parts_full")
self.db.execute("DROP TABLE IF EXISTS sample_parts")

if self.enable_sample_tables:
self.db.execute(
"""
CREATE TABLE IF NOT EXISTS samples (
Expand Down Expand Up @@ -139,10 +163,17 @@ def append_samples(
self,
rows: Sequence["IndexSample"],
) -> None:
"""Insert multiple sample rows efficiently."""
"""Insert multiple sample rows efficiently.

No-op when ``enable_sample_tables`` was set to False at construction time — the
``samples`` table does not exist in that case.
"""

assert self.db is not None, "Database is closed"

if not self.enable_sample_tables:
return

if len(rows) == 0:
return

Expand Down Expand Up @@ -177,10 +208,17 @@ def append_parts(
self,
rows: Sequence["IndexSamplePart"],
) -> None:
"""Insert multiple sample part rows efficiently."""
"""Insert multiple sample part rows efficiently.

No-op when ``enable_sample_tables`` was set to False at construction time — the
``sample_parts`` table does not exist in that case.
"""

assert self.db is not None, "Database is closed"

if not self.enable_sample_tables:
return

if len(rows) == 0:
return

Expand Down Expand Up @@ -375,6 +413,20 @@ def __init__(self, sqlite_path: EPath):

self.db = ThreadLocalSqlite(path, is_uri=True)

def db_has_samples(self) -> bool:
    """Report whether the SQLite index contains a ``samples`` table.

    Returns:
        True if ``sqlite_master`` lists a table named ``samples``, False otherwise.
    """
    assert self.db is not None, "Database is closed"

    samples_row = self.db.select_one(
        "SELECT name FROM sqlite_master WHERE type='table' AND name='samples'"
    )
    # thread_close() is called once the lookup is done, before the result is returned.
    self.db.thread_close()
    if samples_row is None:
        return False
    return True

def db_has_sample_parts(self) -> bool:
"""Check if the database has a sample_parts table.

Expand Down
9 changes: 7 additions & 2 deletions src/megatron/energon/flavors/webdataset/itar_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,7 @@ class SqliteITarEntryReader(ITarReader[str]):
"""

sqlite_reader: SqliteIndexReader
db_has_sample_parts: int
db_has_sample_parts: bool

def __init__(
self,
Expand All @@ -496,7 +496,10 @@ def __init__(
disable_cache: bool = False,
):
from megatron.energon.flavors.webdataset.config import MAIN_FOLDER_NAME
from megatron.energon.flavors.webdataset.indexing import SqliteIndexReader
from megatron.energon.flavors.webdataset.indexing import (
MissingSamplesTableError,
SqliteIndexReader,
)

# shard_name_to_info_idx = {name: i for i, name in enumerate(wds_meta.info_shard_files)}
tar_filenames = get_info_shard_files(base_path)
Expand All @@ -507,6 +510,8 @@ def __init__(
self.sqlite_reader = SqliteIndexReader(sqlite_path)

self.db_has_sample_parts = self.sqlite_reader.db_has_sample_parts()
if not self.sqlite_reader.db_has_samples():
raise MissingSamplesTableError(sqlite_path)

self.key_is_full_entryname = key_is_full_entryname

Expand Down
9 changes: 9 additions & 0 deletions src/megatron/energon/flavors/webdataset/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,7 @@ def prepare_dataset(
tar_index_only: bool = False,
media_filter: Optional[MediaFilterConfig] = None,
fix_duplicates: bool = False,
enable_sample_tables: bool = True,
) -> Tuple[Set[str], List[Tuple[str, int]]]:
"""
Preprocess the shards and write the split config. Preprocessing is done in parallel.
Expand All @@ -507,6 +508,13 @@ def prepare_dataset(
tar_index_only: Only create tar-index, then exit
media_filter: Media filter configuration
fix_duplicates: If True, fix duplicate keys in the dataset by renaming the files in the shards.
enable_sample_tables: If True (default), populate the ``samples`` and ``sample_parts``
tables in the SQLite index. Set to False to skip these tables and their post-insert
btree builds — only the per-tar ``.tar.idx`` files, ``.info.json`` and split config
are produced. Use this for datasets consumed purely by the integer-indexed loader
(``ShardInfosITarReader``); sample-key lookups, polylithic joins and media-metadata
filtering will not work. Substantially reduces preparation time on very large
datasets (100M+ samples) where the SQLite inserts and index builds dominate runtime.

Returns:
The set of all parts found in the shards. But at most 50.
Expand Down Expand Up @@ -566,6 +574,7 @@ def prepare_dataset(
parent_path / MAIN_FOLDER_NAME / INDEX_SQLITE_FILENAME,
total_tasks=len(paths),
progress_fn=progress_fn,
enable_sample_tables=enable_sample_tables,
enable_media_metadata=media_filter is not None,
media_filter=media_filter,
)
Expand Down
19 changes: 19 additions & 0 deletions src/megatron/energon/tools/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,17 @@ def printify_json(data: Any) -> Any:
help="Only (re)generate the tar-index",
is_flag=True,
)
@click.option(
"--no-sample-tables",
help=(
"Skip populating the SQLite samples and sample_parts tables. Only the per-tar "
".tar.idx files, .info.json and split config are produced. Use for datasets "
"consumed purely by the integer-indexed loader; sample-key lookups, polylithic "
"joins and media-metadata filtering will not work. Substantially reduces "
"preparation time on very large datasets."
),
is_flag=True,
)
@click.option(
"--shuffle-tars",
help="If set, the tar files will be shuffled before splitting.",
Expand Down Expand Up @@ -191,6 +202,7 @@ def command(
exclude: str,
num_workers: int,
tar_index_only: bool,
no_sample_tables: bool,
shuffle_tars: bool,
media_metadata_by_glob: str | None,
media_metadata_by_header: bool,
Expand Down Expand Up @@ -220,6 +232,12 @@ def command(
if do_media_metadata and tar_index_only:
raise click.UsageError("--media-metadata-by-... cannot be combined with --tar-index-only")

if no_sample_tables and tar_index_only:
raise click.UsageError(
"--no-sample-tables cannot be combined with --tar-index-only "
"(--tar-index-only operates on an already-prepared dataset)"
)

media_filter_config = (
MediaFilterConfig.parse(
media_metadata_by_glob, media_metadata_by_header, media_metadata_by_extension
Expand Down Expand Up @@ -348,6 +366,7 @@ def progress_fn(els, length=None):
workers=num_workers,
media_filter=media_filter_config,
fix_duplicates=fix_duplicates,
enable_sample_tables=not no_sample_tables,
)

found_types = list(found_types)
Expand Down
130 changes: 130 additions & 0 deletions tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import logging
import math
import random
import sqlite3
import sys
import tempfile
import unittest
Expand Down Expand Up @@ -47,8 +48,11 @@
)
from megatron.energon.dataset_config import get_dataset_from_config
from megatron.energon.edataclass import edataclass
from megatron.energon.epathlib import EPath
from megatron.energon.flavors import BaseWebdatasetFactory
from megatron.energon.flavors.webdataset.config import MAIN_FOLDER_NAME
from megatron.energon.flavors.webdataset.indexing import MissingSamplesTableError, SqliteIndexReader
from megatron.energon.flavors.webdataset.itar_reader import SqliteITarEntryReader
from megatron.energon.task_encoder.base import stateless
from megatron.energon.tools.analyze_debug import command as analyze_debug_command
from megatron.energon.tools.info import command as info_command
Expand Down Expand Up @@ -1803,6 +1807,132 @@ def test_prepare_dataset_noninteractive_crude(self):
content = f.read()
assert "CrudeWebdataset" in content

def test_prepare_dataset_no_sample_tables(self):
    """`--no-sample-tables` skips the SQLite samples/sample_parts tables.

    Verifies that:
    - Prepare still succeeds and emits .info.json + the per-tar .tar.idx files.
    - The SQLite database exists but does not contain the `samples` or `sample_parts`
      tables (the bulk of the SQLite cost on large datasets).
    - The flag rejects being combined with `--tar-index-only`.
    - `SqliteIndexReader.db_has_samples()` reports False, and constructing a
      `SqliteITarEntryReader` on the dataset raises `MissingSamplesTableError`.
    """

    runner = CliRunner()
    result = runner.invoke(
        prepare_command,
        [
            str(self.dataset_path),
            "--non-interactive",
            "--force-overwrite",
            "--split-ratio=1,0,0",
            "--sample-type=CrudeWebdataset",
            "--no-sample-tables",
        ],
        catch_exceptions=False,
    )
    assert result.exit_code == 0, f"Prepare failed: {result.stdout}"

    # .info.json and per-tar .tar.idx files must still exist.
    assert (self.dataset_path / MAIN_FOLDER_NAME / ".info.json").is_file()
    tar_idx_files = list(self.dataset_path.glob("**/*.tar.idx"))
    assert len(tar_idx_files) > 0, "Expected per-tar .tar.idx files to be produced"

    # SQLite file exists, but the samples / sample_parts tables must NOT be created.
    sqlite_path = self.dataset_path / MAIN_FOLDER_NAME / "index.sqlite"
    assert sqlite_path.is_file()
    # `with sqlite3.connect(...)` only scopes the transaction and leaves the connection
    # open, so the connection is closed explicitly here.
    conn = sqlite3.connect(str(sqlite_path))
    try:
        tables = {
            row[0] for row in conn.execute("SELECT name FROM sqlite_master WHERE type='table'")
        }
    finally:
        conn.close()
    assert "samples" not in tables, f"unexpected samples table: {tables}"
    assert "sample_parts" not in tables, f"unexpected sample_parts table: {tables}"

    # --no-sample-tables + --tar-index-only must be rejected.
    result = runner.invoke(
        prepare_command,
        [
            str(self.dataset_path),
            "--non-interactive",
            "--no-sample-tables",
            "--tar-index-only",
        ],
        catch_exceptions=True,
    )
    assert result.exit_code != 0
    assert "--no-sample-tables cannot be combined with --tar-index-only" in result.output

    # SqliteIndexReader exposes db_has_samples() as part of its public surface.
    index_reader = SqliteIndexReader(EPath(str(sqlite_path)))
    assert index_reader.db_has_samples() is False

    # The entry reader must refuse a dataset without the samples table; the message is
    # checked after the context manager exits, once the exception has been captured.
    with self.assertRaises(MissingSamplesTableError) as ctx:
        SqliteITarEntryReader(EPath(str(self.dataset_path)))
    assert "no-sample-tables" in str(ctx.exception)

def test_prepare_dataset_no_sample_tables_save_restore(self):
    """Save/restore must reproduce the sample order on a ``--no-sample-tables`` dataset.

    The fixture is re-prepared with ``--no-sample-tables`` and a captioning field-map
    (so that ``get_train_dataset`` yields decodable samples). Three things are then
    compared:

    - a single uninterrupted run of 20 batches (the reference),
    - a run that saves its state after 10 batches and continues for 10 more,
    - a fresh loader restored from that state and iterated for 10 batches.

    The interrupted run must equal the reference, and the restored loader must yield
    the same batches the original loader produced after ``save_state_rank()``.
    """

    prepare_args = [
        str(self.dataset_path),
        "--non-interactive",
        "--force-overwrite",
        "--split-ratio=1,0,0",
        "--sample-type=CaptioningSample",
        '--field-map={"image": "png", "caption": "txt"}',
        "--no-sample-tables",
    ]
    result = CliRunner().invoke(prepare_command, prepare_args, catch_exceptions=False)
    assert result.exit_code == 0, f"Prepare failed: {result.stdout}"

    def make_loader():
        train_dataset = get_train_dataset(
            self.dataset_path,
            batch_size=2,
            worker_config=no_worker_config,
            shuffle_buffer_size=20,
            max_samples_per_sequence=10,
        )
        return get_savable_loader(train_dataset)

    def take_keys(loader, count):
        # `range` comes first in the zip so that exactly `count` batches are pulled
        # from the loader and no extra batch is consumed.
        collected = []
        for _, batch in zip(range(count), loader):
            collected.append(tuple(batch.__key__))
        return collected

    # Ground truth: one uninterrupted pass.
    reference = take_keys(make_loader(), 20)

    # Second pass: checkpoint after the first ten batches, then keep going.
    interrupted_loader = make_loader()
    before_save = take_keys(interrupted_loader, 10)
    saved_state = interrupted_loader.save_state_rank()
    after_save = take_keys(interrupted_loader, 10)

    # Third pass: a brand-new loader picks up from the checkpoint.
    restored_loader = make_loader()
    restored_loader.restore_state_rank(saved_state)
    after_restore = take_keys(restored_loader, 10)

    uninterrupted = before_save + after_save
    assert uninterrupted == reference, (
        f"Uninterrupted iteration diverges from reference: {uninterrupted} != {reference}"
    )
    assert after_restore == after_save, (
        f"Resume diverged from continued iteration: {after_restore} != {after_save}"
    )

def test_preview_captioning_dataset(self):
runner = CliRunner()
result = runner.invoke(
Expand Down