zero3: SDMA allgather via mori (sdma_allgather) #7999
inkcherry wants to merge 21 commits into deepspeedai:master
Conversation
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 66673546b5
    def _dist_allgather_fn(input_tensor: Tensor, output_tensor: Tensor, group=None):
        work = mori.allgather_into_tensor(input_tensor, output_tensor)
Honor ZeRO's process group before using SDMA
When ZeRO is initialized with a non-WORLD data-parallel group, or with a secondary zero-param group, _all_gather passes that group down as ds_process_group (partition_parameters.py:1463-1471), but this new SDMA call ignores the group argument and uses mori's WORLD-backed default process group. In those model/tensor-parallel configurations mori gathers from more ranks than the caller allocated output_tensor for, which can corrupt fetched parameters or write past the expected buffer. Fall back to the dist path unless group is WORLD, or make mori initialize and use the matching group.
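A minimal sketch of the suggested guard (the WORLD-group test via torch.distributed and the fallback argument order are assumptions based on the diff hunks in this thread, not code from the PR):

    import torch.distributed as torch_dist

    def _dist_allgather_fn(input_tensor, output_tensor, group=None):
        # SDMA is only safe when the caller's group matches the WORLD-backed
        # group mori was initialized with; a narrower group would make mori
        # gather from more ranks than output_tensor was sized for.
        if group is None or group is torch_dist.group.WORLD:
            work = mori.allgather_into_tensor(input_tensor, output_tensor)
            if work is not None:
                return work
        # Output-first argument order, as in the coalesced hunk below.
        return instrument_w_nvtx(dist.allgather_fn)(output_tensor, input_tensor,
                                                    group=group, async_op=True)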
    handle = instrument_w_nvtx(dist.allgather_fn)(
        flat_tensor, partitions[rank_in_group], group=ds_process_group, async_op=True)
Route coalesced allgathers through the SDMA wrapper
With the default stage3_allgather_sequential=false, any ZeRO-3 fetch containing more than one parameter takes _all_gather_dtype, and that path now calls dist.allgather_fn directly instead of _dist_allgather_fn. As a result, enabling sdma_allgather has no effect on the common coalesced prefetch path (including the added sample config, which does not enable sequential allgather), so the advertised optimization is skipped for most multi-parameter buckets.
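A minimal sketch of the suggested rerouting inside _all_gather_dtype (variable names from the hunk above; this also relies on _dist_allgather_fn honoring group, per the previous comment):

    # Route the coalesced bucket through the SDMA wrapper instead of calling
    # dist.allgather_fn directly, so sdma_allgather also covers
    # multi-parameter buckets; the wrapper falls back to
    # instrument_w_nvtx(dist.allgather_fn) when mori is unavailable.
    handle = _dist_allgather_fn(partitions[rank_in_group], flat_tensor, group=ds_process_group)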
Signed-off-by: wuyl1 <yangwu@amd.com>
Move all mori-specific code (handle, dtype map, transit buffer sizing,
PG-registration helper, Work wrapper) out of partition_parameters.py
into a dedicated runtime/comm backend module:
deepspeed/runtime/comm/mori.py
mori.init(max_numel) # one-shot, idempotent, exception-safe
mori.is_enabled() # cheap predicate
mori.allgather_into_tensor(in, out)
-> Work-compatible object on success, None on fallback
The new backend uses mori_cpp.AllGatherIntoTensor (NCCL/RCCL-style
flat->flat C++ dispatcher) instead of the old mori.ccl.AllgatherSdma
templated Python wrapper, so DeepSpeed no longer has to pre-convert
numel into uint32 lane counts or template the C++ class on dtype.
partition_parameters.py is now agnostic to the SDMA path:
    def _dist_allgather_fn(input_tensor, output_tensor, group=None):
        work = mori.allgather_into_tensor(input_tensor, output_tensor)
        if work is not None:
            return work
        return instrument_w_nvtx(dist.allgather_fn)(...)
Init failure (mori missing, non-AMD/ROCm runtime, shmem init error)
leaves the handle unset and logs a single rank-0 warning, so the SDMA
path silently no-ops and dist.allgather_fn (RCCL/NCCL) takes over —
no hard fail.
Net change: partition_parameters.py shrinks by 79 lines; one new
self-contained module under runtime/comm/.
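A minimal sketch of the module's shape, matching the three-function facade above (the mori_cpp import path, constructor signature, and call convention are illustrative assumptions, not the PR's actual code):

    # deepspeed/runtime/comm/mori.py -- illustrative sketch only
    from deepspeed.utils import logger

    _handle = None        # backend handle; stays None whenever init fails
    _initialized = False  # makes init() one-shot even after a failure

    def init(max_numel):
        """One-shot, idempotent, exception-safe init."""
        global _handle, _initialized
        if _initialized:
            return
        _initialized = True
        try:
            import mori_cpp  # import path assumed for illustration
            _handle = mori_cpp.AllGatherIntoTensor(max_numel)  # signature assumed
        except Exception as e:
            # Real code gates this warning to rank 0.
            logger.warning(f"SDMA allgather disabled, using dist fallback: {e}")

    def is_enabled():
        return _handle is not None

    def allgather_into_tensor(input_tensor, output_tensor):
        """Work-compatible object on success, None to request the dist fallback."""
        if _handle is None:
            return None
        return _handle(input_tensor, output_tensor)  # call convention assumed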
Signed-off-by: inkcherry <mingzhi.liu@amd.com>
_SdmaWork.wait() previously blocked the CPU on _event.synchronize()
before issuing the stream-level dependency. RCCL's Work.wait() only
records a stream-level wait (cudaStreamWaitEvent / hipStreamWaitEvent)
and does NOT block the CPU, which is what the ZeRO-3 prefetch pipeline
relies on: while bucket N is in flight on the GPU, the CPU is free to
queue bucket N+1 so it can overlap with the trailing compute of N.
The CPU-blocking variant turned out to be a per-step critical-path tax
that wiped out SDMA's headroom on workloads that issue many small
allgathers per step. Concretely, on Qwen3-32B + ZeRO-3 + seq_len=128,
8x MI300X, ~6400 prefetch buckets per step:
before: SDMA 1014 ms / step (1009 tok/s)
RCCL 932 ms / step (1099 tok/s) -> SDMA -8.0%
after: SDMA 927 ms / step (1104 tok/s)
RCCL 929 ms / step (1100 tok/s) -> within noise
Loss curve is bit-identical with and without the CPU sync, so this is
purely a CPU-pipelining fix. is_completed() is unchanged (it polls
via _event.query() without blocking, same as before).
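A minimal sketch of the change (class and attribute names follow the commit text; the surrounding stream plumbing is assumed):

    import torch

    class _SdmaWork:
        """Work-compatible handle for an in-flight SDMA allgather (sketch)."""

        def __init__(self, event: torch.cuda.Event):
            self._event = event  # recorded on the stream that issued the SDMA copy

        def wait(self):
            # Before: self._event.synchronize() here blocked the CPU until the
            # copy finished, serializing the ZeRO-3 prefetch pipeline.
            # After: stream-level dependency only, matching RCCL Work.wait():
            torch.cuda.current_stream().wait_event(self._event)

        def is_completed(self):
            # Unchanged: non-blocking poll of the recorded event.
            return self._event.query()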
Signed-off-by: inkcherry <mingzhi.liu@amd.com>
Move the zero3_overlap demo dir into examples/sdma_allgather/ (the name
that matches the feature being demoed) and add a Qwen3-32B + ZeRO-3
trainer that reproduces the +9.93% end-to-end speedup of this PR on
8x MI300X with the default DeepSpeed bucket sizes.
Layout:
    ds_config_zero3_{sdma,nosdma}.json   ZeRO-3 + bf16 + DS-default buckets
    run_gpt_sdma_{on,off}.sh             GPT-7B-ish demo (existing trainer)
    run_qwen3_sdma_{on,off}.sh           Qwen3-32B demo (new trainer)
    train_qwen3_zero3.py                 self-contained Qwen3 trainer
    README.md                            feature overview + repro steps
    train_zero3.py                       unchanged (renamed only)
    test_sdma_allgather_zero3.py         unchanged (renamed only)
train_qwen3_zero3.py inlines a minimal wikitext-103 dataloader so the
benchmark has no dependency on external benchmark repos. Loading via
AutoConfig + from_config keeps the example weight-free; only the model
config and tokenizer are pulled from HuggingFace.
The configs use DeepSpeed's default ZeRO-3 bucket sizes
(stage3_prefetch_bucket_size = 5e7, etc.) so the published numbers
in README.md are reproducible without any tuning.
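For reference, a sketch of what the SDMA-enabled config might contain, expressed as a Python dict (the placement of the sdma_allgather flag under zero_optimization is an assumption; ds_config_zero3_sdma.json in the PR is authoritative):

    # Assumed shape of ds_config_zero3_sdma.json; "sdma_allgather"
    # placement is a guess, the bucket size is the DeepSpeed default.
    ds_config = {
        "bf16": {"enabled": True},
        "zero_optimization": {
            "stage": 3,
            "stage3_prefetch_bucket_size": 5e7,
            "sdma_allgather": True,
        },
    }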
Verified on 8x MI300X, two fresh rounds:
Qwen3-32B + ZeRO-3 + DP=8, seq_len=1024, micro_bs=1, 100 steps
SDMA off : 1402.5 ms / step (5841 tok/s)
SDMA on : 1263.2 ms / step (6486 tok/s) -> +9.93% e2e
GPT-7B + ZeRO-3 + DP=8, 100 steps -> +5.9% e2e
Loss curves match across the two backends, peak memory is identical
(96.45 GB), per-step jitter is 1.4-2.7%, so the ~140 ms gap is well
above noise.
Drops:
examples/zero3_overlap/run.sh superseded by run_gpt_*
examples/zero3_overlap/ds_config_zero3.json superseded by *_sdma.json
Signed-off-by: inkcherry <mingzhi.liu@amd.com>
Summary
RFC: #7884
Wire sdma_allgather into ZeRO-3's parameter prefetch path (_dist_allgather_fn). When enabled, ZeRO-3 allgather routes through mori_cpp.AllGatherIntoTensor (intra-node SDMA copy on AMD MI300), with a transparent fallback to dist.allgather_fn (RCCL/NCCL) on init failure.
End-to-end demo + repro steps + verified numbers live in examples/sdma_allgather/README.md.
Headline (8x MI300X, DeepSpeed default ZeRO-3 buckets, 100 steps):
    Qwen3-32B + ZeRO-3 + DP=8: 1402.5 -> 1263.2 ms/step, +9.93% e2e
    GPT-7B + ZeRO-3 + DP=8: +5.9% e2e
Loss curves match off ↔ on, peak memory unchanged.
Speedup is workload-dependent: gains shrink (or invert) when allgather can't be overlapped with compute.
Co-authored-by: wuyl1 <yangwu@amd.com>