File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -598,7 +598,11 @@ def mma(
598598 for k in cutlass .range_constexpr (num_k_blocks ):
599599 k_next = 0 if k + 1 == num_k_blocks else k + 1
600600 if const_expr (k == num_k_blocks - 1 ):
601- # Don't need to sync_warp: the previous instruction was mma.sync from cute.gemm
601+ # TMA writes smem through the async proxy; ldmatrix reads it through the
602+ # generic proxy. Fence before releasing the smem stage for reuse, then
603+ # sync the warp because only one lane signals the mbarrier.
604+ cute .arch .fence_view_async_shared ()
605+ cute .arch .sync_warp ()
602606 ab_pipeline .consumer_release (ab_read_state )
603607 ab_read_state .advance ()
604608 peek_ab_full_status = ab_pipeline .consumer_try_wait (ab_read_state )
@@ -614,6 +618,11 @@ def mma(
614618 for k in cutlass .range_constexpr (num_k_blocks ):
615619 k_next = 0 if k + 1 == num_k_blocks else k + 1
616620 if const_expr (k == num_k_blocks - 1 ):
621+ # TMA writes smem through the async proxy; ldmatrix reads it through the
622+ # generic proxy. Fence before releasing the smem stage for reuse, then
623+ # sync the warp because only one lane signals the mbarrier.
624+ cute .arch .fence_view_async_shared ()
625+ cute .arch .sync_warp ()
617626 ab_pipeline .consumer_release (ab_read_state )
618627 ab_read_state .advance ()
619628 if const_expr (k_next > 0 ):
You can’t perform that action at this time.
0 commit comments