Add sharded data-parallel runs sharing one run_path#295
Open
luciaquirke wants to merge 2 commits into
Open
Conversation
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
for more information, see https://pre-commit.ci
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Data attribution with IFs is embarrassingly parallel, but the SLURM example required users to hand-compute
--splitslices per node, write to per-node run paths, and manually stitch the resulting memmaps. This PR lets independentbergson build/bergson scoreinvocations (e.g. one per SLURM array task) write into the same run_path with no stitching at either end, and survive nodes dying and being restarted.--num_shards N --shard_id ionbuild/score(shard_idinferred fromSLURM_ARRAY_TASK_ID/SLURM_PROCIDwhen unset). Each shard processes one contiguous dataset slice and keeps its usual intra-node NCCL rendezvous.run_path/shards/<i>-of-<n>.partand atomically rename into place, recording provenance inshard.json(dataset row range, host, SLURM ids). A crashed shard leaves a.partdir and is rebuilt by re-running the same command; a published shard is skipped — requeued array tasks are idempotent.run_path/config.yaml(per-invocation fieldsshard_id/overwrite/node_rankstripped); later shards verify equality, so differently-configured shards can never mix in one run_path.bergson/sharding.pywithShardedMemmap, a lazy concatenated view over per-shard memmaps.load_gradients,load_scores,load_token_gradients,load_gradient_dataset,GradientProcessor.load, and the FAISS builder transparently resolve sharded run paths (allow_partialto peek at in-flight runs).bergson status <run_path>reports published / in-progress / missing shards.examples/slurm/data_parallel_score.shrewritten assbatch --array --requeuejob array; sharding docs added todocs/cli.rst.reduce, and pipeline commands explicitly reject sharded mode (factors can't be merged across independent shards yet).Test plan
tests/test_sharded_runs.pyunit tests: shard ranges matchDataset.shard(contiguous=True),ShardedMemmapindexing vsnp.concatenate(flat + structured), canonical config publish/verify, shard inventory/coverage errors, config validation.partleft, not published; restart rebuilds; re-running a published shard is a no-op; mismatched config rejected; sharded index reads back identical to a non-sharded build of the same databergson score(2 shards):load_scoresreturns one concatenated array,is_written(),bergson statusreports completetest_build/data/reduce/multinode/truncation/advantages/batch_size_invariancepass;pre-commit run --all-filesclean🤖 Generated with Claude Code