@@ -61,13 +61,17 @@ def __init__(
6161 self .cta_tile_shape_mnk = (tile_M , tile_N , 1 )
6262
6363 # Warp-level MMA atom layout (the example uses (2, 2, 1))
64- self .atom_layout_mnk = (2 , 2 , 1 )
64+ # TODO: autotune the atom layout over (2, 2, 1) and (4, 2, 1)
65+ self .atom_layout_mnk = (4 , 2 , 1 )
6566 self .mma_inst_mnk = (16 , 8 , 16 )
6667 self .num_mma_warps = math .prod (self .atom_layout_mnk )
67- self .threads_per_cta = (self .num_mma_warps + 1 ) * cute .arch .WARP_SIZE
6868 # For compatibility with SM90 code that uses warp groups
6969 self .num_threads_per_warp_group = 128
70- self .mma_warp_groups = 1
70+ assert self .num_mma_warps % 4 == 0
71+ self .mma_warp_groups = self .num_mma_warps // 4
72+ # threads_per_cta must be a multiple of 128 (warp group size) so that
73+ # the DMA warp's setmaxnreg.dec.sync has a complete warp group to sync with.
74+ self .threads_per_cta = (self .mma_warp_groups + 1 ) * self .num_threads_per_warp_group
7175
7276 self .num_mcast_ctas_a = cluster_shape_mnk [1 ]
7377 if gather_A :
@@ -291,7 +295,7 @@ def kernel(
291295 tCrA = tiled_mma .make_fragment_A (tCsA [None , None , None , 0 ])
292296 tCrB = tiled_mma .make_fragment_B (tCsB [None , None , None , 0 ])
293297 acc_shape = tiled_mma .partition_shape_C (self .cta_tile_shape_mnk [:2 ])
294- accumulators = cute .make_rmem_tensor (acc_shape , self .acc_dtype )
298+ acc = cute .make_rmem_tensor (acc_shape , self .acc_dtype )
295299
296300 TileSchedulerCls = partial (
297301 TileSchedulerCls .create , tile_sched_params , sched_data , sched_pipeline
@@ -303,56 +307,62 @@ def kernel(
303307 k_tile_cnt = cute .ceil_div (mA_mkl .shape [1 ], self .cta_tile_shape_mnk [2 ])
304308
305309 # =====================================================================
306- # DMA warp — reuses SM90's load_AB via tma_get_copy_fn
310+ # DMA warp group — all warps >= num_mma_warps must enter to participate
311+ # in setmaxnreg.dec.sync (warp-group-level barrier).
312+ # Only warp num_mma_warps actually does TMA loads.
307313 # =====================================================================
308- if warp_idx == self .num_mma_warps :
314+ if warp_idx >= self .num_mma_warps :
315+ # All warps in this warp group must execute setmaxnreg (warp-group barrier).
309316 cute .arch .setmaxregister_decrease (self .num_regs_load )
310- tile_scheduler = TileSchedulerCls ()
311- work_tile = tile_scheduler .initial_work_tile_info ()
312- ab_producer_state = make_pipeline_state (
313- pipeline .PipelineUserType .Producer , self .ab_stage
314- )
315- while work_tile .is_valid_tile :
316- tile_coord_mnkl = work_tile .tile_idx
317- batch_idx = tile_coord_mnkl [3 ]
318- mA_mk = varlen_manager .offset_batch_A (mA_mkl , batch_idx )
319- gA_mk = cute .local_tile (
320- mA_mk ,
321- cute .select (self .cta_tile_shape_mnk , [0 , 2 ]),
322- (tile_coord_mnkl [0 ], None ),
323- )
324- gB_nk = cute .local_tile (
325- varlen_manager .offset_batch_B (mB_nkl , batch_idx ),
326- cute .select (self .cta_tile_shape_mnk , [1 , 2 ]),
327- (tile_coord_mnkl [1 ], None ),
328- )
329- copy_A , _ , _ = copy_utils .tma_get_copy_fn (
330- tma_atom_a ,
331- cta_coord = cluster_coord_mnk [1 ],
332- cta_layout = cute .make_layout (
333- cute .slice_ (cluster_layout_mnk , (0 , None , 0 )).shape
334- ),
335- src_tensor = gA_mk ,
336- dst_tensor = sA ,
337- mcast_mask = a_mcast_mask ,
338- )
339- copy_B , _ , _ = copy_utils .tma_get_copy_fn (
340- tma_atom_b ,
341- cta_coord = cluster_coord_mnk [0 ],
342- cta_layout = cute .make_layout (
343- cute .slice_ (cluster_layout_mnk , (None , 0 , 0 )).shape
344- ),
345- src_tensor = gB_nk ,
346- dst_tensor = sB ,
347- mcast_mask = b_mcast_mask ,
317+ # Only the first DMA warp does actual work; the rest are padding
318+ # for the warp-group-level setmaxnreg barrier.
319+ if warp_idx == self .num_mma_warps :
320+ tile_scheduler = TileSchedulerCls ()
321+ work_tile = tile_scheduler .initial_work_tile_info ()
322+ ab_producer_state = make_pipeline_state (
323+ pipeline .PipelineUserType .Producer , self .ab_stage
348324 )
349- ab_producer_state = self .load_AB (
350- ab_pipeline , ab_producer_state , copy_A , copy_B , k_tile_cnt
351- )
352- tile_scheduler .advance_to_next_work (is_scheduler_warp = True )
353- work_tile = tile_scheduler .get_current_work ()
354- ab_pipeline .producer_tail (ab_producer_state )
355- tile_scheduler .producer_tail ()
325+ while work_tile .is_valid_tile :
326+ tile_coord_mnkl = work_tile .tile_idx
327+ batch_idx = tile_coord_mnkl [3 ]
328+ mA_mk = varlen_manager .offset_batch_A (mA_mkl , batch_idx )
329+ gA_mk = cute .local_tile (
330+ mA_mk ,
331+ cute .select (self .cta_tile_shape_mnk , [0 , 2 ]),
332+ (tile_coord_mnkl [0 ], None ),
333+ )
334+ gB_nk = cute .local_tile (
335+ varlen_manager .offset_batch_B (mB_nkl , batch_idx ),
336+ cute .select (self .cta_tile_shape_mnk , [1 , 2 ]),
337+ (tile_coord_mnkl [1 ], None ),
338+ )
339+ copy_A , _ , _ = copy_utils .tma_get_copy_fn (
340+ tma_atom_a ,
341+ cta_coord = cluster_coord_mnk [1 ],
342+ cta_layout = cute .make_layout (
343+ cute .slice_ (cluster_layout_mnk , (0 , None , 0 )).shape
344+ ),
345+ src_tensor = gA_mk ,
346+ dst_tensor = sA ,
347+ mcast_mask = a_mcast_mask ,
348+ )
349+ copy_B , _ , _ = copy_utils .tma_get_copy_fn (
350+ tma_atom_b ,
351+ cta_coord = cluster_coord_mnk [0 ],
352+ cta_layout = cute .make_layout (
353+ cute .slice_ (cluster_layout_mnk , (None , 0 , 0 )).shape
354+ ),
355+ src_tensor = gB_nk ,
356+ dst_tensor = sB ,
357+ mcast_mask = b_mcast_mask ,
358+ )
359+ ab_producer_state = self .load_AB (
360+ ab_pipeline , ab_producer_state , copy_A , copy_B , k_tile_cnt
361+ )
362+ tile_scheduler .advance_to_next_work (is_scheduler_warp = True )
363+ work_tile = tile_scheduler .get_current_work ()
364+ ab_pipeline .producer_tail (ab_producer_state )
365+ tile_scheduler .producer_tail ()
356366
357367 # =====================================================================
358368 # MMA warps
@@ -395,7 +405,7 @@ def kernel(
395405 ab_pipeline ,
396406 ab_read_state ,
397407 tiled_mma ,
398- accumulators ,
408+ acc ,
399409 k_tile_cnt ,
400410 smem_tiled_copy_A ,
401411 smem_tiled_copy_B ,
@@ -434,7 +444,7 @@ def kernel(
434444 tiled_copy_r2s , tRS_rD , tRS_sD = self .epilog_smem_store_and_partition (
435445 tiled_mma , self .d_layout , d_dtype_for_layout , sD , tidx
436446 )
437- tRS_rAcc = self .epi_retile_acc (accumulators , tRS_rD , tiled_copy_r2s )
447+ tRS_rAcc = self .epi_retile_acc (acc , tRS_rD , tiled_copy_r2s )
438448 load_acc_subtile = partial (self .epi_load_acc_subtile , tRS_rAcc )
439449 if const_expr (has_C ):
440450 tiled_copy_s2r , tRS_rC , tSR_rC , tSR_sC = self .epilog_smem_load_and_partition (
@@ -443,7 +453,7 @@ def kernel(
443453 else :
444454 tiled_copy_s2r , tSR_sC , tRS_rC , tSR_rC = None , None , None , None
445455
446- self .epi_visit_acc (epilogue_params , accumulators , tiled_mma , tile_coord_mnkl , tidx )
456+ self .epi_visit_acc (epilogue_params , acc , tiled_mma , tile_coord_mnkl , tidx )
447457
448458 epi_read_state , epi_producer_state = self .epilogue (
449459 epilogue_params ,
0 commit comments