
Commit 8abec4b

liyjiang and cursoragent committed
[MLA ASM kernel] add v4 MLA decode for mi350 (recompile mi300)
The v4 new-model kernel uses an 18-slot kernarg ABI that's incompatible with v3's 14-slot layout: - slot 8 raw `gqa_ratio` (not `s_MQA = gqa_ratio * max_seqlen_q`) - slot 10 `total_kv = kv_seq_lens * num_seqs` - 4 new tail slots: ptr_STP, out_16_nosplit, ptr_QROPE, ptr_KVROPE - scalar = 1/sqrt(kV4DimNope + kV4DimRope) = 1/sqrt(512), independent of head_size - kernel hardcodes wave64 (gfx950) block dim = wv_tg * 64 = 256 Integration follows the canonical aiter JIT module pattern (peer of asm_mla.cu / module_mla_asm, not the compile_template_op jinja path): - csrc/py_itfs_cu/asm_mla_v4.cu: C-ABI dispatcher using aiter_tensor_t*, wrapped in AITER_CTYPES_DEFINE_ENTRYPOINT_VOID so AITER_CHECK failures surface as clean Python RuntimeError via the aiter_get_last_error TLS bridge (no std::abort()). - hsa/gfx950/mla_v4/mla_v4_asm.csv: CSV registry consumed by hsa/codegen.py -m mla_v4 -> asm_mla_v4_configs.hpp / cfg_mla_v4_asm. Schema (qType, kvType, gqa, sub_Q, page_size, num_kv_splits, prefill, causal, knl_name, co_name) is independent of mla_asm.csv so v3 / v4 dispatchers don't collide. - hsa/gfx950/mla_v4/mla_a8w8_qh64_qseqlen4_gqaratio16_nm_recmp.co: compiled shader. Built from poc_kl/mi350/mla_asm/shaders/mla_a8w8_qh64 _1tg_16mx4_16nx1_nm_recompile.s with the literal `mla_kernel_func` symbol sed-replaced to `_ZN5aiter42mla_a8w8_qh64_qseqlen4_gqaratio16 _nm_recmpE` then reassembled with `clang -x assembler --offload-arch=gfx950`. - aiter/jit/optCompilerConfig.json: module_mla_v4_asm entry with blob_gen_cmd = codegen.py -m mla_v4. - aiter/ops/attention.py: mla_decode_v4_asm Python signature (@compile_ops ffi_type="ctypes"). - aiter/mla.py: thin wrapper `mla_decode_fwd_v4_nm` routes through aiter.mla_decode_v4_asm and allocates the canonical 5D [num_seqs, num_kv_splits, num_kv_heads, gqa*max_seqlen_q, v_head_dim] split-logits / attn_lse internally. 
- aiter/utility/dtypes.py: _torch_to_aiter_dtype now accepts both torch.float8_e4m3fn (OCP) and torch.float8_e4m3fnuz (NUZ); both map to AITER_DTYPE_fp8, so callers on gfx942 / gfx950 can use either variant transparently.

Tests (op_tests/test_mla_v4_nm.py):
- test_v4_nm_smoke_default_shape: dispatcher -> kernel launch -> SENTINEL coverage check.
- test_v4_nm_no_half_zero_pattern: wave64-launch regression guard (the historic 256-NaN + 256-zero output pattern on gfx950).
- test_v4_nm_determinism: two back-to-back launches are bit-identical.
- test_v4_nm_out_16_nosplit_arg_accepted: the out_16_nosplit=1 path doesn't crash the dispatcher.
- test_v4_nm_unknown_variant_raises: sub_Q=128 (not in cfg_mla_v4_asm) triggers a clean RuntimeError via the C-ABI error bridge; no silently-loaded wrong .co.
- test_v4_nm_kernarg_scalar_slots: env-gated kernarg hexdump captured via capfd; scalar slots 7-12 + 15 byte-exact match poc_kl's MlaV4HipKernelArgs.

Numerical correctness (op_tests/test_mla_v4_nm_golden.py):
- Byte-exact compare against poc_kl gpu_SPLIT_DATA.hex. Inputs are loaded via raw hipMemcpy H2D from device-side `.bin` dumps produced by an env-gated patch on the poc_kl side (POC_KL_V4_DUMP_DEVBINS=1 in mla_execute_v4_hip.inl). This bypasses the brittle path of reconstructing inputs from hex dumps and sidesteps the v4 nm kernel's degenerate code path on uniform e8m0 scale bytes (which would produce 100% NaN for synthetic all-1.0 inputs).
- CI-safe: libamdhip64.so is lazy-loaded inside the test body, and the test SKIPs when POC_KL_DUMP_DIR / devbins / gpu_SPLIT_DATA.hex are missing or when the GPU is not gfx950.

Variant currently shipped: gqa_ratio=16, sub_Q=64, page_size=1, q_dtype=fp8, kv_dtype=fp8

Validated on gfx950 (MI355X):
- 6 baseline pytest cases pass.
- Byte-exact golden test passes (0/65536 bytes differ vs poc_kl).

Co-authored-by: Cursor <cursoragent@cursor.com>
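The scalar and launch-geometry claims above can be sanity-checked in isolation. A minimal sketch, assuming only what the message states: the kernel-side symbols sum to kV4DimNope + kV4DimRope = 512, and wv_tg = 4 is inferred from the stated block dim of 256 (both constant names here are illustrative):

```python
import math

# Assumed constants: the commit message only states the sum = 512
DIM_NOPE_PLUS_ROPE = 512
WAVE_SIZE = 64   # wave64 on gfx950
WV_TG = 4        # waves per workgroup, inferred from block dim = 256

# The hardcoded softmax scalar is independent of head_size
softmax_scale = 1.0 / math.sqrt(DIM_NOPE_PLUS_ROPE)

# Block dimension for the kernel launch
block_dim = WV_TG * WAVE_SIZE

print(softmax_scale, block_dim)
```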
1 parent 45c428e commit 8abec4b

11 files changed

Lines changed: 1135 additions & 1 deletion

aiter/jit/optCompilerConfig.json

Lines changed: 11 additions & 0 deletions
@@ -138,6 +138,17 @@
         "verbose": "False",
         "blob_gen_cmd": "f'{AITER_META_DIR}/hsa/codegen.py -m mla --output_dir {{}}'"
     },
+    "module_mla_v4_asm": {
+        "srcs": [
+            "f'{AITER_CSRC_DIR}/py_itfs_cu/asm_mla_v4.cu'"
+        ],
+        "flags_extra_cc": [],
+        "flags_extra_hip": [],
+        "extra_ldflags": "None",
+        "extra_include": [],
+        "verbose": "False",
+        "blob_gen_cmd": "f'{AITER_META_DIR}/hsa/codegen.py -m mla_v4 --output_dir {{}}'"
+    },
     "module_cache": {
         "srcs": [
             "f'{AITER_CSRC_DIR}/pybind/cache_pybind.cu'",
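For orientation, a sketch of how such an entry expands: the `blob_gen_cmd` value is an f-string evaluated by the JIT loader, whose escaped `{{}}` survives as a `{}` placeholder later filled with the blob output directory. The two-stage substitution below is illustrative only; the real loader logic lives in aiter's JIT machinery, and `/opt/aiter` / `/tmp/blobs` are made-up paths:

```python
import json

cfg_text = """
{
  "module_mla_v4_asm": {
    "srcs": ["f'{AITER_CSRC_DIR}/py_itfs_cu/asm_mla_v4.cu'"],
    "blob_gen_cmd": "f'{AITER_META_DIR}/hsa/codegen.py -m mla_v4 --output_dir {{}}'"
  }
}
"""
entry = json.loads(cfg_text)["module_mla_v4_asm"]

# Stage 1: mimic f-string evaluation -- substitute the meta dir and
# collapse the escaped {{}} into a {} placeholder.
raw = entry["blob_gen_cmd"][2:-1]  # strip the f'...' wrapper
stage1 = raw.replace("{AITER_META_DIR}", "/opt/aiter").replace("{{}}", "{}")

# Stage 2: fill the placeholder with the blob output directory.
cmd = stage1.format("/tmp/blobs")
print(cmd)
```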

aiter/mla.py

Lines changed: 71 additions & 0 deletions
@@ -973,3 +973,74 @@ def mla_prefill_reduce(
         )  # [tile_q, v_head_dim]
 
         output[qo_start:qo_end, head_idx, :] = final_output[:q_len, :]
+
+
+# ---------------------------------------------------------------------------
+# DSv4 MLA — additive entry point. Does NOT touch any existing
+# gqa_ratio=16, sub_Q=64, page_size=1, q_dtype=fp8, kv_dtype=fp8
+# ---------------------------------------------------------------------------
+def mla_decode_fwd_v4_nm(
+    q,                  # [total_query_len, num_heads, head_size] FP8 packed Q+e8m0
+    qrope,              # [total_query_len, num_heads, kv_rotary] BF16
+    kv_buffer,          # [num_page, page_size, num_kv_heads, dim_qk_packed]
+    kvrope,             # [num_page, page_size, num_kv_heads, kv_rotary]
+    output,             # [total_query_len, num_heads, v_head_dim] BF16 (used for out_16_nosplit=1)
+    qo_indptr,          # [num_seqs+1]
+    kv_indptr,          # [num_seqs+1]
+    kv_page_indices,    # [num_page_used]
+    kv_last_page_lens,  # [num_seqs]
+    split_indptr,       # [num_seqs+1]
+    max_seqlen_q,
+    sm_scale=None,      # ignored on v4 nm; kernel hardcodes 1/sqrt(512)
+    out_16_nosplit=0,
+    num_kv_splits=1,
+    sub_Q=64,
+    logits=None,
+    attn_lse=None,
+):
+    """v4 nm-recompile MLA decode forward (mi350 / gfx950 wave64).
+
+    Routes through the canonical aiter JIT C-ABI module
+    `module_mla_v4_asm` (csrc/py_itfs_cu/asm_mla_v4.cu). Returns
+    (logits, attn_lse) — caller is responsible for any post-reduce /
+    final-O work.
+
+    logits/attn_lse may be pre-allocated (e.g. for SENTINEL pre-fill in
+    correctness tests); if not, we allocate the canonical 5D layout
+    `[num_seqs, num_kv_splits, num_kv_heads, gqa*max_seqlen_q, v_head_dim]`
+    (and `[..., 1]` for attn_lse) inferred from the input shapes.
+    """
+    num_seqs = qo_indptr.shape[0] - 1
+    num_heads = q.size(1)
+    v_head_dim = output.size(2)
+    num_kv_heads = kv_buffer.size(2)
+    gqa_ratio = num_heads // num_kv_heads
+    q_seq_lens_internal = gqa_ratio * max_seqlen_q
+
+    if logits is None:
+        logits = torch.empty(
+            (num_seqs, num_kv_splits, num_kv_heads, q_seq_lens_internal, v_head_dim),
+            dtype=dtypes.fp32, device=q.device,
+        )
+    if attn_lse is None:
+        attn_lse = torch.empty(
+            (num_seqs, num_kv_splits, num_kv_heads, q_seq_lens_internal, 1),
+            dtype=dtypes.fp32, device=q.device,
+        )
+
+    # softmax_scale is ignored by the v4 nm kernel (hardcodes 1/sqrt(512));
+    # we still pass *something* through to satisfy the C ABI.
+    sm_scale_arg = 0.0 if sm_scale is None else float(sm_scale)
+
+    aiter.mla_decode_v4_asm(
+        q, qrope, kv_buffer, kvrope,
+        qo_indptr, kv_indptr, kv_page_indices, kv_last_page_lens,
+        split_indptr,
+        max_seqlen_q,
+        sm_scale_arg,
+        int(out_16_nosplit),
+        int(sub_Q),
+        int(num_kv_splits),
+        logits, attn_lse, output,
+    )
+    return logits, attn_lse
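The 5D allocation the wrapper performs is plain shape arithmetic, so it can be sketched standalone. The sizes below are illustrative examples only (not the only legal values); they happen to hit the shipped gqa_ratio=16 variant:

```python
# Illustrative sizes only; no GPU or torch needed for the shape math.
num_seqs = 2
num_heads = 64
num_kv_heads = 4
max_seqlen_q = 4
num_kv_splits = 1
v_head_dim = 512

gqa_ratio = num_heads // num_kv_heads            # 16 in the shipped variant
q_seq_lens_internal = gqa_ratio * max_seqlen_q   # the gqa*max_seqlen_q axis

logits_shape = (num_seqs, num_kv_splits, num_kv_heads,
                q_seq_lens_internal, v_head_dim)
attn_lse_shape = logits_shape[:-1] + (1,)
print(logits_shape, attn_lse_shape)
```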

aiter/ops/attention.py

Lines changed: 42 additions & 0 deletions
@@ -688,6 +688,48 @@ def mla_decode_stage1_asm_fwd(
 ) -> None: ...
 
 
+MD_NAME_V4 = "module_mla_v4_asm"
+
+
+@compile_ops(MD_NAME_V4, ffi_type="ctypes")
+def mla_decode_v4_asm(
+    # [total_query_len, num_heads, head_size] FP8 packed Q + e8m0 scale region
+    Q: torch.Tensor,
+    # [total_query_len, num_heads, kv_rotary] BF16
+    qrope: torch.Tensor,
+    # [num_page, page_size, num_kv_heads, head_size] FP8
+    KV: torch.Tensor,
+    # [num_page, page_size, num_kv_heads, kv_rotary] BF16
+    kvrope: torch.Tensor,
+    # [num_seqs+1]
+    qo_indptr: torch.Tensor,
+    # [num_seqs+1]
+    kv_indptr: torch.Tensor,
+    # [num_page_used]
+    kv_page_indices: torch.Tensor,
+    # [num_seqs]
+    kv_last_page_lens: torch.Tensor,
+    # [num_seqs+1]
+    split_indptr: torch.Tensor,
+    max_seqlen_q: int,
+    # ignored on v4 nm; kernel hardcodes 1/sqrt(kV4DimNope+kV4DimRope)=1/sqrt(512)
+    softmax_scale: float,
+    # 0 = fp32 split-out path; 1 = bf16 nosplit reduce path
+    out_16_nosplit: int,
+    # poc_kl `sub_Q` (= per-WG Q tile); only 64 currently shipped
+    sub_Q: int,
+    # poc_kl `passes`
+    num_kv_splits: int,
+    # outputs
+    # [num_seqs, num_kv_splits, num_kv_heads, gqa*max_seqlen_q, v_head_dim] FP32
+    splitData: torch.Tensor,
+    # [num_seqs, num_kv_splits, num_kv_heads, gqa*max_seqlen_q, 1] FP32
+    splitLse: torch.Tensor,
+    # [total_query_len, num_heads, v_head_dim] BF16 (used when out_16_nosplit==1)
+    output: torch.Tensor,
+) -> None: ...
+
+
 @compile_ops(MD_NAME, ffi_type="ctypes")
 def mla_prefill_asm_fwd(
     # [num_seqs, num_heads, head_size]
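Behind this signature, the dispatcher selects a .co by variant lookup against cfg_mla_v4_asm; an unknown variant raises rather than loading a wrong shader. A hypothetical Python mirror of that behavior (the key tuple, table contents, and error text are assumptions; the real table is generated from mla_v4_asm.csv by codegen.py):

```python
# Hypothetical config table: one shipped variant, keyed per the CSV schema.
CFG_MLA_V4_ASM = {
    # (qType, kvType, gqa, sub_Q, page_size, num_kv_splits) -> co_name
    ("fp8", "fp8", 16, 64, 1, 1): "mla_a8w8_qh64_qseqlen4_gqaratio16_nm_recmp.co",
}

def lookup_v4_variant(q_type, kv_type, gqa, sub_q, page_size, num_kv_splits):
    """Return the .co name, or raise a clean RuntimeError (no wrong .co loaded)."""
    key = (q_type, kv_type, gqa, sub_q, page_size, num_kv_splits)
    if key not in CFG_MLA_V4_ASM:
        raise RuntimeError(f"no mla_v4 asm variant for {key}")
    return CFG_MLA_V4_ASM[key]
```

For example, sub_Q=128 raises here, mirroring what test_v4_nm_unknown_variant_raises expects at the Python level.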

aiter/utility/dtypes.py

Lines changed: 9 additions & 0 deletions
@@ -39,6 +39,15 @@ def get_dtype_fp8():
 globals().update({f"AITER_DTYPE_{name}": idx for name, idx in aiter_dtypes.items()})
 _torch_to_aiter_dtype = {globals()[name]: idx for name, idx in aiter_dtypes.items()}
 
+# Both e4m3fn (OCP) and e4m3fnuz (ROCm NUZ) are valid FP8 variants at the
+# byte level for kernels that just read raw FP8 bytes. Map both torch dtypes
+# to the same AITER_DTYPE_fp8 enum so the strict dtype check in
+# torch_to_aiter() / torch_to_aiter_pybind() accepts whichever variant the
+# caller has — letting v3/v4 MLA tests use either dtype interchangeably.
+if "fp8" in aiter_dtypes:
+    for _alt_fp8 in (torch.float8_e4m3fn, torch.float8_e4m3fnuz):
+        _torch_to_aiter_dtype.setdefault(_alt_fp8, aiter_dtypes["fp8"])
+
 
 def torch_to_aiter_pybind(tensor: torch.Tensor):
     """Convert torch.Tensor to pybind aiter_tensor_t for passing to C++ ops.
