Skip to content

Commit 314dad9

Browse files
committed
Add schedule for 256x224x256 macro tile
It is a no_unroll schedule to get under the register budget. This gets the macro tile functional with the waveasm backend. For the 7.1 example, it adds - `--wave_shape` flag -- Previously (1,4) was hard-coded, but the 256x224x256 tile needed (2, 2) because the N dimension was not divisible by 4 after pipelining... I think was the reason we chose that. - `--no_unroll` flag to access the new no_unroll schedule. The particular 7.1 example target for this work was `python examples/python/7.1_schedule.py --block 256,224,256 --shape 1024,896,8192 --wave_shape 2,2 --no-unroll --test test_dbuf_4wave_mxfp_preshuffle_b_gemm_cpp` This also adds an e2e waveasm test. At this stage no real effort has been made to make the schedule performant, just to get it working.
1 parent be3a321 commit 314dad9

5 files changed

Lines changed: 448 additions & 8 deletions

File tree

examples/python/7.1_schedule.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
get_mxfp4_dbuf_pingpong_schedule,
2828
get_mxfp4_dbuf_mixed_pingpong_schedule,
2929
get_mxfp4_asymmetric_schedule,
30+
get_mxfp4_asymmetric_nounroll_schedule,
3031
get_mxfp4_dbuf_mixed_pingpong_shuffle_schedule,
3132
get_mxfp4_dbuf_pingpong_schedule_Bshuffled,
3233
get_mxfp4_dbuf_pingpong_schedule_Bshuffled_lds,
@@ -372,19 +373,26 @@ def test_dbuf_4wave_mxfp_preshuffle_b_gemm_cpp(
372373
is_debug=False,
373374
shape=(512, 1024, 8192), # 4*T0, 4*T1, 8192
374375
block=(128, 256, 256),
376+
wave_shape=(1, 4),
375377
eliminate_epilogue=True,
378+
no_unroll=False,
376379
):
377380
"""Preshuffle-B MXFP4 GEMM using C++ WaveASM backend."""
378-
gemm, options = get_tagged_mxfp4_gemm_preshuffle_b(shape, block, wave_shape=(1, 4))
381+
gemm, options = get_tagged_mxfp4_gemm_preshuffle_b(shape, block, wave_shape=wave_shape)
379382
options.backend = "asm"
380383
options.use_buffer_ops = True
381-
options.wave_runtime = True
382384
options.use_wave_asm_backend = True
385+
options.wave_runtime = True
383386
options.dump_intermediates = "build/intermediates"
384387
options.eliminate_epilogue = eliminate_epilogue
385-
schedule = get_mxfp4_asymmetric_schedule(
386-
eliminate_epilogue=eliminate_epilogue, is_bscale_shuffled=True
387-
)
388+
if no_unroll:
389+
schedule = get_mxfp4_asymmetric_nounroll_schedule(
390+
eliminate_epilogue=eliminate_epilogue, is_bscale_shuffled=True
391+
)
392+
else:
393+
schedule = get_mxfp4_asymmetric_schedule(
394+
eliminate_epilogue=eliminate_epilogue, is_bscale_shuffled=True
395+
)
388396
options.print_ir_after = "all" if is_debug else []
389397
options = set_default_run_config(options)
390398
gemm = wave_compile(options, gemm, schedule)
@@ -444,5 +452,7 @@ def test_dbuf_4wave_mxfp_dynamic_preshuffle_b_gemm(
444452
args.shape,
445453
args.block,
446454
args.eliminate_epilogue,
455+
args.wave_shape,
456+
no_unroll=args.no_unroll,
447457
)
448458
exit(0 if success else 1)

examples/python/utils.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,17 @@ def parse_args():
3636
default=None,
3737
help="Enable epilogue elimination (true/false)",
3838
)
39+
parser.add_argument(
40+
"--wave_shape",
41+
type=str,
42+
default=None,
43+
help="Wave shape, e.g. 2,2",
44+
)
45+
parser.add_argument(
46+
"--no-unroll",
47+
action="store_true",
48+
help="Use nounroll (unroll_factor=1) schedule variant",
49+
)
3950

4051
args = parser.parse_args()
4152

@@ -44,6 +55,8 @@ def parse_args():
4455
args.shape = tuple(map(int, args.shape.split(",")))
4556
if isinstance(args.block, str):
4657
args.block = tuple(map(int, args.block.split(",")))
58+
if isinstance(args.wave_shape, str):
59+
args.wave_shape = tuple(map(int, args.wave_shape.split(",")))
4760

4861
return args
4962

@@ -64,6 +77,8 @@ def run_test(
6477
shape=None,
6578
block=None,
6679
eliminate_epilogue=None,
80+
wave_shape=None,
81+
no_unroll=False,
6782
):
6883
"""Run a test function multiple times."""
6984
if test_name not in module_globals:
@@ -78,6 +93,10 @@ def run_test(
7893
kwargs["block"] = block
7994
if eliminate_epilogue is not None:
8095
kwargs["eliminate_epilogue"] = eliminate_epilogue
96+
if wave_shape is not None:
97+
kwargs["wave_shape"] = wave_shape
98+
if no_unroll:
99+
kwargs["no_unroll"] = True
81100

82101
for i in range(repeat):
83102
try:

tests/kernel/wave/asm/test_waveasm_e2e.py

Lines changed: 48 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1142,6 +1142,7 @@ def _dbuf_mxfp4_helper(
11421142
wave_shape=None,
11431143
reorder_workgroups=None,
11441144
eliminate_epilogue=False,
1145+
no_unroll=False,
11451146
):
11461147
"""Shared helper for double-buffered MXFP4 scheduled GEMM tests.
11471148
@@ -1168,6 +1169,7 @@ def _dbuf_mxfp4_helper(
11681169
from wave_lang.kernel.wave.schedules import (
11691170
get_mxfp4_dbuf_schedule,
11701171
get_mxfp4_asymmetric_schedule,
1172+
get_mxfp4_asymmetric_nounroll_schedule,
11711173
)
11721174
from wave_lang.kernel.wave.scheduling.schedule_enums import SchedulingType
11731175
from wave_lang.kernel.wave.utils.run_utils import set_default_run_config
@@ -1200,9 +1202,14 @@ def _dbuf_mxfp4_helper(
12001202
)
12011203
options.eliminate_epilogue = eliminate_epilogue
12021204
if use_schedule:
1203-
schedule = get_mxfp4_asymmetric_schedule(
1204-
eliminate_epilogue=eliminate_epilogue, is_bscale_shuffled=True
1205-
)
1205+
if no_unroll:
1206+
schedule = get_mxfp4_asymmetric_nounroll_schedule(
1207+
eliminate_epilogue=eliminate_epilogue, is_bscale_shuffled=True
1208+
)
1209+
else:
1210+
schedule = get_mxfp4_asymmetric_schedule(
1211+
eliminate_epilogue=eliminate_epilogue, is_bscale_shuffled=True
1212+
)
12061213
else:
12071214
schedule = None
12081215
options.schedule = SchedulingType.NONE
@@ -1442,6 +1449,44 @@ def test_dbuf_4wave_mxfp4_gemm_cpp_backend(
14421449
)
14431450

14441451

1452+
@pytest.mark.parametrize("eliminate_epilogue", [True], ids=["ee"])
1453+
@pytest.mark.parametrize(
1454+
"shape,block,wave_shape",
1455+
[
1456+
pytest.param((1024, 896, 8192), (256, 224, 256), (2, 2), id="256x224x256"),
1457+
],
1458+
)
1459+
def test_dbuf_4wave_mxfp4_nounroll_gemm_cpp_backend(
1460+
shape,
1461+
block,
1462+
wave_shape,
1463+
eliminate_epilogue,
1464+
compiler,
1465+
dump_asm,
1466+
):
1467+
"""End-to-end test for asymmetric MXFP4 GEMM with no-unroll schedule.
1468+
1469+
The no-unroll schedule (unroll_factor=1) reduces register pressure by
1470+
not unrolling the K-loop body, allowing larger block sizes like
1471+
256x224x256 that would otherwise exceed the 256-VGPR hardware limit
1472+
with the standard asymmetric schedule.
1473+
"""
1474+
_dbuf_mxfp4_helper(
1475+
shape=shape,
1476+
block=block,
1477+
num_waves=4,
1478+
use_stagger=False,
1479+
compiler=compiler,
1480+
dump_asm=dump_asm,
1481+
use_buffer_ops=True,
1482+
use_schedule=True,
1483+
output_dtype="f32",
1484+
wave_shape=wave_shape,
1485+
eliminate_epilogue=eliminate_epilogue,
1486+
no_unroll=True,
1487+
)
1488+
1489+
14451490
@pytest.mark.parametrize(
14461491
"shape,block,wave_shape",
14471492
[

wave_lang/kernel/wave/schedules/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
get_mxfp4_dbuf_mixed_pingpong_schedule,
1919
get_mxfp4_dbuf_mixed_pingpong_shuffle_schedule,
2020
get_mxfp4_asymmetric_schedule,
21+
get_mxfp4_asymmetric_nounroll_schedule,
2122
get_mxfp4_dbuf_pingpong_schedule_Bshuffled,
2223
get_mxfp4_dbuf_pingpong_schedule_Bshuffled_lds,
2324
)
@@ -34,6 +35,7 @@
3435
"get_mxfp4_dbuf_pingpong_schedule_Bshuffled",
3536
"get_mxfp4_dbuf_pingpong_schedule_Bshuffled_lds",
3637
"get_mxfp4_asymmetric_schedule",
38+
"get_mxfp4_asymmetric_nounroll_schedule",
3739
"get_mxfp4_dbuf_mixed_pingpong_shuffle_schedule",
3840
"get_attention_prefetch_schedule",
3941
]

0 commit comments

Comments
 (0)