|
10 | 10 | All ops are tagged for use with MXFP4 schedule functions (e.g. get_mxfp4_dbuf_schedule). |
11 | 11 |
|
12 | 12 | Provides: |
13 | | - - get_tagged_mxfp4_gemm: vanilla (A, B via LDS) |
14 | | - - get_tagged_mxfp4_gemm_preshuffle_b: B + B_scale preshuffled (direct global reads) |
| 13 | + - get_tagged_mxfp4_gemm: vanilla (A, B via LDS) |
| 14 | + - get_tagged_mxfp4_gemm_preshuffle_b: B + B_scale preshuffled (direct global reads) |
| 15 | + - get_tagged_mxfp4_gemm_preshuffle_b_wide_store: same as preshuffle_b, plus wide epilogue stores via permlane swap |
15 | 16 |
|
16 | 17 | Required tags: k_loop, read_a, read_a_scale, read_b, read_b_scale, |
17 | 18 | bitcast_a, bitcast_a_scale, bitcast_b, bitcast_b_scale, scaled_mma. |
@@ -377,37 +378,20 @@ def get_tagged_mxfp4_gemm_preshuffle_scales_and_B( |
377 | 378 | ) |
378 | 379 |
|
379 | 380 |
|
380 | | -def get_tagged_mxfp4_gemm_preshuffle_b( |
381 | | - shape: tuple[int, int, int] = (1024, 1024, 8192), |
382 | | - block_shape: tuple[int, int, int] = (256, 256, 256), |
383 | | - wave_shape: tuple[int, int] = (2, 2), |
384 | | - mfma_variant: ScaledMMAType = ScaledMMAType.F32_16x16x128_F8F6F4, |
385 | | - a_address_space: tkl.AddressSpace = SHARED_ADDRESS_SPACE, |
| 381 | +def _get_tagged_mxfp4_gemm_preshuffle_b_impl( |
| 382 | + shape: tuple[int, int, int], |
| 383 | + block_shape: tuple[int, int, int], |
| 384 | + wave_shape: tuple[int, int], |
| 385 | + mfma_variant: ScaledMMAType, |
| 386 | + a_address_space: tkl.AddressSpace, |
| 387 | + *, |
386 | 388 | a_scale_preshuffle: bool = True, |
387 | | - reorder_workgroups=True, |
388 | | - group_size_n=32, |
| 389 | + reorder_workgroups: bool = True, |
| 390 | + group_size_n: int = 32, |
389 | 391 | output_dtype=tkl.f32, |
390 | 392 | wide_stores: bool = False, |
391 | 393 | ): |
392 | | - """Return a tagged MXFP4 scaled GEMM kernel with preshuffled B and B_scale. |
393 | | -
|
394 | | - B data is read directly from global memory using a preshuffle mapping |
395 | | - (aiter shuffle_weight permutation). B scales are also read from global |
396 | | - memory using an e8m0 scale preshuffle mapping. A and A_scale go through |
397 | | - shared memory (LDS) as usual. |
398 | | -
|
399 | | - All ops are tagged for use with MXFP4 schedule functions. |
400 | | -
|
401 | | - Args: |
402 | | - shape: (M, N, K) problem dimensions. |
403 | | - block_shape: (BLOCK_M, BLOCK_N, BLOCK_K) tile sizes. |
404 | | - wave_shape: (WAVE_M, WAVE_N) waves per workgroup. |
405 | | - mfma_variant: Scaled MMA instruction type. |
406 | | - a_address_space: Address space for A and A_scale (typically SHARED). |
407 | | -
|
408 | | - Returns: |
409 | | - (kernel_function, WaveCompileOptions) |
410 | | - """ |
| 394 | + """Shared implementation for preshuffle-B MXFP4 GEMM with optional wide stores.""" |
411 | 395 | M = tkl.sym.M |
412 | 396 | N = tkl.sym.N |
413 | 397 | K = tkl.sym.K |
@@ -599,6 +583,94 @@ def repeat( |
599 | 583 | return gemm, options |
600 | 584 |
|
601 | 585 |
|
| 586 | +def get_tagged_mxfp4_gemm_preshuffle_b( |
| 587 | + shape: tuple[int, int, int] = (1024, 1024, 8192), |
| 588 | + block_shape: tuple[int, int, int] = (256, 256, 256), |
| 589 | + wave_shape: tuple[int, int] = (2, 2), |
| 590 | + mfma_variant: ScaledMMAType = ScaledMMAType.F32_16x16x128_F8F6F4, |
| 591 | + a_address_space: tkl.AddressSpace = SHARED_ADDRESS_SPACE, |
| 592 | + a_scale_preshuffle: bool = True, |
| 593 | + reorder_workgroups=True, |
| 594 | + group_size_n=32, |
| 595 | + output_dtype=tkl.f32, |
| 596 | +): |
| 597 | + """Return a tagged MXFP4 scaled GEMM kernel with preshuffled B and B_scale. |
| 598 | +
|
| 599 | + B data is read directly from global memory using a preshuffle mapping |
| 600 | + (aiter shuffle_weight permutation). B scales are also read from global |
| 601 | + memory using an e8m0 scale preshuffle mapping. A and A_scale go through |
| 602 | + shared memory (LDS) as usual. |
| 603 | +
|
| 604 | + All ops are tagged for use with MXFP4 schedule functions. |
| 605 | +
|
| 606 | + Args: |
| 607 | + shape: (M, N, K) problem dimensions. |
| 608 | + block_shape: (BLOCK_M, BLOCK_N, BLOCK_K) tile sizes. |
| 609 | + wave_shape: (WAVE_M, WAVE_N) waves per workgroup. |
| 610 | + mfma_variant: Scaled MMA instruction type. |
| 611 | + a_address_space: Address space for A and A_scale (typically SHARED). |
| 612 | +
|
| 613 | + Returns: |
| 614 | + (kernel_function, WaveCompileOptions) |
| 615 | + """ |
| 616 | + return _get_tagged_mxfp4_gemm_preshuffle_b_impl( |
| 617 | + shape, |
| 618 | + block_shape, |
| 619 | + wave_shape, |
| 620 | + mfma_variant, |
| 621 | + a_address_space, |
| 622 | + a_scale_preshuffle=a_scale_preshuffle, |
| 623 | + reorder_workgroups=reorder_workgroups, |
| 624 | + group_size_n=group_size_n, |
| 625 | + output_dtype=output_dtype, |
| 626 | + wide_stores=False, |
| 627 | + ) |
| 628 | + |
| 629 | + |
| 630 | +def get_tagged_mxfp4_gemm_preshuffle_b_wide_store( |
| 631 | + shape: tuple[int, int, int] = (1024, 1024, 8192), |
| 632 | + block_shape: tuple[int, int, int] = (256, 256, 256), |
| 633 | + wave_shape: tuple[int, int] = (2, 2), |
| 634 | + mfma_variant: ScaledMMAType = ScaledMMAType.F32_16x16x128_F8F6F4, |
| 635 | + a_address_space: tkl.AddressSpace = SHARED_ADDRESS_SPACE, |
| 636 | + a_scale_preshuffle: bool = True, |
| 637 | + reorder_workgroups=True, |
| 638 | + group_size_n=32, |
| 639 | + output_dtype=tkl.bf16, |
| 640 | +): |
| 641 | + """Return a tagged MXFP4 scaled GEMM kernel with preshuffled B, B_scale, and wide stores. |
| 642 | +
|
| 643 | + Like :func:`get_tagged_mxfp4_gemm_preshuffle_b` but swaps MFMA operands |
| 644 | + (B as LHS, A as RHS) so the accumulator's four contiguous values align with |
| 645 | + the output memory's stride-1 dimension. The ``coalesce_wide_stores`` pass |
| 646 | + emits ``v_permlane16_swap_b32`` + ``buffer_store_dwordx4`` (8 bf16 per |
| 647 | + store) instead of scalar ``buffer_store_short``. |
| 648 | +
|
| 649 | + Args: |
| 650 | + shape: (M, N, K) problem dimensions. |
| 651 | + block_shape: (BLOCK_M, BLOCK_N, BLOCK_K) tile sizes. |
| 652 | + wave_shape: (WAVE_M, WAVE_N) waves per workgroup. |
| 653 | + mfma_variant: Scaled MMA instruction type. |
| 654 | + a_address_space: Address space for A and A_scale (typically SHARED). |
| 655 | + output_dtype: Output element type (default bf16). |
| 656 | +
|
| 657 | + Returns: |
| 658 | + (kernel_function, WaveCompileOptions) |
| 659 | + """ |
| 660 | + return _get_tagged_mxfp4_gemm_preshuffle_b_impl( |
| 661 | + shape, |
| 662 | + block_shape, |
| 663 | + wave_shape, |
| 664 | + mfma_variant, |
| 665 | + a_address_space, |
| 666 | + a_scale_preshuffle=a_scale_preshuffle, |
| 667 | + reorder_workgroups=reorder_workgroups, |
| 668 | + group_size_n=group_size_n, |
| 669 | + output_dtype=output_dtype, |
| 670 | + wide_stores=True, |
| 671 | + ) |
| 672 | + |
| 673 | + |
602 | 674 | def _reorder_mxfp4_workgroups(m, n, block_m, block_n, group_size_n): |
603 | 675 | """Remap workgroup indices to a new order based on group_size_n along N dimension. |
604 | 676 |
|
|
0 commit comments