Skip to content

Commit 88245f2

Browse files
test: save/restore round-trip on a --no-sample-tables dataset
Empirically verifies that the integer-indexed loader's checkpoint / resume path works on a dataset prepared with --no-sample-tables. ShardInfosITarReader and SliceState never touch the SQLite samples tables, so the load-bearing claim of the flag is that training-time save/restore still produces the same sample sequence.

This test exercises the round-trip (the loader uses batch_size=2, so each step below counts batches, not individual samples):
1. Reference: an uninterrupted iteration of 20 batches.
2. Capture state mid-stream (after 10 batches) via save_state_rank().
3. Continue iterating to capture the next 10 batches (post_save).
4. Build a fresh loader, restore_state_rank(state), and iterate 10 batches (post_restore).
5. Assert first_half + post_save == reference (no divergence from the reference run) and post_restore == post_save (resumed iteration matches continued iteration).

Re-prepares the test fixture as CaptioningSample + --no-sample-tables so get_train_dataset returns decodable samples.

Signed-off-by: Pei Li <pei.li@kaiko.ai>
1 parent e71d662 commit 88245f2

1 file changed

Lines changed: 66 additions & 0 deletions

File tree

tests/test_dataset.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1882,6 +1882,72 @@ def test_prepare_dataset_no_sample_tables(self):
18821882
finally:
18831883
reader.close()
18841884

1885+
def test_prepare_dataset_no_sample_tables_save_restore(self):
    """Resume after a mid-iteration checkpoint must reach the same batches in the
    same order on a dataset prepared with ``--no-sample-tables``.

    This is the load-bearing claim of the flag: the integer-indexed loader
    (``ShardInfosITarReader``) and the savable state (``SliceState``) do not
    touch the SQLite samples tables, so save/restore must work without them.
    We re-prepare the fixture with ``--no-sample-tables`` plus a captioning
    field-map (so ``get_train_dataset`` yields decodable samples), then compare
    a save/restore round-trip against a reference run.

    Procedure (all counts are batches; the loader uses ``batch_size=2``):

    1. Reference: one uninterrupted run of 20 batches.
    2. Second loader: read 10 batches, call ``save_state_rank()``, read 10 more.
    3. Third loader: ``restore_state_rank(state)``, read 10 batches.
    4. The second loader must reproduce the reference, and the restored loader
       must reproduce what the second loader yielded after the save point.
    """

    from megatron.energon import get_savable_loader, get_train_dataset

    # Number of batches consumed before the checkpoint, and in the full run.
    half = 10
    total = 2 * half

    runner = CliRunner()
    result = runner.invoke(
        prepare_command,
        [
            str(self.dataset_path),
            "--non-interactive",
            "--force-overwrite",
            "--split-ratio=1,0,0",
            "--sample-type=CaptioningSample",
            '--field-map={"image": "png", "caption": "txt"}',
            "--no-sample-tables",
        ],
        catch_exceptions=False,
    )
    assert result.exit_code == 0, f"Prepare failed: {result.stdout}"

    def loader_factory():
        # Every loader is built with identical arguments so that the runs are
        # comparable with each other.
        return get_savable_loader(
            get_train_dataset(
                self.dataset_path,
                batch_size=2,
                worker_config=no_worker_config,
                shuffle_buffer_size=20,
                max_samples_per_sequence=10,
            )
        )

    def keys_from(loader, n):
        # ``range(n)`` comes first on purpose: ``zip`` stops as soon as the range
        # is exhausted, so the loader is never advanced past ``n`` batches.
        return [tuple(batch.__key__) for _, batch in zip(range(n), loader)]

    # Reference: a single uninterrupted run, used as the ground truth.
    reference = keys_from(loader_factory(), total)

    # Guard against a vacuous pass: the equality checks below would also hold
    # for truncated (or empty) sequences if the loader yielded too few batches.
    assert len(reference) == total, (
        f"Expected {total} reference batches, got {len(reference)}"
    )

    # Capture state mid-stream.
    loader = loader_factory()
    first_half = keys_from(loader, half)
    state = loader.save_state_rank()
    post_save = keys_from(loader, half)

    # Restore into a fresh loader and continue. The resumed sequence must
    # match what the original loader produced after `save_state_rank()`.
    resumed = loader_factory()
    resumed.restore_state_rank(state)
    post_restore = keys_from(resumed, half)

    assert first_half + post_save == reference, (
        f"Uninterrupted iteration diverges from reference: {first_half + post_save} != {reference}"
    )
    assert post_restore == post_save, (
        f"Resume diverged from continued iteration: {post_restore} != {post_save}"
    )
1950+
18851951
def test_preview_captioning_dataset(self):
18861952
runner = CliRunner()
18871953
result = runner.invoke(

0 commit comments

Comments
 (0)