Skip to content

Commit b6c61c5

Browse files
kwen2501 and haijieg
authored and committed
Fused all-gather matmul
Argparser Add test Signed-off-by: Ke Wen <kwen@nvidia.com>
1 parent 1ad7689 commit b6c61c5

File tree

4 files changed

+392
-3
lines changed

4 files changed

+392
-3
lines changed

changelog.d/fix-rewrite-pattern.md

Whitespace-only changes.

samples/AllGatherMatmul.py

Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,193 @@
1+
# SPDX-FileCopyrightText: Copyright (c) <2025> NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
"""
6+
Example demonstrating all-gather and matrix multiplication in a single kernel.
7+
8+
Run with:
9+
python AllGatherMatmul.py --correctness-check
10+
11+
Algorithm:
12+
Each rank has a local input tensor of size (M, K), and a weight tensor of size (K, N).
13+
We want to compute the output tensor of size (M * world_size, N), where each
14+
"slice" of size (M, N) is the result of the matrix multiplication of a peer input tensor
15+
and the weight tensor.
16+
"""
17+
18+
import argparse
19+
import random
20+
import torch
21+
import torch.distributed as dist
22+
import torch.distributed._symmetric_memory as symm_mem
23+
import torch.multiprocessing as mp
24+
import cuda.tile as ct
25+
26+
27+
# cuTile kernel for gather then matmul
@ct.kernel
def gather_matmul_kernel(
    inp_list,
    w,
    out,
    tile_m: ct.Constant[int],
    tile_n: ct.Constant[int],
    tile_k: ct.Constant[int],
):
    """Compute one (tile_m, tile_n) tile of out = concat(inp_list) @ w.

    Each block is addressed by a 2D grid: bid(0) selects the output row tile,
    bid(1) the output column tile. The row-tile index determines which peer's
    symmetric-memory buffer supplies the A operand, so the all-gather happens
    implicitly via remote loads inside the matmul loop.

    inp_list: per-peer input arrays, each (M, K) — presumably symmetric-memory
        views of every rank's local input (see host launcher); element `peer`
        is read remotely.
    w: weight array (K, N), local to this rank.
    out: output array (M * world_size, N), written one tile per block.
    tile_m / tile_n / tile_k: compile-time tile sizes of the blocked matmul.
    """
    # Number of m tiles per peer
    peer_inp_size_m = inp_list[0].shape[0]
    num_tiles_m_per_peer = ct.cdiv(peer_inp_size_m, tile_m)
    num_tiles_k = ct.num_tiles(w, axis=0, shape=(tile_k, tile_n))

    # 0-dim maps to m_tile_idx, 1-dim maps to n_tile_idx, of out tensor
    m_tile_idx = ct.bid(0)
    n_tile_idx = ct.bid(1)

    # Which peer should this tile get input from?
    # Row tiles are laid out peer-major: tiles [0, num_tiles_m_per_peer) come
    # from peer 0, the next num_tiles_m_per_peer from peer 1, and so on.
    peer = m_tile_idx // num_tiles_m_per_peer
    # Select ct.Array from inp_list
    peer_inp = inp_list[peer]
    m_tile_idx_in_peer = m_tile_idx % num_tiles_m_per_peer

    # Initialize accumulator (always fp32 for numerical stability; cast back
    # to out.dtype only after the K loop completes)
    accumulator = ct.full((tile_m, tile_n), 0, dtype=ct.float32)
    zero_pad = ct.PaddingMode.ZERO

    # Convert fp32 to tf32 to use tensorcore
    dtype = ct.tfloat32 if peer_inp.dtype == ct.float32 else peer_inp.dtype

    for k in range(num_tiles_k):
        # Load remote input tile (zero-padded at array edges)
        a = ct.load(
            peer_inp,
            index=(m_tile_idx_in_peer, k),
            shape=(tile_m, tile_k),
            padding_mode=zero_pad,
        ).astype(dtype)
        # Load weight tile
        b = ct.load(
            w,
            index=(k, n_tile_idx),
            shape=(tile_k, tile_n),
            padding_mode=zero_pad,
        ).astype(dtype)
        # Perform matrix multiplication
        accumulator = ct.mma(a, b, accumulator)

    # Cast result back to output dtype
    accumulator = ct.astype(accumulator, out.dtype)

    # Store result tile
    ct.store(out, index=(m_tile_idx, n_tile_idx), tile=accumulator)
82+
83+
84+
# Host-side launcher for all-gather
def cutile_gather_matmul(
    inp: torch.Tensor,
    w: torch.Tensor,
    group: dist.ProcessGroup,
) -> torch.Tensor:
    """All-gather ``inp`` across ``group`` and multiply by ``w`` in one kernel.

    Args:
        inp: local input of shape (M, K); must have been allocated with
            ``symm_mem.empty`` so peers' buffers are remotely addressable.
        w: weight tensor of shape (K, N).
        group: process group whose members participate in the gather.

    Returns:
        Tensor of shape (M * world_size, N); row slice ``r`` equals
        ``rank_r_input @ w``.
    """
    # Exchange symmetric-memory handles so this rank can address every
    # peer's copy of `inp` directly from the kernel.
    handle = symm_mem.rendezvous(inp, group.group_name)
    world_size = handle.world_size
    inp_list = [
        handle.get_buffer(rank, inp.shape, inp.dtype, 0) for rank in range(world_size)
    ]

    assert inp.shape[1] == w.shape[0], "reduction dimension mismatch"
    M = inp.shape[0]
    K = inp.shape[1]
    N = w.shape[1]
    M_out = M * world_size
    # Allocate the output with an explicit dtype: torch.empty defaults to
    # float32, which would silently mismatch (and truncate the kernel's
    # final cast) for non-fp32 inputs.
    out = torch.empty(M_out, N, dtype=inp.dtype, device=inp.device)

    tile_m = 128
    tile_n = 128
    tile_k = 128
    # The launch assumes full tiles only (no ragged edges).
    assert M % tile_m == 0
    assert N % tile_n == 0
    assert K % tile_k == 0

    # Map each output tile to a block
    grid = (ct.cdiv(M_out, tile_m), ct.cdiv(N, tile_n))
    ct.launch(
        torch.cuda.current_stream(),
        grid,
        gather_matmul_kernel,
        (inp_list, w, out, tile_m, tile_n, tile_k),
    )

    return out
121+
122+
123+
# Reference gather then matmul implementation
def ref_gather_matmul(
    inp: torch.Tensor,
    w: torch.Tensor,
    group: dist.ProcessGroup,
) -> torch.Tensor:
    """Reference implementation: NCCL all-gather followed by a plain matmul.

    Produces the same (world_size * M, N) result the fused kernel should,
    for correctness checking.
    """
    world_size = dist.get_world_size(group)
    # Allocate the gather scratch with the input's dtype explicitly:
    # torch.empty defaults to float32, and all_gather_into_tensor requires
    # the output tensor's dtype to match the input's.
    ag_scratch = torch.empty(
        (world_size * inp.shape[0], inp.shape[1]),
        dtype=inp.dtype,
        device=inp.device,
    )
    dist.all_gather_into_tensor(ag_scratch, inp, group=group)
    out = ag_scratch @ w
    return out
134+
135+
136+
def test(rank: int, world_size: int, args: argparse.Namespace, port: int):
    """Per-process entry point spawned once per GPU.

    Initializes NCCL on the device selected by ``rank``, runs the fused
    gather+matmul, and optionally validates it against the reference path.
    """
    print(f"Rank {rank} of {world_size} is initializing")
    device = torch.device(f"cuda:{rank}")
    dist.init_process_group(
        backend="nccl",
        init_method=f"tcp://localhost:{port}",
        rank=rank,
        world_size=world_size,
        device_id=device,
    )
    group = dist.group.WORLD
    # Per-rank seed so every peer contributes a distinct input tensor.
    torch.manual_seed(rank + 52)

    batch_size, hidden_dim, out_dim = 256, 1024, 512
    reference_input = torch.rand((batch_size, hidden_dim), device=device)
    # The kernel's input must live in symmetric memory so peers can read it.
    kernel_input = symm_mem.empty(batch_size, hidden_dim, device=device).copy_(
        reference_input
    )
    weight = torch.rand((hidden_dim, out_dim), device=device)

    # Make sure all ranks have initialized their inputs
    dist.barrier(group)

    out = cutile_gather_matmul(kernel_input, weight, group)

    if not args.correctness_check:
        if rank == 0:
            print("Correctness check disabled")
    else:
        expected_out = ref_gather_matmul(reference_input, weight, group)
        torch.testing.assert_close(
            out,
            expected_out,
            atol=1e-3,
            rtol=1e-3,
            msg=f"Rank {rank} of {world_size}: Correctness check failed",
        )
        print(f"Rank {rank} of {world_size}: Correctness check passed")

    dist.destroy_process_group()
176+
177+
178+
if __name__ == "__main__":
    # Command-line interface: correctness checking is opt-in.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--correctness-check",
        action="store_true",
        help="Check the correctness of the results",
    )
    args = cli.parse_args()

    if not dist.is_nccl_available():
        print("Skipped test: NCCL backend is not available")
    else:
        # IP port number for multi-process rendezvous
        port = random.randint(30000, 60000)
        world_size = torch.cuda.device_count()
        mp.spawn(test, args=(world_size, args, port), nprocs=world_size, join=True)
Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,193 @@
1+
# SPDX-FileCopyrightText: Copyright (c) <2025> NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
"""
6+
Example demonstrating all-gather and matrix multiplication in a single kernel.
7+
8+
Run with:
9+
python AllGatherMatmul.py --correctness-check
10+
11+
Algorithm:
12+
Each rank has a local input tensor of size (M, K), and a weight tensor of size (K, N).
13+
We want to compute the output tensor of size (M * world_size, N), where each
14+
"slice" of size (M, N) is the result of the matrix multiplication of a peer input tensor
15+
and the weight tensor.
16+
"""
17+
18+
import argparse
19+
import random
20+
import torch
21+
import torch.distributed as dist
22+
import torch.distributed._symmetric_memory as symm_mem
23+
import torch.multiprocessing as mp
24+
import cuda.tile as ct
25+
26+
27+
# cuTile kernel for gather then matmul
@ct.kernel
def gather_matmul_kernel(
    inp_list,
    w,
    out,
    tile_m: ct.Constant[int],
    tile_n: ct.Constant[int],
    tile_k: ct.Constant[int],
):
    """Compute one (tile_m, tile_n) tile of out = concat(inp_list) @ w.

    Each block is addressed by a 2D grid: bid(0) selects the output row tile,
    bid(1) the output column tile. The row-tile index determines which peer's
    symmetric-memory buffer supplies the A operand, so the all-gather happens
    implicitly via remote loads inside the matmul loop.

    inp_list: per-peer input arrays, each (M, K) — presumably symmetric-memory
        views of every rank's local input (see host launcher); element `peer`
        is read remotely.
    w: weight array (K, N), local to this rank.
    out: output array (M * world_size, N), written one tile per block.
    tile_m / tile_n / tile_k: compile-time tile sizes of the blocked matmul.
    """
    # Number of m tiles per peer
    peer_inp_size_m = inp_list[0].shape[0]
    num_tiles_m_per_peer = ct.cdiv(peer_inp_size_m, tile_m)
    num_tiles_k = ct.num_tiles(w, axis=0, shape=(tile_k, tile_n))

    # 0-dim maps to m_tile_idx, 1-dim maps to n_tile_idx, of out tensor
    m_tile_idx = ct.bid(0)
    n_tile_idx = ct.bid(1)

    # Which peer should this tile get input from?
    # Row tiles are laid out peer-major: tiles [0, num_tiles_m_per_peer) come
    # from peer 0, the next num_tiles_m_per_peer from peer 1, and so on.
    peer = m_tile_idx // num_tiles_m_per_peer
    # Select ct.Array from inp_list
    peer_inp = inp_list[peer]
    m_tile_idx_in_peer = m_tile_idx % num_tiles_m_per_peer

    # Initialize accumulator (always fp32 for numerical stability; cast back
    # to out.dtype only after the K loop completes)
    accumulator = ct.full((tile_m, tile_n), 0, dtype=ct.float32)
    zero_pad = ct.PaddingMode.ZERO

    # Convert fp32 to tf32 to use tensorcore
    dtype = ct.tfloat32 if peer_inp.dtype == ct.float32 else peer_inp.dtype

    for k in range(num_tiles_k):
        # Load remote input tile (zero-padded at array edges)
        a = ct.load(
            peer_inp,
            index=(m_tile_idx_in_peer, k),
            shape=(tile_m, tile_k),
            padding_mode=zero_pad,
        ).astype(dtype)
        # Load weight tile
        b = ct.load(
            w,
            index=(k, n_tile_idx),
            shape=(tile_k, tile_n),
            padding_mode=zero_pad,
        ).astype(dtype)
        # Perform matrix multiplication
        accumulator = ct.mma(a, b, accumulator)

    # Cast result back to output dtype
    accumulator = ct.astype(accumulator, out.dtype)

    # Store result tile
    ct.store(out, index=(m_tile_idx, n_tile_idx), tile=accumulator)
82+
83+
84+
# Host-side launcher for all-gather
def cutile_gather_matmul(
    inp: torch.Tensor,
    w: torch.Tensor,
    group: dist.ProcessGroup,
) -> torch.Tensor:
    """All-gather ``inp`` across ``group`` and multiply by ``w`` in one kernel.

    Args:
        inp: local input of shape (M, K); must have been allocated with
            ``symm_mem.empty`` so peers' buffers are remotely addressable.
        w: weight tensor of shape (K, N).
        group: process group whose members participate in the gather.

    Returns:
        Tensor of shape (M * world_size, N); row slice ``r`` equals
        ``rank_r_input @ w``.
    """
    # Exchange symmetric-memory handles so this rank can address every
    # peer's copy of `inp` directly from the kernel.
    handle = symm_mem.rendezvous(inp, group.group_name)
    world_size = handle.world_size
    inp_list = [
        handle.get_buffer(rank, inp.shape, inp.dtype, 0) for rank in range(world_size)
    ]

    assert inp.shape[1] == w.shape[0], "reduction dimension mismatch"
    M = inp.shape[0]
    K = inp.shape[1]
    N = w.shape[1]
    M_out = M * world_size
    # Allocate the output with an explicit dtype: torch.empty defaults to
    # float32, which would silently mismatch (and truncate the kernel's
    # final cast) for non-fp32 inputs.
    out = torch.empty(M_out, N, dtype=inp.dtype, device=inp.device)

    tile_m = 128
    tile_n = 128
    tile_k = 128
    # The launch assumes full tiles only (no ragged edges).
    assert M % tile_m == 0
    assert N % tile_n == 0
    assert K % tile_k == 0

    # Map each output tile to a block
    grid = (ct.cdiv(M_out, tile_m), ct.cdiv(N, tile_n))
    ct.launch(
        torch.cuda.current_stream(),
        grid,
        gather_matmul_kernel,
        (inp_list, w, out, tile_m, tile_n, tile_k),
    )

    return out
121+
122+
123+
# Reference gather then matmul implementation
def ref_gather_matmul(
    inp: torch.Tensor,
    w: torch.Tensor,
    group: dist.ProcessGroup,
) -> torch.Tensor:
    """Reference implementation: NCCL all-gather followed by a plain matmul.

    Produces the same (world_size * M, N) result the fused kernel should,
    for correctness checking.
    """
    world_size = dist.get_world_size(group)
    # Allocate the gather scratch with the input's dtype explicitly:
    # torch.empty defaults to float32, and all_gather_into_tensor requires
    # the output tensor's dtype to match the input's.
    ag_scratch = torch.empty(
        (world_size * inp.shape[0], inp.shape[1]),
        dtype=inp.dtype,
        device=inp.device,
    )
    dist.all_gather_into_tensor(ag_scratch, inp, group=group)
    out = ag_scratch @ w
    return out
134+
135+
136+
def test(rank: int, world_size: int, args: argparse.Namespace, port: int):
    """Per-process entry point spawned once per GPU.

    Initializes NCCL on the device selected by ``rank``, runs the fused
    gather+matmul, and optionally validates it against the reference path.
    """
    print(f"Rank {rank} of {world_size} is initializing")
    device = torch.device(f"cuda:{rank}")
    dist.init_process_group(
        backend="nccl",
        init_method=f"tcp://localhost:{port}",
        rank=rank,
        world_size=world_size,
        device_id=device,
    )
    group = dist.group.WORLD
    # Per-rank seed so every peer contributes a distinct input tensor.
    torch.manual_seed(rank + 52)

    batch_size, hidden_dim, out_dim = 256, 1024, 512
    reference_input = torch.rand((batch_size, hidden_dim), device=device)
    # The kernel's input must live in symmetric memory so peers can read it.
    kernel_input = symm_mem.empty(batch_size, hidden_dim, device=device).copy_(
        reference_input
    )
    weight = torch.rand((hidden_dim, out_dim), device=device)

    # Make sure all ranks have initialized their inputs
    dist.barrier(group)

    out = cutile_gather_matmul(kernel_input, weight, group)

    if not args.correctness_check:
        if rank == 0:
            print("Correctness check disabled")
    else:
        expected_out = ref_gather_matmul(reference_input, weight, group)
        torch.testing.assert_close(
            out,
            expected_out,
            atol=1e-3,
            rtol=1e-3,
            msg=f"Rank {rank} of {world_size}: Correctness check failed",
        )
        print(f"Rank {rank} of {world_size}: Correctness check passed")

    dist.destroy_process_group()
176+
177+
178+
if __name__ == "__main__":
    # Command-line interface: correctness checking is opt-in.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--correctness-check",
        action="store_true",
        help="Check the correctness of the results",
    )
    args = cli.parse_args()

    if not dist.is_nccl_available():
        print("Skipped test: NCCL backend is not available")
    else:
        # IP port number for multi-process rendezvous
        port = random.randint(30000, 60000)
        world_size = torch.cuda.device_count()
        mp.spawn(test, args=(world_size, args, port), nprocs=world_size, join=True)

0 commit comments

Comments
 (0)