diff --git a/python/triton_dist/kernels/nvidia/gemm_reduce_scatter.py b/python/triton_dist/kernels/nvidia/gemm_reduce_scatter.py index 23f4927d6..b3b70cd9f 100644 --- a/python/triton_dist/kernels/nvidia/gemm_reduce_scatter.py +++ b/python/triton_dist/kernels/nvidia/gemm_reduce_scatter.py @@ -42,8 +42,12 @@ from triton_dist.kernels.nvidia.gemm_perf_model import estimate_gemm_sol_time_ms from triton_dist.nv_utils import get_intranode_max_speed_gbps - ################### context ################### + +# fall back to cuBLAS + NCCL when overlap overhead > benefit +_SMALL_GEMM_FLOPS_THRESHOLD = 1e11 # 100 GFLOPS + + @dataclasses.dataclass class GEMMReduceScatterTensorParallelContext: rs_ctx: ReduceScatter2DContext @@ -58,6 +62,9 @@ class GEMMReduceScatterTensorParallelContext: # gemm kernel config num_gemm_sms: int + # process group for cuBLAS + NCCL fast path + tp_group: Optional[torch.distributed.ProcessGroup] = None + def finalize(self): self.rs_ctx.finalize() nvshmem_free_tensor_sync(self.gemm_out_bufs[self.rs_ctx.local_rank]) @@ -68,9 +75,17 @@ def get_gemm_out_buf(self, input): return self.gemm_out_bufs[local_rank][:M] -def create_gemm_rs_context(max_M, N, rank, world_size, local_world_size, output_dtype: torch.dtype, - rs_stream: torch.cuda.Stream, - reduce_st: bool = False) -> GEMMReduceScatterTensorParallelContext: +def create_gemm_rs_context( + max_M, + N, + rank, + world_size, + local_world_size, + output_dtype: torch.dtype, + rs_stream: torch.cuda.Stream, + reduce_st: bool = False, + tp_group: Optional[torch.distributed.ProcessGroup] = None, +) -> GEMMReduceScatterTensorParallelContext: rs_ctx = create_reduce_scater_2d_ctx(max_M, N, rank, world_size, local_world_size, output_dtype) NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count @@ -79,7 +94,7 @@ def create_gemm_rs_context(max_M, N, rank, world_size, local_world_size, output_ gemm_out_bufs = nvshmem_create_tensors((max_M // world_size if reduce_st else max_M, N), output_dtype, rank, local_world_size) ctx = GEMMReduceScatterTensorParallelContext(rs_ctx=rs_ctx, output_dtype=output_dtype, gemm_out_bufs=gemm_out_bufs, - rs_stream=rs_stream, num_gemm_sms=num_gemm_sms) + rs_stream=rs_stream, num_gemm_sms=num_gemm_sms, tp_group=tp_group) nvshmem_barrier_all_on_stream(torch.cuda.current_stream()) return ctx @@ -619,6 +634,18 @@ def gemm_rs_op(A: torch.Tensor, B: torch.Tensor, ctx: GEMMReduceScatterTensorPar assert M % world_size == 0 M_per_rank = M // world_size + + # fast path: cuBLAS + NCCL for small GEMMs where overlap overhead > benefit + gemm_flops = 2 * M * N * local_K + if (gemm_flops < _SMALL_GEMM_FLOPS_THRESHOLD and ctx.tp_group is not None and not reduce_st + and A.dtype.is_floating_point): + output = torch.empty((M_per_rank, N), dtype=output_dtype, device=A.device) + gemm_out = torch.matmul(A, B) + if gemm_out.dtype != output_dtype: + gemm_out = gemm_out.to(output_dtype) + torch.distributed.reduce_scatter_tensor(output, gemm_out, group=ctx.tp_group) + return output + current_stream = torch.cuda.current_stream() rs_stream.wait_stream(current_stream) diff --git a/python/triton_dist/layers/nvidia/tp_mlp.py b/python/triton_dist/layers/nvidia/tp_mlp.py index 55b90eb69..f299b05da 100644 --- a/python/triton_dist/layers/nvidia/tp_mlp.py +++ b/python/triton_dist/layers/nvidia/tp_mlp.py @@ -110,6 +110,7 @@ def _init_ctx(self, max_M, ag_intranode_stream: torch.cuda.Stream | None = None, local_world_size=self.world_size, output_dtype=self.dtype, rs_stream=ag_intranode_stream, + tp_group=self.group, ) nvshmem_barrier_all_on_stream(torch.cuda.current_stream()) torch.cuda.synchronize() diff --git a/python/triton_dist/test/nvidia/test_gemm_rs.py b/python/triton_dist/test/nvidia/test_gemm_rs.py index e3e510358..3c454307b 100644 --- a/python/triton_dist/test/nvidia/test_gemm_rs.py +++ b/python/triton_dist/test/nvidia/test_gemm_rs.py @@ -109,7 +109,7 @@ def __init__( self.rs_stream: torch.cuda.Stream = torch.cuda.Stream(priority=-1) self.ctx = create_gemm_rs_context(max_M, N, self.rank, self.world_size, self.local_world_size, output_dtype, - self.rs_stream, reduce_st) + self.rs_stream, reduce_st, tp_group=tp_group) self.reduce_st = reduce_st self.fuse_scatter = fuse_scatter self.persistent = persistent