From bb904d5b17e3e6ae3ae9ad282d1daa0a4bbb546a Mon Sep 17 00:00:00 2001 From: jrosas Date: Fri, 29 May 2026 15:16:03 +0000 Subject: [PATCH 01/10] Initial commit with gluon fused_mxfp4_quant file --- 3rdparty/composable_kernel | 1 - .../gfx1250/quant/fuse_mxfp4_quant.py | 280 ++++++++++++++++++ aiter/ops/triton/quant/fused_mxfp4_quant.py | 25 +- .../quant/test_fused_mxfp4_quant.py | 6 + 4 files changed, 310 insertions(+), 2 deletions(-) delete mode 160000 3rdparty/composable_kernel create mode 100644 aiter/ops/triton/_gluon_kernels/gfx1250/quant/fuse_mxfp4_quant.py diff --git a/3rdparty/composable_kernel b/3rdparty/composable_kernel deleted file mode 160000 index af7118e342..0000000000 --- a/3rdparty/composable_kernel +++ /dev/null @@ -1 +0,0 @@ -Subproject commit af7118e342580ecd3f71edce7b1d0ba465012ecf diff --git a/aiter/ops/triton/_gluon_kernels/gfx1250/quant/fuse_mxfp4_quant.py b/aiter/ops/triton/_gluon_kernels/gfx1250/quant/fuse_mxfp4_quant.py new file mode 100644 index 0000000000..02d79cc04a --- /dev/null +++ b/aiter/ops/triton/_gluon_kernels/gfx1250/quant/fuse_mxfp4_quant.py @@ -0,0 +1,280 @@ +import triton +from triton.experimental import gluon +from aiter.ops.triton._triton_kernels.quant.quant import _mxfp4_quant_op +from aiter.ops.triton._triton_kernels.quant.rmsnorm import _rmsnorm_op +from triton.experimental.gluon import language as gl + + +@triton.heuristics( + { + "EVEN_M_N": lambda args: args["M"] % args["BLOCK_SIZE_M"] == 0 + and args["N1"] % (args["BLOCK_SIZE_N"]) == 0, + } +) +@gluon.jit +def _gluon_fused_rms_mxfp4_quant_kernel( + x1_ptr, + w1_ptr, + x2_ptr, + w2_ptr, + res1_ptr, + out1_fp4_ptr, + out1_bs_ptr, + out2_ptr, + out_res1_ptr, + out1_ptr, + eps1, + eps2, + M, + N1, + N2, + x1_stride_m, + x2_stride_m, + res1_stride_m, + out1_fp4_stride_m, + out1_bs_stride_m, + out1_bs_stride_n, + out2_stride_m, + out_res1_stride_m, + out1_stride_m, + BLOCK_SIZE_M: gl.constexpr, + BLOCK_SIZE_N: gl.constexpr, + BLOCK_SIZE_N2: gl.constexpr, + MXFP4_QUANT_BLOCK_SIZE: gl.constexpr, + HAS_SECOND_INPUT: gl.constexpr, + FIRST_INPUT_RES: gl.constexpr, + FIRST_INPUT_OUT: gl.constexpr, + SCALE_N: gl.constexpr, + SCALE_M_PAD: gl.constexpr, + SCALE_N_PAD: gl.constexpr, + SHUFFLE: gl.constexpr, + SHUFFLE_PAD: gl.constexpr, + EVEN_M_N: gl.constexpr, +): + start_pid = gl.program_id(0) + # get number of programs to determine is 1 or 2 passes + num_pid_m = gl.cdiv(M, BLOCK_SIZE_M) + + # create block layouts + gLayout2D: gl.constexpr = gl.BlockedLayout( + [1, 2], # sizePerThread + [1, 32], # threadsPerWarp + [1, 4], # warpsPerCTA + [1, 0], # order + ) + + gLayoutN: gl.constexpr = gl.SliceLayout(0, gLayout2D) + + # 2D shared layout for matrix rows; 1D shared layout for weight vectors + sharedLayout2D: gl.constexpr = gl.SwizzledSharedLayout(1, 1, 1, order=[1, 0]) + sharedLayoutN: gl.constexpr = gl.SwizzledSharedLayout(1, 1, 1, order=[0]) + + # Tensor descriptors for first input and its weights + x1_desec = gl.amd.gfx1250.tdm.make_tensor_descriptor( + x1_ptr, + [M, N1], + [x1_stride_m, 1], + [BLOCK_SIZE_M, BLOCK_SIZE_N], + sharedLayout2D, + ) + + # tensor descriptor for weight 1 + w1_desec = gl.amd.gfx1250.tdm.make_tensor_descriptor( + w1_ptr, + [N1], + [1], + [BLOCK_SIZE_N], + sharedLayoutN, + ) + + # Shared memory for first input and its weights + smemX1 = gl.allocate_shared_memory( + x1_ptr.dtype.element_ty, [BLOCK_SIZE_M, BLOCK_SIZE_N], sharedLayout2D + ) + smemW1 = gl.allocate_shared_memory( + w1_ptr.dtype.element_ty, [BLOCK_SIZE_N], sharedLayoutN + ) + + # Load x1 and optionally res1 in parallel, then wait + gl.amd.gfx1250.tdm.async_load(x1_desec, [start_pid * BLOCK_SIZE_M, 0], smemX1) + gl.amd.gfx1250.tdm.async_load(w1_desec, [0], smemW1) + + # Tensor descriptor and shared memory for optional residual input + if FIRST_INPUT_RES: + res1_desec = gl.amd.gfx1250.tdm.make_tensor_descriptor( + res1_ptr, + [M, N1], + [res1_stride_m, 1], + [BLOCK_SIZE_M, BLOCK_SIZE_N], + sharedLayout2D, + ) + + smemRes1 = gl.allocate_shared_memory( + res1_ptr.dtype.element_ty, [BLOCK_SIZE_M, BLOCK_SIZE_N], sharedLayout2D + ) + + gl.amd.gfx1250.tdm.async_load( + res1_desec, [start_pid * BLOCK_SIZE_M, 0], smemRes1 + ) + + out_res1_desec = gl.amd.gfx1250.tdm.make_tensor_descriptor( + out_res1_ptr, + [M, N1], + [out_res1_stride_m, 1], + [BLOCK_SIZE_M, BLOCK_SIZE_N], + sharedLayout2D, + ) + + smemOutRes1 = gl.allocate_shared_memory( + out_res1_ptr.dtype.element_ty, [BLOCK_SIZE_M, BLOCK_SIZE_N], sharedLayout2D + ) + + # Second input path — programs with id >= num_pid_m handle x2 + if start_pid >= num_pid_m: + if HAS_SECOND_INPUT: + x2_desec = gl.amd.gfx1250.tdm.make_tensor_descriptor( + x2_ptr, + [M, N2], + [x2_stride_m, 1], + [BLOCK_SIZE_M, BLOCK_SIZE_N2], + sharedLayout2D, + ) + # Load x2 and w2 in parallel then wait for both + w2_desec = gl.amd.gfx1250.tdm.make_tensor_descriptor( + w2_ptr, + [N2], + [1], + [BLOCK_SIZE_N2], + sharedLayoutN, + ) + smemX2 = gl.allocate_shared_memory( + x2_ptr.dtype.element_ty, [BLOCK_SIZE_M, BLOCK_SIZE_N2], sharedLayout2D + ) + smemW2 = gl.allocate_shared_memory( + w2_ptr.dtype.element_ty, [BLOCK_SIZE_N2], sharedLayoutN + ) + start_pid -= num_pid_m + + gl.amd.gfx1250.tdm.async_load( + x2_desec, [start_pid * BLOCK_SIZE_M, 0], smemX2 + ) + gl.amd.gfx1250.tdm.async_load(w2_desec, [0], smemW2) + gl.amd.gfx1250.tdm.async_wait(0) + + x2 = smemX2.load(gLayout2D).to(gl.float32) + w2 = smemW2.load(gLayoutN).to(gl.float32) + w2 = w2.reshape(1, BLOCK_SIZE_N2) + w2 = gl.convert_layout(w2, gLayout2D) + norm2 = _rmsnorm_op(x2, w2, N2, eps2) + + # Store norm2 output via TDM + out2_desec = gl.amd.gfx1250.tdm.make_tensor_descriptor( + out2_ptr, + [M, N2], + [out2_stride_m, 1], + [BLOCK_SIZE_M, BLOCK_SIZE_N2], + sharedLayout2D, + ) + smemOut2 = gl.allocate_shared_memory( + out2_ptr.dtype.element_ty, [BLOCK_SIZE_M, BLOCK_SIZE_N2], sharedLayout2D + ) + smemOut2.store(norm2.to(out2_ptr.dtype.element_ty)) + gl.amd.gfx1250.tdm.async_store( + out2_desec, [start_pid * BLOCK_SIZE_M, 0], smemOut2 + ) + gl.amd.gfx1250.tdm.async_wait(0) + return + + # First input path + NUM_QUANT_BLOCKS: gl.constexpr = BLOCK_SIZE_N // MXFP4_QUANT_BLOCK_SIZE + x1 = smemX1.load(gLayout2D).to(gl.float32) + + if FIRST_INPUT_RES: + res1_loaded = smemRes1.load(gLayout2D).to(gl.float32) + x1 = x1 + res1_loaded + smemOutRes1.store(x1.to(out_res1_ptr.dtype.element_ty)) + gl.amd.gfx1250.tdm.async_store( + out_res1_desec, [start_pid * BLOCK_SIZE_M, 0], smemOutRes1 + ) + + w1 = smemW1.load(gLayoutN).to(gl.float32) + w1 = w1.reshape(1, BLOCK_SIZE_N) + w1 = gl.convert_layout(w1, gLayout2D) + norm1 = _rmsnorm_op(x1, w1, N1, eps1) + + # Store unquantized output via TDM (optional) + if FIRST_INPUT_OUT: + out1_desec = gl.amd.gfx1250.tdm.make_tensor_descriptor( + out1_ptr, + [M, N1], + [out1_stride_m, 1], + [BLOCK_SIZE_M, BLOCK_SIZE_N], + sharedLayout2D, + ) + smemOut1 = gl.allocate_shared_memory( + out1_ptr.dtype.element_ty, [BLOCK_SIZE_M, BLOCK_SIZE_N], sharedLayout2D + ) + smemOut1.store(norm1.to(out1_ptr.dtype.element_ty)) + gl.amd.gfx1250.tdm.async_store( + out1_desec, [start_pid * BLOCK_SIZE_M, 0], smemOut1 + ) + + out1_fp4, bs_e8m0 = _mxfp4_quant_op( + norm1, BLOCK_SIZE_N, BLOCK_SIZE_M, MXFP4_QUANT_BLOCK_SIZE + ) + gl.amd.gfx1250.tdm.async_wait(0) + + # out1_fp4 uses half-width (packed) offsets — keep as regular store + fp4_offs_m = start_pid * BLOCK_SIZE_M + gl.arange(0, BLOCK_SIZE_M) + half_x_offs_n = gl.arange(0, BLOCK_SIZE_N // 2) + out_mask1 = (half_x_offs_n < (N1 // 2))[None, :] + if not EVEN_M_N: + out_mask1 = out_mask1 & (fp4_offs_m < M)[:, None] + + gl.store( + out1_fp4_ptr + fp4_offs_m[:, None] * out1_fp4_stride_m + half_x_offs_n[None, :], + out1_fp4, + mask=out_mask1, + ) + + # out1_bs uses non-linear shuffle offsets — keep as regular store + bs_offs_m = start_pid * BLOCK_SIZE_M + gl.arange(0, BLOCK_SIZE_M) + bs_offs_n = gl.arange(0, NUM_QUANT_BLOCKS) + num_bs_cols = (N1 + MXFP4_QUANT_BLOCK_SIZE - 1) // MXFP4_QUANT_BLOCK_SIZE + if SHUFFLE: + bs_offs_0 = bs_offs_m[:, None] >> 5 # // 32 + bs_offs_1 = bs_offs_m[:, None] & 31 # % 32 + bs_offs_2 = bs_offs_1 & 15 # % 16 + bs_offs_1 = bs_offs_1 >> 4 # // 16 + bs_offs_3 = bs_offs_n[None, :] >> 3 # // 8 + bs_offs_4 = bs_offs_n[None, :] & 7 # % 8 + bs_offs_5 = bs_offs_4 & 3 # % 4 + bs_offs_4 = bs_offs_4 >> 2 # // 4 + bs_offs = ( + bs_offs_1 + + bs_offs_4 * 2 + + bs_offs_2 * 2 * 2 + + bs_offs_5 * 2 * 2 * 16 + + bs_offs_3 * 2 * 2 * 16 * 4 + + bs_offs_0 * 2 * 16 * SCALE_N_PAD + ) + bs_mask_127 = (bs_offs_m < M)[:, None] & (bs_offs_n < num_bs_cols)[None, :] + bs_e8m0 = gl.where(bs_mask_127, bs_e8m0, 127) + else: + bs_offs = ( + bs_offs_m[:, None] * out1_bs_stride_m + + bs_offs_n[None, :] * out1_bs_stride_n + ) + + bs_mask = None + if not EVEN_M_N: + if SHUFFLE_PAD: + bs_mask = (bs_offs_m < SCALE_M_PAD)[:, None] & (bs_offs_n < SCALE_N_PAD)[ + None, : + ] + else: + bs_mask = (bs_offs_m < M)[:, None] & (bs_offs_n < SCALE_N)[None, :] + + gl.store( + out1_bs_ptr + bs_offs, bs_e8m0.to(out1_bs_ptr.type.element_ty), mask=bs_mask + ) diff --git a/aiter/ops/triton/quant/fused_mxfp4_quant.py b/aiter/ops/triton/quant/fused_mxfp4_quant.py index c774b8944b..0131f53a7d 100644 --- a/aiter/ops/triton/quant/fused_mxfp4_quant.py +++ b/aiter/ops/triton/quant/fused_mxfp4_quant.py @@ -3,6 +3,7 @@ import triton import triton.language as tl from typing import Optional +from aiter.ops.triton.utils._triton.arch_info import get_arch from aiter.utility import dtypes from aiter.ops.triton._triton_kernels.quant.fused_mxfp4_quant import ( _fused_rms_mxfp4_quant_kernel, @@ -11,6 +12,9 @@ _fused_reduce_rms_mxfp4_quant_kernel, _fused_dynamic_mxfp4_quant_moe_sort_kernel, ) +from aiter.ops.triton._gluon_kernels.gfx1250.quant.fuse_mxfp4_quant import ( + _gluon_fused_rms_mxfp4_quant_kernel, +) from aiter.ops.triton._triton_kernels.activation import ( _get_activation_from_str, ) @@ -30,6 +34,7 @@ def fused_rms_mxfp4_quant( shuffle: Optional[bool] = False, scale_shuffle_padding: Optional[bool] = False, output_unquantized_inp1=False, + inargs: str = "auto", ): """ This op contains several steps: @@ -104,8 +109,26 @@ def fused_rms_mxfp4_quant( x2_stride_m = x2.stride(0) out2_stride_m = out2.stride(0) + # checks args for gluon or triton. Auto will default to best kernel based on hardware arch + + if inargs == "auto": + if get_arch() == "gfx1250": + kernel = _gluon_fused_rms_mxfp4_quant_kernel + else: + kernel = _fused_rms_mxfp4_quant_kernel + elif inargs == "gluon": + if get_arch() != "gfx1250": + raise ValueError("Gluon kernel only supported on gfx1250") + kernel = _gluon_fused_rms_mxfp4_quant_kernel + elif inargs == "triton": + kernel = _fused_rms_mxfp4_quant_kernel + else: + raise ValueError( + f"Invalid argument: {inargs}. Chose from auto, gluon, or triton" + ) + grid = (triton.cdiv(M, BLOCK_SIZE_M) * (2 if (x2 is not None) else 1),) - _fused_rms_mxfp4_quant_kernel[grid]( + kernel[grid]( x1, x1_weight, x2, diff --git a/op_tests/triton_tests/quant/test_fused_mxfp4_quant.py b/op_tests/triton_tests/quant/test_fused_mxfp4_quant.py index 7cdd5cd48a..28bc408921 100644 --- a/op_tests/triton_tests/quant/test_fused_mxfp4_quant.py +++ b/op_tests/triton_tests/quant/test_fused_mxfp4_quant.py @@ -158,6 +158,7 @@ def test_flatten_quant(B: int, M: int, N: int, dtype): @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("shuffle", [True, False]) @pytest.mark.parametrize("scale_shuffle_padding", [True, False]) +@pytest.mark.parametrize("inargs", ["auto", "gluon", "triton"]) def test_fused_rms_quant( M: int, N1: int, @@ -168,10 +169,14 @@ def test_fused_rms_quant( dtype, shuffle: bool, scale_shuffle_padding: bool, + inargs: str, ): if not (arch_info.is_fp4_avail()): pytest.skip("MXFP4 not supported on this architecture") + if inargs == "gluon" and arch_info.get_arch() != "gfx1250": + pytest.skip("Gluon kernel only supported on gfx1250") + torch.manual_seed(0) torch.cuda.empty_cache() # Helps avoid hangs in large tests @@ -202,6 +207,7 @@ def test_fused_rms_quant( shuffle=shuffle, scale_shuffle_padding=scale_shuffle_padding, output_unquantized_inp1=True, + inargs=inargs, ) ) From b611774d3efc408f2aeacab67b20dce7271fb537 Mon Sep 17 00:00:00 2001 From: jrosas Date: Thu, 14 May 2026 23:23:24 +0000 Subject: [PATCH 02/10] Initial gluon version of fused_mxfp4_quant --- .../gfx1250/quant/fused_mxfp4_quant.py | 285 ++++++++++++++++++ aiter/ops/triton/quant/fused_mxfp4_quant.py | 13 +- .../quant/test_fused_mxfp4_quant.py | 55 +--- 3 files changed, 295 insertions(+), 58 deletions(-) create mode 100644 aiter/ops/triton/_gluon_kernels/gfx1250/quant/fused_mxfp4_quant.py diff --git a/aiter/ops/triton/_gluon_kernels/gfx1250/quant/fused_mxfp4_quant.py b/aiter/ops/triton/_gluon_kernels/gfx1250/quant/fused_mxfp4_quant.py new file mode 100644 index 0000000000..0d886ada42 --- /dev/null +++ b/aiter/ops/triton/_gluon_kernels/gfx1250/quant/fused_mxfp4_quant.py @@ -0,0 +1,285 @@ +import torch +import triton +from triton.experimental import gluon +from typing import Optional +from aiter.ops.triton._triton_kernels.quant.quant import _mxfp4_quant_op +from triton.experimental.gluon import language as gl + +@gluon.jit +def _rmsnorm_op( + row, + weights, + n_cols, + epsilon, +): + + row_norm = row*row + row_norm = gl.sum(row_norm, axis=-1, keep_dims=True) + norm_factor = gl.rsqrt((row_norm / n_cols) + epsilon) + + rms_norm = row * norm_factor * weights + return rms_norm + +@triton.heuristics( + { + "EVEN_M_N": lambda args: args["M"] % args["BLOCK_SIZE_M"] == 0 + and args["N1"] % (args["BLOCK_SIZE_N"]) == 0, + "EVEN_M_N2": lambda args: args["M"] % args["BLOCK_SIZE_M"] == 0 + and args["N2"] % (args["BLOCK_SIZE_N2"]) == 0, + } +) +@gluon.jit +def _gluon_fused_rms_mxfp4_quant_kernel( + x1_ptr, + w1_ptr, + x2_ptr, + w2_ptr, + res1_ptr, + out1_fp4_ptr, + out1_bs_ptr, + out2_ptr, + out_res1_ptr, + out1_ptr, + eps1, + eps2, + M, + N1, + N2, + x1_stride_m, + x2_stride_m, + res1_stride_m, + out1_fp4_stride_m, + out1_bs_stride_m, + out1_bs_stride_n, + out2_stride_m, + out_res1_stride_m, + out1_stride_m, + BLOCK_SIZE_M: gl.constexpr, + BLOCK_SIZE_N: gl.constexpr, + BLOCK_SIZE_N2: gl.constexpr, + MXFP4_QUANT_BLOCK_SIZE: gl.constexpr, + HAS_SECOND_INPUT: gl.constexpr, + FIRST_INPUT_RES: gl.constexpr, + FIRST_INPUT_OUT: gl.constexpr, + SCALE_N: gl.constexpr, + SCALE_M_PAD: gl.constexpr, + SCALE_N_PAD: gl.constexpr, + SHUFFLE: gl.constexpr, + SHUFFLE_PAD: gl.constexpr, + EVEN_M_N: gl.constexpr, + EVEN_M_N2: gl.constexpr, +): + start_pid = gl.program_id(0) + #get number of programs to determine is 1 or 2 passes + num_pid_m = gl.cdiv(M, BLOCK_SIZE_M) + + #create block layouts + gLayout2D: gl.constexpr = gl.BlockedLayout( + [1, 2], # sizePerThread + [1, 32], # threadsPerWarp + [1, 4], # warpsPerCTA + [1, 0], # order + ) + + gLayoutM: gl.constexpr = gl.SliceLayout(1, gLayout2D) + gLayoutN: gl.constexpr = gl.SliceLayout(0, gLayout2D) + + # 2D shared layout for matrix rows; 1D shared layout for weight vectors + sharedLayout2D: gl.constexpr = gl.SwizzledSharedLayout(1, 1, 1, order=[1, 0]) + sharedLayoutN: gl.constexpr = gl.SwizzledSharedLayout(1, 1, 1, order=[0]) + + # Tensor descriptors for first input and its weights + x1_desec = gl.amd.gfx1250.tdm.make_tensor_descriptor( + x1_ptr, + [M, N1], + [x1_stride_m, 1], + [BLOCK_SIZE_M, BLOCK_SIZE_N], + sharedLayout2D, + ) + + w1_desec = gl.amd.gfx1250.tdm.make_tensor_descriptor( + w1_ptr, + [N1], + [1], + [BLOCK_SIZE_N], + sharedLayoutN, + ) + + # Shared memory for first input and its weights + smemX1 = gl.allocate_shared_memory( + x1_ptr.dtype.element_ty, [BLOCK_SIZE_M, BLOCK_SIZE_N], sharedLayout2D + ) + smemW1 = gl.allocate_shared_memory( + w1_ptr.dtype.element_ty, [BLOCK_SIZE_N], sharedLayoutN + ) + + # Tensor descriptor and shared memory for optional residual input + if FIRST_INPUT_RES: + res1_desec = gl.amd.gfx1250.tdm.make_tensor_descriptor( + res1_ptr, + [M, N1], + [res1_stride_m, 1], + [BLOCK_SIZE_M, BLOCK_SIZE_N], + sharedLayout2D, + ) + smemRes1 = gl.allocate_shared_memory( + res1_ptr.dtype.element_ty, [BLOCK_SIZE_M, BLOCK_SIZE_N], sharedLayout2D + ) + + # Second input path — programs with id >= num_pid_m handle x2 + if start_pid >= num_pid_m: + if HAS_SECOND_INPUT: + x2_desec = gl.amd.gfx1250.tdm.make_tensor_descriptor( + x2_ptr, + [M, N2], + [x2_stride_m, 1], + [BLOCK_SIZE_M, BLOCK_SIZE_N2], + sharedLayout2D, + ) + w2_desec = gl.amd.gfx1250.tdm.make_tensor_descriptor( + w2_ptr, + [N2], + [1], + [BLOCK_SIZE_N2], + sharedLayoutN, + ) + smemX2 = gl.allocate_shared_memory( + x2_ptr.dtype.element_ty, [BLOCK_SIZE_M, BLOCK_SIZE_N2], sharedLayout2D + ) + smemW2 = gl.allocate_shared_memory( + w2_ptr.dtype.element_ty, [BLOCK_SIZE_N2], sharedLayoutN + ) + + start_pid -= num_pid_m + + # Load x2 and w2 in parallel then wait for both + gl.amd.gfx1250.tdm.async_load(x2_desec, [start_pid * BLOCK_SIZE_M, 0], smemX2) + gl.amd.gfx1250.tdm.async_load(w2_desec, [0], smemW2) + gl.amd.gfx1250.tdm.async_wait(0) + + x2 = smemX2.load(gLayout2D).to(gl.float32) + w2 = smemW2.load(gLayoutN).to(gl.float32) + w2 = w2.reshape(1, BLOCK_SIZE_N2) + w2 = gl.convert_layout(w2, gLayout2D) + norm2 = _rmsnorm_op(x2, w2, N2, eps2) + + # Store norm2 output via TDM + out2_desec = gl.amd.gfx1250.tdm.make_tensor_descriptor( + out2_ptr, + [M, N2], + [out2_stride_m, 1], + [BLOCK_SIZE_M, BLOCK_SIZE_N2], + sharedLayout2D, + ) + smemOut2 = gl.allocate_shared_memory( + out2_ptr.dtype.element_ty, [BLOCK_SIZE_M, BLOCK_SIZE_N2], sharedLayout2D + ) + smemOut2.store(norm2.to(out2_ptr.dtype.element_ty)) + gl.amd.gfx1250.tdm.async_store(out2_desec, [start_pid * BLOCK_SIZE_M, 0], smemOut2) + gl.amd.gfx1250.tdm.async_wait(0) + return + + # First input path + NUM_QUANT_BLOCKS: gl.constexpr = BLOCK_SIZE_N // MXFP4_QUANT_BLOCK_SIZE + x_offs_m = start_pid * BLOCK_SIZE_M + gl.arange(0, BLOCK_SIZE_M, gLayoutM) + + # Load x1 and optionally res1 in parallel, then wait + gl.amd.gfx1250.tdm.async_load(x1_desec, [start_pid * BLOCK_SIZE_M, 0], smemX1) + if FIRST_INPUT_RES: + gl.amd.gfx1250.tdm.async_load(res1_desec, [start_pid * BLOCK_SIZE_M, 0], smemRes1) + gl.amd.gfx1250.tdm.async_wait(0) + + x1 = smemX1.load(gLayout2D).to(gl.float32) + + if FIRST_INPUT_RES: + res1_loaded = smemRes1.load(gLayout2D).to(gl.float32) + x1 = x1 + res1_loaded + + # Load w1 and wait + gl.amd.gfx1250.tdm.async_load(w1_desec, [0], smemW1) + gl.amd.gfx1250.tdm.async_wait(0) + + w1 = smemW1.load(gLayoutN).to(gl.float32) + w1 = w1.reshape(1, BLOCK_SIZE_N) + w1 = gl.convert_layout(w1, gLayout2D) + norm1 = _rmsnorm_op(x1, w1, N1, eps1) + + # Store unquantized output via TDM (optional) + if FIRST_INPUT_OUT: + out1_desec = gl.amd.gfx1250.tdm.make_tensor_descriptor( + out1_ptr, + [M, N1], + [out1_stride_m, 1], + [BLOCK_SIZE_M, BLOCK_SIZE_N], + sharedLayout2D, + ) + smemOut1 = gl.allocate_shared_memory( + out1_ptr.dtype.element_ty, [BLOCK_SIZE_M, BLOCK_SIZE_N], sharedLayout2D + ) + smemOut1.store(norm1.to(out1_ptr.dtype.element_ty)) + gl.amd.gfx1250.tdm.async_store(out1_desec, [start_pid * BLOCK_SIZE_M, 0], smemOut1) + gl.amd.gfx1250.tdm.async_wait(0) + + out1_fp4, bs_e8m0 = _mxfp4_quant_op(norm1, BLOCK_SIZE_N, BLOCK_SIZE_M, MXFP4_QUANT_BLOCK_SIZE) + out1_fp4 = gl.convert_layout(out1_fp4, gLayout2D) + + # out1_fp4 uses half-width (packed) offsets — keep as regular store + half_x_offs_n = gl.arange(0, BLOCK_SIZE_N // 2) + out_mask1 = (half_x_offs_n < (N1 // 2))[None, :] + if not EVEN_M_N: + out_mask1 = out_mask1 & (x_offs_m < M)[:, None] + gl.store(out1_fp4_ptr + x_offs_m[:, None] * out1_fp4_stride_m + half_x_offs_n[None, :], out1_fp4, mask=out_mask1) + + # out1_bs uses non-linear shuffle offsets — keep as regular store + bs_offs_m = start_pid * BLOCK_SIZE_M + gl.arange(0, BLOCK_SIZE_M) + bs_offs_n = gl.arange(0, NUM_QUANT_BLOCKS) + num_bs_cols = (N1 + MXFP4_QUANT_BLOCK_SIZE - 1) // MXFP4_QUANT_BLOCK_SIZE + if SHUFFLE: + bs_offs_0 = bs_offs_m[:, None] // 32 + bs_offs_1 = bs_offs_m[:, None] % 32 + bs_offs_2 = bs_offs_1 % 16 + bs_offs_1 = bs_offs_1 // 16 + bs_offs_3 = bs_offs_n[None, :] // 8 + bs_offs_4 = bs_offs_n[None, :] % 8 + bs_offs_5 = bs_offs_4 % 4 + bs_offs_4 = bs_offs_4 // 4 + bs_offs = ( + bs_offs_1 + + bs_offs_4 * 2 + + bs_offs_2 * 2 * 2 + + bs_offs_5 * 2 * 2 * 16 + + bs_offs_3 * 2 * 2 * 16 * 4 + + bs_offs_0 * 2 * 16 * SCALE_N_PAD + ) + bs_mask_127 = (bs_offs_m < M)[:, None] & (bs_offs_n < num_bs_cols)[None, :] + bs_e8m0 = gl.where(bs_mask_127, bs_e8m0, 127) + else: + bs_offs = ( + bs_offs_m[:, None] * out1_bs_stride_m + + bs_offs_n[None, :] * out1_bs_stride_n + ) + + bs_mask = None + if not EVEN_M_N: + if not SHUFFLE_PAD: + bs_mask = (bs_offs_m < M)[:, None] & (bs_offs_n < SCALE_N)[None, :] + else: + bs_mask = (bs_offs_m < SCALE_M_PAD)[:, None] & (bs_offs_n < SCALE_N_PAD)[None, :] + + gl.store(out1_bs_ptr + bs_offs, bs_e8m0.to(out1_bs_ptr.type.element_ty), mask=bs_mask) + + # Store residual output via TDM + if FIRST_INPUT_RES: + out_res1_desec = gl.amd.gfx1250.tdm.make_tensor_descriptor( + out_res1_ptr, + [M, N1], + [out_res1_stride_m, 1], + [BLOCK_SIZE_M, BLOCK_SIZE_N], + sharedLayout2D, + ) + smemOutRes1 = gl.allocate_shared_memory( + out_res1_ptr.dtype.element_ty, [BLOCK_SIZE_M, BLOCK_SIZE_N], sharedLayout2D + ) + smemOutRes1.store(x1.to(out_res1_ptr.dtype.element_ty)) + gl.amd.gfx1250.tdm.async_store(out_res1_desec, [start_pid * BLOCK_SIZE_M, 0], smemOutRes1) + gl.amd.gfx1250.tdm.async_wait(0) \ No newline at end of file diff --git a/aiter/ops/triton/quant/fused_mxfp4_quant.py b/aiter/ops/triton/quant/fused_mxfp4_quant.py index 0131f53a7d..60c40ea649 100644 --- a/aiter/ops/triton/quant/fused_mxfp4_quant.py +++ b/aiter/ops/triton/quant/fused_mxfp4_quant.py @@ -5,6 +5,7 @@ from typing import Optional from aiter.ops.triton.utils._triton.arch_info import get_arch from aiter.utility import dtypes +from aiter.ops.triton.utils._triton.arch_info import get_arch from aiter.ops.triton._triton_kernels.quant.fused_mxfp4_quant import ( _fused_rms_mxfp4_quant_kernel, _fused_flatten_mxfp4_quant, @@ -18,8 +19,10 @@ from aiter.ops.triton._triton_kernels.activation import ( _get_activation_from_str, ) +from aiter.ops.triton._gluon_kernels.gfx1250.quant.fused_mxfp4_quant import _gluon_fused_rms_mxfp4_quant_kernel from aiter.ops.triton.utils.logger import AiterTritonLogger + _LOGGER = AiterTritonLogger() @@ -109,8 +112,7 @@ def fused_rms_mxfp4_quant( x2_stride_m = x2.stride(0) out2_stride_m = out2.stride(0) - # checks args for gluon or triton. Auto will default to best kernel based on hardware arch - + #checks args for either gluon, triton, or auto. auto will check for gfx1250 hardware and set to gluon if it exists, otherwise defaults to triton if inargs == "auto": if get_arch() == "gfx1250": kernel = _gluon_fused_rms_mxfp4_quant_kernel @@ -118,14 +120,13 @@ def fused_rms_mxfp4_quant( kernel = _fused_rms_mxfp4_quant_kernel elif inargs == "gluon": if get_arch() != "gfx1250": - raise ValueError("Gluon kernel only supported on gfx1250") + raise RuntimeError("Gluon kernel only supported on gfx1250 hardware") kernel = _gluon_fused_rms_mxfp4_quant_kernel elif inargs == "triton": kernel = _fused_rms_mxfp4_quant_kernel else: - raise ValueError( - f"Invalid argument: {inargs}. Chose from auto, gluon, or triton" - ) + raise ValueError(f"Invalid argument: {inargs}. Choose from auto, gluon, or triton") + grid = (triton.cdiv(M, BLOCK_SIZE_M) * (2 if (x2 is not None) else 1),) kernel[grid]( diff --git a/op_tests/triton_tests/quant/test_fused_mxfp4_quant.py b/op_tests/triton_tests/quant/test_fused_mxfp4_quant.py index 28bc408921..d02e39bfbe 100644 --- a/op_tests/triton_tests/quant/test_fused_mxfp4_quant.py +++ b/op_tests/triton_tests/quant/test_fused_mxfp4_quant.py @@ -158,7 +158,7 @@ def test_flatten_quant(B: int, M: int, N: int, dtype): @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("shuffle", [True, False]) @pytest.mark.parametrize("scale_shuffle_padding", [True, False]) -@pytest.mark.parametrize("inargs", ["auto", "gluon", "triton"]) +@pytest.mark.parametrize("inargs", ["triton", "gluon"]) def test_fused_rms_quant( M: int, N1: int, @@ -171,13 +171,12 @@ def test_fused_rms_quant( scale_shuffle_padding: bool, inargs: str, ): + if not (arch_info.is_fp4_avail()): pytest.skip("MXFP4 not supported on this architecture") if inargs == "gluon" and arch_info.get_arch() != "gfx1250": - pytest.skip("Gluon kernel only supported on gfx1250") - - torch.manual_seed(0) + pytest.skip("Gluon kernel only supported on gfx1250 hardware") torch.cuda.empty_cache() # Helps avoid hangs in large tests x1, x2, rms1_w, rms2_w, resid1 = generate_fused_rms_quant_data( @@ -238,54 +237,6 @@ def test_fused_rms_quant( torch.testing.assert_close(y1_fp32_torch, y1_fp32_triton) -def run_torch_reduce_act_mul_mxfp4_group_quant(x, x2, activation, dtype, shuffle): - x = x.to(torch.float32) - d = x.shape[-1] // 2 - y2 = None - if x.dim() == 3: - x = x.sum(axis=0) - y2 = x2.sum(axis=0).to(dtype=dtype) - else: - assert x2 is None, "x2 must be None in x.dim() == 2 cases" - x, x_mul = x.split([d, d], dim=-1) - if activation == "silu": - out = F.silu(x) * x_mul - elif activation == "gelu": - out = F.gelu(x) * x_mul - out, out_scale = torch_dynamic_mxfp4_quant(out) - if shuffle: - # out_scale_pad = out_scale - M = out_scale.shape[0] - N = out.shape[1] * 2 - scaleM = (M + 255) // 256 * 256 - scaleN_valid = (N + 31) // 32 - scaleN = (scaleN_valid + 7) // 8 * 8 - out_scale_pad = torch.empty( - (scaleM, scaleN), dtype=out_scale.dtype, device=out_scale.device - ) - out_scale_pad[:M, :scaleN] = out_scale[:M, :scaleN] - out_scale = shuffle_scales(out_scale_pad) - out_scale = out_scale.view(out_scale.shape[0] * 32, -1) - return (out, out_scale), y2 - - -def generate_fused_reduce_act_mul_mxfp4_group_quant( - M: int, - N1: int, - dtype=torch.bfloat16, - SPK: int = 1, - N2: int = 1, -): - if SPK == 1: - x = torch.randn((M, N1 * 2), dtype=dtype).cuda() / 10 - else: - x = torch.randn((SPK, M, N1 * 2), dtype=torch.float32).cuda() / 10 - x2 = None - if SPK > 1: - x2 = torch.randn((SPK, M, N2), dtype=torch.float32).cuda() / 10 - - return x, x2 - @pytest.mark.parametrize( "M, N1, N2", From e2f0e9b65ac01ba790e9f077d91eb2f596dc6199 Mon Sep 17 00:00:00 2001 From: jrosas Date: Thu, 14 May 2026 23:49:51 +0000 Subject: [PATCH 03/10] linter and small fixes --- .../_gluon_kernels/gfx1250/quant/__init__.py | 0 .../gfx1250/quant/fused_mxfp4_quant.py | 52 +++++++++++++------ aiter/ops/triton/quant/fused_mxfp4_quant.py | 12 +++-- 3 files changed, 43 insertions(+), 21 deletions(-) create mode 100644 aiter/ops/triton/_gluon_kernels/gfx1250/quant/__init__.py diff --git a/aiter/ops/triton/_gluon_kernels/gfx1250/quant/__init__.py b/aiter/ops/triton/_gluon_kernels/gfx1250/quant/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/aiter/ops/triton/_gluon_kernels/gfx1250/quant/fused_mxfp4_quant.py b/aiter/ops/triton/_gluon_kernels/gfx1250/quant/fused_mxfp4_quant.py index 0d886ada42..fec708a42f 100644 --- a/aiter/ops/triton/_gluon_kernels/gfx1250/quant/fused_mxfp4_quant.py +++ b/aiter/ops/triton/_gluon_kernels/gfx1250/quant/fused_mxfp4_quant.py @@ -1,10 +1,9 @@ -import torch import triton from triton.experimental import gluon -from typing import Optional from aiter.ops.triton._triton_kernels.quant.quant import _mxfp4_quant_op from triton.experimental.gluon import language as gl + @gluon.jit def _rmsnorm_op( row, @@ -13,13 +12,14 @@ def _rmsnorm_op( epsilon, ): - row_norm = row*row + row_norm = row * row row_norm = gl.sum(row_norm, axis=-1, keep_dims=True) norm_factor = gl.rsqrt((row_norm / n_cols) + epsilon) rms_norm = row * norm_factor * weights return rms_norm + @triton.heuristics( { "EVEN_M_N": lambda args: args["M"] % args["BLOCK_SIZE_M"] == 0 @@ -70,13 +70,13 @@ def _gluon_fused_rms_mxfp4_quant_kernel( EVEN_M_N2: gl.constexpr, ): start_pid = gl.program_id(0) - #get number of programs to determine is 1 or 2 passes + # get number of programs to determine is 1 or 2 passes num_pid_m = gl.cdiv(M, BLOCK_SIZE_M) - #create block layouts + # create block layouts gLayout2D: gl.constexpr = gl.BlockedLayout( [1, 2], # sizePerThread - [1, 32], # threadsPerWarp + [1, 32], # threadsPerWarp [1, 4], # warpsPerCTA [1, 0], # order ) @@ -153,7 +153,9 @@ def _gluon_fused_rms_mxfp4_quant_kernel( start_pid -= num_pid_m # Load x2 and w2 in parallel then wait for both - gl.amd.gfx1250.tdm.async_load(x2_desec, [start_pid * BLOCK_SIZE_M, 0], smemX2) + gl.amd.gfx1250.tdm.async_load( + x2_desec, [start_pid * BLOCK_SIZE_M, 0], smemX2 + ) gl.amd.gfx1250.tdm.async_load(w2_desec, [0], smemW2) gl.amd.gfx1250.tdm.async_wait(0) @@ -175,7 +177,9 @@ def _gluon_fused_rms_mxfp4_quant_kernel( out2_ptr.dtype.element_ty, [BLOCK_SIZE_M, BLOCK_SIZE_N2], sharedLayout2D ) smemOut2.store(norm2.to(out2_ptr.dtype.element_ty)) - gl.amd.gfx1250.tdm.async_store(out2_desec, [start_pid * BLOCK_SIZE_M, 0], smemOut2) + gl.amd.gfx1250.tdm.async_store( + out2_desec, [start_pid * BLOCK_SIZE_M, 0], smemOut2 + ) gl.amd.gfx1250.tdm.async_wait(0) return @@ -186,7 +190,9 @@ def _gluon_fused_rms_mxfp4_quant_kernel( # Load x1 and optionally res1 in parallel, then wait gl.amd.gfx1250.tdm.async_load(x1_desec, [start_pid * BLOCK_SIZE_M, 0], smemX1) if FIRST_INPUT_RES: - gl.amd.gfx1250.tdm.async_load(res1_desec, [start_pid * BLOCK_SIZE_M, 0], smemRes1) + gl.amd.gfx1250.tdm.async_load( + res1_desec, [start_pid * BLOCK_SIZE_M, 0], smemRes1 + ) gl.amd.gfx1250.tdm.async_wait(0) x1 = smemX1.load(gLayout2D).to(gl.float32) @@ -217,10 +223,14 @@ def _gluon_fused_rms_mxfp4_quant_kernel( out1_ptr.dtype.element_ty, [BLOCK_SIZE_M, BLOCK_SIZE_N], sharedLayout2D ) smemOut1.store(norm1.to(out1_ptr.dtype.element_ty)) - gl.amd.gfx1250.tdm.async_store(out1_desec, [start_pid * BLOCK_SIZE_M, 0], smemOut1) + gl.amd.gfx1250.tdm.async_store( + out1_desec, [start_pid * BLOCK_SIZE_M, 0], smemOut1 + ) gl.amd.gfx1250.tdm.async_wait(0) - out1_fp4, bs_e8m0 = _mxfp4_quant_op(norm1, BLOCK_SIZE_N, BLOCK_SIZE_M, MXFP4_QUANT_BLOCK_SIZE) + out1_fp4, bs_e8m0 = _mxfp4_quant_op( + norm1, BLOCK_SIZE_N, BLOCK_SIZE_M, MXFP4_QUANT_BLOCK_SIZE + ) out1_fp4 = gl.convert_layout(out1_fp4, gLayout2D) # out1_fp4 uses half-width (packed) offsets — keep as regular store @@ -228,7 +238,11 @@ def _gluon_fused_rms_mxfp4_quant_kernel( out_mask1 = (half_x_offs_n < (N1 // 2))[None, :] if not EVEN_M_N: out_mask1 = out_mask1 & (x_offs_m < M)[:, None] - gl.store(out1_fp4_ptr + x_offs_m[:, None] * out1_fp4_stride_m + half_x_offs_n[None, :], out1_fp4, mask=out_mask1) + gl.store( + out1_fp4_ptr + x_offs_m[:, None] * out1_fp4_stride_m + half_x_offs_n[None, :], + out1_fp4, + mask=out_mask1, + ) # out1_bs uses non-linear shuffle offsets — keep as regular store bs_offs_m = start_pid * BLOCK_SIZE_M + gl.arange(0, BLOCK_SIZE_M) @@ -264,9 +278,13 @@ def _gluon_fused_rms_mxfp4_quant_kernel( if not SHUFFLE_PAD: bs_mask = (bs_offs_m < M)[:, None] & (bs_offs_n < SCALE_N)[None, :] else: - bs_mask = (bs_offs_m < SCALE_M_PAD)[:, None] & (bs_offs_n < SCALE_N_PAD)[None, :] + bs_mask = (bs_offs_m < SCALE_M_PAD)[:, None] & (bs_offs_n < SCALE_N_PAD)[ + None, : + ] - gl.store(out1_bs_ptr + bs_offs, bs_e8m0.to(out1_bs_ptr.type.element_ty), mask=bs_mask) + gl.store( + out1_bs_ptr + bs_offs, bs_e8m0.to(out1_bs_ptr.type.element_ty), mask=bs_mask + ) # Store residual output via TDM if FIRST_INPUT_RES: @@ -281,5 +299,7 @@ def _gluon_fused_rms_mxfp4_quant_kernel( out_res1_ptr.dtype.element_ty, [BLOCK_SIZE_M, BLOCK_SIZE_N], sharedLayout2D ) smemOutRes1.store(x1.to(out_res1_ptr.dtype.element_ty)) - gl.amd.gfx1250.tdm.async_store(out_res1_desec, [start_pid * BLOCK_SIZE_M, 0], smemOutRes1) - gl.amd.gfx1250.tdm.async_wait(0) \ No newline at end of file + gl.amd.gfx1250.tdm.async_store( + out_res1_desec, [start_pid * BLOCK_SIZE_M, 0], smemOutRes1 + ) + gl.amd.gfx1250.tdm.async_wait(0) diff --git a/aiter/ops/triton/quant/fused_mxfp4_quant.py b/aiter/ops/triton/quant/fused_mxfp4_quant.py index 60c40ea649..3c6f4b1f25 100644 --- a/aiter/ops/triton/quant/fused_mxfp4_quant.py +++ b/aiter/ops/triton/quant/fused_mxfp4_quant.py @@ -19,10 +19,11 @@ from aiter.ops.triton._triton_kernels.activation import ( _get_activation_from_str, ) -from aiter.ops.triton._gluon_kernels.gfx1250.quant.fused_mxfp4_quant import _gluon_fused_rms_mxfp4_quant_kernel +from aiter.ops.triton._gluon_kernels.gfx1250.quant.fused_mxfp4_quant import ( + _gluon_fused_rms_mxfp4_quant_kernel, +) from aiter.ops.triton.utils.logger import AiterTritonLogger - _LOGGER = AiterTritonLogger() @@ -112,7 +113,7 @@ def fused_rms_mxfp4_quant( x2_stride_m = x2.stride(0) out2_stride_m = out2.stride(0) - #checks args for either gluon, triton, or auto. auto will check for gfx1250 hardware and set to gluon if it exists, otherwise defaults to triton + # checks args for either gluon, triton, or auto. auto will check for gfx1250 hardware and set to gluon if it exists, otherwise defaults to triton if inargs == "auto": if get_arch() == "gfx1250": kernel = _gluon_fused_rms_mxfp4_quant_kernel @@ -125,8 +126,9 @@ def fused_rms_mxfp4_quant( elif inargs == "triton": kernel = _fused_rms_mxfp4_quant_kernel else: - raise ValueError(f"Invalid argument: {inargs}. Choose from auto, gluon, or triton") - + raise ValueError( + f"Invalid argument: {inargs}. Choose from auto, gluon, or triton" + ) grid = (triton.cdiv(M, BLOCK_SIZE_M) * (2 if (x2 is not None) else 1),) kernel[grid]( From fdfc7ad772bdcdf1dbba5f68d564b63958ded8c0 Mon Sep 17 00:00:00 2001 From: jrosas Date: Fri, 15 May 2026 00:00:39 +0000 Subject: [PATCH 04/10] run black for test_fused_mxfp4 --- .gitmodules | 3 --- op_tests/triton_tests/quant/test_fused_mxfp4_quant.py | 2 -- 2 files changed, 5 deletions(-) diff --git a/.gitmodules b/.gitmodules index 61bc1cf7e0..e69de29bb2 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +0,0 @@ -[submodule "3rdparty/composable_kernel"] - path = 3rdparty/composable_kernel - url = https://github.com/ROCm/composable_kernel.git diff --git a/op_tests/triton_tests/quant/test_fused_mxfp4_quant.py b/op_tests/triton_tests/quant/test_fused_mxfp4_quant.py index d02e39bfbe..72dd40505b 100644 --- a/op_tests/triton_tests/quant/test_fused_mxfp4_quant.py +++ b/op_tests/triton_tests/quant/test_fused_mxfp4_quant.py @@ -1,5 +1,4 @@ import torch -import torch.nn.functional as F import pytest from aiter.ops.triton.quant.fused_mxfp4_quant import ( fused_flatten_mxfp4_quant, @@ -237,7 +236,6 @@ def test_fused_rms_quant( torch.testing.assert_close(y1_fp32_torch, y1_fp32_triton) - @pytest.mark.parametrize( "M, N1, N2", [ From f37ce82c1a15db0015b89b07d03aa3b000d6e580 Mon Sep 17 00:00:00 2001 From: jrosas Date: Fri, 29 May 2026 15:51:36 +0000 Subject: [PATCH 05/10] Removed wrongly named file and updated fused_mxfp4_quant.py with necessary changes --- .../gfx1250/quant/fuse_mxfp4_quant.py | 280 ------------------ .../gfx1250/quant/fused_mxfp4_quant.py | 96 +++--- 2 files changed, 43 insertions(+), 333 deletions(-) delete mode 100644 aiter/ops/triton/_gluon_kernels/gfx1250/quant/fuse_mxfp4_quant.py diff --git a/aiter/ops/triton/_gluon_kernels/gfx1250/quant/fuse_mxfp4_quant.py b/aiter/ops/triton/_gluon_kernels/gfx1250/quant/fuse_mxfp4_quant.py deleted file mode 100644 index 02d79cc04a..0000000000 --- a/aiter/ops/triton/_gluon_kernels/gfx1250/quant/fuse_mxfp4_quant.py +++ /dev/null @@ -1,280 +0,0 @@ -import triton -from triton.experimental import gluon -from aiter.ops.triton._triton_kernels.quant.quant import _mxfp4_quant_op -from aiter.ops.triton._triton_kernels.quant.rmsnorm import _rmsnorm_op -from triton.experimental.gluon import language as gl - - -@triton.heuristics( - { - "EVEN_M_N": lambda args: args["M"] % args["BLOCK_SIZE_M"] == 0 - and args["N1"] % (args["BLOCK_SIZE_N"]) == 0, - } -) -@gluon.jit -def _gluon_fused_rms_mxfp4_quant_kernel( - x1_ptr, - w1_ptr, - x2_ptr, - w2_ptr, - res1_ptr, - out1_fp4_ptr, - out1_bs_ptr, - out2_ptr, - out_res1_ptr, - out1_ptr, - eps1, - eps2, - M, - N1, - N2, - x1_stride_m, - x2_stride_m, - res1_stride_m, - out1_fp4_stride_m, - out1_bs_stride_m, - out1_bs_stride_n, - out2_stride_m, - out_res1_stride_m, - out1_stride_m, - BLOCK_SIZE_M: gl.constexpr, - BLOCK_SIZE_N: gl.constexpr, - BLOCK_SIZE_N2: gl.constexpr, - MXFP4_QUANT_BLOCK_SIZE: gl.constexpr, - HAS_SECOND_INPUT: gl.constexpr, - FIRST_INPUT_RES: gl.constexpr, - FIRST_INPUT_OUT: gl.constexpr, - SCALE_N: gl.constexpr, - SCALE_M_PAD: gl.constexpr, - SCALE_N_PAD: gl.constexpr, - SHUFFLE: gl.constexpr, - SHUFFLE_PAD: gl.constexpr, - EVEN_M_N: gl.constexpr, -): - start_pid = gl.program_id(0) - # get number of programs to determine is 1 or 2 passes - num_pid_m = gl.cdiv(M, BLOCK_SIZE_M) - - # create block layouts - gLayout2D: gl.constexpr = gl.BlockedLayout( - [1, 2], # sizePerThread - [1, 32], # threadsPerWarp - [1, 4], # warpsPerCTA - [1, 0], # order - ) - - gLayoutN: gl.constexpr = gl.SliceLayout(0, gLayout2D) - - # 2D shared layout for matrix rows; 1D shared layout for weight vectors - sharedLayout2D: gl.constexpr = gl.SwizzledSharedLayout(1, 1, 1, order=[1, 0]) - sharedLayoutN: gl.constexpr = gl.SwizzledSharedLayout(1, 1, 1, order=[0]) - - # Tensor descriptors for first input and its weights - x1_desec = gl.amd.gfx1250.tdm.make_tensor_descriptor( - x1_ptr, - [M, N1], - [x1_stride_m, 1], - [BLOCK_SIZE_M, BLOCK_SIZE_N], - sharedLayout2D, - ) - - # tensor descriptor for weight 1 - w1_desec = gl.amd.gfx1250.tdm.make_tensor_descriptor( - w1_ptr, - [N1], - [1], - [BLOCK_SIZE_N], - sharedLayoutN, - ) - - # Shared memory for first input and its weights - smemX1 = gl.allocate_shared_memory( - x1_ptr.dtype.element_ty, [BLOCK_SIZE_M, BLOCK_SIZE_N], sharedLayout2D - ) - smemW1 = gl.allocate_shared_memory( - w1_ptr.dtype.element_ty, [BLOCK_SIZE_N], sharedLayoutN - ) - - # Load x1 and optionally res1 in parallel, then wait - gl.amd.gfx1250.tdm.async_load(x1_desec, [start_pid * BLOCK_SIZE_M, 0], smemX1) - gl.amd.gfx1250.tdm.async_load(w1_desec, [0], smemW1) - - # Tensor descriptor and shared memory for optional residual input - if FIRST_INPUT_RES: - res1_desec = gl.amd.gfx1250.tdm.make_tensor_descriptor( - res1_ptr, - [M, N1], - [res1_stride_m, 1], - [BLOCK_SIZE_M, BLOCK_SIZE_N], - sharedLayout2D, - ) - - smemRes1 = gl.allocate_shared_memory( - res1_ptr.dtype.element_ty, [BLOCK_SIZE_M, BLOCK_SIZE_N], sharedLayout2D - ) - - gl.amd.gfx1250.tdm.async_load( - res1_desec, [start_pid * BLOCK_SIZE_M, 0], smemRes1 - ) - - out_res1_desec = gl.amd.gfx1250.tdm.make_tensor_descriptor( - out_res1_ptr, - [M, N1], - [out_res1_stride_m, 1], - [BLOCK_SIZE_M, BLOCK_SIZE_N], - sharedLayout2D, - ) - - smemOutRes1 = gl.allocate_shared_memory( - out_res1_ptr.dtype.element_ty, [BLOCK_SIZE_M, BLOCK_SIZE_N], sharedLayout2D - ) - - # Second input path — programs with id >= num_pid_m handle x2 - if start_pid >= num_pid_m: - if HAS_SECOND_INPUT: - x2_desec = gl.amd.gfx1250.tdm.make_tensor_descriptor( - x2_ptr, - [M, N2], - [x2_stride_m, 1], - [BLOCK_SIZE_M, BLOCK_SIZE_N2], - sharedLayout2D, - ) - # Load x2 and w2 in parallel then wait for both - w2_desec = gl.amd.gfx1250.tdm.make_tensor_descriptor( - w2_ptr, - [N2], - [1], - [BLOCK_SIZE_N2], - sharedLayoutN, - ) - smemX2 = gl.allocate_shared_memory( - x2_ptr.dtype.element_ty, [BLOCK_SIZE_M, BLOCK_SIZE_N2], sharedLayout2D - ) - smemW2 = gl.allocate_shared_memory( - w2_ptr.dtype.element_ty, [BLOCK_SIZE_N2], sharedLayoutN - ) - start_pid -= num_pid_m - - gl.amd.gfx1250.tdm.async_load( - x2_desec, [start_pid * BLOCK_SIZE_M, 0], smemX2 - ) - gl.amd.gfx1250.tdm.async_load(w2_desec, [0], smemW2) - gl.amd.gfx1250.tdm.async_wait(0) - - x2 = smemX2.load(gLayout2D).to(gl.float32) - w2 = smemW2.load(gLayoutN).to(gl.float32) - w2 = w2.reshape(1, BLOCK_SIZE_N2) - w2 = gl.convert_layout(w2, gLayout2D) - norm2 = _rmsnorm_op(x2, w2, N2, eps2) - - # Store norm2 output via TDM - out2_desec = gl.amd.gfx1250.tdm.make_tensor_descriptor( - out2_ptr, - [M, N2], - [out2_stride_m, 1], - [BLOCK_SIZE_M, BLOCK_SIZE_N2], - sharedLayout2D, - ) - smemOut2 = gl.allocate_shared_memory( - out2_ptr.dtype.element_ty, [BLOCK_SIZE_M, BLOCK_SIZE_N2], sharedLayout2D - ) - smemOut2.store(norm2.to(out2_ptr.dtype.element_ty)) - gl.amd.gfx1250.tdm.async_store( - out2_desec, [start_pid * BLOCK_SIZE_M, 0], smemOut2 - ) - gl.amd.gfx1250.tdm.async_wait(0) - return - - # First input path - NUM_QUANT_BLOCKS: gl.constexpr = BLOCK_SIZE_N // MXFP4_QUANT_BLOCK_SIZE - x1 = smemX1.load(gLayout2D).to(gl.float32) - - if FIRST_INPUT_RES: - res1_loaded = smemRes1.load(gLayout2D).to(gl.float32) - x1 = x1 + res1_loaded - smemOutRes1.store(x1.to(out_res1_ptr.dtype.element_ty)) - gl.amd.gfx1250.tdm.async_store( - out_res1_desec, [start_pid * BLOCK_SIZE_M, 0], smemOutRes1 - ) - - w1 = smemW1.load(gLayoutN).to(gl.float32) - w1 = w1.reshape(1, BLOCK_SIZE_N) - w1 = gl.convert_layout(w1, gLayout2D) - norm1 = _rmsnorm_op(x1, w1, N1, eps1) - - # Store unquantized output via TDM (optional) - if FIRST_INPUT_OUT: - out1_desec = gl.amd.gfx1250.tdm.make_tensor_descriptor( - out1_ptr, - [M, N1], - [out1_stride_m, 1], - [BLOCK_SIZE_M, BLOCK_SIZE_N], - sharedLayout2D, - ) - smemOut1 = gl.allocate_shared_memory( - out1_ptr.dtype.element_ty, [BLOCK_SIZE_M, BLOCK_SIZE_N], sharedLayout2D - ) - smemOut1.store(norm1.to(out1_ptr.dtype.element_ty)) - gl.amd.gfx1250.tdm.async_store( - out1_desec, [start_pid * BLOCK_SIZE_M, 0], smemOut1 - ) - - out1_fp4, bs_e8m0 = _mxfp4_quant_op( - norm1, BLOCK_SIZE_N, BLOCK_SIZE_M, MXFP4_QUANT_BLOCK_SIZE - ) - gl.amd.gfx1250.tdm.async_wait(0) - - # out1_fp4 uses half-width (packed) offsets — keep as regular store - fp4_offs_m = start_pid * BLOCK_SIZE_M + gl.arange(0, BLOCK_SIZE_M) - half_x_offs_n = gl.arange(0, BLOCK_SIZE_N // 2) - out_mask1 = (half_x_offs_n < (N1 // 2))[None, :] - if not EVEN_M_N: - out_mask1 = out_mask1 & (fp4_offs_m < M)[:, None] - - gl.store( - out1_fp4_ptr + fp4_offs_m[:, None] * out1_fp4_stride_m + half_x_offs_n[None, :], - out1_fp4, - mask=out_mask1, - ) - - # out1_bs uses non-linear shuffle offsets — keep as regular store - bs_offs_m = start_pid * BLOCK_SIZE_M + gl.arange(0, BLOCK_SIZE_M) - bs_offs_n = gl.arange(0, NUM_QUANT_BLOCKS) - num_bs_cols = (N1 + MXFP4_QUANT_BLOCK_SIZE - 1) // MXFP4_QUANT_BLOCK_SIZE - if SHUFFLE: - bs_offs_0 = bs_offs_m[:, None] >> 5 # // 32 - bs_offs_1 = bs_offs_m[:, None] & 31 # % 32 - bs_offs_2 = bs_offs_1 & 15 # % 16 - bs_offs_1 = bs_offs_1 >> 4 # // 16 - bs_offs_3 = bs_offs_n[None, :] >> 3 # // 8 - bs_offs_4 = bs_offs_n[None, :] & 7 # % 8 - bs_offs_5 = bs_offs_4 & 3 # % 4 - bs_offs_4 = bs_offs_4 >> 2 # // 4 - bs_offs = ( - bs_offs_1 - + bs_offs_4 * 2 - + bs_offs_2 * 2 * 2 - + bs_offs_5 * 2 * 2 * 16 - + bs_offs_3 * 2 * 2 * 16 * 4 - + bs_offs_0 * 2 * 16 * SCALE_N_PAD - ) - bs_mask_127 = (bs_offs_m < M)[:, None] & (bs_offs_n < num_bs_cols)[None, :] - bs_e8m0 = gl.where(bs_mask_127, bs_e8m0, 127) - else: - bs_offs = ( - bs_offs_m[:, None] * out1_bs_stride_m - + bs_offs_n[None, :] * out1_bs_stride_n - ) - - bs_mask = None - if not EVEN_M_N: - if SHUFFLE_PAD: - bs_mask = (bs_offs_m < SCALE_M_PAD)[:, None] & (bs_offs_n < SCALE_N_PAD)[ - None, : - ] - else: - bs_mask = (bs_offs_m < M)[:, None] & (bs_offs_n < SCALE_N)[None, :] - - gl.store( - out1_bs_ptr + bs_offs, bs_e8m0.to(out1_bs_ptr.type.element_ty), mask=bs_mask - ) diff --git a/aiter/ops/triton/_gluon_kernels/gfx1250/quant/fused_mxfp4_quant.py b/aiter/ops/triton/_gluon_kernels/gfx1250/quant/fused_mxfp4_quant.py index fec708a42f..3d4602282e 100644 --- a/aiter/ops/triton/_gluon_kernels/gfx1250/quant/fused_mxfp4_quant.py +++ b/aiter/ops/triton/_gluon_kernels/gfx1250/quant/fused_mxfp4_quant.py @@ -24,8 +24,6 @@ def _rmsnorm_op( { "EVEN_M_N": lambda args: args["M"] % args["BLOCK_SIZE_M"] == 0 and args["N1"] % (args["BLOCK_SIZE_N"]) == 0, - "EVEN_M_N2": lambda args: args["M"] % args["BLOCK_SIZE_M"] == 0 - and args["N2"] % (args["BLOCK_SIZE_N2"]) == 0, } ) @gluon.jit @@ -67,7 +65,6 @@ def _gluon_fused_rms_mxfp4_quant_kernel( SHUFFLE: gl.constexpr, SHUFFLE_PAD: gl.constexpr, EVEN_M_N: gl.constexpr, - EVEN_M_N2: gl.constexpr, ): start_pid = gl.program_id(0) # get number of programs to determine is 1 or 2 passes @@ -81,7 +78,6 @@ def _gluon_fused_rms_mxfp4_quant_kernel( [1, 0], # order ) - gLayoutM: gl.constexpr = gl.SliceLayout(1, gLayout2D) gLayoutN: gl.constexpr = gl.SliceLayout(0, gLayout2D) # 2D shared layout for matrix rows; 1D shared layout for weight vectors @@ -97,6 +93,7 @@ def _gluon_fused_rms_mxfp4_quant_kernel( sharedLayout2D, ) + # tensor descriptor for weight 1 w1_desec = gl.amd.gfx1250.tdm.make_tensor_descriptor( w1_ptr, [N1], @@ -113,6 +110,10 @@ def _gluon_fused_rms_mxfp4_quant_kernel( w1_ptr.dtype.element_ty, [BLOCK_SIZE_N], sharedLayoutN ) + # Load x1 and optionally res1 in parallel, then wait + gl.amd.gfx1250.tdm.async_load(x1_desec, [start_pid * BLOCK_SIZE_M, 0], smemX1) + gl.amd.gfx1250.tdm.async_load(w1_desec, [0], smemW1) + # Tensor descriptor and shared memory for optional residual input if FIRST_INPUT_RES: res1_desec = gl.amd.gfx1250.tdm.make_tensor_descriptor( @@ -122,10 +123,27 @@ def _gluon_fused_rms_mxfp4_quant_kernel( [BLOCK_SIZE_M, BLOCK_SIZE_N], sharedLayout2D, ) + smemRes1 = gl.allocate_shared_memory( res1_ptr.dtype.element_ty, [BLOCK_SIZE_M, BLOCK_SIZE_N], sharedLayout2D ) + gl.amd.gfx1250.tdm.async_load( + res1_desec, [start_pid * BLOCK_SIZE_M, 0], smemRes1 + ) + + out_res1_desec = gl.amd.gfx1250.tdm.make_tensor_descriptor( + out_res1_ptr, + [M, N1], + [out_res1_stride_m, 1], + [BLOCK_SIZE_M, BLOCK_SIZE_N], + sharedLayout2D, + ) + + smemOutRes1 = gl.allocate_shared_memory( + out_res1_ptr.dtype.element_ty, [BLOCK_SIZE_M, BLOCK_SIZE_N], sharedLayout2D + ) + # Second input path — programs with id >= num_pid_m handle x2 if start_pid >= num_pid_m: if HAS_SECOND_INPUT: @@ -136,6 +154,7 @@ def _gluon_fused_rms_mxfp4_quant_kernel( [BLOCK_SIZE_M, BLOCK_SIZE_N2], sharedLayout2D, ) + # Load x2 and w2 in parallel then wait for both w2_desec = gl.amd.gfx1250.tdm.make_tensor_descriptor( w2_ptr, [N2], @@ -149,10 +168,8 @@ def _gluon_fused_rms_mxfp4_quant_kernel( smemW2 = gl.allocate_shared_memory( w2_ptr.dtype.element_ty, [BLOCK_SIZE_N2], sharedLayoutN ) - start_pid -= num_pid_m - # Load x2 and w2 in parallel then wait for both gl.amd.gfx1250.tdm.async_load( x2_desec, [start_pid * BLOCK_SIZE_M, 0], smemX2 ) @@ -185,25 +202,15 @@ def _gluon_fused_rms_mxfp4_quant_kernel( # First input path NUM_QUANT_BLOCKS: gl.constexpr = BLOCK_SIZE_N // MXFP4_QUANT_BLOCK_SIZE - x_offs_m = start_pid * BLOCK_SIZE_M + gl.arange(0, BLOCK_SIZE_M, gLayoutM) - - # Load x1 and optionally res1 in parallel, then wait - gl.amd.gfx1250.tdm.async_load(x1_desec, [start_pid * BLOCK_SIZE_M, 0], smemX1) - if FIRST_INPUT_RES: - gl.amd.gfx1250.tdm.async_load( - res1_desec, [start_pid * BLOCK_SIZE_M, 0], smemRes1 - ) - gl.amd.gfx1250.tdm.async_wait(0) - x1 = smemX1.load(gLayout2D).to(gl.float32) if FIRST_INPUT_RES: res1_loaded = smemRes1.load(gLayout2D).to(gl.float32) x1 = x1 + res1_loaded - - # Load w1 and wait - gl.amd.gfx1250.tdm.async_load(w1_desec, [0], smemW1) - gl.amd.gfx1250.tdm.async_wait(0) + smemOutRes1.store(x1.to(out_res1_ptr.dtype.element_ty)) + gl.amd.gfx1250.tdm.async_store( + out_res1_desec, [start_pid * BLOCK_SIZE_M, 0], smemOutRes1 + ) w1 = smemW1.load(gLayoutN).to(gl.float32) w1 = w1.reshape(1, BLOCK_SIZE_N) @@ -226,20 +233,21 @@ def _gluon_fused_rms_mxfp4_quant_kernel( gl.amd.gfx1250.tdm.async_store( out1_desec, [start_pid * BLOCK_SIZE_M, 0], smemOut1 ) - gl.amd.gfx1250.tdm.async_wait(0) out1_fp4, bs_e8m0 = _mxfp4_quant_op( norm1, BLOCK_SIZE_N, BLOCK_SIZE_M, MXFP4_QUANT_BLOCK_SIZE ) - out1_fp4 = gl.convert_layout(out1_fp4, gLayout2D) + gl.amd.gfx1250.tdm.async_wait(0) # out1_fp4 uses half-width (packed) offsets — keep as regular store + fp4_offs_m = start_pid * BLOCK_SIZE_M + gl.arange(0, BLOCK_SIZE_M) half_x_offs_n = gl.arange(0, BLOCK_SIZE_N // 2) out_mask1 = (half_x_offs_n < (N1 // 2))[None, :] if not EVEN_M_N: - out_mask1 = out_mask1 & (x_offs_m < M)[:, None] + out_mask1 = out_mask1 & (fp4_offs_m < M)[:, None] + gl.store( - out1_fp4_ptr + x_offs_m[:, None] * out1_fp4_stride_m + half_x_offs_n[None, :], + out1_fp4_ptr + fp4_offs_m[:, None] * out1_fp4_stride_m + half_x_offs_n[None, :], out1_fp4, mask=out_mask1, ) @@ -249,14 +257,14 @@ def _gluon_fused_rms_mxfp4_quant_kernel( bs_offs_n = gl.arange(0, NUM_QUANT_BLOCKS) num_bs_cols = (N1 + MXFP4_QUANT_BLOCK_SIZE - 1) // MXFP4_QUANT_BLOCK_SIZE if SHUFFLE: - bs_offs_0 = bs_offs_m[:, None] // 32 - bs_offs_1 = bs_offs_m[:, None] % 32 - bs_offs_2 = bs_offs_1 % 16 - bs_offs_1 = bs_offs_1 // 16 - bs_offs_3 = bs_offs_n[None, :] // 8 - bs_offs_4 = bs_offs_n[None, :] % 8 - bs_offs_5 = bs_offs_4 % 4 - bs_offs_4 = bs_offs_4 // 4 + bs_offs_0 = bs_offs_m[:, None] >> 5 # // 32 + bs_offs_1 = bs_offs_m[:, None] & 31 # % 32 + bs_offs_2 = bs_offs_1 & 15 # % 16 + bs_offs_1 = bs_offs_1 >> 4 # // 16 + bs_offs_3 = bs_offs_n[None, :] >> 3 # // 8 + bs_offs_4 = bs_offs_n[None, :] & 7 # % 8 + bs_offs_5 = bs_offs_4 & 3 # % 4 + bs_offs_4 = bs_offs_4 >> 2 # // 4 bs_offs = ( bs_offs_1 + bs_offs_4 * 2 @@ -275,31 +283,13 @@ def _gluon_fused_rms_mxfp4_quant_kernel( bs_mask = None if not EVEN_M_N: - if not SHUFFLE_PAD: - bs_mask = (bs_offs_m < M)[:, None] & (bs_offs_n < SCALE_N)[None, :] - else: + if SHUFFLE_PAD: bs_mask = (bs_offs_m < SCALE_M_PAD)[:, None] & (bs_offs_n < SCALE_N_PAD)[ None, : ] + else: + bs_mask = (bs_offs_m < M)[:, None] & (bs_offs_n < SCALE_N)[None, :] gl.store( out1_bs_ptr + bs_offs, bs_e8m0.to(out1_bs_ptr.type.element_ty), mask=bs_mask ) - - # Store residual output via TDM - if FIRST_INPUT_RES: - out_res1_desec = gl.amd.gfx1250.tdm.make_tensor_descriptor( - out_res1_ptr, - [M, N1], - [out_res1_stride_m, 1], - [BLOCK_SIZE_M, BLOCK_SIZE_N], - sharedLayout2D, - ) - smemOutRes1 = gl.allocate_shared_memory( - out_res1_ptr.dtype.element_ty, [BLOCK_SIZE_M, BLOCK_SIZE_N], sharedLayout2D - ) - smemOutRes1.store(x1.to(out_res1_ptr.dtype.element_ty)) - gl.amd.gfx1250.tdm.async_store( - out_res1_desec, [start_pid * BLOCK_SIZE_M, 0], smemOutRes1 - ) - gl.amd.gfx1250.tdm.async_wait(0) From 5356190d280d4d5ca027e4b88967d33373ffd05f Mon Sep 17 00:00:00 2001 From: Omar Muhammad Date: Fri, 29 May 2026 22:42:05 +0000 Subject: [PATCH 06/10] Reorganized tdm async calls in fused_mxfp4_quant kernel. Changed gl.store to async_store for fp4 otuput. Fixed improper file call in api call --- .../gfx1250/quant/fused_mxfp4_quant.py | 111 ++++++++++-------- aiter/ops/triton/quant/fused_mxfp4_quant.py | 6 +- 2 files changed, 60 insertions(+), 57 deletions(-) diff --git a/aiter/ops/triton/_gluon_kernels/gfx1250/quant/fused_mxfp4_quant.py b/aiter/ops/triton/_gluon_kernels/gfx1250/quant/fused_mxfp4_quant.py index 3d4602282e..05fb0c459e 100644 --- a/aiter/ops/triton/_gluon_kernels/gfx1250/quant/fused_mxfp4_quant.py +++ b/aiter/ops/triton/_gluon_kernels/gfx1250/quant/fused_mxfp4_quant.py @@ -110,9 +110,8 @@ def _gluon_fused_rms_mxfp4_quant_kernel( w1_ptr.dtype.element_ty, [BLOCK_SIZE_N], sharedLayoutN ) - # Load x1 and optionally res1 in parallel, then wait + # x1 load issued unconditionally for early latency hiding (OOB-safe for second-input programs) gl.amd.gfx1250.tdm.async_load(x1_desec, [start_pid * BLOCK_SIZE_M, 0], smemX1) - gl.amd.gfx1250.tdm.async_load(w1_desec, [0], smemW1) # Tensor descriptor and shared memory for optional residual input if FIRST_INPUT_RES: @@ -154,7 +153,6 @@ def _gluon_fused_rms_mxfp4_quant_kernel( [BLOCK_SIZE_M, BLOCK_SIZE_N2], sharedLayout2D, ) - # Load x2 and w2 in parallel then wait for both w2_desec = gl.amd.gfx1250.tdm.make_tensor_descriptor( w2_ptr, [N2], @@ -200,59 +198,22 @@ def _gluon_fused_rms_mxfp4_quant_kernel( gl.amd.gfx1250.tdm.async_wait(0) return - # First input path - NUM_QUANT_BLOCKS: gl.constexpr = BLOCK_SIZE_N // MXFP4_QUANT_BLOCK_SIZE - x1 = smemX1.load(gLayout2D).to(gl.float32) - - if FIRST_INPUT_RES: - res1_loaded = smemRes1.load(gLayout2D).to(gl.float32) - x1 = x1 + res1_loaded - smemOutRes1.store(x1.to(out_res1_ptr.dtype.element_ty)) - gl.amd.gfx1250.tdm.async_store( - out_res1_desec, [start_pid * BLOCK_SIZE_M, 0], smemOutRes1 - ) - - w1 = smemW1.load(gLayoutN).to(gl.float32) - w1 = w1.reshape(1, BLOCK_SIZE_N) - w1 = gl.convert_layout(w1, gLayout2D) - norm1 = _rmsnorm_op(x1, w1, N1, eps1) + gl.amd.gfx1250.tdm.async_load(w1_desec, [0], smemW1) - # Store unquantized output via TDM (optional) - if FIRST_INPUT_OUT: - out1_desec = gl.amd.gfx1250.tdm.make_tensor_descriptor( - out1_ptr, - [M, N1], - [out1_stride_m, 1], - [BLOCK_SIZE_M, BLOCK_SIZE_N], - sharedLayout2D, - ) - smemOut1 = gl.allocate_shared_memory( - out1_ptr.dtype.element_ty, [BLOCK_SIZE_M, BLOCK_SIZE_N], sharedLayout2D - ) - smemOut1.store(norm1.to(out1_ptr.dtype.element_ty)) - gl.amd.gfx1250.tdm.async_store( - out1_desec, [start_pid * BLOCK_SIZE_M, 0], smemOut1 - ) + NUM_QUANT_BLOCKS: gl.constexpr = BLOCK_SIZE_N // MXFP4_QUANT_BLOCK_SIZE - out1_fp4, bs_e8m0 = _mxfp4_quant_op( - norm1, BLOCK_SIZE_N, BLOCK_SIZE_M, MXFP4_QUANT_BLOCK_SIZE + # Descriptor and smem for fp4 TDM async_store + out1_fp4_desec = gl.amd.gfx1250.tdm.make_tensor_descriptor( + out1_fp4_ptr, + [M, N1 // 2], + [out1_fp4_stride_m, 1], + [BLOCK_SIZE_M, BLOCK_SIZE_N // 2], + sharedLayout2D, ) - gl.amd.gfx1250.tdm.async_wait(0) - - # out1_fp4 uses half-width (packed) offsets — keep as regular store - fp4_offs_m = start_pid * BLOCK_SIZE_M + gl.arange(0, BLOCK_SIZE_M) - half_x_offs_n = gl.arange(0, BLOCK_SIZE_N // 2) - out_mask1 = (half_x_offs_n < (N1 // 2))[None, :] - if not EVEN_M_N: - out_mask1 = out_mask1 & (fp4_offs_m < M)[:, None] - - gl.store( - out1_fp4_ptr + fp4_offs_m[:, None] * out1_fp4_stride_m + half_x_offs_n[None, :], - out1_fp4, - mask=out_mask1, + smemOutFp4 = gl.allocate_shared_memory( + out1_fp4_ptr.dtype.element_ty, [BLOCK_SIZE_M, BLOCK_SIZE_N // 2], sharedLayout2D ) - # out1_bs uses non-linear shuffle offsets — keep as regular store bs_offs_m = start_pid * BLOCK_SIZE_M + gl.arange(0, BLOCK_SIZE_M) bs_offs_n = gl.arange(0, NUM_QUANT_BLOCKS) num_bs_cols = (N1 + MXFP4_QUANT_BLOCK_SIZE - 1) // MXFP4_QUANT_BLOCK_SIZE @@ -274,7 +235,6 @@ def _gluon_fused_rms_mxfp4_quant_kernel( + bs_offs_0 * 2 * 16 * SCALE_N_PAD ) bs_mask_127 = (bs_offs_m < M)[:, None] & (bs_offs_n < num_bs_cols)[None, :] - bs_e8m0 = gl.where(bs_mask_127, bs_e8m0, 127) else: bs_offs = ( bs_offs_m[:, None] * out1_bs_stride_m @@ -290,6 +250,53 @@ def _gluon_fused_rms_mxfp4_quant_kernel( else: bs_mask = (bs_offs_m < M)[:, None] & (bs_offs_n < SCALE_N)[None, :] + x1 = smemX1.load(gLayout2D).to(gl.float32) + + if FIRST_INPUT_RES: + res1_loaded = smemRes1.load(gLayout2D).to(gl.float32) + x1 = x1 + res1_loaded + smemOutRes1.store(x1.to(out_res1_ptr.dtype.element_ty)) + gl.amd.gfx1250.tdm.async_store( + out_res1_desec, [start_pid * BLOCK_SIZE_M, 0], smemOutRes1 + ) + + w1 = smemW1.load(gLayoutN).to(gl.float32) + w1 = w1.reshape(1, BLOCK_SIZE_N) + w1 = gl.convert_layout(w1, gLayout2D) + norm1 = _rmsnorm_op(x1, w1, N1, eps1) + + # Store unquantized output via TDM (optional) + if FIRST_INPUT_OUT: + out1_desec = gl.amd.gfx1250.tdm.make_tensor_descriptor( + out1_ptr, + [M, N1], + [out1_stride_m, 1], + [BLOCK_SIZE_M, BLOCK_SIZE_N], + sharedLayout2D, + ) + smemOut1 = gl.allocate_shared_memory( + out1_ptr.dtype.element_ty, [BLOCK_SIZE_M, BLOCK_SIZE_N], sharedLayout2D + ) + smemOut1.store(norm1.to(out1_ptr.dtype.element_ty)) + gl.amd.gfx1250.tdm.async_store( + out1_desec, [start_pid * BLOCK_SIZE_M, 0], smemOut1 + ) + + out1_fp4, bs_e8m0 = _mxfp4_quant_op( + norm1, BLOCK_SIZE_N, BLOCK_SIZE_M, MXFP4_QUANT_BLOCK_SIZE + ) + + # Apply out-of-range scale mask + if SHUFFLE: + bs_e8m0 = gl.where(bs_mask_127, bs_e8m0, 127) + + smemOutFp4.store(out1_fp4) + gl.amd.gfx1250.tdm.async_store( + out1_fp4_desec, [start_pid * BLOCK_SIZE_M, 0], smemOutFp4 + ) + gl.store( out1_bs_ptr + bs_offs, bs_e8m0.to(out1_bs_ptr.type.element_ty), mask=bs_mask ) + + gl.amd.gfx1250.tdm.async_wait(0) diff --git a/aiter/ops/triton/quant/fused_mxfp4_quant.py b/aiter/ops/triton/quant/fused_mxfp4_quant.py index 3c6f4b1f25..426fe3b4b7 100644 --- a/aiter/ops/triton/quant/fused_mxfp4_quant.py +++ b/aiter/ops/triton/quant/fused_mxfp4_quant.py @@ -5,7 +5,6 @@ from typing import Optional from aiter.ops.triton.utils._triton.arch_info import get_arch from aiter.utility import dtypes -from aiter.ops.triton.utils._triton.arch_info import get_arch from aiter.ops.triton._triton_kernels.quant.fused_mxfp4_quant import ( _fused_rms_mxfp4_quant_kernel, _fused_flatten_mxfp4_quant, @@ -13,15 +12,12 @@ _fused_reduce_rms_mxfp4_quant_kernel, _fused_dynamic_mxfp4_quant_moe_sort_kernel, ) -from aiter.ops.triton._gluon_kernels.gfx1250.quant.fuse_mxfp4_quant import ( +from aiter.ops.triton._gluon_kernels.gfx1250.quant.fused_mxfp4_quant import ( _gluon_fused_rms_mxfp4_quant_kernel, ) from aiter.ops.triton._triton_kernels.activation import ( _get_activation_from_str, ) -from aiter.ops.triton._gluon_kernels.gfx1250.quant.fused_mxfp4_quant import ( - _gluon_fused_rms_mxfp4_quant_kernel, -) from aiter.ops.triton.utils.logger import AiterTritonLogger _LOGGER = AiterTritonLogger() From fd06c367ccccb2299a71b1b824a30835698b93cc Mon Sep 17 00:00:00 2001 From: Omar Muhammad Date: Sat, 30 May 2026 03:59:04 +0000 Subject: [PATCH 07/10] Add aysnc_wait() at line 358. Included triton kernel _mxfp4_quant_op in gluon kernel for workaround. --- .../gfx1250/quant/fused_mxfp4_quant.py | 109 +++++++++++++++++- 1 file changed, 108 insertions(+), 1 deletion(-) diff --git a/aiter/ops/triton/_gluon_kernels/gfx1250/quant/fused_mxfp4_quant.py b/aiter/ops/triton/_gluon_kernels/gfx1250/quant/fused_mxfp4_quant.py index 05fb0c459e..5b4f7623a5 100644 --- a/aiter/ops/triton/_gluon_kernels/gfx1250/quant/fused_mxfp4_quant.py +++ b/aiter/ops/triton/_gluon_kernels/gfx1250/quant/fused_mxfp4_quant.py @@ -1,6 +1,5 @@ import triton from triton.experimental import gluon -from aiter.ops.triton._triton_kernels.quant.quant import _mxfp4_quant_op from triton.experimental.gluon import language as gl @@ -20,6 +19,112 @@ def _rmsnorm_op( return rms_norm +@triton.jit +def _mxfp4_quant_op( + x, + BLOCK_SIZE_N, + BLOCK_SIZE_M, + MXFP4_QUANT_BLOCK_SIZE, +): + """ + Converts given x (in fp32) to mxfp4 format. + x: [BLOCK_SIZE_M, BLOCK_SIZE_N], fp32 + + """ + EXP_BIAS_FP32: tl.constexpr = 127 + EXP_BIAS_FP4: tl.constexpr = 1 + EBITS_F32: tl.constexpr = 8 + EBITS_FP4: tl.constexpr = 2 + MBITS_F32: tl.constexpr = 23 + MBITS_FP4: tl.constexpr = 1 + + max_normal: tl.constexpr = 6 + min_normal: tl.constexpr = 1 + + NUM_QUANT_BLOCKS: tl.constexpr = BLOCK_SIZE_N // MXFP4_QUANT_BLOCK_SIZE + x = x.reshape(BLOCK_SIZE_M, NUM_QUANT_BLOCKS, MXFP4_QUANT_BLOCK_SIZE) + # Calculate scale + amax = tl.max(tl.abs(x), axis=-1, keep_dims=True) + amax = amax.to(tl.int32, bitcast=True) + amax = (amax + 0x200000).to(tl.uint32, bitcast=True) & 0xFF800000 + amax = amax.to(tl.float32, bitcast=True) + scale_e8m0_unbiased = tl.log2(amax).floor() - 2 + scale_e8m0_unbiased = tl.clamp(scale_e8m0_unbiased, min=-127, max=127) + + # blockscale_e8m0 + bs_e8m0 = scale_e8m0_unbiased.to(tl.uint8) + 127 # in fp32, we have 2&(e - 127) + + quant_scale = tl.exp2(-scale_e8m0_unbiased) + + # Compute quantized x + qx = x * quant_scale + + # Convert quantized fp32 tensor to uint32 before converting to mxfp4 format + # Note: MXFP4 S:1-bit, E:2-bit, M:1-bit + # Zeros: S000 -> +/-0 + # Denormal Numbers: S001 -> +/- 0.5 + # Normal Numbers: + # S010 -> +/- 1.0 + # S011 -> +/- 1.5 + # S100 -> +/- 2.0 + # S101 -> +/- 3.0 + # S110 -> +/- 4.0 + # S111 -> +/- 6.0 + qx = qx.to(tl.uint32, bitcast=True) + + # Extract sign + s = qx & 0x80000000 + # Set everything to positive, will add sign back at the end + qx = qx ^ s + + qx_fp32 = qx.to(tl.float32, bitcast=True) + saturate_mask = qx_fp32 >= max_normal + denormal_mask = (not saturate_mask) & (qx_fp32 < min_normal) + normal_mask = not (saturate_mask | denormal_mask) + + # Denormal numbers + denorm_exp: tl.constexpr = ( + (EXP_BIAS_FP32 - EXP_BIAS_FP4) + (MBITS_F32 - MBITS_FP4) + 1 + ) + denorm_mask_int: tl.constexpr = denorm_exp << MBITS_F32 + denorm_mask_float: tl.constexpr = tl.cast(denorm_mask_int, tl.float32, bitcast=True) + + denormal_x = qx_fp32 + denorm_mask_float + denormal_x = denormal_x.to(tl.uint32, bitcast=True) + denormal_x -= denorm_mask_int + denormal_x = denormal_x.to(tl.uint8) + + # Normal numbers + normal_x = qx + # resulting mantissa is odd + mant_odd = (normal_x >> (MBITS_F32 - MBITS_FP4)) & 1 + # update exponent, rounding bias part 1 + val_to_add = ((EXP_BIAS_FP4 - EXP_BIAS_FP32) << MBITS_F32) + (1 << 21) - 1 + normal_x += val_to_add + # rounding bias part 2 + normal_x += mant_odd + # take the bits! + normal_x = normal_x >> (MBITS_F32 - MBITS_FP4) + normal_x = normal_x.to(tl.uint8) + + # Merge results + e2m1_value = tl.full(qx.type.get_block_shapes(), 0x7, dtype=tl.uint8) + e2m1_value = tl.where(normal_mask, normal_x, e2m1_value) + e2m1_value = tl.where(denormal_mask, denormal_x, e2m1_value) + # add sign back + sign_lp = s >> (MBITS_F32 + EBITS_F32 - MBITS_FP4 - EBITS_FP4) + sign_lp = sign_lp.to(tl.uint8) + e2m1_value = e2m1_value | sign_lp + e2m1_value = tl.reshape( + e2m1_value, [BLOCK_SIZE_M, NUM_QUANT_BLOCKS, MXFP4_QUANT_BLOCK_SIZE // 2, 2] + ) + evens, odds = tl.split(e2m1_value) + x_fp4 = evens | (odds << 4) + x_fp4 = x_fp4.reshape(BLOCK_SIZE_M, BLOCK_SIZE_N // 2) + + return x_fp4, bs_e8m0.reshape(BLOCK_SIZE_M, NUM_QUANT_BLOCKS) + + @triton.heuristics( { "EVEN_M_N": lambda args: args["M"] % args["BLOCK_SIZE_M"] == 0 @@ -250,6 +355,8 @@ def _gluon_fused_rms_mxfp4_quant_kernel( else: bs_mask = (bs_offs_m < M)[:, None] & (bs_offs_n < SCALE_N)[None, :] + gl.amd.gfx1250.tdm.async_wait(0) + x1 = smemX1.load(gLayout2D).to(gl.float32) if FIRST_INPUT_RES: From eec4570f28905c02b54074ce1fbc008e1fd1feaf Mon Sep 17 00:00:00 2001 From: Omar Muhammad Date: Sat, 30 May 2026 04:05:42 +0000 Subject: [PATCH 08/10] Include import for triton lanaguge for _mxfp4_quant_op() --- .../ops/triton/_gluon_kernels/gfx1250/quant/fused_mxfp4_quant.py | 1 + 1 file changed, 1 insertion(+) diff --git a/aiter/ops/triton/_gluon_kernels/gfx1250/quant/fused_mxfp4_quant.py b/aiter/ops/triton/_gluon_kernels/gfx1250/quant/fused_mxfp4_quant.py index 5b4f7623a5..83b845e2ec 100644 --- a/aiter/ops/triton/_gluon_kernels/gfx1250/quant/fused_mxfp4_quant.py +++ b/aiter/ops/triton/_gluon_kernels/gfx1250/quant/fused_mxfp4_quant.py @@ -1,4 +1,5 @@ import triton +import triton.language as tl from triton.experimental import gluon from triton.experimental.gluon import language as gl From 7bd76cd7ab73ee7364a31d61c319878f1ccb1b67 Mon Sep 17 00:00:00 2001 From: Omar Muhammad Date: Fri, 5 Jun 2026 21:16:34 +0000 Subject: [PATCH 09/10] Updated fused_mxfp4_quant to utilize more threads alongside software pipelining. --- .../gfx1250/quant/fused_mxfp4_quant.py | 317 +++++++++--------- aiter/ops/triton/quant/fused_mxfp4_quant.py | 35 +- 2 files changed, 190 insertions(+), 162 deletions(-) diff --git a/aiter/ops/triton/_gluon_kernels/gfx1250/quant/fused_mxfp4_quant.py b/aiter/ops/triton/_gluon_kernels/gfx1250/quant/fused_mxfp4_quant.py index 83b845e2ec..21816973aa 100644 --- a/aiter/ops/triton/_gluon_kernels/gfx1250/quant/fused_mxfp4_quant.py +++ b/aiter/ops/triton/_gluon_kernels/gfx1250/quant/fused_mxfp4_quant.py @@ -4,18 +4,16 @@ from triton.experimental.gluon import language as gl -@gluon.jit +@triton.jit def _rmsnorm_op( row, weights, n_cols, epsilon, ): - row_norm = row * row row_norm = gl.sum(row_norm, axis=-1, keep_dims=True) norm_factor = gl.rsqrt((row_norm / n_cols) + epsilon) - rms_norm = row * norm_factor * weights return rms_norm @@ -128,7 +126,7 @@ def _mxfp4_quant_op( @triton.heuristics( { - "EVEN_M_N": lambda args: args["M"] % args["BLOCK_SIZE_M"] == 0 + "EVEN_M_N": lambda args: args["M"] % args["ROWS_PER_CTA"] == 0 and args["N1"] % (args["BLOCK_SIZE_N"]) == 0, } ) @@ -171,26 +169,36 @@ def _gluon_fused_rms_mxfp4_quant_kernel( SHUFFLE: gl.constexpr, SHUFFLE_PAD: gl.constexpr, EVEN_M_N: gl.constexpr, + ROWS_PER_CTA: gl.constexpr, ): start_pid = gl.program_id(0) - # get number of programs to determine is 1 or 2 passes - num_pid_m = gl.cdiv(M, BLOCK_SIZE_M) - - # create block layouts - gLayout2D: gl.constexpr = gl.BlockedLayout( - [1, 2], # sizePerThread - [1, 32], # threadsPerWarp - [1, 4], # warpsPerCTA - [1, 0], # order + # Calculate numbers of grouped CTAs and the base row index + num_pid_m = gl.cdiv(M, ROWS_PER_CTA) + cta_base = start_pid * ROWS_PER_CTA + + # Layout descriptors for the first input + X1_SPT: gl.constexpr = min(16, BLOCK_SIZE_N // 128) + gLayout2D_x1: gl.constexpr = gl.BlockedLayout( + [1, X1_SPT], + [1, 32], + [1, 4], + [1, 0], + ) + gLayoutN_x1: gl.constexpr = gl.SliceLayout(0, gLayout2D_x1) + + # Layout descriptors for the second input + X2_SPT: gl.constexpr = min(16, BLOCK_SIZE_N2 // 128) + gLayout2D_x2: gl.constexpr = gl.BlockedLayout( + [1, X2_SPT], + [1, 32], + [1, 4], + [1, 0], ) - gLayoutN: gl.constexpr = gl.SliceLayout(0, gLayout2D) - - # 2D shared layout for matrix rows; 1D shared layout for weight vectors + gLayoutN_x2: gl.constexpr = gl.SliceLayout(0, gLayout2D_x2) sharedLayout2D: gl.constexpr = gl.SwizzledSharedLayout(1, 1, 1, order=[1, 0]) sharedLayoutN: gl.constexpr = gl.SwizzledSharedLayout(1, 1, 1, order=[0]) - # Tensor descriptors for first input and its weights x1_desec = gl.amd.gfx1250.tdm.make_tensor_descriptor( x1_ptr, [M, N1], @@ -198,8 +206,6 @@ def _gluon_fused_rms_mxfp4_quant_kernel( [BLOCK_SIZE_M, BLOCK_SIZE_N], sharedLayout2D, ) - - # tensor descriptor for weight 1 w1_desec = gl.amd.gfx1250.tdm.make_tensor_descriptor( w1_ptr, [N1], @@ -208,7 +214,6 @@ def _gluon_fused_rms_mxfp4_quant_kernel( sharedLayoutN, ) - # Shared memory for first input and its weights smemX1 = gl.allocate_shared_memory( x1_ptr.dtype.element_ty, [BLOCK_SIZE_M, BLOCK_SIZE_N], sharedLayout2D ) @@ -216,10 +221,7 @@ def _gluon_fused_rms_mxfp4_quant_kernel( w1_ptr.dtype.element_ty, [BLOCK_SIZE_N], sharedLayoutN ) - # x1 load issued unconditionally for early latency hiding (OOB-safe for second-input programs) - gl.amd.gfx1250.tdm.async_load(x1_desec, [start_pid * BLOCK_SIZE_M, 0], smemX1) - - # Tensor descriptor and shared memory for optional residual input + # Creates tensor descriptors and preloads residual input and output into shared memory if present if FIRST_INPUT_RES: res1_desec = gl.amd.gfx1250.tdm.make_tensor_descriptor( res1_ptr, @@ -228,15 +230,6 @@ def _gluon_fused_rms_mxfp4_quant_kernel( [BLOCK_SIZE_M, BLOCK_SIZE_N], sharedLayout2D, ) - - smemRes1 = gl.allocate_shared_memory( - res1_ptr.dtype.element_ty, [BLOCK_SIZE_M, BLOCK_SIZE_N], sharedLayout2D - ) - - gl.amd.gfx1250.tdm.async_load( - res1_desec, [start_pid * BLOCK_SIZE_M, 0], smemRes1 - ) - out_res1_desec = gl.amd.gfx1250.tdm.make_tensor_descriptor( out_res1_ptr, [M, N1], @@ -244,71 +237,78 @@ def _gluon_fused_rms_mxfp4_quant_kernel( [BLOCK_SIZE_M, BLOCK_SIZE_N], sharedLayout2D, ) - + smemRes1 = gl.allocate_shared_memory( + res1_ptr.dtype.element_ty, [BLOCK_SIZE_M, BLOCK_SIZE_N], sharedLayout2D + ) smemOutRes1 = gl.allocate_shared_memory( out_res1_ptr.dtype.element_ty, [BLOCK_SIZE_M, BLOCK_SIZE_N], sharedLayout2D ) - # Second input path — programs with id >= num_pid_m handle x2 + # Handles second input path if present + if HAS_SECOND_INPUT: + x2_desec = gl.amd.gfx1250.tdm.make_tensor_descriptor( + x2_ptr, + [M, N2], + [x2_stride_m, 1], + [BLOCK_SIZE_M, BLOCK_SIZE_N2], + sharedLayout2D, + ) + w2_desec = gl.amd.gfx1250.tdm.make_tensor_descriptor( + w2_ptr, + [N2], + [1], + [BLOCK_SIZE_N2], + sharedLayoutN, + ) + out2_desec = gl.amd.gfx1250.tdm.make_tensor_descriptor( + out2_ptr, + [M, N2], + [out2_stride_m, 1], + [BLOCK_SIZE_M, BLOCK_SIZE_N2], + sharedLayout2D, + ) + smemX2 = gl.allocate_shared_memory( + x2_ptr.dtype.element_ty, [BLOCK_SIZE_M, BLOCK_SIZE_N2], sharedLayout2D + ) + smemW2 = gl.allocate_shared_memory( + w2_ptr.dtype.element_ty, [BLOCK_SIZE_N2], sharedLayoutN + ) + smemOut2 = gl.allocate_shared_memory( + out2_ptr.dtype.element_ty, [BLOCK_SIZE_M, BLOCK_SIZE_N2], sharedLayout2D + ) + + # Checks if the current PID is in the second input path if start_pid >= num_pid_m: if HAS_SECOND_INPUT: - x2_desec = gl.amd.gfx1250.tdm.make_tensor_descriptor( - x2_ptr, - [M, N2], - [x2_stride_m, 1], - [BLOCK_SIZE_M, BLOCK_SIZE_N2], - sharedLayout2D, - ) - w2_desec = gl.amd.gfx1250.tdm.make_tensor_descriptor( - w2_ptr, - [N2], - [1], - [BLOCK_SIZE_N2], - sharedLayoutN, - ) - smemX2 = gl.allocate_shared_memory( - x2_ptr.dtype.element_ty, [BLOCK_SIZE_M, BLOCK_SIZE_N2], sharedLayout2D - ) - smemW2 = gl.allocate_shared_memory( - w2_ptr.dtype.element_ty, [BLOCK_SIZE_N2], sharedLayoutN - ) - start_pid -= num_pid_m + x2_local_pid = start_pid - num_pid_m + x2_cta_base = x2_local_pid * ROWS_PER_CTA - gl.amd.gfx1250.tdm.async_load( - x2_desec, [start_pid * BLOCK_SIZE_M, 0], smemX2 - ) + gl.amd.gfx1250.tdm.async_load(x2_desec, [x2_cta_base, 0], smemX2) gl.amd.gfx1250.tdm.async_load(w2_desec, [0], smemW2) gl.amd.gfx1250.tdm.async_wait(0) - x2 = smemX2.load(gLayout2D).to(gl.float32) - w2 = smemW2.load(gLayoutN).to(gl.float32) + w2 = smemW2.load(gLayoutN_x2).to(gl.float32) w2 = w2.reshape(1, BLOCK_SIZE_N2) - w2 = gl.convert_layout(w2, gLayout2D) - norm2 = _rmsnorm_op(x2, w2, N2, eps2) - - # Store norm2 output via TDM - out2_desec = gl.amd.gfx1250.tdm.make_tensor_descriptor( - out2_ptr, - [M, N2], - [out2_stride_m, 1], - [BLOCK_SIZE_M, BLOCK_SIZE_N2], - sharedLayout2D, - ) - smemOut2 = gl.allocate_shared_memory( - out2_ptr.dtype.element_ty, [BLOCK_SIZE_M, BLOCK_SIZE_N2], sharedLayout2D - ) - smemOut2.store(norm2.to(out2_ptr.dtype.element_ty)) - gl.amd.gfx1250.tdm.async_store( - out2_desec, [start_pid * BLOCK_SIZE_M, 0], smemOut2 - ) - gl.amd.gfx1250.tdm.async_wait(0) + w2 = gl.convert_layout(w2, gLayout2D_x2) + + for i in range(ROWS_PER_CTA): + x2_row_abs = x2_cta_base + i + x2 = smemX2.load(gLayout2D_x2).to(gl.float32) + if i + 1 < ROWS_PER_CTA: + gl.amd.gfx1250.tdm.async_load(x2_desec, [x2_row_abs + 1, 0], smemX2) + norm2 = _rmsnorm_op(x2, w2, N2, eps2) + smemOut2.store(norm2.to(out2_ptr.dtype.element_ty)) + gl.amd.gfx1250.tdm.async_store(out2_desec, [x2_row_abs, 0], smemOut2) + gl.amd.gfx1250.tdm.async_wait(0) return + gl.amd.gfx1250.tdm.async_load(x1_desec, [cta_base, 0], smemX1) + if FIRST_INPUT_RES: + gl.amd.gfx1250.tdm.async_load(res1_desec, [cta_base, 0], smemRes1) gl.amd.gfx1250.tdm.async_load(w1_desec, [0], smemW1) NUM_QUANT_BLOCKS: gl.constexpr = BLOCK_SIZE_N // MXFP4_QUANT_BLOCK_SIZE - # Descriptor and smem for fp4 TDM async_store out1_fp4_desec = gl.amd.gfx1250.tdm.make_tensor_descriptor( out1_fp4_ptr, [M, N1 // 2], @@ -320,60 +320,6 @@ def _gluon_fused_rms_mxfp4_quant_kernel( out1_fp4_ptr.dtype.element_ty, [BLOCK_SIZE_M, BLOCK_SIZE_N // 2], sharedLayout2D ) - bs_offs_m = start_pid * BLOCK_SIZE_M + gl.arange(0, BLOCK_SIZE_M) - bs_offs_n = gl.arange(0, NUM_QUANT_BLOCKS) - num_bs_cols = (N1 + MXFP4_QUANT_BLOCK_SIZE - 1) // MXFP4_QUANT_BLOCK_SIZE - if SHUFFLE: - bs_offs_0 = bs_offs_m[:, None] >> 5 # // 32 - bs_offs_1 = bs_offs_m[:, None] & 31 # % 32 - bs_offs_2 = bs_offs_1 & 15 # % 16 - bs_offs_1 = bs_offs_1 >> 4 # // 16 - bs_offs_3 = bs_offs_n[None, :] >> 3 # // 8 - bs_offs_4 = bs_offs_n[None, :] & 7 # % 8 - bs_offs_5 = bs_offs_4 & 3 # % 4 - bs_offs_4 = bs_offs_4 >> 2 # // 4 - bs_offs = ( - bs_offs_1 - + bs_offs_4 * 2 - + bs_offs_2 * 2 * 2 - + bs_offs_5 * 2 * 2 * 16 - + bs_offs_3 * 2 * 2 * 16 * 4 - + bs_offs_0 * 2 * 16 * SCALE_N_PAD - ) - bs_mask_127 = (bs_offs_m < M)[:, None] & (bs_offs_n < num_bs_cols)[None, :] - else: - bs_offs = ( - bs_offs_m[:, None] * out1_bs_stride_m - + bs_offs_n[None, :] * out1_bs_stride_n - ) - - bs_mask = None - if not EVEN_M_N: - if SHUFFLE_PAD: - bs_mask = (bs_offs_m < SCALE_M_PAD)[:, None] & (bs_offs_n < SCALE_N_PAD)[ - None, : - ] - else: - bs_mask = (bs_offs_m < M)[:, None] & (bs_offs_n < SCALE_N)[None, :] - - gl.amd.gfx1250.tdm.async_wait(0) - - x1 = smemX1.load(gLayout2D).to(gl.float32) - - if FIRST_INPUT_RES: - res1_loaded = smemRes1.load(gLayout2D).to(gl.float32) - x1 = x1 + res1_loaded - smemOutRes1.store(x1.to(out_res1_ptr.dtype.element_ty)) - gl.amd.gfx1250.tdm.async_store( - out_res1_desec, [start_pid * BLOCK_SIZE_M, 0], smemOutRes1 - ) - - w1 = smemW1.load(gLayoutN).to(gl.float32) - w1 = w1.reshape(1, BLOCK_SIZE_N) - w1 = gl.convert_layout(w1, gLayout2D) - norm1 = _rmsnorm_op(x1, w1, N1, eps1) - - # Store unquantized output via TDM (optional) if FIRST_INPUT_OUT: out1_desec = gl.amd.gfx1250.tdm.make_tensor_descriptor( out1_ptr, @@ -385,26 +331,91 @@ def _gluon_fused_rms_mxfp4_quant_kernel( smemOut1 = gl.allocate_shared_memory( out1_ptr.dtype.element_ty, [BLOCK_SIZE_M, BLOCK_SIZE_N], sharedLayout2D ) - smemOut1.store(norm1.to(out1_ptr.dtype.element_ty)) - gl.amd.gfx1250.tdm.async_store( - out1_desec, [start_pid * BLOCK_SIZE_M, 0], smemOut1 - ) - out1_fp4, bs_e8m0 = _mxfp4_quant_op( - norm1, BLOCK_SIZE_N, BLOCK_SIZE_M, MXFP4_QUANT_BLOCK_SIZE - ) + num_bs_cols = (N1 + MXFP4_QUANT_BLOCK_SIZE - 1) // MXFP4_QUANT_BLOCK_SIZE + bs_offs_n = gl.arange(0, NUM_QUANT_BLOCKS) + gl.amd.gfx1250.tdm.async_wait(0) + + w1 = smemW1.load(gLayoutN_x1).to(gl.float32) + w1 = w1.reshape(1, BLOCK_SIZE_N) + w1 = gl.convert_layout(w1, gLayout2D_x1) + + # Loop through each row in the CTA + for i in range(ROWS_PER_CTA): + row_abs = cta_base + i + bs_offs_m_i = row_abs + gl.arange(0, BLOCK_SIZE_M) + + # Blockscale offset computation. + if SHUFFLE: + bs_offs_0 = bs_offs_m_i[:, None] >> 5 + bs_offs_1 = bs_offs_m_i[:, None] & 31 + bs_offs_2 = bs_offs_1 & 15 + bs_offs_1 = bs_offs_1 >> 4 + bs_offs_3 = bs_offs_n[None, :] >> 3 + bs_offs_4 = bs_offs_n[None, :] & 7 + bs_offs_5 = bs_offs_4 & 3 + bs_offs_4 = bs_offs_4 >> 2 + bs_offs_i = ( + bs_offs_1 + + bs_offs_4 * 2 + + bs_offs_2 * 2 * 2 + + bs_offs_5 * 2 * 2 * 16 + + bs_offs_3 * 2 * 2 * 16 * 4 + + bs_offs_0 * 2 * 16 * SCALE_N_PAD + ) + bs_mask_127_i = (bs_offs_m_i < M)[:, None] & (bs_offs_n < num_bs_cols)[ + None, : + ] + else: + bs_offs_i = ( + bs_offs_m_i[:, None] * out1_bs_stride_m + + bs_offs_n[None, :] * out1_bs_stride_n + ) - # Apply out-of-range scale mask - if SHUFFLE: - bs_e8m0 = gl.where(bs_mask_127, bs_e8m0, 127) + bs_mask_i = None + if not EVEN_M_N: + if SHUFFLE_PAD: + bs_mask_i = (bs_offs_m_i < SCALE_M_PAD)[:, None] & ( + bs_offs_n < SCALE_N_PAD + )[None, :] + else: + bs_mask_i = (bs_offs_m_i < M)[:, None] & (bs_offs_n < SCALE_N)[None, :] + + x1 = smemX1.load(gLayout2D_x1).to(gl.float32) + if FIRST_INPUT_RES: + res1_loaded = smemRes1.load(gLayout2D_x1).to(gl.float32) + + # Fetch the next row while the current row is being processed + if i + 1 < ROWS_PER_CTA: + gl.amd.gfx1250.tdm.async_load(x1_desec, [row_abs + 1, 0], smemX1) + if FIRST_INPUT_RES: + gl.amd.gfx1250.tdm.async_load(res1_desec, [row_abs + 1, 0], smemRes1) + + if FIRST_INPUT_RES: + x1 = x1 + res1_loaded + smemOutRes1.store(x1.to(out_res1_ptr.dtype.element_ty)) + gl.amd.gfx1250.tdm.async_store(out_res1_desec, [row_abs, 0], smemOutRes1) + + norm1 = _rmsnorm_op(x1, w1, N1, eps1) + + if FIRST_INPUT_OUT: + smemOut1.store(norm1.to(out1_ptr.dtype.element_ty)) + gl.amd.gfx1250.tdm.async_store(out1_desec, [row_abs, 0], smemOut1) + + out1_fp4, bs_e8m0 = _mxfp4_quant_op( + norm1, BLOCK_SIZE_N, BLOCK_SIZE_M, MXFP4_QUANT_BLOCK_SIZE + ) - smemOutFp4.store(out1_fp4) - gl.amd.gfx1250.tdm.async_store( - out1_fp4_desec, [start_pid * BLOCK_SIZE_M, 0], smemOutFp4 - ) + if SHUFFLE: + bs_e8m0 = gl.where(bs_mask_127_i, bs_e8m0, 127) - gl.store( - out1_bs_ptr + bs_offs, bs_e8m0.to(out1_bs_ptr.type.element_ty), mask=bs_mask - ) + # Store the quantized and blockscale values + smemOutFp4.store(out1_fp4) + gl.amd.gfx1250.tdm.async_store(out1_fp4_desec, [row_abs, 0], smemOutFp4) + gl.store( + out1_bs_ptr + bs_offs_i, + bs_e8m0.to(out1_bs_ptr.type.element_ty), + mask=bs_mask_i, + ) - gl.amd.gfx1250.tdm.async_wait(0) + gl.amd.gfx1250.tdm.async_wait(0) diff --git a/aiter/ops/triton/quant/fused_mxfp4_quant.py b/aiter/ops/triton/quant/fused_mxfp4_quant.py index 426fe3b4b7..2d81aa535d 100644 --- a/aiter/ops/triton/quant/fused_mxfp4_quant.py +++ b/aiter/ops/triton/quant/fused_mxfp4_quant.py @@ -68,7 +68,7 @@ def fused_rms_mxfp4_quant( # as we merge 2 fp4s to 1 uint8 assert N1 % 2 == 0 BLOCK_SIZE_M = 1 - # BLOCK_SIZE_M = 32 + BLOCK_SIZE_N = max(BLOCK_SIZE_N, MXFP4_QUANT_BLOCK_SIZE) out1_fp4 = torch.empty((M, N1 // 2), dtype=torch.uint8, device=x1.device) SCALE_N_valid = triton.cdiv(N1, MXFP4_QUANT_BLOCK_SIZE) @@ -111,23 +111,19 @@ def fused_rms_mxfp4_quant( # checks args for either gluon, triton, or auto. auto will check for gfx1250 hardware and set to gluon if it exists, otherwise defaults to triton if inargs == "auto": - if get_arch() == "gfx1250": - kernel = _gluon_fused_rms_mxfp4_quant_kernel - else: - kernel = _fused_rms_mxfp4_quant_kernel + use_gluon = get_arch() == "gfx1250" elif inargs == "gluon": if get_arch() != "gfx1250": raise RuntimeError("Gluon kernel only supported on gfx1250 hardware") - kernel = _gluon_fused_rms_mxfp4_quant_kernel + use_gluon = True elif inargs == "triton": - kernel = _fused_rms_mxfp4_quant_kernel + use_gluon = False else: raise ValueError( f"Invalid argument: {inargs}. Choose from auto, gluon, or triton" ) - grid = (triton.cdiv(M, BLOCK_SIZE_M) * (2 if (x2 is not None) else 1),) - kernel[grid]( + _common_args = ( x1, x1_weight, x2, @@ -151,6 +147,8 @@ def fused_rms_mxfp4_quant( out2_stride_m, out_res1_stride_m, out1_stride_m, + ) + _common_kwargs = dict( BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_N2=BLOCK_SIZE_N2, @@ -165,6 +163,25 @@ def fused_rms_mxfp4_quant( SHUFFLE_PAD=use_scale_shuffle_padding, ) + if use_gluon: + # Aim for at least 32 CTAs to keep all WGPs fed. + _TARGET_MIN_CTAS = 32 + ROWS_PER_CTA = max(1, min(8, M // _TARGET_MIN_CTAS)) + while ROWS_PER_CTA > 1 and M % ROWS_PER_CTA != 0: + ROWS_PER_CTA //= 2 + grid = (triton.cdiv(M, ROWS_PER_CTA) * (1 if x2 is None else 2),) + _gluon_fused_rms_mxfp4_quant_kernel[grid]( + *_common_args, + **_common_kwargs, + ROWS_PER_CTA=ROWS_PER_CTA, + ) + else: + grid = (M * (1 if x2 is None else 2),) + _fused_rms_mxfp4_quant_kernel[grid]( + *_common_args, + **_common_kwargs, + ) + return (out1_fp4, out1_bs), out1, out2, out_res1 From 4293a528c243aa3d72ec4f8080a91ee15f816e94 Mon Sep 17 00:00:00 2001 From: Omar Muhammad Date: Fri, 5 Jun 2026 21:43:34 +0000 Subject: [PATCH 10/10] fix: restore composable_kernel submodule from main --- .gitmodules | 3 +++ 3rdparty/composable_kernel | 1 + 2 files changed, 4 insertions(+) create mode 160000 3rdparty/composable_kernel diff --git a/.gitmodules b/.gitmodules index e69de29bb2..61bc1cf7e0 100644 --- a/.gitmodules +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "3rdparty/composable_kernel"] + path = 3rdparty/composable_kernel + url = https://github.com/ROCm/composable_kernel.git diff --git a/3rdparty/composable_kernel b/3rdparty/composable_kernel new file mode 160000 index 0000000000..af7118e342 --- /dev/null +++ b/3rdparty/composable_kernel @@ -0,0 +1 @@ +Subproject commit af7118e342580ecd3f71edce7b1d0ba465012ecf