Skip to content

Commit 58e8693

Browse files
committed
[Gemm] Add gemm_sm120.py
1 parent be55f58 commit 58e8693

5 files changed

Lines changed: 532 additions & 29 deletions

File tree

benchmarks/benchmark_gemm.py

Lines changed: 14 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,7 @@
1414
from cutlass.cute.runtime import from_dlpack, make_ptr
1515
from cutlass import Int32, Boolean
1616

17-
from quack.gemm_default_epi import GemmDefaultSm90, GemmDefaultSm100
17+
from quack.gemm_default_epi import GemmDefaultSm90, GemmDefaultSm100, GemmDefaultSm120
1818
from quack.gemm_sm90 import TileSchedulerOptions
1919

2020
from quack.cute_dsl_utils import get_device_capacity
@@ -141,6 +141,7 @@ def parse_arguments() -> argparse.Namespace:
141141
parser.add_argument("--gather_A", action="store_true", help="Gather A")
142142
parser.add_argument("--add_to_output", action="store_true", help="Add to output")
143143
parser.add_argument("--fp8_fast_accum", action="store_true", help="FP8 fast accum")
144+
parser.add_argument("--sm120", action="store_true", help="Use SM120 warp-level MMA (on SM90 HW)")
144145
parser.add_argument("--skip_ref_check", action="store_true", help="Skip reference checking")
145146

146147
args = parser.parse_args()
@@ -181,6 +182,7 @@ def run(
181182
gather_A: bool,
182183
add_to_output: bool,
183184
fp8_fast_accum: bool,
185+
sm120: bool = False,
184186
**kwargs,
185187
):
186188
"""
@@ -235,7 +237,7 @@ def run(
235237
# Unpack parameters
236238
m, n, k, l = mnkl
237239
cluster_shape_mnk = (*cluster_shape_mn, 1)
238-
GemmCls = GemmDefaultSm100 if is_sm100 else GemmDefaultSm90
240+
GemmCls = GemmDefaultSm100 if is_sm100 else (GemmDefaultSm120 if sm120 else GemmDefaultSm90)
239241

240242
# Skip unsupported types
241243
if not GemmCls.is_valid_dtypes(
@@ -377,6 +379,15 @@ def create_and_permute_tensor(l, mode0, mode1, is_mode0_major, dtype, is_dynamic
377379
gather_A=gather_A,
378380
use_clc_persistence=dynamic_persistent,
379381
)
382+
elif sm120:
383+
gemm = GemmCls(
384+
acc_dtype,
385+
a_dtype,
386+
tile_shape_mn,
387+
cluster_shape_mnk,
388+
is_persistent=persistent,
389+
gather_A=gather_A,
390+
)
380391
else:
381392
gemm = GemmCls(
382393
acc_dtype,
@@ -600,5 +611,6 @@ def fn():
600611
args.gather_A,
601612
args.add_to_output,
602613
args.fp8_fast_accum,
614+
args.sm120,
603615
)
604616
print("PASS")

quack/gemm_default_epi.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from quack.epi_ops import Scalar, RowVecLoad, ColVecLoad
1111
from quack.gemm_sm90 import GemmSm90
1212
from quack.gemm_sm100 import GemmSm100
13+
from quack.gemm_sm120 import GemmSm120
1314
from quack.rounding import RoundingMode
1415
import quack.utils as utils
1516

@@ -101,3 +102,7 @@ class GemmDefaultSm90(GemmDefaultEpiMixin, GemmSm90):
101102

102103
class GemmDefaultSm100(GemmDefaultEpiMixin, GemmSm100):
103104
pass
105+
106+
107+
class GemmDefaultSm120(GemmDefaultEpiMixin, GemmSm120):
108+
pass

quack/gemm_sm100.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
# Copyright (c) 2025-2026, Tri Dao.
12
# Based on the cute-dsl example:
23
# https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/blackwell/dense_gemm_persistent.py
34

@@ -210,6 +211,10 @@ def __init__(
210211
self.epi_load_warp_id = self.ab_load_warp_id + self.num_ab_load_warps
211212
self.scheduler_warp_id = self.epi_load_warp_id + 1
212213
self.num_epi_warps = len(self.epilog_warp_id)
214+
self.epilogue_barrier = pipeline.NamedBarrier(
215+
barrier_id=int(NamedBarrierGemm.Epilogue),
216+
num_threads=self.num_epi_warps * cute.arch.WARP_SIZE,
217+
)
213218
# Register reallocation for gather_A (3 warp groups, 504 regs total, 168 per WG default).
214219
# Heavy epilogues (e.g. colvec_reduce in DGated) override these to avoid register spilling.
215220
# Without gather_A there are only 2 WGs (512 total, 256 per WG = max), no reallocation needed.
@@ -1393,11 +1398,6 @@ def kernel(
13931398
# (MMA, MMA_M, MMA_N, STAGE)
13941399
tCtAcc_base = cute.make_tensor(acc_tmem_ptr, tCtAcc_fake.layout)
13951400

1396-
epilogue_barrier = pipeline.NamedBarrier(
1397-
barrier_id=int(NamedBarrierGemm.Epilogue),
1398-
num_threads=self.num_epi_warps * cute.arch.WARP_SIZE,
1399-
)
1400-
14011401
# Partition for epilogue
14021402
epi_tidx = tidx
14031403
tiled_copy_t2r, tTR_tAcc_base, tTR_rAcc = self.epilog_tmem_copy_and_partition(
@@ -1479,7 +1479,7 @@ def kernel(
14791479
copy_C,
14801480
tile_coord_mnkl,
14811481
varlen_manager,
1482-
epilogue_barrier,
1482+
self.epilogue_barrier,
14831483
tile_scheduler,
14841484
epi_tidx,
14851485
is_tma_warp,

0 commit comments

Comments
 (0)