Add MultiGroupDistributedDataParallel#688
Draft
AkshitaB wants to merge 3 commits into
Draft
Conversation
Add a data-parallel module wrapper under `olmo_core.nn.parallel` that all-reduces gradients across one or more process groups. Unlike `torch.nn.parallel.DistributedDataParallel`, it accumulates gradients into flat, pre-allocated bucket views (each parameter's `.grad` is a view into a contiguous bucket) and supports assigning different parameters to different process groups via `param_process_group_fn`. This is useful when subsets of a model are replicated over different device meshes (e.g. expert weights under expert parallelism, which must be data-parallel-reduced only over the ranks that hold the same experts). Bucketed all-reduces are launched and overlapped with the backward pass as each bucket fills; `finalize_grad_reduce()` is called after `loss.backward()` to wait for them to complete. Gradients can optionally be accumulated and/or reduced in fp32. Includes a gloo/nccl distributed test covering gradient parity against a manual all-reduce reference and `no_sync()` accumulation. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Adds a data-parallel module wrapper,
MultiGroupDistributedDataParallel, underolmo_core.nn.parallel. It all-reduces gradients across one or more process groups.The distinguishing feature vs.
torch.nn.parallel.DistributedDataParallel: different parameters can be reduced over different process groups, via aparam_process_group_fn(name, param) -> process_groupmapping. This is needed when subsets of a model are replicated over different device meshes — e.g. expert weights under expert parallelism, which must be data-parallel-reduced only over the ranks holding the same experts, while the rest of the model reduces over the full DP group.Mechanics
.gradis a view into its bucket slice. This enables one large all-reduce per bucket instead of many tiny ones.finalize_grad_reduce()is called afterloss.backward()to wait for completion.Notes
getattr(module, ...)) for syncing fp8/mxfp8 gradient stores; these are no-ops for standard modules.Tests
New gloo/nccl distributed test (
src/test/nn/parallel/distributed_test.py):test_grad_parity— DDP-averaged gradients match a manual all-reduce reference.test_no_sync_accumulation—no_sync()defers the reduction and correctly accumulates gradients across micro-batches before a single reduce.pytest src/test/nn/parallel/passes on gloo (nccl auto-skips without a GPU);make checks(isort/black/ruff/mypy) clean.🤖 Generated with Claude Code