Skip to content

Commit 809d2cb

Browse files
committed
Add schedule for 256x224x256 macro tile
It is a no_unroll schedule, used to get under the register budget. This gets the macro tile functional with the waveasm backend.

For the 7.1 example, it adds:
- `--wave_shape` flag -- Previously (1, 4) was hard-coded, but the 256x224x256 tile needed (2, 2). I think the reason we chose that was that the N dimension was not divisible by 4 after pipelining.
- `--no-unroll` flag to access the new no_unroll schedule.

The particular 7.1 example target for this work was:
`python examples/python/7.1_schedule.py --block 256,224,256 --shape 1024,896,8192 --wave_shape 2,2 --no-unroll --test test_dbuf_4wave_mxfp_preshuffle_b_gemm_cpp`

This also adds an e2e waveasm test. At this stage no real effort has been made to make the schedule performant, just to get it working.

Signed-off-by: William G Hatch <william@hatch.uno>
1 parent d519efd commit 809d2cb

5 files changed

Lines changed: 450 additions & 8 deletions

File tree

examples/python/7.1_schedule.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
get_mxfp4_dbuf_pingpong_schedule,
2828
get_mxfp4_dbuf_mixed_pingpong_schedule,
2929
get_mxfp4_asymmetric_schedule,
30+
get_mxfp4_asymmetric_nounroll_schedule,
3031
get_mxfp4_dbuf_mixed_pingpong_shuffle_schedule,
3132
get_mxfp4_dbuf_pingpong_schedule_Bshuffled,
3233
get_mxfp4_dbuf_pingpong_schedule_Bshuffled_lds,
@@ -372,19 +373,28 @@ def test_dbuf_4wave_mxfp_preshuffle_b_gemm_cpp(
372373
is_debug=False,
373374
shape=(512, 1024, 8192), # 4*T0, 4*T1, 8192
374375
block=(128, 256, 256),
376+
wave_shape=(1, 4),
375377
eliminate_epilogue=True,
378+
no_unroll=False,
376379
):
377380
"""Preshuffle-B MXFP4 GEMM using C++ WaveASM backend."""
378-
gemm, options = get_tagged_mxfp4_gemm_preshuffle_b(shape, block, wave_shape=(1, 4))
381+
gemm, options = get_tagged_mxfp4_gemm_preshuffle_b(
382+
shape, block, wave_shape=wave_shape
383+
)
379384
options.backend = "asm"
380385
options.use_buffer_ops = True
381-
options.wave_runtime = True
382386
options.use_wave_asm_backend = True
387+
options.wave_runtime = True
383388
options.dump_intermediates = "build/intermediates"
384389
options.eliminate_epilogue = eliminate_epilogue
385-
schedule = get_mxfp4_asymmetric_schedule(
386-
eliminate_epilogue=eliminate_epilogue, is_bscale_shuffled=True
387-
)
390+
if no_unroll:
391+
schedule = get_mxfp4_asymmetric_nounroll_schedule(
392+
eliminate_epilogue=eliminate_epilogue, is_bscale_shuffled=True
393+
)
394+
else:
395+
schedule = get_mxfp4_asymmetric_schedule(
396+
eliminate_epilogue=eliminate_epilogue, is_bscale_shuffled=True
397+
)
388398
options.print_ir_after = "all" if is_debug else []
389399
options = set_default_run_config(options)
390400
gemm = wave_compile(options, gemm, schedule)
@@ -444,5 +454,7 @@ def test_dbuf_4wave_mxfp_dynamic_preshuffle_b_gemm(
444454
args.shape,
445455
args.block,
446456
args.eliminate_epilogue,
457+
args.wave_shape,
458+
no_unroll=args.no_unroll,
447459
)
448460
exit(0 if success else 1)

examples/python/utils.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,17 @@ def parse_args():
3636
default=None,
3737
help="Enable epilogue elimination (true/false)",
3838
)
39+
parser.add_argument(
40+
"--wave_shape",
41+
type=str,
42+
default=None,
43+
help="Wave shape, e.g. 2,2",
44+
)
45+
parser.add_argument(
46+
"--no-unroll",
47+
action="store_true",
48+
help="Use nounroll (unroll_factor=1) schedule variant",
49+
)
3950

4051
args = parser.parse_args()
4152

@@ -44,6 +55,8 @@ def parse_args():
4455
args.shape = tuple(map(int, args.shape.split(",")))
4556
if isinstance(args.block, str):
4657
args.block = tuple(map(int, args.block.split(",")))
58+
if isinstance(args.wave_shape, str):
59+
args.wave_shape = tuple(map(int, args.wave_shape.split(",")))
4760

4861
return args
4962

@@ -64,6 +77,8 @@ def run_test(
6477
shape=None,
6578
block=None,
6679
eliminate_epilogue=None,
80+
wave_shape=None,
81+
no_unroll=False,
6782
):
6883
"""Run a test function multiple times."""
6984
if test_name not in module_globals:
@@ -78,6 +93,10 @@ def run_test(
7893
kwargs["block"] = block
7994
if eliminate_epilogue is not None:
8095
kwargs["eliminate_epilogue"] = eliminate_epilogue
96+
if wave_shape is not None:
97+
kwargs["wave_shape"] = wave_shape
98+
if no_unroll:
99+
kwargs["no_unroll"] = True
81100

82101
for i in range(repeat):
83102
try:

tests/kernel/wave/asm/test_waveasm_e2e.py

Lines changed: 48 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1142,6 +1142,7 @@ def _dbuf_mxfp4_helper(
11421142
wave_shape=None,
11431143
reorder_workgroups=None,
11441144
eliminate_epilogue=False,
1145+
no_unroll=False,
11451146
):
11461147
"""Shared helper for double-buffered MXFP4 scheduled GEMM tests.
11471148
@@ -1168,6 +1169,7 @@ def _dbuf_mxfp4_helper(
11681169
from wave_lang.kernel.wave.schedules import (
11691170
get_mxfp4_dbuf_schedule,
11701171
get_mxfp4_asymmetric_schedule,
1172+
get_mxfp4_asymmetric_nounroll_schedule,
11711173
)
11721174
from wave_lang.kernel.wave.scheduling.schedule_enums import SchedulingType
11731175
from wave_lang.kernel.wave.utils.run_utils import set_default_run_config
@@ -1200,9 +1202,14 @@ def _dbuf_mxfp4_helper(
12001202
)
12011203
options.eliminate_epilogue = eliminate_epilogue
12021204
if use_schedule:
1203-
schedule = get_mxfp4_asymmetric_schedule(
1204-
eliminate_epilogue=eliminate_epilogue, is_bscale_shuffled=True
1205-
)
1205+
if no_unroll:
1206+
schedule = get_mxfp4_asymmetric_nounroll_schedule(
1207+
eliminate_epilogue=eliminate_epilogue, is_bscale_shuffled=True
1208+
)
1209+
else:
1210+
schedule = get_mxfp4_asymmetric_schedule(
1211+
eliminate_epilogue=eliminate_epilogue, is_bscale_shuffled=True
1212+
)
12061213
else:
12071214
schedule = None
12081215
options.schedule = SchedulingType.NONE
@@ -1445,6 +1452,44 @@ def test_dbuf_4wave_mxfp4_gemm_cpp_backend(
14451452
)
14461453

14471454

1455+
@pytest.mark.parametrize("eliminate_epilogue", [True], ids=["ee"])
1456+
@pytest.mark.parametrize(
1457+
"shape,block,wave_shape",
1458+
[
1459+
pytest.param((1024, 896, 8192), (256, 224, 256), (2, 2), id="256x224x256"),
1460+
],
1461+
)
1462+
def test_dbuf_4wave_mxfp4_nounroll_gemm_cpp_backend(
1463+
shape,
1464+
block,
1465+
wave_shape,
1466+
eliminate_epilogue,
1467+
compiler,
1468+
dump_asm,
1469+
):
1470+
"""End-to-end test for asymmetric MXFP4 GEMM with no-unroll schedule.
1471+
1472+
The no-unroll schedule (unroll_factor=1) reduces register pressure by
1473+
not unrolling the K-loop body, allowing larger block sizes like
1474+
256x224x256 that would otherwise exceed the 256-VGPR hardware limit
1475+
with the standard asymmetric schedule.
1476+
"""
1477+
_dbuf_mxfp4_helper(
1478+
shape=shape,
1479+
block=block,
1480+
num_waves=4,
1481+
use_stagger=False,
1482+
compiler=compiler,
1483+
dump_asm=dump_asm,
1484+
use_buffer_ops=True,
1485+
use_schedule=True,
1486+
output_dtype="f32",
1487+
wave_shape=wave_shape,
1488+
eliminate_epilogue=eliminate_epilogue,
1489+
no_unroll=True,
1490+
)
1491+
1492+
14481493
@pytest.mark.parametrize(
14491494
"shape,block,wave_shape",
14501495
[

wave_lang/kernel/wave/schedules/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
get_mxfp4_dbuf_mixed_pingpong_schedule,
1919
get_mxfp4_dbuf_mixed_pingpong_shuffle_schedule,
2020
get_mxfp4_asymmetric_schedule,
21+
get_mxfp4_asymmetric_nounroll_schedule,
2122
get_mxfp4_dbuf_pingpong_schedule_Bshuffled,
2223
get_mxfp4_dbuf_pingpong_schedule_Bshuffled_lds,
2324
)
@@ -34,6 +35,7 @@
3435
"get_mxfp4_dbuf_pingpong_schedule_Bshuffled",
3536
"get_mxfp4_dbuf_pingpong_schedule_Bshuffled_lds",
3637
"get_mxfp4_asymmetric_schedule",
38+
"get_mxfp4_asymmetric_nounroll_schedule",
3739
"get_mxfp4_dbuf_mixed_pingpong_shuffle_schedule",
3840
"get_attention_prefetch_schedule",
3941
]

0 commit comments

Comments
 (0)