@@ -785,42 +785,37 @@ def _run_mxfp_gemm(gemm, shape):
785785def _run_mxfp_gemm_preshuffle (
786786 gemm ,
787787 shape ,
788- all = False ,
789- only_scale = False ,
790- only_b = False ,
791788 output_dtype = torch .float32 ,
792- transpose_output = False ,
789+ swap_inputs = False ,
790+ ** kwargs ,
793791):
794- """Run compiled GEMM kernel with preshuffled B and B_scale , verify against reference.
792+ """Run compiled GEMM kernel, verify against reference.
795793
796- Shuffling is applied based on the flags:
797- all - shuffle a_scale (x_scales), b_scale (w_scales), and b (w_t)
798- only_scale - shuffle a_scale (x_scales) and b_scale (w_scales) only
799- only_b - shuffle b_scale (w_scales) only
800-
801- When transpose_output is True, the kernel writes C^T [N, M] instead of C [M, N].
794+ When swap_inputs is True, the kernel computes C^T = B x A^T (with A=X, B=W)
795+ and writes C [M, N] directly via transpose_output + coalesced epilogue.
796+ When swap_inputs is False (baseline), uses standard input order.
802797 """
803798 x , w , x_scales , w_scales = generate_gemm_afp4wfp4_inputs (shape )
804799 torch_out = torchScaledGemmMXFP4 (x , w , x_scales , w_scales )
805800
806801 w_t = w .T .contiguous ()
807802
808- w_t_ps = b_preshuffle (w_t ) if all else w_t
809-
810- x_scales_ps = e8m0_shuffle (x_scales ) if (all or only_scale ) else x_scales
811-
812- w_scales_ps = e8m0_shuffle (w_scales ) if (all or only_scale or only_b ) else w_scales
813-
814- x , w_t_ps = x .cuda (), w_t_ps .cuda ()
815- x_scales_ps , w_scales_ps = x_scales_ps .cuda (), w_scales_ps .cuda ()
816- if transpose_output :
817- out = torch .zeros (w_t_ps .shape [0 ], x .shape [0 ], dtype = output_dtype ).cuda ()
803+ if swap_inputs :
804+ kern_a = w_t .cuda ()
805+ kern_a_scale = e8m0_shuffle (w_scales ).cuda ()
806+ kern_b = b_preshuffle (x ).cuda ()
807+ kern_b_scale = e8m0_shuffle (x_scales ).cuda ()
808+ out = torch .zeros (shape [0 ], shape [1 ], dtype = output_dtype ).cuda ()
809+ gemm (kern_a , kern_a_scale , kern_b , kern_b_scale , out )
810+ result = out .cpu ()
818811 else :
819- out = torch .zeros (x .shape [0 ], w_t_ps .shape [0 ], dtype = output_dtype ).cuda ()
820-
821- gemm (x , x_scales_ps , w_t_ps , w_scales_ps , out )
822-
823- result = out .T .contiguous ().cpu () if transpose_output else out .cpu ()
812+ kern_a = x .cuda ()
813+ kern_b = b_preshuffle (w_t ).cuda ()
814+ kern_a_scale = e8m0_shuffle (x_scales ).cuda ()
815+ kern_b_scale = e8m0_shuffle (w_scales ).cuda ()
816+ out = torch .zeros (shape [0 ], shape [1 ], dtype = output_dtype ).cuda ()
817+ gemm (kern_a , kern_a_scale , kern_b , kern_b_scale , out )
818+ result = out .cpu ()
824819
825820 if os .environ .get ("WAVE_DEBUG_COMPARE" ):
826821 ref = torch_out .to (torch .float32 ).cpu ()
@@ -1441,7 +1436,7 @@ def test_dbuf_4wave_mxfp_dynamic_preshuffle_b_gemm_asm_bf16_cvt_first(
14411436 gemm = wave_compile (options , gemm , schedule )
14421437
14431438 _run_mxfp_gemm_preshuffle (
1444- gemm , shape , all = True , output_dtype = torch .bfloat16 , transpose_output = True
1439+ gemm , shape , output_dtype = torch .bfloat16 , swap_inputs = True
14451440 )
14461441 print ("MXFP GEMM bf16 convert-first epilogue (WaveASM backend) test passed!" )
14471442
@@ -1525,7 +1520,7 @@ def test_dbuf_4wave_mxfp_dynamic_preshuffle_b_gemm_asm_bf16_lds_epilogue(
15251520 gemm = wave_compile (options , gemm , schedule )
15261521
15271522 _run_mxfp_gemm_preshuffle (
1528- gemm , shape , all = True , output_dtype = torch .bfloat16 , transpose_output = True
1523+ gemm , shape , output_dtype = torch .bfloat16 , swap_inputs = True
15291524 )
15301525 print ("MXFP GEMM bf16 pipelined bpermute epilogue (WaveASM backend) test passed!" )
15311526
@@ -1539,15 +1534,21 @@ def _compile_bf16_kernel(
15391534 lds_epilogue = False ,
15401535 bpermute_masked = False ,
15411536 bpermute_pipelined = False ,
1537+ swap_inputs = False ,
15421538 transpose_only = False ,
15431539):
1544- """Compile a bf16 kernel once (M,N,K dynamic). Returns (kernel, transpose_output)."""
1540+ """Compile a bf16 kernel once (M,N,K dynamic). Returns (kernel, mode_str).
1541+
1542+ mode_str is one of: False (baseline), True (transpose), "swap" (input swap).
1543+ """
15451544 shape_placeholder = (block [0 ] * 4 , block [1 ] * 4 , block [2 ] * 4 )
15461545 use_transpose = (
15471546 coalesce
15481547 or lds_epilogue
15491548 or bpermute_masked
15501549 or bpermute_pipelined
1550+ or swap_inputs
1551+ or convert_first
15511552 or transpose_only
15521553 )
15531554 kwargs = dict (
@@ -1568,49 +1569,62 @@ def _compile_bf16_kernel(
15681569 options .use_wave_asm_backend = True
15691570 options .wave_runtime = True
15701571 options .eliminate_epilogue = True
1571- if coalesce or lds_epilogue or bpermute_masked or bpermute_pipelined :
1572+ if (
1573+ coalesce
1574+ or lds_epilogue
1575+ or bpermute_masked
1576+ or bpermute_pipelined
1577+ or swap_inputs
1578+ or convert_first
1579+ ):
15721580 options .coalesce_epilogue_stores = True
15731581 if bpermute_pipelined :
15741582 options .asm_transform = bpermute_pipelined_epilogue_transform
1583+ elif convert_first :
1584+ options .asm_transform = convert_first_eliminate_cndmask
1585+ elif swap_inputs :
1586+ options .asm_transform = bpermute_masked_epilogue_transform
15751587 elif bpermute_masked :
15761588 options .asm_transform = bpermute_masked_epilogue_transform
15771589 elif lds_epilogue :
15781590 options .asm_transform = lds_epilogue_transform
1579- elif convert_first :
1580- options .asm_transform = convert_first_eliminate_cndmask
15811591 elif dwordx4 :
15821592 options .asm_transform = coalesce_buffer_stores_dwordx4
15831593 schedule = get_mxfp4_asymmetric_schedule (
15841594 eliminate_epilogue = True ,
15851595 is_bscale_shuffled = True ,
15861596 )
15871597 options = set_default_run_config (options )
1588- return wave_compile (options , gemm , schedule ), use_transpose
1598+ mode = "swap" if swap_inputs else use_transpose
1599+ return wave_compile (options , gemm , schedule ), mode
15891600
15901601
1591- def _time_kernel (gemm , shape , transpose_output , warmup = 2 , iters = 5 ):
1602+ def _time_kernel (gemm , shape , warmup = 2 , iters = 5 , swap_inputs = False , ** kwargs ):
15921603 """Time a compiled GEMM kernel on the given shape. Returns median us."""
15931604 x , w , x_scales , w_scales = generate_gemm_afp4wfp4_inputs (shape )
15941605 w_t = w .T .contiguous ()
1595- w_t_ps = b_preshuffle (w_t )
1596- x_scales_ps = e8m0_shuffle (x_scales )
1597- w_scales_ps = e8m0_shuffle (w_scales )
1598- x , w_t_ps = x .cuda (), w_t_ps .cuda ()
1599- x_scales_ps , w_scales_ps = x_scales_ps .cuda (), w_scales_ps .cuda ()
1600- if transpose_output :
1601- out = torch .zeros (w_t_ps .shape [0 ], x .shape [0 ], dtype = torch .bfloat16 ).cuda ()
1606+
1607+ if swap_inputs :
1608+ kern_a = w_t .cuda ()
1609+ kern_a_scale = e8m0_shuffle (w_scales ).cuda ()
1610+ kern_b = b_preshuffle (x ).cuda ()
1611+ kern_b_scale = e8m0_shuffle (x_scales ).cuda ()
16021612 else :
1603- out = torch .zeros (x .shape [0 ], w_t_ps .shape [0 ], dtype = torch .bfloat16 ).cuda ()
1613+ kern_a = x .cuda ()
1614+ kern_b = b_preshuffle (w_t ).cuda ()
1615+ kern_a_scale = e8m0_shuffle (x_scales ).cuda ()
1616+ kern_b_scale = e8m0_shuffle (w_scales ).cuda ()
1617+ out = torch .zeros (shape [0 ], shape [1 ], dtype = torch .bfloat16 ).cuda ()
16041618
16051619 for _ in range (warmup ):
1606- gemm (x , x_scales_ps , w_t_ps , w_scales_ps , out )
1620+ gemm (kern_a , kern_a_scale , kern_b , kern_b_scale , out )
16071621 torch .cuda .synchronize ()
16081622
16091623 start_events = [torch .cuda .Event (enable_timing = True ) for _ in range (iters )]
16101624 end_events = [torch .cuda .Event (enable_timing = True ) for _ in range (iters )]
16111625 for i in range (iters ):
16121626 start_events [i ].record ()
1613- gemm (x , x_scales_ps , w_t_ps , w_scales_ps , out )
1627+ gemm (kern_a , kern_a_scale , kern_b , kern_b_scale , out )
16141628 end_events [i ].record ()
16151629 torch .cuda .synchronize ()
16161630
0 commit comments