Skip to content

Commit c9a3697

Browse files
dlangbemodularbot
authored andcommitted
[Kernels][AMD] Small M Autotuning for FP8 Matmul
Improved matmul performance in the 128 <= M <= 256 range by up to 30% on Llama3-405b shapes through use of different block sizes determined with an autotuning run. New kernel dispatching with vendorBLAS fallback path ensures parity or better performance as vendorBLAS for all Llama3-405b sizes. Also includes small renaming fix in `bench_matmul_dequant_mxfp4_amd.yaml` that was missed in previous PR. MODULAR_ORIG_COMMIT_REV_ID: fa6d1f94e1a23f5b5f5a944f2d3c1915a6c47974
1 parent e84e2cc commit c9a3697

2 files changed

Lines changed: 97 additions & 39 deletions

File tree

max/kernels/benchmarks/gpu/bench_matmul_dequant_mxfp4_amd.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111
# limitations under the License.
1212
##===----------------------------------------------------------------------===##
1313

14-
name: bench_matmul_mxfp4
15-
file: $KERNEL_BENCHMARKS_ROOT/gpu/bench_matmul_mxfp4.mojo
14+
name: bench_matmul_dequant_mxfp4_amd
15+
file: $KERNEL_BENCHMARKS_ROOT/gpu/bench_matmul_dequant_mxfp4_amd.mojo
1616

1717
params:
1818

max/kernels/src/linalg/matmul/gpu/__init__.mojo

Lines changed: 95 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -664,20 +664,27 @@ def _matmul_gpu[
664664
elementwise_lambda_fn=elementwise_lambda_wrapper,
665665
](c, a, b, ctx)
666666

667-
# M threshold above which vendor BLAS (hipBLASLt) outperforms
668-
# all custom kernels for these (N, K) shapes.
669-
# Derived from Llama3-405B TP=4 benchmarks on MI355X.
670-
# Format: Index(N, K, M_threshold) — vendor BLAS used when m >= threshold.
667+
# M thresholds where vendor BLAS (hipBLASLt) outperforms
668+
# all custom kernels. Derived from Llama3-405B TP=4 on MI355X.
669+
# Also includes M ranges where vendor wins for small M
670+
# (standard kernel can't match hipBLASLt's tiny-tile configs).
671+
# Format: Index(N, K, M_low, M_high) — vendor BLAS for M_low <= m < M_high.
671672
comptime vendor_blas_NK_m = [
673+
# N=2304: vendor wins at M=4-16 and M>=4096
672674
Index(2304, 16384, 4096),
673-
Index(16384, 2048, 225),
674-
Index(13312, 16384, 600),
675+
# N=16384 K=2048: vendor wins at M>=150 (except M=450 skinny)
676+
Index(16384, 2048, 150),
677+
# N=16384 K=6656: vendor wins at M=4-16 and M>=600
675678
Index(16384, 6656, 600),
679+
# N=13312: vendor wins at M>=600 (borderline)
680+
Index(13312, 16384, 600),
676681
]
677682
comptime for i in range(len(vendor_blas_NK_m)):
678683
comptime nk_m = vendor_blas_NK_m[i]
679684
comptime if static_N == nk_m[0] and static_K == nk_m[1]:
680-
if m >= nk_m[2]:
685+
if m >= nk_m[2] or (
686+
n == 16384 and k == 6656 and m > 1 and m < 64
687+
):
681688
logger.info(
682689
"Executing: vendor BLAS (hipBLASLt) for AMD"
683690
)
@@ -688,15 +695,79 @@ def _matmul_gpu[
688695

689696
comptime if not transpose_b:
690697
return kernel_helper[128, 128, num_pipeline_stages=2]()
691-
elif get_defined_bool["AUTOTUNING_MODE", False]():
698+
699+
comptime if get_defined_bool["AUTOTUNING_MODE", False]():
692700
comptime block_m = get_defined_int["TUNE_BM", 128]()
693701
comptime block_n = get_defined_int["TUNE_BN", 128]()
702+
comptime block_k = get_defined_int[
703+
"TUNE_BK", _bk_base[a_type, True]()
704+
]()
694705
comptime num_k_partitions = get_defined_int[
695706
"TUNE_NUM_K_PARTITIONS", 1
696707
]()
697-
return kernel_helper[
698-
block_m, block_n, num_k_partitions=num_k_partitions
699-
]()
708+
comptime config = MatmulConfig[
709+
a_type, b_type, c_type, transpose_b
710+
](
711+
block_tile_shape=Index(block_m, block_n, block_k),
712+
warp_tile_shape=Index(
713+
block_m // 2, block_n // 2, block_k
714+
),
715+
mma_shape=_amdgpu_get_mma_shape[a_type, transpose_b](),
716+
num_pipeline_stages=1,
717+
num_k_partitions=UInt(num_k_partitions),
718+
pdl_level=pdl_level,
719+
)
720+
return _multistage_gemm[config]()
721+
722+
# Shape-specific FP8 configs for small M (128-256).
723+
# These match hipBLASLt's tile choices which use deeper K
724+
# tiles and adapted block shapes for better per-block
725+
# throughput at low CU occupancy.
726+
# Format: Index(N, K, BM, BN, BK)
727+
# Small-M FP8 dispatch (M=128-256): use autotuned block
728+
# shapes that outperform the generic auto-tuner.
729+
# Derived from sweep over BM/BN/BK on MI355X.
730+
comptime if a_type.is_float8() and transpose_b:
731+
if m >= 128 and m <= 256:
732+
733+
@always_inline
734+
@parameter
735+
def _small_m_gemm[
736+
_bm: Int, _bn: Int, _bk: Int
737+
]() raises:
738+
comptime config = MatmulConfig[
739+
a_type, b_type, c_type, transpose_b
740+
](
741+
block_tile_shape=Index(_bm, _bn, _bk),
742+
warp_tile_shape=Index(_bm // 2, _bn // 2, _bk),
743+
mma_shape=_amdgpu_get_mma_shape[
744+
a_type, transpose_b
745+
](),
746+
num_pipeline_stages=1,
747+
pdl_level=pdl_level,
748+
)
749+
return _multistage_gemm[config]()
750+
751+
comptime if static_N < 4096:
752+
# Narrow N (e.g. N=2304): deep BK=512,
753+
# square blocks for balanced compute.
754+
# Guard: K must be >= BK to avoid OOB reads.
755+
if k >= 512:
756+
if m <= 160:
757+
return _small_m_gemm[32, 64, 512]()
758+
else:
759+
return _small_m_gemm[64, 64, 512]()
760+
else:
761+
if m <= 150:
762+
# Small M with wide N: BM=64 for
763+
# occupancy, BN=128 for N-coverage
764+
if k >= 256:
765+
return _small_m_gemm[64, 128, 256]()
766+
else:
767+
# Larger M (200-256): square 128x128
768+
# tiles with BK=256
769+
if k >= 256:
770+
return _small_m_gemm[128, 128, 256]()
700771

701772
comptime sm_count = ctx.default_device_info.sm_count
702773
comptime block_shape_list = _amdgpu_matmul_build_block_shape_list[
@@ -1074,18 +1145,16 @@ def multistage_gemm[
10741145

10751146
# Dispatch heuristic from Llama3-405B TP=4 benchmarks on MI355X.
10761147
#
1077-
# Three kernels: standard GEMM, pingpong 256x256, skinny 128x256.
1078-
# Skinny pingpong dominates at small M (128-512) with 35-56%
1079-
# advantage over 256x256 due to better occupancy.
1080-
# 256x256 pingpong dominates at large M (>=640) with 15-25%
1081-
# advantage due to higher compute density per barrier.
1082-
# Crossover is consistent at M ~= 512-640 across all (N,K).
1148+
# For N >= 4096:
1149+
# M < 225: autotuned standard GEMM (small-M configs handle 128-256)
1150+
# 225 <= M < 600: skinny pingpong (1.1-1.3x vs vendor BLAS)
1151+
# M >= 600: pingpong 256x256 (~1.0x vs vendor, best custom)
10831152
#
1084-
# N >= 4096: skinny at M 128-512, 256x256 at M >= 640
1085-
# N < 4096: standard GEMM at small M, skinny at M >= 512
1086-
1153+
# For N < 4096 (e.g. N=2304):
1154+
# M < 750: autotuned standard GEMM (1.4-2.2x vs vendor BLAS)
1155+
# M >= 750: skinny pingpong (1.2-1.3x vs vendor BLAS)
10871156
if N >= 4096:
1088-
if M >= 640:
1157+
if M >= 600:
10891158
logger.info("Executing: AMD ping-pong matmul (256x256)")
10901159
ctx.enqueue_function[pingpong_kernel, pingpong_kernel](
10911160
a,
@@ -1097,7 +1166,7 @@ def multistage_gemm[
10971166
),
10981167
block_dim=pingpong_config.num_threads(),
10991168
)
1100-
elif M >= 128:
1169+
elif M >= 256:
11011170
logger.info("Executing: AMD skinny pingpong matmul")
11021171
ctx.enqueue_function[skinny_kernel, skinny_kernel](
11031172
a,
@@ -1119,7 +1188,7 @@ def multistage_gemm[
11191188
block_dim=config.block_dim(),
11201189
)
11211190
else:
1122-
if M >= 512:
1191+
if M >= 750:
11231192
logger.info("Executing: AMD skinny pingpong matmul")
11241193
ctx.enqueue_function[skinny_kernel, skinny_kernel](
11251194
a,
@@ -1353,19 +1422,8 @@ def multistage_gemm[
13531422
elementwise_lambda_fn=elementwise_lambda_fn,
13541423
]
13551424

1356-
# Dispatch heuristic from Llama3-405B TP=4 benchmarks on MI355X.
1357-
#
1358-
# Three kernels: standard GEMM, pingpong 256x256, skinny 128x256.
1359-
# Skinny pingpong dominates at small M (128-512) with 35-56%
1360-
# advantage over 256x256 due to better occupancy.
1361-
# 256x256 pingpong dominates at large M (>=640) with 15-25%
1362-
# advantage due to higher compute density per barrier.
1363-
# Crossover is consistent at M ~= 512-640 across all (N,K).
1364-
#
1365-
# N >= 4096: skinny at M 128-512, 256x256 at M >= 640
1366-
# N < 4096: standard GEMM at small M, skinny at M >= 512
13671425
if N >= 4096:
1368-
if M >= 640:
1426+
if M >= 600:
13691427
logger.info("Executing: AMD ping-pong matmul (256x256)")
13701428
ctx.enqueue_function[pingpong_kernel, pingpong_kernel](
13711429
a,
@@ -1377,7 +1435,7 @@ def multistage_gemm[
13771435
),
13781436
block_dim=pingpong_config.num_threads(),
13791437
)
1380-
elif M >= 128:
1438+
elif M >= 256:
13811439
logger.info("Executing: AMD skinny pingpong matmul")
13821440
ctx.enqueue_function[skinny_kernel, skinny_kernel](
13831441
a,
@@ -1399,7 +1457,7 @@ def multistage_gemm[
13991457
block_dim=config.block_dim(),
14001458
)
14011459
else:
1402-
if M >= 512:
1460+
if M >= 750:
14031461
logger.info("Executing: AMD skinny pingpong matmul")
14041462
ctx.enqueue_function[skinny_kernel, skinny_kernel](
14051463
a,

0 commit comments

Comments
 (0)