Skip to content

Commit e2ed8c4

Browse files
committed
[Gemm,Sm120] Use 8 MMA warps
1 parent 69451c7 commit e2ed8c4

1 file changed

Lines changed: 64 additions & 54 deletions

File tree

quack/gemm_sm120.py

Lines changed: 64 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -61,13 +61,17 @@ def __init__(
6161
self.cta_tile_shape_mnk = (tile_M, tile_N, 1)
6262

6363
# Warp-level MMA atom layout (the reference example uses (2, 2, 1))
64-
self.atom_layout_mnk = (2, 2, 1)
64+
# TODO: autotune (2, 2, 1) and (4, 2, 1)
65+
self.atom_layout_mnk = (4, 2, 1)
6566
self.mma_inst_mnk = (16, 8, 16)
6667
self.num_mma_warps = math.prod(self.atom_layout_mnk)
67-
self.threads_per_cta = (self.num_mma_warps + 1) * cute.arch.WARP_SIZE
6868
# For compatibility with SM90 code that uses warp groups
6969
self.num_threads_per_warp_group = 128
70-
self.mma_warp_groups = 1
70+
assert self.num_mma_warps % 4 == 0
71+
self.mma_warp_groups = self.num_mma_warps // 4
72+
# threads_per_cta must be a multiple of 128 (warp group size) so that
73+
# the DMA warp's setmaxnreg.dec.sync has a complete warp group to sync with.
74+
self.threads_per_cta = (self.mma_warp_groups + 1) * self.num_threads_per_warp_group
7175

7276
self.num_mcast_ctas_a = cluster_shape_mnk[1]
7377
if gather_A:
@@ -291,7 +295,7 @@ def kernel(
291295
tCrA = tiled_mma.make_fragment_A(tCsA[None, None, None, 0])
292296
tCrB = tiled_mma.make_fragment_B(tCsB[None, None, None, 0])
293297
acc_shape = tiled_mma.partition_shape_C(self.cta_tile_shape_mnk[:2])
294-
accumulators = cute.make_rmem_tensor(acc_shape, self.acc_dtype)
298+
acc = cute.make_rmem_tensor(acc_shape, self.acc_dtype)
295299

296300
TileSchedulerCls = partial(
297301
TileSchedulerCls.create, tile_sched_params, sched_data, sched_pipeline
@@ -303,56 +307,62 @@ def kernel(
303307
k_tile_cnt = cute.ceil_div(mA_mkl.shape[1], self.cta_tile_shape_mnk[2])
304308

305309
# =====================================================================
306-
# DMA warp — reuses SM90's load_AB via tma_get_copy_fn
310+
# DMA warp group — all warps >= num_mma_warps must enter to participate
311+
# in setmaxnreg.dec.sync (warp-group-level barrier).
312+
# Only warp num_mma_warps actually does TMA loads.
307313
# =====================================================================
308-
if warp_idx == self.num_mma_warps:
314+
if warp_idx >= self.num_mma_warps:
315+
# All warps in this warp group must execute setmaxnreg (warp-group barrier).
309316
cute.arch.setmaxregister_decrease(self.num_regs_load)
310-
tile_scheduler = TileSchedulerCls()
311-
work_tile = tile_scheduler.initial_work_tile_info()
312-
ab_producer_state = make_pipeline_state(
313-
pipeline.PipelineUserType.Producer, self.ab_stage
314-
)
315-
while work_tile.is_valid_tile:
316-
tile_coord_mnkl = work_tile.tile_idx
317-
batch_idx = tile_coord_mnkl[3]
318-
mA_mk = varlen_manager.offset_batch_A(mA_mkl, batch_idx)
319-
gA_mk = cute.local_tile(
320-
mA_mk,
321-
cute.select(self.cta_tile_shape_mnk, [0, 2]),
322-
(tile_coord_mnkl[0], None),
323-
)
324-
gB_nk = cute.local_tile(
325-
varlen_manager.offset_batch_B(mB_nkl, batch_idx),
326-
cute.select(self.cta_tile_shape_mnk, [1, 2]),
327-
(tile_coord_mnkl[1], None),
328-
)
329-
copy_A, _, _ = copy_utils.tma_get_copy_fn(
330-
tma_atom_a,
331-
cta_coord=cluster_coord_mnk[1],
332-
cta_layout=cute.make_layout(
333-
cute.slice_(cluster_layout_mnk, (0, None, 0)).shape
334-
),
335-
src_tensor=gA_mk,
336-
dst_tensor=sA,
337-
mcast_mask=a_mcast_mask,
338-
)
339-
copy_B, _, _ = copy_utils.tma_get_copy_fn(
340-
tma_atom_b,
341-
cta_coord=cluster_coord_mnk[0],
342-
cta_layout=cute.make_layout(
343-
cute.slice_(cluster_layout_mnk, (None, 0, 0)).shape
344-
),
345-
src_tensor=gB_nk,
346-
dst_tensor=sB,
347-
mcast_mask=b_mcast_mask,
317+
# Only the first DMA warp does actual work; the rest are padding
318+
# for the warp-group-level setmaxnreg barrier.
319+
if warp_idx == self.num_mma_warps:
320+
tile_scheduler = TileSchedulerCls()
321+
work_tile = tile_scheduler.initial_work_tile_info()
322+
ab_producer_state = make_pipeline_state(
323+
pipeline.PipelineUserType.Producer, self.ab_stage
348324
)
349-
ab_producer_state = self.load_AB(
350-
ab_pipeline, ab_producer_state, copy_A, copy_B, k_tile_cnt
351-
)
352-
tile_scheduler.advance_to_next_work(is_scheduler_warp=True)
353-
work_tile = tile_scheduler.get_current_work()
354-
ab_pipeline.producer_tail(ab_producer_state)
355-
tile_scheduler.producer_tail()
325+
while work_tile.is_valid_tile:
326+
tile_coord_mnkl = work_tile.tile_idx
327+
batch_idx = tile_coord_mnkl[3]
328+
mA_mk = varlen_manager.offset_batch_A(mA_mkl, batch_idx)
329+
gA_mk = cute.local_tile(
330+
mA_mk,
331+
cute.select(self.cta_tile_shape_mnk, [0, 2]),
332+
(tile_coord_mnkl[0], None),
333+
)
334+
gB_nk = cute.local_tile(
335+
varlen_manager.offset_batch_B(mB_nkl, batch_idx),
336+
cute.select(self.cta_tile_shape_mnk, [1, 2]),
337+
(tile_coord_mnkl[1], None),
338+
)
339+
copy_A, _, _ = copy_utils.tma_get_copy_fn(
340+
tma_atom_a,
341+
cta_coord=cluster_coord_mnk[1],
342+
cta_layout=cute.make_layout(
343+
cute.slice_(cluster_layout_mnk, (0, None, 0)).shape
344+
),
345+
src_tensor=gA_mk,
346+
dst_tensor=sA,
347+
mcast_mask=a_mcast_mask,
348+
)
349+
copy_B, _, _ = copy_utils.tma_get_copy_fn(
350+
tma_atom_b,
351+
cta_coord=cluster_coord_mnk[0],
352+
cta_layout=cute.make_layout(
353+
cute.slice_(cluster_layout_mnk, (None, 0, 0)).shape
354+
),
355+
src_tensor=gB_nk,
356+
dst_tensor=sB,
357+
mcast_mask=b_mcast_mask,
358+
)
359+
ab_producer_state = self.load_AB(
360+
ab_pipeline, ab_producer_state, copy_A, copy_B, k_tile_cnt
361+
)
362+
tile_scheduler.advance_to_next_work(is_scheduler_warp=True)
363+
work_tile = tile_scheduler.get_current_work()
364+
ab_pipeline.producer_tail(ab_producer_state)
365+
tile_scheduler.producer_tail()
356366

357367
# =====================================================================
358368
# MMA warps
@@ -395,7 +405,7 @@ def kernel(
395405
ab_pipeline,
396406
ab_read_state,
397407
tiled_mma,
398-
accumulators,
408+
acc,
399409
k_tile_cnt,
400410
smem_tiled_copy_A,
401411
smem_tiled_copy_B,
@@ -434,7 +444,7 @@ def kernel(
434444
tiled_copy_r2s, tRS_rD, tRS_sD = self.epilog_smem_store_and_partition(
435445
tiled_mma, self.d_layout, d_dtype_for_layout, sD, tidx
436446
)
437-
tRS_rAcc = self.epi_retile_acc(accumulators, tRS_rD, tiled_copy_r2s)
447+
tRS_rAcc = self.epi_retile_acc(acc, tRS_rD, tiled_copy_r2s)
438448
load_acc_subtile = partial(self.epi_load_acc_subtile, tRS_rAcc)
439449
if const_expr(has_C):
440450
tiled_copy_s2r, tRS_rC, tSR_rC, tSR_sC = self.epilog_smem_load_and_partition(
@@ -443,7 +453,7 @@ def kernel(
443453
else:
444454
tiled_copy_s2r, tSR_sC, tRS_rC, tSR_rC = None, None, None, None
445455

446-
self.epi_visit_acc(epilogue_params, accumulators, tiled_mma, tile_coord_mnkl, tidx)
456+
self.epi_visit_acc(epilogue_params, acc, tiled_mma, tile_coord_mnkl, tidx)
447457

448458
epi_read_state, epi_producer_state = self.epilogue(
449459
epilogue_params,

0 commit comments

Comments
 (0)