28 commits
9d393ba
Add initial config for A16W16
vgokhale Jan 22, 2026
5f85097
Fix for pytorch manual_seed
vgokhale Jan 22, 2026
5c5eccb
test_moe_gemm_a8w4, test_fused_qkv_split_qk_rope changes to work with…
Boss2002n Jan 22, 2026
681d10b
fixes needed for test_fused_fp8_quant: added fp8 defaultdtype for gfx…
ahmed-bsod Jan 23, 2026
c753970
shapes tested and passing for fused_fp8_quant
ahmed-bsod Jan 23, 2026
e0230cf
Remove manual seeds. Add gfx12 to device list
azaidy Jan 23, 2026
9467ee3
removing the hip workaround added earlier in gemm_a8w8 kernel
ahmed-bsod Jan 29, 2026
d273c20
rmsnorm and fused_mul_add
Boss2002n Jan 29, 2026
e9e4aa9
moe_gemm_a8w8
Boss2002n Jan 29, 2026
b250948
added gfx1250 to get_fp8_dtypes() function
ahmed-bsod Feb 6, 2026
c386748
gluon gemm a8w8 in progress, slice layout issue
amirumoAMD Feb 11, 2026
d808afe
Revert "gluon gemm a8w8 in progress, slice layout issue"
amirumoAMD Feb 12, 2026
fdeef66
[Gluon] Unified Attention 3D development for gfx12 (#2048)
k50112113 Feb 16, 2026
6d56150
[TRITON] [GLUON] Adding 2d unified attention Gluon kernel (#2112)
cagrikymk Feb 26, 2026
e46d13d
[TRITON][GLUON] add TDM Gather to 2D Attention (#2155)
cagrikymk Mar 2, 2026
e7e62d9
[mi450] [gluon] UA3D updates (#2119)
k50112113 Mar 3, 2026
6fa3e3f
Add gfx1250 support: GFX_MAP + default GEMM/MHA configs (#2315)
vgokhale Mar 18, 2026
3bbe486
[TRITON]: Add gfx1250 arch enablement: fp8 support + test refactoring…
vgokhale Mar 19, 2026
4e8856a
Add gfx1250 to GFX_MAP in chip_info.py
Boss2002n Mar 24, 2026
09837f8
Merge branch 'main' into shared/triton-gfx12
Boss2002n Mar 24, 2026
1b94b36
Merge branch 'main' into shared/triton-gfx12
Boss2002n Mar 26, 2026
fe37ad7
Merge branch 'main' into shared/triton-gfx12
Boss2002n Mar 27, 2026
73c3f69
Merge branch 'main' into shared/triton-gfx12
Boss2002n Mar 27, 2026
bfcb3e7
Merge branch 'main' into shared/triton-gfx12
Boss2002n Mar 29, 2026
1c89036
Merge branch 'main' into shared/triton-gfx12
Boss2002n Mar 30, 2026
9347aad
Merge branch 'main' into shared/triton-gfx12
Boss2002n Mar 31, 2026
3cab881
fuse_mxfp4_quant gluon kernel for gfx1250
amd-jrosas May 8, 2026
cb11cb6
Format changes and removed unused variables
amd-jrosas May 8, 2026
1 change: 0 additions & 1 deletion 3rdparty/composable_kernel
Submodule composable_kernel deleted from 345a56
Empty file.
302 changes: 302 additions & 0 deletions aiter/ops/triton/_gluon_kernels/quant/fuse_mxfp4_quant.py
@@ -0,0 +1,302 @@
import triton
from triton.experimental import gluon
from aiter.ops.triton._triton_kernels.quant.fused_mxfp4_quant import _mxfp4_quant_op
from triton.experimental.gluon import language as gl


@gluon.jit
def _rmsnorm_op(
row,
weights,
n_cols,
epsilon,
):

row_norm = row * row
row_norm = gl.sum(row_norm, axis=-1, keep_dims=True)
norm_factor = gl.rsqrt((row_norm / n_cols) + epsilon)

rms_norm = row * norm_factor * weights
return rms_norm


@triton.heuristics(
{
"EVEN_M_N": lambda args: args["M"] % args["BLOCK_SIZE_M"] == 0
and args["N1"] % (args["BLOCK_SIZE_N"]) == 0,
}
)
@gluon.jit
def _gluon_fused_rms_mxfp4_quant(
x1_ptr,
w1_ptr,
x2_ptr,
w2_ptr,
res1_ptr,
out1_fp4_ptr,
out1_bs_ptr,
out2_ptr,
out_res1_ptr,
out1_ptr,
eps1,
eps2,
M,
N1,
N2,
x1_stride_m,
x2_stride_m,
res1_stride_m,
out1_fp4_stride_m,
out1_bs_stride_m,
out1_bs_stride_n,
out2_stride_m,
out_res1_stride_m,
out1_stride_m,
BLOCK_SIZE_M: gl.constexpr,
BLOCK_SIZE_N: gl.constexpr,
BLOCK_SIZE_N2: gl.constexpr,
MXFP4_QUANT_BLOCK_SIZE: gl.constexpr,
HAS_SECOND_INPUT: gl.constexpr,
FIRST_INPUT_RES: gl.constexpr,
FIRST_INPUT_OUT: gl.constexpr,
SCALE_N: gl.constexpr,
SCALE_M_PAD: gl.constexpr,
SCALE_N_PAD: gl.constexpr,
SHUFFLE: gl.constexpr,
SHUFFLE_PAD: gl.constexpr,
EVEN_M_N: gl.constexpr,
):
start_pid = gl.program_id(0)
# get the number of row blocks to decide whether this program handles the first or second input
num_pid_m = gl.cdiv(M, BLOCK_SIZE_M)
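# Programs [0, num_pid_m) handle the first input; when HAS_SECOND_INPUT is set,
# a second group of num_pid_m programs handles x2 (the launch grid is assumed to
# be sized accordingly by the caller).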

# create block layouts
gLayout2D: gl.constexpr = gl.BlockedLayout(
[1, 2], # sizePerThread
[1, 32], # threadsPerWarp
[1, 4], # warpsPerCTA
[1, 0], # order
)

gLayoutM: gl.constexpr = gl.SliceLayout(1, gLayout2D)
gLayoutN: gl.constexpr = gl.SliceLayout(0, gLayout2D)

# 2D shared layout for matrix rows; 1D shared layout for weight vectors
sharedLayout2D: gl.constexpr = gl.SwizzledSharedLayout(1, 1, 1, order=[1, 0])
sharedLayoutN: gl.constexpr = gl.SwizzledSharedLayout(1, 1, 1, order=[0])

# Tensor descriptors for first input and its weights
x1_desec = gl.amd.gfx1250.tdm.make_tensor_descriptor(
x1_ptr,
[M, N1],
[x1_stride_m, 1],
[BLOCK_SIZE_M, BLOCK_SIZE_N],
sharedLayout2D,
)

w1_desec = gl.amd.gfx1250.tdm.make_tensor_descriptor(
w1_ptr,
[N1],
[1],
[BLOCK_SIZE_N],
sharedLayoutN,
)

# Shared memory for first input and its weights
smemX1 = gl.allocate_shared_memory(
x1_ptr.dtype.element_ty, [BLOCK_SIZE_M, BLOCK_SIZE_N], sharedLayout2D
)
smemW1 = gl.allocate_shared_memory(
w1_ptr.dtype.element_ty, [BLOCK_SIZE_N], sharedLayoutN
)

# Tensor descriptor and shared memory for optional residual input
if FIRST_INPUT_RES:
res1_desec = gl.amd.gfx1250.tdm.make_tensor_descriptor(
res1_ptr,
[M, N1],
[res1_stride_m, 1],
[BLOCK_SIZE_M, BLOCK_SIZE_N],
sharedLayout2D,
)
smemRes1 = gl.allocate_shared_memory(
res1_ptr.dtype.element_ty, [BLOCK_SIZE_M, BLOCK_SIZE_N], sharedLayout2D
)

# Second input path — programs with id >= num_pid_m handle x2
if start_pid >= num_pid_m:
if HAS_SECOND_INPUT:
x2_desec = gl.amd.gfx1250.tdm.make_tensor_descriptor(
x2_ptr,
[M, N2],
[x2_stride_m, 1],
[BLOCK_SIZE_M, BLOCK_SIZE_N2],
sharedLayout2D,
)
w2_desec = gl.amd.gfx1250.tdm.make_tensor_descriptor(
w2_ptr,
[N2],
[1],
[BLOCK_SIZE_N2],
sharedLayoutN,
)
smemX2 = gl.allocate_shared_memory(
x2_ptr.dtype.element_ty, [BLOCK_SIZE_M, BLOCK_SIZE_N2], sharedLayout2D
)
smemW2 = gl.allocate_shared_memory(
w2_ptr.dtype.element_ty, [BLOCK_SIZE_N2], sharedLayoutN
)

start_pid -= num_pid_m
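# Re-base the program id so it indexes row blocks of x2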

# Load x2 and w2 in parallel then wait for both
gl.amd.gfx1250.tdm.async_load(
x2_desec, [start_pid * BLOCK_SIZE_M, 0], smemX2
)
gl.amd.gfx1250.tdm.async_load(w2_desec, [0], smemW2)
gl.amd.gfx1250.tdm.async_wait(0)

x2 = smemX2.load(gLayout2D).to(gl.float32)
w2 = smemW2.load(gLayoutN).to(gl.float32)
w2 = w2.reshape(1, BLOCK_SIZE_N2)
w2 = gl.convert_layout(w2, gLayout2D)
norm2 = _rmsnorm_op(x2, w2, N2, eps2)

# Store norm2 output via TDM
out2_desec = gl.amd.gfx1250.tdm.make_tensor_descriptor(
out2_ptr,
[M, N2],
[out2_stride_m, 1],
[BLOCK_SIZE_M, BLOCK_SIZE_N2],
sharedLayout2D,
)
smemOut2 = gl.allocate_shared_memory(
out2_ptr.dtype.element_ty, [BLOCK_SIZE_M, BLOCK_SIZE_N2], sharedLayout2D
)
smemOut2.store(norm2.to(out2_ptr.dtype.element_ty))
gl.amd.gfx1250.tdm.async_store(
out2_desec, [start_pid * BLOCK_SIZE_M, 0], smemOut2
)
gl.amd.gfx1250.tdm.async_wait(0)
return

# First input path
NUM_QUANT_BLOCKS: gl.constexpr = BLOCK_SIZE_N // MXFP4_QUANT_BLOCK_SIZE
x_offs_m = start_pid * BLOCK_SIZE_M + gl.arange(0, BLOCK_SIZE_M, gLayoutM)

# Load x1 and optionally res1 in parallel, then wait
gl.amd.gfx1250.tdm.async_load(x1_desec, [start_pid * BLOCK_SIZE_M, 0], smemX1)
if FIRST_INPUT_RES:
gl.amd.gfx1250.tdm.async_load(
res1_desec, [start_pid * BLOCK_SIZE_M, 0], smemRes1
)
gl.amd.gfx1250.tdm.async_wait(0)

x1 = smemX1.load(gLayout2D).to(gl.float32)

if FIRST_INPUT_RES:
res1_loaded = smemRes1.load(gLayout2D).to(gl.float32)
x1 = x1 + res1_loaded

# Load w1 and wait
gl.amd.gfx1250.tdm.async_load(w1_desec, [0], smemW1)
gl.amd.gfx1250.tdm.async_wait(0)

w1 = smemW1.load(gLayoutN).to(gl.float32)
w1 = w1.reshape(1, BLOCK_SIZE_N)
w1 = gl.convert_layout(w1, gLayout2D)
norm1 = _rmsnorm_op(x1, w1, N1, eps1)

# Store unquantized output via TDM (optional)
if FIRST_INPUT_OUT:
out1_desec = gl.amd.gfx1250.tdm.make_tensor_descriptor(
out1_ptr,
[M, N1],
[out1_stride_m, 1],
[BLOCK_SIZE_M, BLOCK_SIZE_N],
sharedLayout2D,
)
smemOut1 = gl.allocate_shared_memory(
out1_ptr.dtype.element_ty, [BLOCK_SIZE_M, BLOCK_SIZE_N], sharedLayout2D
)
smemOut1.store(norm1.to(out1_ptr.dtype.element_ty))
gl.amd.gfx1250.tdm.async_store(
out1_desec, [start_pid * BLOCK_SIZE_M, 0], smemOut1
)
gl.amd.gfx1250.tdm.async_wait(0)

out1_fp4, bs_e8m0 = _mxfp4_quant_op(
norm1, BLOCK_SIZE_N, BLOCK_SIZE_M, MXFP4_QUANT_BLOCK_SIZE
)
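# _mxfp4_quant_op is expected to return the packed fp4 payload (two values per
# byte) and one e8m0 block scale per MXFP4_QUANT_BLOCK_SIZE elements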
out1_fp4 = gl.convert_layout(out1_fp4, gLayout2D)

# out1_fp4 uses half-width (packed) offsets — keep as regular store
half_x_offs_n = gl.arange(0, BLOCK_SIZE_N // 2)
out_mask1 = (half_x_offs_n < (N1 // 2))[None, :]
if not EVEN_M_N:
out_mask1 = out_mask1 & (x_offs_m < M)[:, None]
gl.store(
out1_fp4_ptr + x_offs_m[:, None] * out1_fp4_stride_m + half_x_offs_n[None, :],
out1_fp4,
mask=out_mask1,
)

# Optionally shuffle the block-scale offsets into the interleaved, padded layout
bs_offs_m = start_pid * BLOCK_SIZE_M + gl.arange(0, BLOCK_SIZE_M)
bs_offs_n = gl.arange(0, NUM_QUANT_BLOCKS)
num_bs_cols = (N1 + MXFP4_QUANT_BLOCK_SIZE - 1) // MXFP4_QUANT_BLOCK_SIZE
if SHUFFLE:
bs_offs_0 = bs_offs_m[:, None] // 32
bs_offs_1 = bs_offs_m[:, None] % 32
bs_offs_2 = bs_offs_1 % 16
bs_offs_1 = bs_offs_1 // 16
bs_offs_3 = bs_offs_n[None, :] // 8
bs_offs_4 = bs_offs_n[None, :] % 8
bs_offs_5 = bs_offs_4 % 4
bs_offs_4 = bs_offs_4 // 4
bs_offs = (
bs_offs_1
+ bs_offs_4 * 2
+ bs_offs_2 * 2 * 2
+ bs_offs_5 * 2 * 2 * 16
+ bs_offs_3 * 2 * 2 * 16 * 4
+ bs_offs_0 * 2 * 16 * SCALE_N_PAD
)
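# Out-of-range scale entries are padded with 127, the e8m0 encoding of 1.0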
bs_mask_127 = (bs_offs_m < M)[:, None] & (bs_offs_n < num_bs_cols)[None, :]
bs_e8m0 = gl.where(bs_mask_127, bs_e8m0, 127)
else:
bs_offs = (
bs_offs_m[:, None] * out1_bs_stride_m
+ bs_offs_n[None, :] * out1_bs_stride_n
)

bs_mask = None
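# Block-scale stores are masked only when the tile may be ragged; with SHUFFLE_PAD
# set, the bounds come from the padded scale shape (SCALE_M_PAD, SCALE_N_PAD)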
if not EVEN_M_N:
if not SHUFFLE_PAD:
bs_mask = (bs_offs_m < M)[:, None] & (bs_offs_n < SCALE_N)[None, :]
else:
bs_mask = (bs_offs_m < SCALE_M_PAD)[:, None] & (bs_offs_n < SCALE_N_PAD)[
None, :
]

gl.store(
out1_bs_ptr + bs_offs, bs_e8m0.to(out1_bs_ptr.type.element_ty), mask=bs_mask
)

# Store residual output via TDM
if FIRST_INPUT_RES:
out_res1_desec = gl.amd.gfx1250.tdm.make_tensor_descriptor(
out_res1_ptr,
[M, N1],
[out_res1_stride_m, 1],
[BLOCK_SIZE_M, BLOCK_SIZE_N],
sharedLayout2D,
)
smemOutRes1 = gl.allocate_shared_memory(
out_res1_ptr.dtype.element_ty, [BLOCK_SIZE_M, BLOCK_SIZE_N], sharedLayout2D
)
smemOutRes1.store(x1.to(out_res1_ptr.dtype.element_ty))
gl.amd.gfx1250.tdm.async_store(
out_res1_desec, [start_pid * BLOCK_SIZE_M, 0], smemOutRes1
)
gl.amd.gfx1250.tdm.async_wait(0)