@@ -664,20 +664,27 @@ def _matmul_gpu[
664664 elementwise_lambda_fn=elementwise_lambda_wrapper,
665665 ](c, a, b, ctx)
666666
667- # M threshold above which vendor BLAS (hipBLASLt) outperforms
668- # all custom kernels for these (N, K) shapes.
669- # Derived from Llama3-405B TP=4 benchmarks on MI355X.
670- # Format: Index(N, K, M_threshold) — vendor BLAS used when m >= threshold.
667+ # M thresholds where vendor BLAS (hipBLASLt) outperforms
668+ # all custom kernels. Derived from Llama3-405B TP=4 on MI355X.
669+ # Vendor also wins for some small M (standard kernel can't match
670+ # hipBLASLt's tiny-tile configs); that case is handled by the extra condition in the dispatch check below, not by this table.
671+ # Format: Index(N, K, M_threshold) — vendor BLAS used when m >= M_threshold.
671672 comptime vendor_blas_NK_m = [
673+ # N=2304: vendor wins at M=4-16 and M>=4096
672674 Index(2304 , 16384 , 4096 ),
673- Index(16384 , 2048 , 225 ),
674- Index(13312 , 16384 , 600 ),
675+ # N=16384 K=2048: vendor wins at M>=150 (except M=450 skinny)
676+ Index(16384 , 2048 , 150 ),
677+ # N=16384 K=6656: vendor wins at M=4-16 and M>=600
675678 Index(16384 , 6656 , 600 ),
679+ # N=13312: vendor wins at M>=600 (borderline)
680+ Index(13312 , 16384 , 600 ),
676681 ]
677682 comptime for i in range (len (vendor_blas_NK_m)):
678683 comptime nk_m = vendor_blas_NK_m[i]
679684 comptime if static_N == nk_m[0 ] and static_K == nk_m[1 ]:
680- if m >= nk_m[2 ]:
685+ if m >= nk_m[2 ] or (
686+ n == 16384 and k == 6656 and m > 1 and m < 64
687+ ):
681688 logger.info(
682689 " Executing: vendor BLAS (hipBLASLt) for AMD"
683690 )
@@ -688,15 +695,79 @@ def _matmul_gpu[
688695
689696 comptime if not transpose_b:
690697 return kernel_helper[128 , 128 , num_pipeline_stages=2 ]()
691- elif get_defined_bool[" AUTOTUNING_MODE" , False ]():
698+
699+ comptime if get_defined_bool[" AUTOTUNING_MODE" , False ]():
692700 comptime block_m = get_defined_int[" TUNE_BM" , 128 ]()
693701 comptime block_n = get_defined_int[" TUNE_BN" , 128 ]()
702+ comptime block_k = get_defined_int[
703+ " TUNE_BK" , _bk_base[a_type, True ]()
704+ ]()
694705 comptime num_k_partitions = get_defined_int[
695706 " TUNE_NUM_K_PARTITIONS" , 1
696707 ]()
697- return kernel_helper[
698- block_m, block_n, num_k_partitions=num_k_partitions
699- ]()
708+ comptime config = MatmulConfig[
709+ a_type, b_type, c_type, transpose_b
710+ ](
711+ block_tile_shape = Index(block_m, block_n, block_k),
712+ warp_tile_shape = Index(
713+ block_m // 2 , block_n // 2 , block_k
714+ ),
715+ mma_shape = _amdgpu_get_mma_shape[a_type, transpose_b](),
716+ num_pipeline_stages = 1 ,
717+ num_k_partitions = UInt(num_k_partitions),
718+ pdl_level = pdl_level,
719+ )
720+ return _multistage_gemm[config]()
721+
722+ # Shape-specific FP8 dispatch for small M (128-256).
723+ # The block shapes below match hipBLASLt's tile choices, which
724+ # use deeper K tiles and adapted block shapes for better
725+ # per-block throughput at low CU occupancy.
726+ # They were derived from a sweep over BM/BN/BK on MI355X and
727+ # outperform the generic auto-tuner in this M range.
728+ # Each shape is selected via _small_m_gemm[BM, BN, BK]; when no
729+ # branch below returns, dispatch falls through to the code that follows.
730+ comptime if a_type.is_float8() and transpose_b:
731+ if m >= 128 and m <= 256 :
732+
733+ @always_inline
734+ @parameter
735+ def _small_m_gemm [
736+ _bm : Int, _bn : Int, _bk : Int
737+ ]() raises :
738+ comptime config = MatmulConfig[
739+ a_type, b_type, c_type, transpose_b
740+ ](
741+ block_tile_shape = Index(_bm, _bn, _bk),
742+ warp_tile_shape = Index(_bm // 2 , _bn // 2 , _bk),
743+ mma_shape = _amdgpu_get_mma_shape[
744+ a_type, transpose_b
745+ ](),
746+ num_pipeline_stages = 1 ,
747+ pdl_level = pdl_level,
748+ )
749+ return _multistage_gemm[config]()
750+
751+ comptime if static_N < 4096 :
752+ # Narrow N (e.g. N=2304): deep BK=512,
753+ # square blocks for balanced compute.
754+ # Guard: K must be >= BK to avoid OOB reads.
755+ if k >= 512 :
756+ if m <= 160 :
757+ return _small_m_gemm[32 , 64 , 512 ]()
758+ else :
759+ return _small_m_gemm[64 , 64 , 512 ]()
760+ else :
761+ if m <= 150 :
762+ # Small M with wide N: BM=64 for
763+ # occupancy, BN=128 for N-coverage
764+ if k >= 256 :
765+ return _small_m_gemm[64 , 128 , 256 ]()
766+ else :
767+ # Larger M (150 < M <= 256): square 128x128
768+ # tiles with BK=256
769+ if k >= 256 :
770+ return _small_m_gemm[128 , 128 , 256 ]()
700771
701772 comptime sm_count = ctx.default_device_info.sm_count
702773 comptime block_shape_list = _amdgpu_matmul_build_block_shape_list[
@@ -1074,18 +1145,16 @@ def multistage_gemm[
10741145
10751146 # Dispatch heuristic from Llama3-405B TP=4 benchmarks on MI355X.
10761147 #
1077- # Three kernels: standard GEMM, pingpong 256x256, skinny 128x256.
1078- # Skinny pingpong dominates at small M (128-512) with 35-56%
1079- # advantage over 256x256 due to better occupancy.
1080- # 256x256 pingpong dominates at large M (>=640) with 15-25%
1081- # advantage due to higher compute density per barrier.
1082- # Crossover is consistent at M ~= 512-640 across all (N,K).
1148+ # For N >= 4096:
1149+ # M < 256: autotuned standard GEMM (small-M configs handle 128-256)
1150+ # 256 <= M < 600: skinny pingpong (1.1-1.3x vs vendor BLAS)
1151+ # M >= 600: pingpong 256x256 (~1.0x vs vendor, best custom)
10831152 #
1084- # N >= 4096: skinny at M 128-512, 256x256 at M >= 640
1085- # N < 4096: standard GEMM at small M, skinny at M >= 512
1086-
1153+ # For N < 4096 (e.g. N=2304):
1154+ # M < 750: autotuned standard GEMM (1.4-2.2x vs vendor BLAS)
1155+ # M >= 750: skinny pingpong (1.2-1.3x vs vendor BLAS)
10871156 if N >= 4096 :
1088- if M >= 640 :
1157+ if M >= 600 :
10891158 logger.info(" Executing: AMD ping-pong matmul (256x256)" )
10901159 ctx.enqueue_function[pingpong_kernel, pingpong_kernel](
10911160 a,
@@ -1097,7 +1166,7 @@ def multistage_gemm[
10971166 ),
10981167 block_dim = pingpong_config.num_threads(),
10991168 )
1100- elif M >= 128 :
1169+ elif M >= 256 :
11011170 logger.info(" Executing: AMD skinny pingpong matmul" )
11021171 ctx.enqueue_function[skinny_kernel, skinny_kernel](
11031172 a,
@@ -1119,7 +1188,7 @@ def multistage_gemm[
11191188 block_dim = config.block_dim(),
11201189 )
11211190 else :
1122- if M >= 512 :
1191+ if M >= 750 :
11231192 logger.info(" Executing: AMD skinny pingpong matmul" )
11241193 ctx.enqueue_function[skinny_kernel, skinny_kernel](
11251194 a,
@@ -1353,19 +1422,8 @@ def multistage_gemm[
13531422 elementwise_lambda_fn=elementwise_lambda_fn,
13541423 ]
13551424
1356- # Dispatch heuristic from Llama3-405B TP=4 benchmarks on MI355X.
1357- #
1358- # Three kernels: standard GEMM, pingpong 256x256, skinny 128x256.
1359- # Skinny pingpong dominates at small M (128-512) with 35-56%
1360- # advantage over 256x256 due to better occupancy.
1361- # 256x256 pingpong dominates at large M (>=640) with 15-25%
1362- # advantage due to higher compute density per barrier.
1363- # Crossover is consistent at M ~= 512-640 across all (N,K).
1364- #
1365- # N >= 4096: skinny at M 128-512, 256x256 at M >= 640
1366- # N < 4096: standard GEMM at small M, skinny at M >= 512
13671425 if N >= 4096 :
1368- if M >= 640 :
1426+ if M >= 600 :
13691427 logger.info(" Executing: AMD ping-pong matmul (256x256)" )
13701428 ctx.enqueue_function[pingpong_kernel, pingpong_kernel](
13711429 a,
@@ -1377,7 +1435,7 @@ def multistage_gemm[
13771435 ),
13781436 block_dim = pingpong_config.num_threads(),
13791437 )
1380- elif M >= 128 :
1438+ elif M >= 256 :
13811439 logger.info(" Executing: AMD skinny pingpong matmul" )
13821440 ctx.enqueue_function[skinny_kernel, skinny_kernel](
13831441 a,
@@ -1399,7 +1457,7 @@ def multistage_gemm[
13991457 block_dim = config.block_dim(),
14001458 )
14011459 else :
1402- if M >= 512 :
1460+ if M >= 750 :
14031461 logger.info(" Executing: AMD skinny pingpong matmul" )
14041462 ctx.enqueue_function[skinny_kernel, skinny_kernel](
14051463 a,
0 commit comments