Skip to content

Commit 0ab127b

Browse files
committed
[Gemm,Sm120] Add fence_view_async_shared between ldmatrix and TMA
1 parent bccaf78 commit 0ab127b

1 file changed

Lines changed: 10 additions & 1 deletion

File tree

quack/gemm_sm120.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -598,7 +598,11 @@ def mma(
598598
for k in cutlass.range_constexpr(num_k_blocks):
599599
k_next = 0 if k + 1 == num_k_blocks else k + 1
600600
if const_expr(k == num_k_blocks - 1):
601-
# Don't need to sync_warp: the previous instruction was mma.sync from cute.gemm
601+
# TMA writes smem through the async proxy; ldmatrix reads it through the
602+
# generic proxy. Fence before releasing the smem stage for reuse, then
603+
# sync the warp because only one lane signals the mbarrier.
604+
cute.arch.fence_view_async_shared()
605+
cute.arch.sync_warp()
602606
ab_pipeline.consumer_release(ab_read_state)
603607
ab_read_state.advance()
604608
peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_read_state)
@@ -614,6 +618,11 @@ def mma(
614618
for k in cutlass.range_constexpr(num_k_blocks):
615619
k_next = 0 if k + 1 == num_k_blocks else k + 1
616620
if const_expr(k == num_k_blocks - 1):
621+
# TMA writes smem through the async proxy; ldmatrix reads it through the
622+
# generic proxy. Fence before releasing the smem stage for reuse, then
623+
# sync the warp because only one lane signals the mbarrier.
624+
cute.arch.fence_view_async_shared()
625+
cute.arch.sync_warp()
617626
ab_pipeline.consumer_release(ab_read_state)
618627
ab_read_state.advance()
619628
if const_expr(k_next > 0):

0 commit comments

Comments
 (0)