Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 32 additions & 5 deletions python/triton_dist/kernels/nvidia/gemm_reduce_scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,12 @@
from triton_dist.kernels.nvidia.gemm_perf_model import estimate_gemm_sol_time_ms
from triton_dist.nv_utils import get_intranode_max_speed_gbps


################### context ###################

# fall back to cuBLAS + NCCL when overlap overhead > benefit
_SMALL_GEMM_FLOPS_THRESHOLD = 1e11 # 100 GFLOPS


@dataclasses.dataclass
class GEMMReduceScatterTensorParallelContext:
rs_ctx: ReduceScatter2DContext
Expand All @@ -58,6 +62,9 @@ class GEMMReduceScatterTensorParallelContext:
# gemm kernel config
num_gemm_sms: int

# process group for cuBLAS + NCCL fast path
tp_group: Optional[torch.distributed.ProcessGroup] = None

def finalize(self):
self.rs_ctx.finalize()
nvshmem_free_tensor_sync(self.gemm_out_bufs[self.rs_ctx.local_rank])
Expand All @@ -68,9 +75,17 @@ def get_gemm_out_buf(self, input):
return self.gemm_out_bufs[local_rank][:M]


def create_gemm_rs_context(max_M, N, rank, world_size, local_world_size, output_dtype: torch.dtype,
rs_stream: torch.cuda.Stream,
reduce_st: bool = False) -> GEMMReduceScatterTensorParallelContext:
def create_gemm_rs_context(
max_M,
N,
rank,
world_size,
local_world_size,
output_dtype: torch.dtype,
rs_stream: torch.cuda.Stream,
reduce_st: bool = False,
tp_group: Optional[torch.distributed.ProcessGroup] = None,
) -> GEMMReduceScatterTensorParallelContext:
rs_ctx = create_reduce_scater_2d_ctx(max_M, N, rank, world_size, local_world_size, output_dtype)
NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count

Expand All @@ -79,7 +94,7 @@ def create_gemm_rs_context(max_M, N, rank, world_size, local_world_size, output_
gemm_out_bufs = nvshmem_create_tensors((max_M // world_size if reduce_st else max_M, N), output_dtype, rank,
local_world_size)
ctx = GEMMReduceScatterTensorParallelContext(rs_ctx=rs_ctx, output_dtype=output_dtype, gemm_out_bufs=gemm_out_bufs,
rs_stream=rs_stream, num_gemm_sms=num_gemm_sms)
rs_stream=rs_stream, num_gemm_sms=num_gemm_sms, tp_group=tp_group)
nvshmem_barrier_all_on_stream(torch.cuda.current_stream())
return ctx

Expand Down Expand Up @@ -619,6 +634,18 @@ def gemm_rs_op(A: torch.Tensor, B: torch.Tensor, ctx: GEMMReduceScatterTensorPar

assert M % world_size == 0
M_per_rank = M // world_size

# fast path: cuBLAS + NCCL for small GEMMs where overlap overhead > benefit
gemm_flops = 2 * M * N * local_K
if (gemm_flops < _SMALL_GEMM_FLOPS_THRESHOLD and ctx.tp_group is not None and not reduce_st
and A.dtype.is_floating_point):
output = torch.empty((M_per_rank, N), dtype=output_dtype, device=A.device)
Comment thread
yxs marked this conversation as resolved.
gemm_out = torch.matmul(A, B)
Comment thread
yxs marked this conversation as resolved.
if gemm_out.dtype != output_dtype:
gemm_out = gemm_out.to(output_dtype)
torch.distributed.reduce_scatter_tensor(output, gemm_out, group=ctx.tp_group)
Comment on lines +642 to +646
return output

current_stream = torch.cuda.current_stream()
rs_stream.wait_stream(current_stream)

Expand Down
1 change: 1 addition & 0 deletions python/triton_dist/layers/nvidia/tp_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ def _init_ctx(self, max_M, ag_intranode_stream: torch.cuda.Stream | None = None,
local_world_size=self.world_size,
output_dtype=self.dtype,
rs_stream=ag_intranode_stream,
tp_group=self.group,
)
nvshmem_barrier_all_on_stream(torch.cuda.current_stream())
torch.cuda.synchronize()
Expand Down
2 changes: 1 addition & 1 deletion python/triton_dist/test/nvidia/test_gemm_rs.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def __init__(
self.rs_stream: torch.cuda.Stream = torch.cuda.Stream(priority=-1)

self.ctx = create_gemm_rs_context(max_M, N, self.rank, self.world_size, self.local_world_size, output_dtype,
self.rs_stream, reduce_st)
self.rs_stream, reduce_st, tp_group=tp_group)
self.reduce_st = reduce_st
self.fuse_scatter = fuse_scatter
self.persistent = persistent
Expand Down
Loading