Add cuBLAS+NCCL fast path for small GEMM in GEMM+ReduceScatter #166

Open

yxs wants to merge 2 commits into ByteDance-Seed:main from yxs:feat/gemm-rs-small-gemm-fast-path

Conversation

yxs commented Apr 13, 2026

Summary

For small GEMM shapes, the persistent kernel's overlap mechanism (TMA descriptors, barriers, signals) has fixed overhead that exceeds the overlap benefit, making gemm_rs slower than plain PyTorch (torch.matmul + reduce_scatter_tensor).

This PR adds a fast path: when per-GPU GEMM FLOPs < 100 GFLOPs, bypass the overlap kernel and fall back directly to cuBLAS GEMM + NCCL ReduceScatter.
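
As a rough sketch, the gate computes the local GEMM's FLOP count and compares it to the constant. The helper name and the assumption that K is the dimension sharded across ranks are mine; `_SMALL_GEMM_FLOPS_THRESHOLD`, `ctx.tp_group`, and the fallback calls come from the diff quoted in the review below:

```python
_SMALL_GEMM_FLOPS_THRESHOLD = 1e11  # 100 GFLOPs per rank

def _should_use_fast_path(A, B, ctx):
    # Hypothetical helper. For A of shape (M, K_local) and B of shape
    # (K_local, N), the per-rank GEMM costs 2 * M * N * K_local FLOPs.
    M, K_local = A.shape
    N = B.shape[1]
    flops = 2 * M * N * K_local
    # Only active when a tp_group was provided (backward compatible).
    return ctx.tp_group is not None and flops < _SMALL_GEMM_FLOPS_THRESHOLD
```

For M=4096, N=4096, K=8192 this gives 2·4096·4096·1024 ≈ 34 GFLOPs per rank on 8 GPUs (fast path), but 137 GFLOPs on 2 GPUs (overlap kernel).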

Benchmark (8×H100 80GB, NV18 full mesh)

Small shape: M=4096, N=4096, K=8192

| GPUs | Before | After | PyTorch | Speedup vs Before |
|------|--------|-------|---------|-------------------|
| 2 | 0.259 ms | 0.256 ms | 0.275 ms | — (above threshold) |
| 4 | 0.286 ms | 0.231 ms | 0.205 ms | +24% |
| 8 | 0.323 ms | 0.167 ms | 0.167 ms | +94% |

Large shape: M=8192, N=8192, K=29568 (LLaMA-3.1-70B)

| GPUs | Before | After | PyTorch | Speedup vs Before |
|------|--------|-------|---------|-------------------|
| 2 | 2.886 ms | 2.902 ms | 3.077 ms | no regression |
| 4 | 1.548 ms | 1.575 ms | 1.707 ms | no regression |
| 8 | 0.855 ms | 0.851 ms | 1.045 ms | no regression |

The threshold (100 GFLOPs) is the crossover point where the overlap kernel stops being profitable, determined by benchmarking on 8×H100:

  • 2×4096×4096×4096 = 137 GFLOPs → overlap wins (1.06× vs PyTorch)
  • 2×4096×4096×2048 = 69 GFLOPs → overlap loses (0.71× vs PyTorch)

The 100 GFLOPs threshold is set between these two data points.

The fast path is backward-compatible: it only activates when tp_group is provided and GEMM FLOPs < threshold. Existing callers that don't pass tp_group are unaffected.
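
For illustration, opting in from a caller might look like this (a hypothetical snippet; create_gemm_rs_context's other arguments are elided, and only the new tp_group keyword is taken from this PR):

```python
import torch.distributed as dist

# Assumes dist.init_process_group(...) has already run.
tp_group = dist.group.WORLD  # or any tensor-parallel subgroup

# tp_group defaults to None; omitting it keeps the persistent overlap
# kernel unconditionally, exactly as before this PR.
ctx = create_gemm_rs_context(..., tp_group=tp_group)
```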

Tested:

  • Correctness: 2/4/8 GPUs × small/large shapes × autotune/fuse_scatter/float16
  • Performance: no regression on large shapes

Commit message:

For small GEMM shapes (FLOPs < 100 GFLOPs per GPU), the persistent
kernel's fixed overhead (TMA descriptors, barriers, signals) exceeds
the overlap benefit. On H100, this caused 4-GPU to be 40% slower and
8-GPU to be 93% slower than PyTorch for M=4096 N=4096 K=8192.

Fall back to cuBLAS matmul + NCCL reduce_scatter when the GEMM is too
small for overlap to be profitable. Benchmarked on 8xH100 (NV18):
- 4-GPU small shape: 0.286ms -> 0.231ms (+24%)
- 8-GPU small shape: 0.323ms -> 0.167ms (+94%)
- Large shapes (LLaMA-70B): no regression
Copilot AI review requested due to automatic review settings April 13, 2026 04:28

yxs commented Apr 13, 2026

cc @KnowingNothing


Copilot AI left a comment


Pull request overview

This PR introduces a “small GEMM” fast path for the GEMM+ReduceScatter operation: when the per-rank GEMM workload is below a FLOPs threshold, it bypasses the persistent overlap kernel and instead runs cuBLAS-backed torch.matmul followed by NCCL-backed reduce_scatter_tensor, improving performance for small shapes while keeping large-shape behavior unchanged.

Changes:

  • Add a FLOPs-based threshold and a cuBLAS+NCCL fallback path in gemm_rs_op.
  • Extend GEMMReduceScatterTensorParallelContext / create_gemm_rs_context to carry an optional tp_group used by the fallback path.
  • Wire tp_group through representative call sites (TP MLP and the NVIDIA GEMM-RS test harness).

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 2 comments.

| File | Description |
|------|-------------|
| python/triton_dist/kernels/nvidia/gemm_reduce_scatter.py | Adds tp_group to the context and implements the FLOPs-threshold cuBLAS+NCCL fast path in gemm_rs_op. |
| python/triton_dist/layers/nvidia/tp_mlp.py | Passes the tensor-parallel process group into create_gemm_rs_context so the fast path can activate. |
| python/triton_dist/test/nvidia/test_gemm_rs.py | Updates context creation to pass tp_group, exercising the new fast path in this harness. |


Comment threads (2): python/triton_dist/kernels/nvidia/gemm_reduce_scatter.py
torch.matmul with int8 accumulates in int8 and wraps on overflow,
while the Triton kernel accumulates in fp32. Restrict the fast path
to floating-point dtypes only.
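
A sketch of the resulting guard (the helper name is hypothetical; Tensor.is_floating_point() is the standard PyTorch check):

```python
def _fast_path_dtype_ok(A, B):
    # Integer inputs must stay on the Triton kernel: torch.matmul would
    # accumulate int8 in int8 and wrap on overflow, whereas the Triton
    # kernel accumulates in fp32.
    return A.is_floating_point() and B.is_floating_point()
```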

Copilot AI left a comment


Pull request overview

Adds an opt-in “small GEMM” fast path for NVIDIA GEMM+ReduceScatter that bypasses the persistent overlap kernel overhead by falling back to cuBLAS GEMM + NCCL reduce_scatter_tensor when the per-rank GEMM work is below a FLOP-count threshold.

Changes:

  • Extend GEMMReduceScatterTensorParallelContext / create_gemm_rs_context to carry an optional tp_group used by the NCCL fast path.
  • Add a FLOP-threshold-based branch in gemm_rs_op that runs torch.matmul + torch.distributed.reduce_scatter_tensor.
  • Wire tp_group through a TP MLP layer and the NVIDIA GEMM-RS test harness.

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 2 comments.

| File | Description |
|------|-------------|
| python/triton_dist/kernels/nvidia/gemm_reduce_scatter.py | Adds tp_group to context and introduces cuBLAS+NCCL fast path guarded by a small-GEMM threshold. |
| python/triton_dist/layers/nvidia/tp_mlp.py | Passes the TP process group into the GEMM-RS context so the fast path can be used. |
| python/triton_dist/test/nvidia/test_gemm_rs.py | Passes tp_group into context creation for test coverage / benchmarking. |


Comment on lines +642 to +646

```python
output = torch.empty((M_per_rank, N), dtype=output_dtype, device=A.device)
gemm_out = torch.matmul(A, B)  # cuBLAS-backed GEMM on the local shard
if gemm_out.dtype != output_dtype:
    gemm_out = gemm_out.to(output_dtype)
# NCCL reduce-scatter of the per-rank partials across the TP group
torch.distributed.reduce_scatter_tensor(output, gemm_out, group=ctx.tp_group)
```

The threshold constant added in the same file:

```python
################### context ###################

# fall back to cuBLAS + NCCL when overlap overhead > benefit
_SMALL_GEMM_FLOPS_THRESHOLD = 1e11  # 100 GFLOPs
```
