Add cuBLAS+NCCL fast path for small GEMM in GEMM+ReduceScatter #166
yxs wants to merge 2 commits into ByteDance-Seed:main
Conversation
For small GEMM shapes (< 100 GFLOPs per GPU), the persistent kernel's fixed overhead (TMA descriptors, barriers, signals) exceeds the overlap benefit. On H100, this made the 4-GPU case 40% slower and the 8-GPU case 93% slower than PyTorch for M=4096, N=4096, K=8192. Fall back to cuBLAS matmul + NCCL reduce_scatter when the GEMM is too small for overlap to be profitable.

Benchmarked on 8xH100 (NV18):

- 4-GPU small shape: 0.286 ms -> 0.231 ms (+24%)
- 8-GPU small shape: 0.323 ms -> 0.167 ms (+94%)
- Large shapes (LLaMA-70B): no regression
Pull request overview
This PR introduces a “small GEMM” fast path for the GEMM+ReduceScatter operation: when the per-rank GEMM workload is below a FLOPs threshold, it bypasses the persistent overlap kernel and instead runs cuBLAS-backed torch.matmul followed by NCCL-backed reduce_scatter_tensor, improving performance for small shapes while keeping large-shape behavior unchanged.
Changes:
- Add a FLOPs-based threshold and a cuBLAS+NCCL fallback path in `gemm_rs_op`.
- Extend `GEMMReduceScatterTensorParallelContext` / `create_gemm_rs_context` to carry an optional `tp_group` used by the fallback path.
- Wire `tp_group` through representative call sites (TP MLP and the NVIDIA GEMM-RS test harness).
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| python/triton_dist/kernels/nvidia/gemm_reduce_scatter.py | Adds tp_group to context and implements the FLOPs-threshold cuBLAS+NCCL fast path in gemm_rs_op. |
| python/triton_dist/layers/nvidia/tp_mlp.py | Passes the tensor-parallel process group into create_gemm_rs_context so the fast path can activate. |
| python/triton_dist/test/nvidia/test_gemm_rs.py | Updates context creation to pass tp_group, enabling coverage/exercise of the new fast path in this harness. |
`torch.matmul` with int8 accumulates in int8 and wraps on overflow, while the Triton kernel accumulates in fp32. Restrict the fast path to floating-point dtypes only.
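A minimal sketch of such a guard (the helper name is hypothetical, not from the PR):

```python
import torch

def _fast_path_dtype_ok(A: torch.Tensor, B: torch.Tensor) -> bool:
    # torch.matmul on integer inputs accumulates in the input dtype and can
    # wrap on overflow, unlike the Triton kernel's fp32 accumulator, so the
    # cuBLAS+NCCL fast path should be limited to floating-point inputs.
    return A.dtype.is_floating_point and B.dtype.is_floating_point
```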
Pull request overview
Adds an opt-in “small GEMM” fast path for NVIDIA GEMM+ReduceScatter that bypasses the persistent overlap kernel overhead by falling back to cuBLAS GEMM + NCCL reduce_scatter_tensor when the per-rank GEMM work is below a FLOP-count threshold.
Changes:
- Extend `GEMMReduceScatterTensorParallelContext` / `create_gemm_rs_context` to carry an optional `tp_group` used by the NCCL fast path.
- Add a FLOP-threshold-based branch in `gemm_rs_op` that runs `torch.matmul` + `torch.distributed.reduce_scatter_tensor`.
- Wire `tp_group` through a TP MLP layer and the NVIDIA GEMM-RS test harness.
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| python/triton_dist/kernels/nvidia/gemm_reduce_scatter.py | Adds tp_group to context and introduces cuBLAS+NCCL fast path guarded by a small-GEMM threshold. |
| python/triton_dist/layers/nvidia/tp_mlp.py | Passes the TP process group into the GEMM-RS context so the fast path can be used. |
| python/triton_dist/test/nvidia/test_gemm_rs.py | Passes tp_group into context creation for test coverage / benchmarking. |
```python
output = torch.empty((M_per_rank, N), dtype=output_dtype, device=A.device)
gemm_out = torch.matmul(A, B)  # cuBLAS-backed GEMM: per-rank partial result
if gemm_out.dtype != output_dtype:
    gemm_out = gemm_out.to(output_dtype)  # NCCL reduces in the communicated dtype
torch.distributed.reduce_scatter_tensor(output, gemm_out, group=ctx.tp_group)
```
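For reference, the fallback is equivalent in spirit to this standalone sketch (shapes and the K-sharding layout are assumptions for illustration, not the PR's code; run under `torchrun --nproc_per_node=8`):

```python
import torch
import torch.distributed as dist

dist.init_process_group("nccl")
rank, world = dist.get_rank(), dist.get_world_size()
torch.cuda.set_device(rank)

M, N, K = 4096, 4096, 8192                 # the PR's small benchmark shape
A = torch.randn(M, K // world, dtype=torch.float16, device="cuda")
B = torch.randn(K // world, N, dtype=torch.float16, device="cuda")

partial = torch.matmul(A, B)               # cuBLAS GEMM: per-rank (M, N) partial sum
out = torch.empty(M // world, N, dtype=partial.dtype, device="cuda")
# NCCL sums the partials across ranks and scatters row blocks, so the input's
# first dimension must be world_size times the output's first dimension.
dist.reduce_scatter_tensor(out, partial)
dist.destroy_process_group()
```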
```python
################### context ###################

# fall back to cuBLAS + NCCL when overlap overhead > benefit
_SMALL_GEMM_FLOPS_THRESHOLD = 1e11  # 100 GFLOPs
```
Summary
For small GEMM shapes, the persistent kernel's overlap mechanism (TMA descriptors, barriers, signals) has fixed overhead that exceeds the overlap benefit, making `gemm_rs` slower than plain PyTorch (`torch.matmul` + `reduce_scatter_tensor`).

This PR adds a fast path: when per-GPU GEMM FLOPs < 100 GFLOPs, bypass the overlap kernel and fall back to cuBLAS GEMM + NCCL ReduceScatter directly.
Benchmark (8×H100 80GB, NV18 full mesh)
Small shape: M=4096, N=4096, K=8192
Large shape: M=8192, N=8192, K=29568 (LLaMA-3.1-70B)
The threshold (100 GFLOPs) is the crossover point where the overlap kernel stops being profitable, determined by benchmarking on 8×H100: the small shape sits below it and the large shape above it, and 100 GFLOPs is set between these two data points.
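As a sanity check, the per-GPU FLOP counts for the two benchmark shapes (assuming K is sharded across the 8 GPUs) land on opposite sides of the threshold:

```python
small = 2 * 4096 * 4096 * 8192 / 8    # ≈ 3.4e10 < 1e11 -> fast path
large = 2 * 8192 * 8192 * 29568 / 8   # ≈ 5.0e11 > 1e11 -> overlap kernel
```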
The fast path is backward-compatible: it only activates when `tp_group` is provided and GEMM FLOPs < threshold. Existing callers that don't pass `tp_group` are unaffected.
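A rough sketch of the backward-compatible context change (only the new field is shown; the default and surrounding fields are assumptions):

```python
from dataclasses import dataclass
from typing import Optional

import torch.distributed as dist

@dataclass
class GEMMReduceScatterTensorParallelContext:
    # ...existing fields unchanged...
    # Optional NCCL process group for the cuBLAS+NCCL fast path; defaulting
    # to None keeps callers that never pass tp_group on the old code path.
    tp_group: Optional[dist.ProcessGroup] = None
```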