|
14 | 14 | from cutlass.cute.runtime import from_dlpack, make_ptr |
15 | 15 | from cutlass import Int32, Boolean |
16 | 16 |
|
17 | | -from quack.gemm_default_epi import GemmDefaultSm90, GemmDefaultSm100 |
| 17 | +from quack.gemm_default_epi import GemmDefaultSm90, GemmDefaultSm100, GemmDefaultSm120 |
18 | 18 | from quack.gemm_sm90 import TileSchedulerOptions |
19 | 19 |
|
20 | 20 | from quack.cute_dsl_utils import get_device_capacity |
@@ -141,6 +141,7 @@ def parse_arguments() -> argparse.Namespace: |
141 | 141 | parser.add_argument("--gather_A", action="store_true", help="Gather A") |
142 | 142 | parser.add_argument("--add_to_output", action="store_true", help="Add to output") |
143 | 143 | parser.add_argument("--fp8_fast_accum", action="store_true", help="FP8 fast accum") |
| 144 | + parser.add_argument("--sm120", action="store_true", help="Use SM120 warp-level MMA (on SM90 HW)") |
144 | 145 | parser.add_argument("--skip_ref_check", action="store_true", help="Skip reference checking") |
145 | 146 |
|
146 | 147 | args = parser.parse_args() |
@@ -181,6 +182,7 @@ def run( |
181 | 182 | gather_A: bool, |
182 | 183 | add_to_output: bool, |
183 | 184 | fp8_fast_accum: bool, |
| 185 | + sm120: bool = False, |
184 | 186 | **kwargs, |
185 | 187 | ): |
186 | 188 | """ |
@@ -235,7 +237,7 @@ def run( |
235 | 237 | # Unpack parameters |
236 | 238 | m, n, k, l = mnkl |
237 | 239 | cluster_shape_mnk = (*cluster_shape_mn, 1) |
238 | | - GemmCls = GemmDefaultSm100 if is_sm100 else GemmDefaultSm90 |
| 240 | + GemmCls = GemmDefaultSm100 if is_sm100 else (GemmDefaultSm120 if sm120 else GemmDefaultSm90) |
239 | 241 |
|
240 | 242 | # Skip unsupported types |
241 | 243 | if not GemmCls.is_valid_dtypes( |
@@ -377,6 +379,15 @@ def create_and_permute_tensor(l, mode0, mode1, is_mode0_major, dtype, is_dynamic |
377 | 379 | gather_A=gather_A, |
378 | 380 | use_clc_persistence=dynamic_persistent, |
379 | 381 | ) |
| 382 | + elif sm120: |
| 383 | + gemm = GemmCls( |
| 384 | + acc_dtype, |
| 385 | + a_dtype, |
| 386 | + tile_shape_mn, |
| 387 | + cluster_shape_mnk, |
| 388 | + is_persistent=persistent, |
| 389 | + gather_A=gather_A, |
| 390 | + ) |
380 | 391 | else: |
381 | 392 | gemm = GemmCls( |
382 | 393 | acc_dtype, |
@@ -600,5 +611,6 @@ def fn(): |
600 | 611 | args.gather_A, |
601 | 612 | args.add_to_output, |
602 | 613 | args.fp8_fast_accum, |
| 614 | + args.sm120, |
603 | 615 | ) |
604 | 616 | print("PASS") |
0 commit comments