Skip to content

Commit 2c03c78

Browse files
committed
addressed review commits
Signed-off-by: xintin <gaurav.verma@amd.com>
1 parent 33d140f commit 2c03c78

7 files changed

Lines changed: 130 additions & 55 deletions

File tree

examples/python/7.1_schedule.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from wave_lang.kernel.wave.templates import (
3333
get_tagged_mxfp4_gemm,
3434
get_tagged_mxfp4_gemm_preshuffle_b,
35+
get_tagged_mxfp4_gemm_preshuffle_b_wide_store,
3536
get_tagged_mxfp4_gemm_preshuffle_scales,
3637
get_tagged_mxfp4_gemm_preshuffle_scales_and_B,
3738
)
@@ -432,18 +433,16 @@ def test_dbuf_4wave_mxfp_dynamic_preshuffle_b_gemm_wide_stores(
432433
):
433434
"""Preshuffle-B MXFP4 GEMM with dynamic M, N, K and wide epilogue stores.
434435
435-
Uses wide_stores=True to swap MFMA operands (B as LHS, A as RHS),
436+
Uses the wide_store variant to swap MFMA operands (B as LHS, A as RHS),
436437
aligning the accumulator's contiguous values with the output's stride-1
437438
dimension. The coalesce_wide_stores pass emits v_permlane16_swap_b32
438439
+ buffer_store_dwordx4 (8 bf16 per store) instead of buffer_store_short.
439440
"""
440-
gemm, options = get_tagged_mxfp4_gemm_preshuffle_b(
441+
gemm, options = get_tagged_mxfp4_gemm_preshuffle_b_wide_store(
441442
shape,
442443
block,
443444
wave_shape=(2, 2),
444445
reorder_workgroups=True,
445-
output_dtype=tkl.bf16,
446-
wide_stores=True,
447446
)
448447
dynamic_symbols = [tkl.sym.M, tkl.sym.N, tkl.sym.K]
449448
for sym in dynamic_symbols:

lit_tests/kernel/wave/wide_stores_mxfp4.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
"""
44
Test wide store coalescing for preshuffle-B MXFP4 GEMM with bf16 output.
55
6-
When wide_stores=True, the kernel swaps MFMA operands (B as LHS, A as RHS)
6+
The wide_store variant kernel swaps MFMA operands (B as LHS, A as RHS)
77
so the accumulator's 4-contiguous values align with the output's stride-1
88
dimension. The coalesce_wide_stores pass tags eligible bf16 global
99
writes, and the codegen emits v_permlane16_swap_b32 to exchange data
@@ -23,21 +23,21 @@
2323
from wave_lang.kernel.wave.compile import wave_compile
2424
from wave_lang.kernel.wave.constraints import ScaledMMAType
2525
from wave_lang.kernel.wave.schedules import get_mxfp4_asymmetric_schedule
26-
from wave_lang.kernel.wave.templates import get_tagged_mxfp4_gemm_preshuffle_b
26+
from wave_lang.kernel.wave.templates import (
27+
get_tagged_mxfp4_gemm_preshuffle_b_wide_store,
28+
)
2729
from wave_lang.kernel.wave.utils.general_utils import run_test
2830

2931

3032
@run_test
3133
def test_wide_stores_preshuffle_b_mxfp4():
3234
shape = (1024, 3072, 8192)
3335
block = (256, 192, 256)
34-
kernel, options = get_tagged_mxfp4_gemm_preshuffle_b(
36+
kernel, options = get_tagged_mxfp4_gemm_preshuffle_b_wide_store(
3537
shape,
3638
block,
3739
wave_shape=(2, 2),
3840
reorder_workgroups=True,
39-
output_dtype=tkl.bf16,
40-
wide_stores=True,
4141
mfma_variant=ScaledMMAType.F32_16x16x128_F8F6F4,
4242
)
4343
dynamic_symbols = [tkl.sym.M, tkl.sym.N, tkl.sym.K]

tests/kernel/wave_gemm_mxfp_test.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from wave_lang.kernel.wave.templates import (
3030
get_tagged_mxfp4_gemm,
3131
get_tagged_mxfp4_gemm_preshuffle_b,
32+
get_tagged_mxfp4_gemm_preshuffle_b_wide_store,
3233
get_tagged_mxfp4_gemm_preshuffle_scales,
3334
get_tagged_mxfp4_gemm_preshuffle_scales_and_B,
3435
)
@@ -1051,17 +1052,15 @@ def testScaledGemmMXFP4PreshuffleBWideStores(
10511052
):
10521053
"""End-to-end test for MXFP4 GEMM with wide epilogue stores (dwordx4).
10531054
1054-
Uses wide_stores=True to swap MFMA operands and emit buffer_store_dwordx4
1055-
via v_permlane16_swap_b32 for bf16 output.
1055+
Uses the wide_store variant to swap MFMA operands and emit
1056+
buffer_store_dwordx4 via v_permlane16_swap_b32 for bf16 output.
10561057
"""
1057-
gemm, options = get_tagged_mxfp4_gemm_preshuffle_b(
1058+
gemm, options = get_tagged_mxfp4_gemm_preshuffle_b_wide_store(
10581059
shape,
10591060
block_shape,
10601061
wave_shape=wave_shape,
10611062
mfma_variant=mfma_variant,
10621063
reorder_workgroups=True,
1063-
output_dtype=tkl.bf16,
1064-
wide_stores=True,
10651064
)
10661065
dynamic_symbols = [tkl.sym.M, tkl.sym.N, tkl.sym.K]
10671066
for sym in dynamic_symbols:

wave_lang/kernel/compiler/wave_codegen/read_write.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1436,14 +1436,16 @@ def _write_permlane_pack_to_global(
14361436
wide_i32 = vector_d.from_elements(v4i32_type, [d0, d1, d2, d3])
14371437
wide_vec = vector_d.bitcast(v8bf16_type, wide_i32)
14381438

1439-
four = arith_d.constant(idx_type, 4)
1439+
elems_per_thread = arith_d.constant(idx_type, num_elems)
14401440

14411441
adj_th = list(start_indices_th)
1442-
adj_th[-1] = arith_d.select(is_lower, adj_th[-1], arith_d.subi(adj_th[-1], four))
1442+
adj_th[-1] = arith_d.select(
1443+
is_lower, adj_th[-1], arith_d.subi(adj_th[-1], elems_per_thread)
1444+
)
14431445

14441446
adj_full = list(start_indices)
14451447
adj_full[-1] = arith_d.select(
1446-
is_lower, adj_full[-1], arith_d.subi(adj_full[-1], four)
1448+
is_lower, adj_full[-1], arith_d.subi(adj_full[-1], elems_per_thread)
14471449
)
14481450

14491451
_create_vec_read_write(

wave_lang/kernel/wave/templates/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from .tagged_mxfp4_gemm import (
1010
get_tagged_mxfp4_gemm,
1111
get_tagged_mxfp4_gemm_preshuffle_b,
12+
get_tagged_mxfp4_gemm_preshuffle_b_wide_store,
1213
get_tagged_mxfp4_gemm_preshuffle_scales,
1314
get_tagged_mxfp4_gemm_preshuffle_scales_and_B,
1415
)
@@ -18,6 +19,7 @@
1819
"get_tagged_bshd_attention_kernel",
1920
"get_tagged_mxfp4_gemm",
2021
"get_tagged_mxfp4_gemm_preshuffle_b",
22+
"get_tagged_mxfp4_gemm_preshuffle_b_wide_store",
2123
"get_tagged_mxfp4_gemm_preshuffle_scales",
2224
"get_tagged_mxfp4_gemm_preshuffle_scales_and_B",
2325
]

wave_lang/kernel/wave/templates/tagged_mxfp4_gemm.py

Lines changed: 101 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,9 @@
1010
All ops are tagged for use with MXFP4 schedule functions (e.g. get_mxfp4_dbuf_schedule).
1111
1212
Provides:
13-
- get_tagged_mxfp4_gemm: vanilla (A, B via LDS)
14-
- get_tagged_mxfp4_gemm_preshuffle_b: B + B_scale preshuffled (direct global reads)
13+
- get_tagged_mxfp4_gemm: vanilla (A, B via LDS)
14+
- get_tagged_mxfp4_gemm_preshuffle_b: B + B_scale preshuffled (direct global reads)
15+
- get_tagged_mxfp4_gemm_preshuffle_b_wide_store: same + wide epilogue stores via permlane swap
1516
1617
Required tags: k_loop, read_a, read_a_scale, read_b, read_b_scale,
1718
bitcast_a, bitcast_a_scale, bitcast_b, bitcast_b_scale, scaled_mma.
@@ -377,37 +378,20 @@ def get_tagged_mxfp4_gemm_preshuffle_scales_and_B(
377378
)
378379

379380

380-
def get_tagged_mxfp4_gemm_preshuffle_b(
381-
shape: tuple[int, int, int] = (1024, 1024, 8192),
382-
block_shape: tuple[int, int, int] = (256, 256, 256),
383-
wave_shape: tuple[int, int] = (2, 2),
384-
mfma_variant: ScaledMMAType = ScaledMMAType.F32_16x16x128_F8F6F4,
385-
a_address_space: tkl.AddressSpace = SHARED_ADDRESS_SPACE,
381+
def _get_tagged_mxfp4_gemm_preshuffle_b_impl(
382+
shape: tuple[int, int, int],
383+
block_shape: tuple[int, int, int],
384+
wave_shape: tuple[int, int],
385+
mfma_variant: ScaledMMAType,
386+
a_address_space: tkl.AddressSpace,
387+
*,
386388
a_scale_preshuffle: bool = True,
387-
reorder_workgroups=True,
388-
group_size_n=32,
389+
reorder_workgroups: bool = True,
390+
group_size_n: int = 32,
389391
output_dtype=tkl.f32,
390392
wide_stores: bool = False,
391393
):
392-
"""Return a tagged MXFP4 scaled GEMM kernel with preshuffled B and B_scale.
393-
394-
B data is read directly from global memory using a preshuffle mapping
395-
(aiter shuffle_weight permutation). B scales are also read from global
396-
memory using an e8m0 scale preshuffle mapping. A and A_scale go through
397-
shared memory (LDS) as usual.
398-
399-
All ops are tagged for use with MXFP4 schedule functions.
400-
401-
Args:
402-
shape: (M, N, K) problem dimensions.
403-
block_shape: (BLOCK_M, BLOCK_N, BLOCK_K) tile sizes.
404-
wave_shape: (WAVE_M, WAVE_N) waves per workgroup.
405-
mfma_variant: Scaled MMA instruction type.
406-
a_address_space: Address space for A and A_scale (typically SHARED).
407-
408-
Returns:
409-
(kernel_function, WaveCompileOptions)
410-
"""
394+
"""Shared implementation for preshuffle-B MXFP4 GEMM with optional wide stores."""
411395
M = tkl.sym.M
412396
N = tkl.sym.N
413397
K = tkl.sym.K
@@ -599,6 +583,94 @@ def repeat(
599583
return gemm, options
600584

601585

586+
def get_tagged_mxfp4_gemm_preshuffle_b(
587+
shape: tuple[int, int, int] = (1024, 1024, 8192),
588+
block_shape: tuple[int, int, int] = (256, 256, 256),
589+
wave_shape: tuple[int, int] = (2, 2),
590+
mfma_variant: ScaledMMAType = ScaledMMAType.F32_16x16x128_F8F6F4,
591+
a_address_space: tkl.AddressSpace = SHARED_ADDRESS_SPACE,
592+
a_scale_preshuffle: bool = True,
593+
reorder_workgroups=True,
594+
group_size_n=32,
595+
output_dtype=tkl.f32,
596+
):
597+
"""Return a tagged MXFP4 scaled GEMM kernel with preshuffled B and B_scale.
598+
599+
B data is read directly from global memory using a preshuffle mapping
600+
(aiter shuffle_weight permutation). B scales are also read from global
601+
memory using an e8m0 scale preshuffle mapping. A and A_scale go through
602+
shared memory (LDS) as usual.
603+
604+
All ops are tagged for use with MXFP4 schedule functions.
605+
606+
Args:
607+
shape: (M, N, K) problem dimensions.
608+
block_shape: (BLOCK_M, BLOCK_N, BLOCK_K) tile sizes.
609+
wave_shape: (WAVE_M, WAVE_N) waves per workgroup.
610+
mfma_variant: Scaled MMA instruction type.
611+
a_address_space: Address space for A and A_scale (typically SHARED).
612+
613+
Returns:
614+
(kernel_function, WaveCompileOptions)
615+
"""
616+
return _get_tagged_mxfp4_gemm_preshuffle_b_impl(
617+
shape,
618+
block_shape,
619+
wave_shape,
620+
mfma_variant,
621+
a_address_space,
622+
a_scale_preshuffle=a_scale_preshuffle,
623+
reorder_workgroups=reorder_workgroups,
624+
group_size_n=group_size_n,
625+
output_dtype=output_dtype,
626+
wide_stores=False,
627+
)
628+
629+
630+
def get_tagged_mxfp4_gemm_preshuffle_b_wide_store(
631+
shape: tuple[int, int, int] = (1024, 1024, 8192),
632+
block_shape: tuple[int, int, int] = (256, 256, 256),
633+
wave_shape: tuple[int, int] = (2, 2),
634+
mfma_variant: ScaledMMAType = ScaledMMAType.F32_16x16x128_F8F6F4,
635+
a_address_space: tkl.AddressSpace = SHARED_ADDRESS_SPACE,
636+
a_scale_preshuffle: bool = True,
637+
reorder_workgroups=True,
638+
group_size_n=32,
639+
output_dtype=tkl.bf16,
640+
):
641+
"""Return a tagged MXFP4 scaled GEMM kernel with preshuffled B, B_scale, and wide stores.
642+
643+
Like :func:`get_tagged_mxfp4_gemm_preshuffle_b` but swaps MFMA operands
644+
(B as LHS, A as RHS) so the accumulator's 4-contiguous values align with
645+
the output memory's stride-1 dimension. The ``coalesce_wide_stores`` pass
646+
emits ``v_permlane16_swap_b32`` + ``buffer_store_dwordx4`` (8 bf16 per
647+
store) instead of scalar ``buffer_store_short``.
648+
649+
Args:
650+
shape: (M, N, K) problem dimensions.
651+
block_shape: (BLOCK_M, BLOCK_N, BLOCK_K) tile sizes.
652+
wave_shape: (WAVE_M, WAVE_N) waves per workgroup.
653+
mfma_variant: Scaled MMA instruction type.
654+
a_address_space: Address space for A and A_scale (typically SHARED).
655+
output_dtype: Output element type (default bf16).
656+
657+
Returns:
658+
(kernel_function, WaveCompileOptions)
659+
"""
660+
return _get_tagged_mxfp4_gemm_preshuffle_b_impl(
661+
shape,
662+
block_shape,
663+
wave_shape,
664+
mfma_variant,
665+
a_address_space,
666+
a_scale_preshuffle=a_scale_preshuffle,
667+
reorder_workgroups=reorder_workgroups,
668+
group_size_n=group_size_n,
669+
output_dtype=output_dtype,
670+
wide_stores=True,
671+
)
672+
673+
602674
def _reorder_mxfp4_workgroups(m, n, block_m, block_n, group_size_n):
603675
"""Remap workgroup indices to a new order based on group_size_n along N dimension.
604676

wave_lang/kernel/wave/wide_store_coalescing.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,13 @@
55
"""
66
Graph pass that tags eligible epilogue bf16 stores for wide store coalescing.
77
8-
When a kernel uses swapped MFMA operands (wide_stores=True), the
9-
accumulator's 4-contiguous values align with the output's stride-1
10-
dimension. This pass identifies Write nodes that use the source/target
11-
dimension remapping pattern (indicating swapped operands) and tags them
12-
so the codegen emits v_permlane16_swap_b32 + buffer_store_dwordx4
13-
instead of scalar buffer_store_short.
8+
When a kernel uses swapped MFMA operands (e.g.
9+
``get_tagged_mxfp4_gemm_preshuffle_b_wide_store``), the accumulator's
10+
4-contiguous values align with the output's stride-1 dimension. This
11+
pass identifies Write nodes that use the source/target dimension
12+
remapping pattern (indicating swapped operands) and tags them so the
13+
codegen emits v_permlane16_swap_b32 + buffer_store_dwordx4 instead of
14+
scalar buffer_store_short.
1415
1516
Only tags writes that satisfy ALL conditions:
1617
1. Target memory is global address space
@@ -30,9 +31,9 @@ def coalesce_wide_stores(trace: CapturedTrace):
3031
"""Tag eligible bf16 global writes for permlane16_swap wide stores.
3132
3233
Only tags Write nodes that use the source/target dimension remapping
33-
pattern, which indicates the kernel was built with ``wide_stores=True``
34-
(swapped MFMA operands). Writes without source/target are left
35-
untouched, making this pass safe to run unconditionally.
34+
pattern (swapped MFMA operands, as produced by the wide_store kernel
35+
variant). Writes without source/target are left untouched, making
36+
this pass safe to run unconditionally.
3637
"""
3738
import wave_lang.kernel.lang as tkl
3839

0 commit comments

Comments
 (0)