Skip to content

Commit 3de5084

Browse files
committed
compute C^T=B.A^T; store C
Signed-off-by: xintin <gaurav.verma@amd.com>
1 parent 67b07a7 commit 3de5084

1 file changed

Lines changed: 58 additions & 44 deletions

File tree

examples/python/7.1_schedule.py

Lines changed: 58 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -785,42 +785,37 @@ def _run_mxfp_gemm(gemm, shape):
785785
def _run_mxfp_gemm_preshuffle(
786786
gemm,
787787
shape,
788-
all=False,
789-
only_scale=False,
790-
only_b=False,
791788
output_dtype=torch.float32,
792-
transpose_output=False,
789+
swap_inputs=False,
790+
**kwargs,
793791
):
794-
"""Run compiled GEMM kernel with preshuffled B and B_scale, verify against reference.
792+
"""Run compiled GEMM kernel, verify against reference.
795793
796-
Shuffling is applied based on the flags:
797-
all - shuffle a_scale (x_scales), b_scale (w_scales), and b (w_t)
798-
only_scale - shuffle a_scale (x_scales) and b_scale (w_scales) only
799-
only_b - shuffle b_scale (w_scales) only
800-
801-
When transpose_output is True, the kernel writes C^T [N, M] instead of C [M, N].
794+
When swap_inputs is True, the kernel computes C^T = B x A^T (with A=X, B=W)
795+
and writes C [M, N] directly via transpose_output + coalesced epilogue.
796+
When swap_inputs is False (baseline), uses standard input order.
802797
"""
803798
x, w, x_scales, w_scales = generate_gemm_afp4wfp4_inputs(shape)
804799
torch_out = torchScaledGemmMXFP4(x, w, x_scales, w_scales)
805800

806801
w_t = w.T.contiguous()
807802

808-
w_t_ps = b_preshuffle(w_t) if all else w_t
809-
810-
x_scales_ps = e8m0_shuffle(x_scales) if (all or only_scale) else x_scales
811-
812-
w_scales_ps = e8m0_shuffle(w_scales) if (all or only_scale or only_b) else w_scales
813-
814-
x, w_t_ps = x.cuda(), w_t_ps.cuda()
815-
x_scales_ps, w_scales_ps = x_scales_ps.cuda(), w_scales_ps.cuda()
816-
if transpose_output:
817-
out = torch.zeros(w_t_ps.shape[0], x.shape[0], dtype=output_dtype).cuda()
803+
if swap_inputs:
804+
kern_a = w_t.cuda()
805+
kern_a_scale = e8m0_shuffle(w_scales).cuda()
806+
kern_b = b_preshuffle(x).cuda()
807+
kern_b_scale = e8m0_shuffle(x_scales).cuda()
808+
out = torch.zeros(shape[0], shape[1], dtype=output_dtype).cuda()
809+
gemm(kern_a, kern_a_scale, kern_b, kern_b_scale, out)
810+
result = out.cpu()
818811
else:
819-
out = torch.zeros(x.shape[0], w_t_ps.shape[0], dtype=output_dtype).cuda()
820-
821-
gemm(x, x_scales_ps, w_t_ps, w_scales_ps, out)
822-
823-
result = out.T.contiguous().cpu() if transpose_output else out.cpu()
812+
kern_a = x.cuda()
813+
kern_b = b_preshuffle(w_t).cuda()
814+
kern_a_scale = e8m0_shuffle(x_scales).cuda()
815+
kern_b_scale = e8m0_shuffle(w_scales).cuda()
816+
out = torch.zeros(shape[0], shape[1], dtype=output_dtype).cuda()
817+
gemm(kern_a, kern_a_scale, kern_b, kern_b_scale, out)
818+
result = out.cpu()
824819

825820
if os.environ.get("WAVE_DEBUG_COMPARE"):
826821
ref = torch_out.to(torch.float32).cpu()
@@ -1441,7 +1436,7 @@ def test_dbuf_4wave_mxfp_dynamic_preshuffle_b_gemm_asm_bf16_cvt_first(
14411436
gemm = wave_compile(options, gemm, schedule)
14421437

14431438
_run_mxfp_gemm_preshuffle(
1444-
gemm, shape, all=True, output_dtype=torch.bfloat16, transpose_output=True
1439+
gemm, shape, output_dtype=torch.bfloat16, swap_inputs=True
14451440
)
14461441
print("MXFP GEMM bf16 convert-first epilogue (WaveASM backend) test passed!")
14471442

@@ -1525,7 +1520,7 @@ def test_dbuf_4wave_mxfp_dynamic_preshuffle_b_gemm_asm_bf16_lds_epilogue(
15251520
gemm = wave_compile(options, gemm, schedule)
15261521

15271522
_run_mxfp_gemm_preshuffle(
1528-
gemm, shape, all=True, output_dtype=torch.bfloat16, transpose_output=True
1523+
gemm, shape, output_dtype=torch.bfloat16, swap_inputs=True
15291524
)
15301525
print("MXFP GEMM bf16 pipelined bpermute epilogue (WaveASM backend) test passed!")
15311526

@@ -1539,15 +1534,21 @@ def _compile_bf16_kernel(
15391534
lds_epilogue=False,
15401535
bpermute_masked=False,
15411536
bpermute_pipelined=False,
1537+
swap_inputs=False,
15421538
transpose_only=False,
15431539
):
1544-
"""Compile a bf16 kernel once (M,N,K dynamic). Returns (kernel, transpose_output)."""
1540+
"""Compile a bf16 kernel once (M,N,K dynamic). Returns (kernel, mode_str).
1541+
1542+
mode_str is one of: False (baseline), True (transpose), "swap" (input swap).
1543+
"""
15451544
shape_placeholder = (block[0] * 4, block[1] * 4, block[2] * 4)
15461545
use_transpose = (
15471546
coalesce
15481547
or lds_epilogue
15491548
or bpermute_masked
15501549
or bpermute_pipelined
1550+
or swap_inputs
1551+
or convert_first
15511552
or transpose_only
15521553
)
15531554
kwargs = dict(
@@ -1568,49 +1569,62 @@ def _compile_bf16_kernel(
15681569
options.use_wave_asm_backend = True
15691570
options.wave_runtime = True
15701571
options.eliminate_epilogue = True
1571-
if coalesce or lds_epilogue or bpermute_masked or bpermute_pipelined:
1572+
if (
1573+
coalesce
1574+
or lds_epilogue
1575+
or bpermute_masked
1576+
or bpermute_pipelined
1577+
or swap_inputs
1578+
or convert_first
1579+
):
15721580
options.coalesce_epilogue_stores = True
15731581
if bpermute_pipelined:
15741582
options.asm_transform = bpermute_pipelined_epilogue_transform
1583+
elif convert_first:
1584+
options.asm_transform = convert_first_eliminate_cndmask
1585+
elif swap_inputs:
1586+
options.asm_transform = bpermute_masked_epilogue_transform
15751587
elif bpermute_masked:
15761588
options.asm_transform = bpermute_masked_epilogue_transform
15771589
elif lds_epilogue:
15781590
options.asm_transform = lds_epilogue_transform
1579-
elif convert_first:
1580-
options.asm_transform = convert_first_eliminate_cndmask
15811591
elif dwordx4:
15821592
options.asm_transform = coalesce_buffer_stores_dwordx4
15831593
schedule = get_mxfp4_asymmetric_schedule(
15841594
eliminate_epilogue=True,
15851595
is_bscale_shuffled=True,
15861596
)
15871597
options = set_default_run_config(options)
1588-
return wave_compile(options, gemm, schedule), use_transpose
1598+
mode = "swap" if swap_inputs else use_transpose
1599+
return wave_compile(options, gemm, schedule), mode
15891600

15901601

1591-
def _time_kernel(gemm, shape, transpose_output, warmup=2, iters=5):
1602+
def _time_kernel(gemm, shape, warmup=2, iters=5, swap_inputs=False, **kwargs):
15921603
"""Time a compiled GEMM kernel on the given shape. Returns median us."""
15931604
x, w, x_scales, w_scales = generate_gemm_afp4wfp4_inputs(shape)
15941605
w_t = w.T.contiguous()
1595-
w_t_ps = b_preshuffle(w_t)
1596-
x_scales_ps = e8m0_shuffle(x_scales)
1597-
w_scales_ps = e8m0_shuffle(w_scales)
1598-
x, w_t_ps = x.cuda(), w_t_ps.cuda()
1599-
x_scales_ps, w_scales_ps = x_scales_ps.cuda(), w_scales_ps.cuda()
1600-
if transpose_output:
1601-
out = torch.zeros(w_t_ps.shape[0], x.shape[0], dtype=torch.bfloat16).cuda()
1606+
1607+
if swap_inputs:
1608+
kern_a = w_t.cuda()
1609+
kern_a_scale = e8m0_shuffle(w_scales).cuda()
1610+
kern_b = b_preshuffle(x).cuda()
1611+
kern_b_scale = e8m0_shuffle(x_scales).cuda()
16021612
else:
1603-
out = torch.zeros(x.shape[0], w_t_ps.shape[0], dtype=torch.bfloat16).cuda()
1613+
kern_a = x.cuda()
1614+
kern_b = b_preshuffle(w_t).cuda()
1615+
kern_a_scale = e8m0_shuffle(x_scales).cuda()
1616+
kern_b_scale = e8m0_shuffle(w_scales).cuda()
1617+
out = torch.zeros(shape[0], shape[1], dtype=torch.bfloat16).cuda()
16041618

16051619
for _ in range(warmup):
1606-
gemm(x, x_scales_ps, w_t_ps, w_scales_ps, out)
1620+
gemm(kern_a, kern_a_scale, kern_b, kern_b_scale, out)
16071621
torch.cuda.synchronize()
16081622

16091623
start_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)]
16101624
end_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)]
16111625
for i in range(iters):
16121626
start_events[i].record()
1613-
gemm(x, x_scales_ps, w_t_ps, w_scales_ps, out)
1627+
gemm(kern_a, kern_a_scale, kern_b, kern_b_scale, out)
16141628
end_events[i].record()
16151629
torch.cuda.synchronize()
16161630

0 commit comments

Comments
 (0)