Skip to content

Commit 3cbedec

Browse files
committed
compute C^T=B.A^T; store C
Signed-off-by: xintin <gaurav.verma@amd.com>
1 parent 67b07a7 commit 3cbedec

1 file changed

Lines changed: 68 additions & 57 deletions

File tree

examples/python/7.1_schedule.py

Lines changed: 68 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ def convert_first_eliminate_cndmask(asm_text):
133133
lines = asm_text.split("\n")
134134
out = []
135135
tile_count = 0
136+
vcc_emitted = False
136137
i = 0
137138

138139
while i < len(lines):
@@ -216,6 +217,9 @@ def convert_first_eliminate_cndmask(asm_text):
216217
# Preserved lines (lane mask, offset select, addr comp, SRD)
217218
# must come first so v244 and v253 are set before we use them.
218219
out.extend(preserved)
220+
if not vcc_emitted:
221+
out.append(" v_cmp_ne_u32 vcc, v244, 0")
222+
vcc_emitted = True
219223
out.extend(
220224
[
221225
f" v_accvgpr_read_b32 v0, a{a0}",
@@ -224,23 +228,16 @@ def convert_first_eliminate_cndmask(asm_text):
224228
f" v_accvgpr_read_b32 v0, a{a2}",
225229
f" v_accvgpr_read_b32 v1, a{a3}",
226230
f" v_cvt_pk_bf16_f32 v3, v0, v1",
231+
" v_mov_b32 v8, v2",
232+
" v_mov_b32 v9, v3",
227233
"s_nop 1",
228234
" v_permlane16_swap_b32 v4, v2",
229235
"s_nop 1",
230236
" v_permlane16_swap_b32 v5, v3",
231-
f" v_accvgpr_read_b32 v0, a{a0}",
232-
f" v_accvgpr_read_b32 v1, a{a1}",
233-
f" v_cvt_pk_bf16_f32 v2, v0, v1",
234-
f" v_accvgpr_read_b32 v0, a{a2}",
235-
f" v_accvgpr_read_b32 v1, a{a3}",
236-
f" v_cvt_pk_bf16_f32 v3, v0, v1",
237-
" v_cmp_ne_u32 vcc, v244, 0",
238-
" v_cndmask_b32 v0, v4, v2",
239-
" v_cndmask_b32 v1, v5, v3",
240-
" v_cndmask_b32 v6, v2, v4",
241-
" v_cndmask_b32 v7, v3, v5",
242-
" v_mov_b32 v2, v6",
243-
" v_mov_b32 v3, v7",
237+
" v_cndmask_b32 v0, v4, v8",
238+
" v_cndmask_b32 v1, v5, v9",
239+
" v_cndmask_b32 v2, v8, v4",
240+
" v_cndmask_b32 v3, v9, v5",
244241
" v_lshlrev_b32 v245, 1, v253",
245242
f" buffer_store_dwordx4 v[0:3], v245, {srd}, 0 offen",
246243
]
@@ -785,42 +782,37 @@ def _run_mxfp_gemm(gemm, shape):
785782
def _run_mxfp_gemm_preshuffle(
786783
gemm,
787784
shape,
788-
all=False,
789-
only_scale=False,
790-
only_b=False,
791785
output_dtype=torch.float32,
792-
transpose_output=False,
786+
swap_inputs=False,
787+
**kwargs,
793788
):
794-
"""Run compiled GEMM kernel with preshuffled B and B_scale, verify against reference.
795-
796-
Shuffling is applied based on the flags:
797-
all - shuffle a_scale (x_scales), b_scale (w_scales), and b (w_t)
798-
only_scale - shuffle a_scale (x_scales) and b_scale (w_scales) only
799-
only_b - shuffle b_scale (w_scales) only
789+
"""Run compiled GEMM kernel, verify against reference.
800790
801-
When transpose_output is True, the kernel writes C^T [N, M] instead of C [M, N].
791+
When swap_inputs is True, the kernel computes C^T = B x A^T (with A=X, B=W)
792+
and writes C [M, N] directly via transpose_output + coalesced epilogue.
793+
When swap_inputs is False (baseline), uses standard input order.
802794
"""
803795
x, w, x_scales, w_scales = generate_gemm_afp4wfp4_inputs(shape)
804796
torch_out = torchScaledGemmMXFP4(x, w, x_scales, w_scales)
805797

806798
w_t = w.T.contiguous()
807799

808-
w_t_ps = b_preshuffle(w_t) if all else w_t
809-
810-
x_scales_ps = e8m0_shuffle(x_scales) if (all or only_scale) else x_scales
811-
812-
w_scales_ps = e8m0_shuffle(w_scales) if (all or only_scale or only_b) else w_scales
813-
814-
x, w_t_ps = x.cuda(), w_t_ps.cuda()
815-
x_scales_ps, w_scales_ps = x_scales_ps.cuda(), w_scales_ps.cuda()
816-
if transpose_output:
817-
out = torch.zeros(w_t_ps.shape[0], x.shape[0], dtype=output_dtype).cuda()
800+
if swap_inputs:
801+
kern_a = w_t.cuda()
802+
kern_a_scale = e8m0_shuffle(w_scales).cuda()
803+
kern_b = b_preshuffle(x).cuda()
804+
kern_b_scale = e8m0_shuffle(x_scales).cuda()
805+
out = torch.zeros(shape[0], shape[1], dtype=output_dtype).cuda()
806+
gemm(kern_a, kern_a_scale, kern_b, kern_b_scale, out)
807+
result = out.cpu()
818808
else:
819-
out = torch.zeros(x.shape[0], w_t_ps.shape[0], dtype=output_dtype).cuda()
820-
821-
gemm(x, x_scales_ps, w_t_ps, w_scales_ps, out)
822-
823-
result = out.T.contiguous().cpu() if transpose_output else out.cpu()
809+
kern_a = x.cuda()
810+
kern_b = b_preshuffle(w_t).cuda()
811+
kern_a_scale = e8m0_shuffle(x_scales).cuda()
812+
kern_b_scale = e8m0_shuffle(w_scales).cuda()
813+
out = torch.zeros(shape[0], shape[1], dtype=output_dtype).cuda()
814+
gemm(kern_a, kern_a_scale, kern_b, kern_b_scale, out)
815+
result = out.cpu()
824816

825817
if os.environ.get("WAVE_DEBUG_COMPARE"):
826818
ref = torch_out.to(torch.float32).cpu()
@@ -1441,7 +1433,7 @@ def test_dbuf_4wave_mxfp_dynamic_preshuffle_b_gemm_asm_bf16_cvt_first(
14411433
gemm = wave_compile(options, gemm, schedule)
14421434

14431435
_run_mxfp_gemm_preshuffle(
1444-
gemm, shape, all=True, output_dtype=torch.bfloat16, transpose_output=True
1436+
gemm, shape, output_dtype=torch.bfloat16, swap_inputs=True
14451437
)
14461438
print("MXFP GEMM bf16 convert-first epilogue (WaveASM backend) test passed!")
14471439

@@ -1525,7 +1517,7 @@ def test_dbuf_4wave_mxfp_dynamic_preshuffle_b_gemm_asm_bf16_lds_epilogue(
15251517
gemm = wave_compile(options, gemm, schedule)
15261518

15271519
_run_mxfp_gemm_preshuffle(
1528-
gemm, shape, all=True, output_dtype=torch.bfloat16, transpose_output=True
1520+
gemm, shape, output_dtype=torch.bfloat16, swap_inputs=True
15291521
)
15301522
print("MXFP GEMM bf16 pipelined bpermute epilogue (WaveASM backend) test passed!")
15311523

@@ -1539,15 +1531,21 @@ def _compile_bf16_kernel(
15391531
lds_epilogue=False,
15401532
bpermute_masked=False,
15411533
bpermute_pipelined=False,
1534+
swap_inputs=False,
15421535
transpose_only=False,
15431536
):
1544-
"""Compile a bf16 kernel once (M,N,K dynamic). Returns (kernel, transpose_output)."""
1537+
"""Compile a bf16 kernel once (M,N,K dynamic). Returns (kernel, mode_str).
1538+
1539+
mode_str is one of: False (baseline), True (transpose), "swap" (input swap).
1540+
"""
15451541
shape_placeholder = (block[0] * 4, block[1] * 4, block[2] * 4)
15461542
use_transpose = (
15471543
coalesce
15481544
or lds_epilogue
15491545
or bpermute_masked
15501546
or bpermute_pipelined
1547+
or swap_inputs
1548+
or convert_first
15511549
or transpose_only
15521550
)
15531551
kwargs = dict(
@@ -1568,49 +1566,62 @@ def _compile_bf16_kernel(
15681566
options.use_wave_asm_backend = True
15691567
options.wave_runtime = True
15701568
options.eliminate_epilogue = True
1571-
if coalesce or lds_epilogue or bpermute_masked or bpermute_pipelined:
1569+
if (
1570+
coalesce
1571+
or lds_epilogue
1572+
or bpermute_masked
1573+
or bpermute_pipelined
1574+
or swap_inputs
1575+
or convert_first
1576+
):
15721577
options.coalesce_epilogue_stores = True
15731578
if bpermute_pipelined:
15741579
options.asm_transform = bpermute_pipelined_epilogue_transform
1580+
elif convert_first:
1581+
options.asm_transform = convert_first_eliminate_cndmask
1582+
elif swap_inputs:
1583+
options.asm_transform = bpermute_masked_epilogue_transform
15751584
elif bpermute_masked:
15761585
options.asm_transform = bpermute_masked_epilogue_transform
15771586
elif lds_epilogue:
15781587
options.asm_transform = lds_epilogue_transform
1579-
elif convert_first:
1580-
options.asm_transform = convert_first_eliminate_cndmask
15811588
elif dwordx4:
15821589
options.asm_transform = coalesce_buffer_stores_dwordx4
15831590
schedule = get_mxfp4_asymmetric_schedule(
15841591
eliminate_epilogue=True,
15851592
is_bscale_shuffled=True,
15861593
)
15871594
options = set_default_run_config(options)
1588-
return wave_compile(options, gemm, schedule), use_transpose
1595+
mode = "swap" if swap_inputs else use_transpose
1596+
return wave_compile(options, gemm, schedule), mode
15891597

15901598

1591-
def _time_kernel(gemm, shape, transpose_output, warmup=2, iters=5):
1599+
def _time_kernel(gemm, shape, warmup=2, iters=5, swap_inputs=False, **kwargs):
15921600
"""Time a compiled GEMM kernel on the given shape. Returns median us."""
15931601
x, w, x_scales, w_scales = generate_gemm_afp4wfp4_inputs(shape)
15941602
w_t = w.T.contiguous()
1595-
w_t_ps = b_preshuffle(w_t)
1596-
x_scales_ps = e8m0_shuffle(x_scales)
1597-
w_scales_ps = e8m0_shuffle(w_scales)
1598-
x, w_t_ps = x.cuda(), w_t_ps.cuda()
1599-
x_scales_ps, w_scales_ps = x_scales_ps.cuda(), w_scales_ps.cuda()
1600-
if transpose_output:
1601-
out = torch.zeros(w_t_ps.shape[0], x.shape[0], dtype=torch.bfloat16).cuda()
1603+
1604+
if swap_inputs:
1605+
kern_a = w_t.cuda()
1606+
kern_a_scale = e8m0_shuffle(w_scales).cuda()
1607+
kern_b = b_preshuffle(x).cuda()
1608+
kern_b_scale = e8m0_shuffle(x_scales).cuda()
16021609
else:
1603-
out = torch.zeros(x.shape[0], w_t_ps.shape[0], dtype=torch.bfloat16).cuda()
1610+
kern_a = x.cuda()
1611+
kern_b = b_preshuffle(w_t).cuda()
1612+
kern_a_scale = e8m0_shuffle(x_scales).cuda()
1613+
kern_b_scale = e8m0_shuffle(w_scales).cuda()
1614+
out = torch.zeros(shape[0], shape[1], dtype=torch.bfloat16).cuda()
16041615

16051616
for _ in range(warmup):
1606-
gemm(x, x_scales_ps, w_t_ps, w_scales_ps, out)
1617+
gemm(kern_a, kern_a_scale, kern_b, kern_b_scale, out)
16071618
torch.cuda.synchronize()
16081619

16091620
start_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)]
16101621
end_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)]
16111622
for i in range(iters):
16121623
start_events[i].record()
1613-
gemm(x, x_scales_ps, w_t_ps, w_scales_ps, out)
1624+
gemm(kern_a, kern_a_scale, kern_b, kern_b_scale, out)
16141625
end_events[i].record()
16151626
torch.cuda.synchronize()
16161627

0 commit comments

Comments
 (0)