@@ -133,6 +133,7 @@ def convert_first_eliminate_cndmask(asm_text):
133133 lines = asm_text .split ("\n " )
134134 out = []
135135 tile_count = 0
136+ vcc_emitted = False
136137 i = 0
137138
138139 while i < len (lines ):
@@ -216,6 +217,9 @@ def convert_first_eliminate_cndmask(asm_text):
216217 # Preserved lines (lane mask, offset select, addr comp, SRD)
217218 # must come first so v244 and v253 are set before we use them.
218219 out .extend (preserved )
220+ if not vcc_emitted :
221+ out .append (" v_cmp_ne_u32 vcc, v244, 0" )
222+ vcc_emitted = True
219223 out .extend (
220224 [
221225 f" v_accvgpr_read_b32 v0, a{ a0 } " ,
@@ -224,23 +228,16 @@ def convert_first_eliminate_cndmask(asm_text):
224228 f" v_accvgpr_read_b32 v0, a{ a2 } " ,
225229 f" v_accvgpr_read_b32 v1, a{ a3 } " ,
226230 f" v_cvt_pk_bf16_f32 v3, v0, v1" ,
231+ " v_mov_b32 v8, v2" ,
232+ " v_mov_b32 v9, v3" ,
227233 "s_nop 1" ,
228234 " v_permlane16_swap_b32 v4, v2" ,
229235 "s_nop 1" ,
230236 " v_permlane16_swap_b32 v5, v3" ,
231- f" v_accvgpr_read_b32 v0, a{ a0 } " ,
232- f" v_accvgpr_read_b32 v1, a{ a1 } " ,
233- f" v_cvt_pk_bf16_f32 v2, v0, v1" ,
234- f" v_accvgpr_read_b32 v0, a{ a2 } " ,
235- f" v_accvgpr_read_b32 v1, a{ a3 } " ,
236- f" v_cvt_pk_bf16_f32 v3, v0, v1" ,
237- " v_cmp_ne_u32 vcc, v244, 0" ,
238- " v_cndmask_b32 v0, v4, v2" ,
239- " v_cndmask_b32 v1, v5, v3" ,
240- " v_cndmask_b32 v6, v2, v4" ,
241- " v_cndmask_b32 v7, v3, v5" ,
242- " v_mov_b32 v2, v6" ,
243- " v_mov_b32 v3, v7" ,
237+ " v_cndmask_b32 v0, v4, v8" ,
238+ " v_cndmask_b32 v1, v5, v9" ,
239+ " v_cndmask_b32 v2, v8, v4" ,
240+ " v_cndmask_b32 v3, v9, v5" ,
244241 " v_lshlrev_b32 v245, 1, v253" ,
245242 f" buffer_store_dwordx4 v[0:3], v245, { srd } , 0 offen" ,
246243 ]
@@ -785,42 +782,37 @@ def _run_mxfp_gemm(gemm, shape):
785782def _run_mxfp_gemm_preshuffle (
786783 gemm ,
787784 shape ,
788- all = False ,
789- only_scale = False ,
790- only_b = False ,
791785 output_dtype = torch .float32 ,
792- transpose_output = False ,
786+ swap_inputs = False ,
787+ ** kwargs ,
793788):
794- """Run compiled GEMM kernel with preshuffled B and B_scale, verify against reference.
795-
796- Shuffling is applied based on the flags:
797- all - shuffle a_scale (x_scales), b_scale (w_scales), and b (w_t)
798- only_scale - shuffle a_scale (x_scales) and b_scale (w_scales) only
799- only_b - shuffle b_scale (w_scales) only
789+ """Run compiled GEMM kernel, verify against reference.
800790
801- When transpose_output is True, the kernel writes C^T [N, M] instead of C [M, N].
791+ When swap_inputs is True, the kernel computes C^T = B x A^T (with A=X, B=W)
792+ and writes C [M, N] directly via transpose_output + coalesced epilogue.
793+ When swap_inputs is False (baseline), uses standard input order.
802794 """
803795 x , w , x_scales , w_scales = generate_gemm_afp4wfp4_inputs (shape )
804796 torch_out = torchScaledGemmMXFP4 (x , w , x_scales , w_scales )
805797
806798 w_t = w .T .contiguous ()
807799
808- w_t_ps = b_preshuffle (w_t ) if all else w_t
809-
810- x_scales_ps = e8m0_shuffle (x_scales ) if (all or only_scale ) else x_scales
811-
812- w_scales_ps = e8m0_shuffle (w_scales ) if (all or only_scale or only_b ) else w_scales
813-
814- x , w_t_ps = x .cuda (), w_t_ps .cuda ()
815- x_scales_ps , w_scales_ps = x_scales_ps .cuda (), w_scales_ps .cuda ()
816- if transpose_output :
817- out = torch .zeros (w_t_ps .shape [0 ], x .shape [0 ], dtype = output_dtype ).cuda ()
800+ if swap_inputs :
801+ kern_a = w_t .cuda ()
802+ kern_a_scale = e8m0_shuffle (w_scales ).cuda ()
803+ kern_b = b_preshuffle (x ).cuda ()
804+ kern_b_scale = e8m0_shuffle (x_scales ).cuda ()
805+ out = torch .zeros (shape [0 ], shape [1 ], dtype = output_dtype ).cuda ()
806+ gemm (kern_a , kern_a_scale , kern_b , kern_b_scale , out )
807+ result = out .cpu ()
818808 else :
819- out = torch .zeros (x .shape [0 ], w_t_ps .shape [0 ], dtype = output_dtype ).cuda ()
820-
821- gemm (x , x_scales_ps , w_t_ps , w_scales_ps , out )
822-
823- result = out .T .contiguous ().cpu () if transpose_output else out .cpu ()
809+ kern_a = x .cuda ()
810+ kern_b = b_preshuffle (w_t ).cuda ()
811+ kern_a_scale = e8m0_shuffle (x_scales ).cuda ()
812+ kern_b_scale = e8m0_shuffle (w_scales ).cuda ()
813+ out = torch .zeros (shape [0 ], shape [1 ], dtype = output_dtype ).cuda ()
814+ gemm (kern_a , kern_a_scale , kern_b , kern_b_scale , out )
815+ result = out .cpu ()
824816
825817 if os .environ .get ("WAVE_DEBUG_COMPARE" ):
826818 ref = torch_out .to (torch .float32 ).cpu ()
@@ -1441,7 +1433,7 @@ def test_dbuf_4wave_mxfp_dynamic_preshuffle_b_gemm_asm_bf16_cvt_first(
14411433 gemm = wave_compile (options , gemm , schedule )
14421434
14431435 _run_mxfp_gemm_preshuffle (
1444- gemm , shape , all = True , output_dtype = torch .bfloat16 , transpose_output = True
1436+ gemm , shape , output_dtype = torch .bfloat16 , swap_inputs = True
14451437 )
14461438 print ("MXFP GEMM bf16 convert-first epilogue (WaveASM backend) test passed!" )
14471439
@@ -1525,7 +1517,7 @@ def test_dbuf_4wave_mxfp_dynamic_preshuffle_b_gemm_asm_bf16_lds_epilogue(
15251517 gemm = wave_compile (options , gemm , schedule )
15261518
15271519 _run_mxfp_gemm_preshuffle (
1528- gemm , shape , all = True , output_dtype = torch .bfloat16 , transpose_output = True
1520+ gemm , shape , output_dtype = torch .bfloat16 , swap_inputs = True
15291521 )
15301522 print ("MXFP GEMM bf16 pipelined bpermute epilogue (WaveASM backend) test passed!" )
15311523
@@ -1539,15 +1531,21 @@ def _compile_bf16_kernel(
15391531 lds_epilogue = False ,
15401532 bpermute_masked = False ,
15411533 bpermute_pipelined = False ,
1534+ swap_inputs = False ,
15421535 transpose_only = False ,
15431536):
1544- """Compile a bf16 kernel once (M,N,K dynamic). Returns (kernel, transpose_output)."""
1537+ """Compile a bf16 kernel once (M,N,K dynamic). Returns (kernel, mode).
1538+
1539+ mode is one of: False (baseline), True (transpose), or the string "swap" (input swap).
1540+ """
15451541 shape_placeholder = (block [0 ] * 4 , block [1 ] * 4 , block [2 ] * 4 )
15461542 use_transpose = (
15471543 coalesce
15481544 or lds_epilogue
15491545 or bpermute_masked
15501546 or bpermute_pipelined
1547+ or swap_inputs
1548+ or convert_first
15511549 or transpose_only
15521550 )
15531551 kwargs = dict (
@@ -1568,49 +1566,62 @@ def _compile_bf16_kernel(
15681566 options .use_wave_asm_backend = True
15691567 options .wave_runtime = True
15701568 options .eliminate_epilogue = True
1571- if coalesce or lds_epilogue or bpermute_masked or bpermute_pipelined :
1569+ if (
1570+ coalesce
1571+ or lds_epilogue
1572+ or bpermute_masked
1573+ or bpermute_pipelined
1574+ or swap_inputs
1575+ or convert_first
1576+ ):
15721577 options .coalesce_epilogue_stores = True
15731578 if bpermute_pipelined :
15741579 options .asm_transform = bpermute_pipelined_epilogue_transform
1580+ elif convert_first :
1581+ options .asm_transform = convert_first_eliminate_cndmask
1582+ elif swap_inputs :
1583+ options .asm_transform = bpermute_masked_epilogue_transform
15751584 elif bpermute_masked :
15761585 options .asm_transform = bpermute_masked_epilogue_transform
15771586 elif lds_epilogue :
15781587 options .asm_transform = lds_epilogue_transform
1579- elif convert_first :
1580- options .asm_transform = convert_first_eliminate_cndmask
15811588 elif dwordx4 :
15821589 options .asm_transform = coalesce_buffer_stores_dwordx4
15831590 schedule = get_mxfp4_asymmetric_schedule (
15841591 eliminate_epilogue = True ,
15851592 is_bscale_shuffled = True ,
15861593 )
15871594 options = set_default_run_config (options )
1588- return wave_compile (options , gemm , schedule ), use_transpose
1595+ mode = "swap" if swap_inputs else use_transpose
1596+ return wave_compile (options , gemm , schedule ), mode
15891597
15901598
1591- def _time_kernel (gemm , shape , transpose_output , warmup = 2 , iters = 5 ):
1599+ def _time_kernel (gemm , shape , warmup = 2 , iters = 5 , swap_inputs = False , ** kwargs ):
15921600 """Time a compiled GEMM kernel on the given shape. Returns median us."""
15931601 x , w , x_scales , w_scales = generate_gemm_afp4wfp4_inputs (shape )
15941602 w_t = w .T .contiguous ()
1595- w_t_ps = b_preshuffle (w_t )
1596- x_scales_ps = e8m0_shuffle (x_scales )
1597- w_scales_ps = e8m0_shuffle (w_scales )
1598- x , w_t_ps = x .cuda (), w_t_ps .cuda ()
1599- x_scales_ps , w_scales_ps = x_scales_ps .cuda (), w_scales_ps .cuda ()
1600- if transpose_output :
1601- out = torch .zeros (w_t_ps .shape [0 ], x .shape [0 ], dtype = torch .bfloat16 ).cuda ()
1603+
1604+ if swap_inputs :
1605+ kern_a = w_t .cuda ()
1606+ kern_a_scale = e8m0_shuffle (w_scales ).cuda ()
1607+ kern_b = b_preshuffle (x ).cuda ()
1608+ kern_b_scale = e8m0_shuffle (x_scales ).cuda ()
16021609 else :
1603- out = torch .zeros (x .shape [0 ], w_t_ps .shape [0 ], dtype = torch .bfloat16 ).cuda ()
1610+ kern_a = x .cuda ()
1611+ kern_b = b_preshuffle (w_t ).cuda ()
1612+ kern_a_scale = e8m0_shuffle (x_scales ).cuda ()
1613+ kern_b_scale = e8m0_shuffle (w_scales ).cuda ()
1614+ out = torch .zeros (shape [0 ], shape [1 ], dtype = torch .bfloat16 ).cuda ()
16041615
16051616 for _ in range (warmup ):
1606- gemm (x , x_scales_ps , w_t_ps , w_scales_ps , out )
1617+ gemm (kern_a , kern_a_scale , kern_b , kern_b_scale , out )
16071618 torch .cuda .synchronize ()
16081619
16091620 start_events = [torch .cuda .Event (enable_timing = True ) for _ in range (iters )]
16101621 end_events = [torch .cuda .Event (enable_timing = True ) for _ in range (iters )]
16111622 for i in range (iters ):
16121623 start_events [i ].record ()
1613- gemm (x , x_scales_ps , w_t_ps , w_scales_ps , out )
1624+ gemm (kern_a , kern_a_scale , kern_b , kern_b_scale , out )
16141625 end_events [i ].record ()
16151626 torch .cuda .synchronize ()
16161627
0 commit comments