Skip to content

Add MultiGroupDistributedDataParallel#688

Draft
AkshitaB wants to merge 3 commits into
mainfrom
akshitab/multigroup-ddp
Draft

Add MultiGroupDistributedDataParallel#688
AkshitaB wants to merge 3 commits into
mainfrom
akshitab/multigroup-ddp

Conversation

@AkshitaB
Copy link
Copy Markdown
Contributor

Summary

Adds a data-parallel module wrapper, MultiGroupDistributedDataParallel, under olmo_core.nn.parallel. It all-reduces gradients across one or more process groups.

The distinguishing feature vs. torch.nn.parallel.DistributedDataParallel: different parameters can be reduced over different process groups, via a param_process_group_fn(name, param) -> process_group mapping. This is needed when subsets of a model are replicated over different device meshes — e.g. expert weights under expert parallelism, which must be data-parallel-reduced only over the ranks holding the same experts, while the rest of the model reduces over the full DP group.

Mechanics

  • Bucket views: gradients are accumulated into flat, pre-allocated buffers; each parameter's .grad is a view into its bucket slice. This enables one large all-reduce per bucket instead of many tiny ones.
  • Per-group buckets: a bucket never mixes process groups, so each bucket is reduced over its own group with its own world size.
  • Overlap: bucketed all-reduces launch during the backward pass as each bucket fills; finalize_grad_reduce() is called after loss.backward() to wait for completion.
  • Optional fp32: gradients can be accumulated and/or reduced in fp32 for numerical stability with low-precision parameters.
ddp = MultiGroupDistributedDataParallel(model)
out = ddp(inputs)
out.loss.backward()
ddp.finalize_grad_reduce()  # waits for the overlapped all-reduces
optimizer.step()

Notes

  • This lands the primitive ahead of its consumers; nothing in-tree wires it in yet.
  • The wrapper includes opt-in, duck-typed hooks (getattr(module, ...)) for syncing fp8/mxfp8 gradient stores; these are no-ops for standard modules.

Tests

New gloo/nccl distributed test (src/test/nn/parallel/distributed_test.py):

  • test_grad_parity — DDP-averaged gradients match a manual all-reduce reference.
  • test_no_sync_accumulationno_sync() defers the reduction and correctly accumulates gradients across micro-batches before a single reduce.

pytest src/test/nn/parallel/ passes on gloo (nccl auto-skips without a GPU); make checks (isort/black/ruff/mypy) clean.

🤖 Generated with Claude Code

Add a data-parallel module wrapper under `olmo_core.nn.parallel` that
all-reduces gradients across one or more process groups.

Unlike `torch.nn.parallel.DistributedDataParallel`, it accumulates
gradients into flat, pre-allocated bucket views (each parameter's
`.grad` is a view into a contiguous bucket) and supports assigning
different parameters to different process groups via
`param_process_group_fn`. This is useful when subsets of a model are
replicated over different device meshes (e.g. expert weights under
expert parallelism, which must be data-parallel-reduced only over the
ranks that hold the same experts).

Bucketed all-reduces are launched and overlapped with the backward
pass as each bucket fills; `finalize_grad_reduce()` is called after
`loss.backward()` to wait for them to complete. Gradients can
optionally be accumulated and/or reduced in fp32.

Includes a gloo/nccl distributed test covering gradient parity against
a manual all-reduce reference and `no_sync()` accumulation.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@AkshitaB AkshitaB requested a review from TianhuaTao May 27, 2026 04:07
@AkshitaB AkshitaB marked this pull request as draft May 27, 2026 04:20
AkshitaB and others added 2 commits June 2, 2026 18:10
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant