Skip to content

Commit 07f67b6

Browse files
committed
[Gemm,Sm100] Implement TMA gather for varlen_k
1 parent 96000aa commit 07f67b6

4 files changed

Lines changed: 178 additions & 11 deletions

File tree

quack/copy_utils.py

Lines changed: 70 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,10 @@
1313
import cutlass.pipeline
1414
from cutlass._mlir.dialects import llvm
1515
from cutlass._mlir import ir
16+
from cutlass._mlir.dialects import cute_nvgpu as _cute_nvgpu_ir
1617

18+
from quack import layout_utils
1719
from quack.utils import make_vector
18-
from cutlass._mlir.dialects import cute_nvgpu as _cute_nvgpu_ir
1920

2021

2122
Sm100MmaPeerBitMask = 0xFEFFFFFF
@@ -1023,16 +1024,83 @@ def gather_m_get_tma_copy_fn(
10231024
tma_gather4_load_fn = partial(tma_gather4_load, tma_desc_ptr, num_cta=cta_group)
10241025

10251026
def copy_fn(src_idx, dst_idx, tma_bar_ptr: cute.Pointer):
1027+
tSR_sA_cur = tSR_sA[None, None, None, dst_idx]
10261028
col_idx = tile_K * src_idx
10271029
for m in cutlass.range(cute.size(tSR_rAIdx, mode=[1]), unroll_full=True):
10281030
row_indices = [tSR_rAIdx[v, m] for v in range(4)]
1029-
smem_ptr = tSR_sA[None, m, None, dst_idx].iterator
1031+
smem_ptr = tSR_sA_cur[None, m, None].iterator
10301032
with cute.arch.elect_one():
10311033
tma_gather4_load_fn(smem_ptr, tma_bar_ptr, col_idx, row_indices)
10321034

10331035
return copy_fn
10341036

10351037

1038+
@cute.jit
1039+
def gather_k_get_tma_copy_fn(
1040+
tma_atom: cute.CopyAtom,
1041+
sA: cute.Tensor, # ((4, tile_K/4), (tile_M,), STAGE) — K-grouped load layout
1042+
sAIdx: cute.Tensor, # (tile_K, a_prefetch_stage) — K indices in smem
1043+
col_idx: Int32, # M offset in global tensor (contiguous dim for M-major)
1044+
warp_idx: Int32,
1045+
num_warps: int,
1046+
num_cta: int = 1,
1047+
) -> Tuple[Callable, Callable]:
1048+
"""Build a copy function for TMA gather4 in K dimension (M-major A).
1049+
1050+
Each gather4 instruction loads 4 K-columns × tile_M contiguous M-elements.
1051+
col_idx is the absolute M position in the global tensor.
1052+
K indices come from sAIdx (prefetched to smem by the scheduler warp).
1053+
1054+
Returns copy_fn(src_idx, dst_idx, tma_bar_ptr) which:
1055+
Issues gather4 calls with those K indices as row_indices
1056+
"""
1057+
tile_K = cute.size(sAIdx, mode=[0])
1058+
assert tile_K % 4 == 0
1059+
cta_group = num_cta
1060+
1061+
# Tiled copy for loading K indices from smem to registers (4 per vector, across warps)
1062+
copy_AIdx_s2r = cute.make_tiled_copy_tv(
1063+
cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), Int32, num_bits_per_copy=128),
1064+
cute.make_layout(num_warps), # thr_layout
1065+
cute.make_layout(4), # val_layout — 4 K indices per gather4
1066+
)
1067+
warp_idx = cute.arch.make_warp_uniform(warp_idx)
1068+
warp_copy_AIdx_s2r = copy_AIdx_s2r.get_slice(warp_idx)
1069+
tSR_sAIdx = warp_copy_AIdx_s2r.partition_S(sAIdx) # (((4,1),4,4))
1070+
# ((4,1),4,(64,2),(1,4)):((64,0),1024,(1,4096),(0,8192))
1071+
tSR_sA = warp_copy_AIdx_s2r.partition_S(layout_utils.transpose_view(sA))
1072+
tma_desc_ptr = get_tma_desc_addr(tma_atom)
1073+
tma_gather4_load_fn = partial(tma_gather4_load, tma_desc_ptr, num_cta=cta_group)
1074+
1075+
def prefetch_from_smem_fn(
1076+
a_prefetch_pipeline,
1077+
src_idx,
1078+
dst_idx,
1079+
a_prefetch_consumer_state,
1080+
) -> cute.Tensor:
1081+
a_prefetch_pipeline.consumer_wait(a_prefetch_consumer_state)
1082+
tSR_rAIdx = load_s2r(tSR_sAIdx[None, None, dst_idx])
1083+
cute.arch.sync_warp()
1084+
with cute.arch.elect_one():
1085+
a_prefetch_pipeline.consumer_release(a_prefetch_consumer_state)
1086+
return tSR_rAIdx
1087+
1088+
def copy_fn(src_idx, dst_idx, tSR_rAIdx, tma_bar_ptr: cute.Pointer):
1089+
# Issue gather4: col_idx = M position, row_indices = 4 K positions
1090+
tSR_sA_cur = tSR_sA[None, None, None, dst_idx]
1091+
gather_dim = cute.size(tSR_sA_cur, mode=[2, 0]) # Typically 64
1092+
for k in cutlass.range(cute.size(tSR_rAIdx, mode=[1]), unroll_full=True):
1093+
row_indices = [tSR_rAIdx[v, k] for v in range(4)]
1094+
for m in cutlass.range(cute.size(tSR_sA_cur, mode=[2, 1]), unroll_full=True):
1095+
smem_ptr = tSR_sA_cur[None, k, (None, m)].iterator
1096+
with cute.arch.elect_one():
1097+
tma_gather4_load_fn(
1098+
smem_ptr, tma_bar_ptr, col_idx + m * gather_dim, row_indices
1099+
)
1100+
1101+
return copy_fn, prefetch_from_smem_fn
1102+
1103+
10361104
# ---------------------------------------------------------------------------
10371105
# Store helpers
10381106
# ---------------------------------------------------------------------------

quack/gemm.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -188,9 +188,6 @@ def gemm(
188188
assert device_capacity[0] in [9, 10, 11, 12], "Only SM90, SM100, SM110, and SM120 are supported"
189189
if use_tma_gather:
190190
assert device_capacity[0] in [10, 11], "TMA gather currently requires SM100/SM110"
191-
assert gather_A and varlen_m and not varlen_k, (
192-
"TMA gather currently only supports varlen_m + gather_A"
193-
)
194191
if rounding_mode == RoundingMode.RS:
195192
assert device_capacity[0] == 10, "Stochastic rounding (RoundingMode.RS) requires SM100"
196193
if is_dynamic_persistent and device_capacity[0] == 9:

quack/gemm_sm100.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -504,8 +504,6 @@ def __call__(
504504
assert (varlen_args.mAIdx is not None) == self.gather_A
505505
varlen_m = varlen_args.mCuSeqlensM is not None
506506
varlen_k = varlen_args.mCuSeqlensK is not None
507-
if const_expr(self.use_tma_gather):
508-
assert varlen_m and not varlen_k, "TMA gather currently only supports varlen_m"
509507

510508
# Setup attributes that dependent on gemm inputs
511509
self._setup_attributes(epilogue_args, varlen_args)
@@ -794,8 +792,6 @@ def kernel(
794792
assert not (varlen_m and varlen_k)
795793
if const_expr(self.gather_A):
796794
assert varlen_m or varlen_k
797-
if const_expr(self.use_tma_gather):
798-
assert varlen_m and not varlen_k
799795
has_D = const_expr(mD_mnl is not None)
800796
has_C = const_expr(mC_mnl is not None)
801797

@@ -1561,7 +1557,7 @@ def _make_gather_A_copy(
15611557
)
15621558
elif const_expr(varlen_k):
15631559
col_idx = Int32(tile_coord_mnkl[0] * self.cta_tile_shape_mnk[0])
1564-
copy_A = copy_utils.gather_k_get_tma_copy_fn(
1560+
copy_A, prefetch_A = copy_utils.gather_k_get_tma_copy_fn(
15651561
tma_atom_a,
15661562
sA,
15671563
sAIdx,
@@ -1686,7 +1682,7 @@ def load_AB_tma_gather(
16861682
tma_bar_ptr = ab_pipeline.producer_get_barrier(ab_producer_state)
16871683
if is_tma_warp:
16881684
copy_B(k_tile, smem_idx, tma_bar_ptr=tma_bar_ptr)
1689-
copy_A(k_tile, smem_idx, tma_bar_ptr=tma_bar_ptr, *prefetch_out)
1685+
copy_A(k_tile, smem_idx, *prefetch_out, tma_bar_ptr=tma_bar_ptr)
16901686
ab_pipeline.producer_commit(ab_producer_state)
16911687
ab_producer_state.advance()
16921688
peek_ab_empty_status = Boolean(True)

tests/test_linear_varlen_k.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import pytest
44
import torch
55

6+
from quack.cute_dsl_utils import get_device_capacity
7+
from quack.gemm import gemm as quack_gemm
68
from quack.gemm_interface import (
79
gemm,
810
gemm_ref,
@@ -11,6 +13,11 @@
1113
gemm_add_inplace,
1214
)
1315

16+
sm100_tma_gather_only = pytest.mark.skipif(
17+
not torch.cuda.is_available() or get_device_capacity(torch.device("cuda"))[0] not in (10, 11),
18+
reason="TMA gather tests require SM100/SM110",
19+
)
20+
1421

1522
def generate_A_with_gather(m, total_k, device, dtype, gather_A=False):
1623
"""Generate A matrix and optionally A_idx for gather_A case with varlen_k.
@@ -42,6 +49,105 @@ def generate_A_with_gather(m, total_k, device, dtype, gather_A=False):
4249
return A, A_idx
4350

4451

52+
def run_lowlevel_varlen_k_gemm(
53+
A,
54+
B,
55+
out,
56+
cu_seqlens_k,
57+
A_idx,
58+
*,
59+
dynamic_persistent=False,
60+
use_tma_gather=False,
61+
):
62+
device_capacity = get_device_capacity(A.device)[0]
63+
tile_count_semaphore = (
64+
torch.zeros(1, dtype=torch.int32, device=A.device)
65+
if dynamic_persistent and device_capacity == 9
66+
else None
67+
)
68+
quack_gemm(
69+
A,
70+
B,
71+
out,
72+
C=None,
73+
tile_count_semaphore=tile_count_semaphore,
74+
tile_M=256,
75+
tile_N=256,
76+
cluster_M=2,
77+
cluster_N=1,
78+
persistent=True,
79+
is_dynamic_persistent=dynamic_persistent,
80+
cu_seqlens_k=cu_seqlens_k,
81+
A_idx=A_idx,
82+
use_tma_gather=use_tma_gather,
83+
)
84+
85+
86+
@sm100_tma_gather_only
87+
@pytest.mark.parametrize("dynamic_persistent", [False, True])
88+
@pytest.mark.parametrize("input_dtype", [torch.bfloat16])
89+
@pytest.mark.parametrize("n", [1024])
90+
@pytest.mark.parametrize("m", [2048])
91+
def test_gemm_varlen_k_tma_gather_matches_cpasync(
92+
m,
93+
n,
94+
input_dtype,
95+
dynamic_persistent,
96+
):
97+
"""Compare TMA gather vs cp.async gather for varlen_k."""
98+
device = "cuda"
99+
torch.random.manual_seed(42)
100+
num_groups = 4
101+
# Use K values divisible by tile_K (64 for bf16) to avoid partial-tile edge cases
102+
seq_lens = torch.randint(2, 6, (num_groups,), device="cpu") * 64
103+
total_k = seq_lens.sum().item()
104+
cu_seqlens_k = torch.cat(
105+
[torch.zeros(1, dtype=torch.int32), seq_lens.cumsum(0).to(torch.int32)]
106+
).to(device)
107+
A, A_idx = generate_A_with_gather(m, total_k, device, input_dtype, gather_A=True)
108+
# B for quack_gemm varlen_k: 2D (n, total_k), n-major (stride(-2)==1)
109+
B_ref = torch.randn((total_k, n), device=device, dtype=input_dtype) / math.sqrt(
110+
total_k / num_groups
111+
)
112+
B = B_ref.T # (n, total_k) with n contiguous — stride(-2)==1
113+
114+
out_cpasync = torch.empty((num_groups, m, n), device=device, dtype=input_dtype)
115+
out_tma = torch.empty_like(out_cpasync)
116+
117+
run_lowlevel_varlen_k_gemm(
118+
A,
119+
B,
120+
out_cpasync,
121+
cu_seqlens_k,
122+
A_idx,
123+
dynamic_persistent=dynamic_persistent,
124+
use_tma_gather=False,
125+
)
126+
run_lowlevel_varlen_k_gemm(
127+
A,
128+
B,
129+
out_tma,
130+
cu_seqlens_k,
131+
A_idx,
132+
dynamic_persistent=dynamic_persistent,
133+
use_tma_gather=True,
134+
)
135+
136+
# gemm_ref expects B as (total_K, N)
137+
out_ref = gemm_ref(
138+
A.float(),
139+
B_ref.float(),
140+
cu_seqlens_k=cu_seqlens_k,
141+
A_idx=A_idx,
142+
)
143+
out_pt = gemm_ref(A, B_ref, cu_seqlens_k=cu_seqlens_k, A_idx=A_idx)
144+
145+
assert out_tma.shape == (num_groups, m, n)
146+
assert (out_tma - out_ref).abs().max() < 2 * (out_pt - out_ref).abs().max() + 1e-5
147+
assert (out_cpasync - out_ref).abs().max() < 2 * (out_pt - out_ref).abs().max() + 1e-5
148+
torch.testing.assert_close(out_tma, out_cpasync, atol=3e-2, rtol=1e-3)
149+
150+
45151
@pytest.mark.parametrize("permute_batch", [False, True])
46152
@pytest.mark.parametrize("gather_A", [False, True])
47153
# @pytest.mark.parametrize("gather_A", [False])

0 commit comments

Comments
 (0)