Skip to content

Commit ec26b9d

Browse files
committed
Fused all-gather matmul
Signed-off-by: Ke Wen <kwen@nvidia.com>
1 parent 29ce019 commit ec26b9d

1 file changed

Lines changed: 159 additions & 0 deletions

File tree

samples/all_gather_matmul.py

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
# SPDX-FileCopyrightText: Copyright (c) <2025> NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
"""
6+
Example demonstrating all-gather and matrix multiplication in a single kernel.
7+
8+
Run with:
9+
torchrun --nproc-per-node 4 --standalone all_gather_matmul.py
10+
"""
11+
12+
import torch
13+
import torch.distributed as dist
14+
import torch.distributed._symmetric_memory as symm_mem
15+
import cuda.tile as ct
16+
17+
18+
# cuTile kernel for gather then matmul
@ct.kernel
def gather_matmul_kernel(
    inp_list,
    w,
    out,
    tile_m: ct.Constant[int],
    tile_n: ct.Constant[int],
    tile_k: ct.Constant[int],
):
    """Fused all-gather + matmul: out = concat(inp_list, dim=0) @ w.

    Each (bid(0), bid(1)) block computes one (tile_m, tile_n) output tile.
    Instead of materializing the gathered input first, the block loads its
    A-tiles directly from the owning peer's buffer in ``inp_list`` — the
    all-gather is fused into the load.

    Args:
        inp_list: per-peer input buffers; assumed to all share the shape of
            ``inp_list[0]`` (rows_per_peer, K) — TODO confirm the launcher
            guarantees uniform shapes.
        w: weight matrix, indexed as (K tiles, N tiles).
        out: output tensor covering rows from all peers.
        tile_m: compile-time tile size along the output row dimension.
        tile_n: compile-time tile size along the output column dimension.
        tile_k: compile-time tile size along the reduction dimension.
    """
    # Number of m tiles per peer
    peer_inp_size_m = inp_list[0].shape[0]
    num_tiles_m_per_peer = ct.cdiv(peer_inp_size_m, tile_m)
    num_tiles_k = ct.num_tiles(w, axis=0, shape=(tile_k, tile_n))

    # 0-dim maps to m_tile_idx, 1-dim maps to n_tile_idx, of out tensor
    m_tile_idx = ct.bid(0)
    n_tile_idx = ct.bid(1)

    # Which peer should this tile get input from?
    # Output rows are laid out peer-major: peer 0's rows first, then peer 1's, etc.
    peer = m_tile_idx // num_tiles_m_per_peer
    # Select ct.Array from inp_list
    peer_inp = inp_list[peer]
    # Row-tile index within that peer's local buffer.
    m_tile_idx_in_peer = m_tile_idx % num_tiles_m_per_peer

    # Initialize accumulator; accumulate in fp32 regardless of input dtype.
    accumulator = ct.full((tile_m, tile_n), 0, dtype=ct.float32)
    zero_pad = ct.PaddingMode.ZERO

    # Convert fp32 to tf32 to use tensorcore
    dtype = ct.tfloat32 if peer_inp.dtype == ct.float32 else peer_inp.dtype

    for k in range(num_tiles_k):
        # Load remote input tile (may cross NVLink if `peer` is not this rank)
        a = ct.load(
            peer_inp,
            index=(m_tile_idx_in_peer, k),
            shape=(tile_m, tile_k),
            padding_mode=zero_pad,
        ).astype(dtype)
        # Load weight tile (zero-padded at the edges, so partial tiles are safe)
        b = ct.load(
            w,
            index=(k, n_tile_idx),
            shape=(tile_k, tile_n),
            padding_mode=zero_pad,
        ).astype(dtype)
        # Perform matrix multiplication, accumulating over the k tiles
        accumulator = ct.mma(a, b, accumulator)

    # Cast result back to output dtype
    accumulator = ct.astype(accumulator, out.dtype)

    # Store result tile
    ct.store(out, index=(m_tile_idx, n_tile_idx), tile=accumulator)
73+
74+
75+
# Host-side launcher for all-gather
def cutile_gather_matmul(
    inp: torch.Tensor,
    w: torch.Tensor,
    group: dist.ProcessGroup,
):
    """All-gather ``inp`` across ``group`` and multiply by ``w`` in one kernel.

    Args:
        inp: this rank's (M, K) shard, allocated with ``symm_mem.empty`` so
            peers can read it directly.
        w: (K, N) weight matrix, local to this rank.
        group: process group whose members contribute input shards.

    Returns:
        (M * world_size, N) tensor equal to the gathered input times ``w``.

    Raises:
        AssertionError: if K of ``inp`` and ``w`` disagree, or if any
            dimension is not a multiple of its tile size.
    """
    # Exchange buffer handles so every rank can address every peer's shard.
    handle = symm_mem.rendezvous(inp, group.group_name)
    world_size = handle.world_size
    inp_list = [
        handle.get_buffer(rank, inp.shape, inp.dtype, 0) for rank in range(world_size)
    ]

    # Allocate output tensor.
    # BUGFIX: match the input dtype explicitly — torch.empty defaults to
    # float32, which silently broke fp16/bf16 inputs.
    M = inp.shape[0]
    M_out = M * world_size
    N = w.shape[1]
    out = torch.empty(M_out, N, device=inp.device, dtype=inp.dtype)

    assert inp.shape[1] == w.shape[0], "reduction dimension mismatch"
    K = inp.shape[1]
    tile_m = 128
    tile_n = 128
    tile_k = 128
    # The kernel zero-pads edge tiles, but this sample keeps the shape
    # constraints explicit for clarity.
    assert M % tile_m == 0
    assert N % tile_n == 0
    assert K % tile_k == 0

    # Map each output tile to a block
    grid = (ct.cdiv(M_out, tile_m), ct.cdiv(N, tile_n),)
    ct.launch(
        torch.cuda.current_stream(),
        grid,
        gather_matmul_kernel,
        (inp_list, w, out, tile_m, tile_n, tile_k),
    )

    return out
112+
113+
114+
# Reference gather then matmul implementation
def ref_gather_matmul(
    inp: torch.Tensor,
    w: torch.Tensor,
    group: dist.ProcessGroup,
):
    """Reference: all-gather ``inp`` across ``group``, then matmul with ``w``.

    Args:
        inp: this rank's (M, K) shard.
        w: (K, N) weight matrix.
        group: process group to gather across.

    Returns:
        (M * world_size, N) tensor; used to validate the fused kernel.
    """
    world_size = dist.get_world_size(group)
    # BUGFIX: allocate the gather scratch with the input's dtype —
    # torch.empty defaults to float32, which broke non-fp32 inputs
    # (all_gather_into_tensor requires matching dtypes).
    ag_scratch = torch.empty(
        (world_size * inp.shape[0], inp.shape[1]),
        device=inp.device,
        dtype=inp.dtype,
    )
    dist.all_gather_into_tensor(ag_scratch, inp, group=group)
    out = ag_scratch @ w
    return out
125+
126+
127+
def main():
    """Driver: set up one NCCL rank, run both implementations, compare."""
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    device = torch.device(f"cuda:{rank}")
    group = dist.group.WORLD

    # Distinct seed per rank so every shard carries different data.
    torch.manual_seed(rank + 52)
    print(f"Rank {rank} of {world_size} is initializing")

    # Problem sizes: per-rank batch, hidden width, output width.
    bs, hid, out_hid = 256, 1024, 512

    # The fused kernel needs its input in symmetric memory; keep a plain
    # copy for the reference path.
    ref_inp = torch.rand((bs, hid), device=device)
    inp = symm_mem.empty(bs, hid, device=device).copy_(ref_inp)
    w = torch.rand((hid, out_hid), device=device)

    expected_out = ref_gather_matmul(ref_inp, w, group)
    out = cutile_gather_matmul(inp, w, group)

    # Loose tolerances: the kernel accumulates via tf32 tensor cores.
    torch.testing.assert_close(out, expected_out, atol=1e-3, rtol=1e-3)

    print(f"Rank {rank} of {world_size}: correct")
    dist.destroy_process_group()


if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)