
[MLA ASM kernel] add v4 MLA decode for mi350 (recompile mi300)#3112

Open
liyjiang wants to merge 1 commit into main from liyjiang/asm_mla_mi350_nm

Conversation

@liyjiang (Contributor)

The v4 new-model kernel uses an 18-slot kernarg ABI that's incompatible with v3's 14-slot layout:

  • slot 8 raw gqa_ratio (not s_MQA = gqa_ratio * max_seqlen_q)
  • slot 10 total_kv = kv_seq_lens * num_seqs
  • 4 new tail slots: ptr_STP, out_16_nosplit, ptr_QROPE, ptr_KVROPE
  • scalar = 1/sqrt(kV4DimNope + kV4DimRope) = 1/sqrt(512), independent of head_size
  • kernel hardcodes wave64 (gfx950) block dim = wv_tg * 64 = 256
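A minimal host-side sketch of packing such a layout, assuming uniform 8-byte little-endian slots and leaving the unnamed slots as placeholders (the slot width, ordering of the unlabeled slots, and helper name are all assumptions; only the slots called out above are labeled, and the scale-slot placement is deliberately not sketched):

```python
import struct

NUM_SLOTS = 18   # v4 kernarg slot count from the PR description
SLOT_BYTES = 8   # assumed: one 8-byte slot per argument

def pack_v4_kernargs(ptrs, gqa_ratio, kv_seq_lens, num_seqs):
    """Hypothetical packer for the 18-slot v4 kernarg buffer."""
    slots = [0] * NUM_SLOTS
    # slots 0..7: buffer pointers / leading scalars (placeholders here)
    for i, p in enumerate(ptrs[:8]):
        slots[i] = p
    slots[8] = gqa_ratio                # raw gqa_ratio, NOT gqa_ratio * max_seqlen_q
    slots[10] = kv_seq_lens * num_seqs  # total_kv
    # slots 14..17 would hold the 4 new tail args in this sketch:
    # ptr_STP, out_16_nosplit, ptr_QROPE, ptr_KVROPE
    return struct.pack(f"<{NUM_SLOTS}Q", *slots)
```

The point of the sketch is the two v3/v4 differences: slot 8 carries the raw ratio and slot 10 a pre-multiplied `total_kv`, so a v3 dispatcher writing `s_MQA` into slot 8 would silently mis-launch the v4 kernel.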

Files:

  • csrc/cpp_itfs/mla/asm_mla_v4_nm.{py,cpp.jinja}: dispatcher + Jinja template for the C++ entry; uses absolute include paths via AITER_CORE_DIR.
  • hsa/gfx950/mla/mla_a8w8_qh64_qseqlen4_gqaratio16_nm_recmp.co: compiled shader. Built from poc_kl shaders/mla_a8w8_qh64_1tg_16mx4_16nx1_nm_recompile.s with the literal mla_kernel_func symbol sed-replaced to _ZN5aiter42mla_a8w8_qh64_qseqlen4_gqaratio16_nm_recmpE before reassembling with clang -x assembler --offload-arch=gfx950.
  • hsa/gfx950/mla/mla_asm.csv: registers the new variant for the AsmKernel registry.
  • aiter/mla.py: thin Python wrapper mla_decode_fwd_v4_nm.
  • op_tests/test_mla_v4_nm.py: 5 pytest cases covering smoke, wave64-launch regression guard, multi-batch determinism, the out_16_nosplit arg path, and unknown-variant error handling.
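The replacement symbol above follows the Itanium mangling shape for an identifier inside namespace `aiter`: `_ZN` + `5aiter` + length-prefixed name + `E` (the `42` is just the length of the kernel name). A small sketch, assuming a plain namespaced identifier with no parameter-type encoding (the helper name is hypothetical):

```python
def mangle_aiter_symbol(name: str) -> str:
    """Itanium-style mangling for a bare identifier in namespace `aiter`:
    _ZN + <len("aiter")>aiter + <len(name)>name + E."""
    return f"_ZN5aiter{len(name)}{name}E"

# The recompiled shader's entry point:
print(mangle_aiter_symbol("mla_a8w8_qh64_qseqlen4_gqaratio16_nm_recmp"))
# -> _ZN5aiter42mla_a8w8_qh64_qseqlen4_gqaratio16_nm_recmpE
```

This is why a sed replace on the literal `mla_kernel_func` works: the .s file carries an unmangled symbol, and the rewrite makes the reassembled .co export the name the C++ side expects.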

Variant currently shipped:
gqa_ratio=16, sub_Q=64, page_size=1, q_dtype=fp8, kv_dtype=fp8

Validated:

  • Same configuration passes in poc_kl with TEST PASSED (batch=2 kv_seq_lens=64 q_seq_lens=4 sub_Q=64 passes=1).
  • All 5 aiter pytest cases pass on gfx950 (mi350).


@liyjiang liyjiang requested review from a team and valarLip May 11, 2026 02:47
@github-actions (Contributor)

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests on MI35X (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

| Label | Tests |
| --- | --- |
| ci:triton-300x | Run an additional Triton test job on MI300X in PRs; main branch always runs both MI35X and MI300X |
| ci:sglang | SGLang integration tests: DeepSeek-R1-MXFP4 accuracy, Qwen 3.5 accuracy |
| ci:atom | ATOM benchmark: DeepSeek-R1-0528, GPT-OSS-120B |
| ci:atom_full | ATOM accuracy suite for PR and main models from ATOM models_accuracy.json |
| ci:vllm | vLLM benchmark: GPT-OSS-120B, DeepSeek-R1-0528, Kimi-K2.5 |
| ci:all | All standard extended tests (excludes ci:atom_full) |

Only add ci:atom_full for FlyDSL or Triton upgrades.
Add labels via the sidebar or gh pr edit 3112 --add-label <label>

The v4 new-model kernel uses an 18-slot kernarg ABI that's incompatible
with v3's 14-slot layout:

- slot 8  raw `gqa_ratio` (not `s_MQA = gqa_ratio * max_seqlen_q`)
- slot 10 `total_kv = kv_seq_lens * num_seqs`
- 4 new tail slots: ptr_STP, out_16_nosplit, ptr_QROPE, ptr_KVROPE
- scalar = 1/sqrt(kV4DimNope + kV4DimRope) = 1/sqrt(512), independent of
  head_size
- kernel hardcodes wave64 (gfx950) block dim = wv_tg * 64 = 256

Integration follows the canonical aiter JIT module pattern (peer of
asm_mla.cu / module_mla_asm, not the compile_template_op jinja path):

- csrc/py_itfs_cu/asm_mla_v4.cu: C-ABI dispatcher using aiter_tensor_t*,
  wrapped in AITER_CTYPES_DEFINE_ENTRYPOINT_VOID so AITER_CHECK failures
  surface as clean Python RuntimeError via the aiter_get_last_error TLS
  bridge (no std::abort()).
- hsa/gfx950/mla_v4/mla_v4_asm.csv: CSV registry consumed by
  hsa/codegen.py -m mla_v4 -> asm_mla_v4_configs.hpp / cfg_mla_v4_asm.
  Schema (qType, kvType, gqa, sub_Q, page_size, num_kv_splits, prefill,
  causal, knl_name, co_name) is independent of mla_asm.csv so v3 / v4
  dispatchers don't collide.
- hsa/gfx950/mla_v4/mla_a8w8_qh64_qseqlen4_gqaratio16_nm_recmp.co:
  compiled shader. Built from
  poc_kl/mi350/mla_asm/shaders/mla_a8w8_qh64_1tg_16mx4_16nx1_nm_recompile.s
  with the literal `mla_kernel_func` symbol sed-replaced to
  `_ZN5aiter42mla_a8w8_qh64_qseqlen4_gqaratio16_nm_recmpE`, then
  reassembled with `clang -x assembler --offload-arch=gfx950`.
- aiter/jit/optCompilerConfig.json: module_mla_v4_asm entry with
  blob_gen_cmd = codegen.py -m mla_v4.
- aiter/ops/attention.py: mla_decode_v4_asm Python signature
  (@compile_ops ffi_type="ctypes").
- aiter/mla.py: thin wrapper `mla_decode_fwd_v4_nm` routes through
  aiter.mla_decode_v4_asm and allocates the canonical 5D
  [num_seqs, num_kv_splits, num_kv_heads, gqa*max_seqlen_q, v_head_dim]
  split-logits / attn_lse internally.
- aiter/utility/dtypes.py: _torch_to_aiter_dtype now accepts both
  torch.float8_e4m3fn (OCP) and torch.float8_e4m3fnuz (NUZ); both map
  to AITER_DTYPE_fp8, so callers on gfx942 / gfx950 can use either
  variant transparently.
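The dtype widening in the last bullet can be sketched as follows. This is a minimal stand-in keyed by dtype *names* so it runs without torch; the real `_torch_to_aiter_dtype` keys on torch dtype objects, and the enum value here is a placeholder:

```python
AITER_DTYPE_fp8 = "AITER_DTYPE_fp8"  # placeholder for the real enum value

# Both fp8 e4m3 flavors collapse to the same aiter dtype, so callers on
# gfx942 (fnuz) and gfx950 (OCP fn) go through one kernel registry entry.
_NAME_TO_AITER_DTYPE = {
    "float8_e4m3fn":   AITER_DTYPE_fp8,  # OCP variant
    "float8_e4m3fnuz": AITER_DTYPE_fp8,  # NUZ variant
}

def to_aiter_dtype(name: str) -> str:
    try:
        return _NAME_TO_AITER_DTYPE[name]
    except KeyError:
        raise TypeError(f"unsupported dtype: {name}") from None
```

The design point is that the mapping is many-to-one: the kernel consumes raw fp8 bytes, so the host-side torch flavor only needs to resolve to `AITER_DTYPE_fp8`, not to distinct kernels.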

Tests (op_tests/test_mla_v4_nm.py):

- test_v4_nm_smoke_default_shape    : dispatcher -> kernel launch ->
                                       SENTINEL coverage check.
- test_v4_nm_no_half_zero_pattern   : wave64-launch regression guard
                                       (the historic 256-NaN + 256-zero
                                       output pattern on gfx950).
- test_v4_nm_determinism            : two back-to-back launches are
                                       bit-identical.
- test_v4_nm_out_16_nosplit_arg_accepted
                                    : out_16_nosplit=1 path doesn't
                                       crash the dispatcher.
- test_v4_nm_unknown_variant_raises : sub_Q=128 (not in cfg_mla_v4_asm)
                                       triggers a clean RuntimeError via
                                       the C-ABI error bridge, no
                                       silently-loaded wrong .co.
- test_v4_nm_kernarg_scalar_slots   : env-gated kernarg hexdump
                                       captured via capfd; scalar slots
                                       7-12 + 15 byte-exact match
                                       poc_kl's MlaV4HipKernelArgs.
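The scalar-slot comparison in the last test can be sketched as a byte-exact check against the fixed v4 scale, 1/sqrt(kV4DimNope + kV4DimRope) = 1/sqrt(512). The 4-byte little-endian float layout and the helper names are assumptions; the real test hexdumps the kernarg buffer via capfd:

```python
import math
import struct

def expected_scale_bytes() -> bytes:
    # v4 softmax scale is fixed at 1/sqrt(512), independent of head_size
    return struct.pack("<f", 1.0 / math.sqrt(512.0))

def check_scale_slot(slot_bytes: bytes) -> bool:
    # compare the leading 4 bytes of the dumped slot byte-exactly,
    # avoiding any float round-trip tolerance questions
    return slot_bytes[:4] == expected_scale_bytes()
```

Comparing bytes rather than parsed floats is deliberate: it catches a dispatcher that writes a double, or the v3 head_size-dependent scale, into the slot.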

Numerical correctness (op_tests/test_mla_v4_nm_golden.py):

- Byte-exact compare against poc_kl gpu_SPLIT_DATA.hex. Inputs loaded
  via raw hipMemcpy H2D from device-side `.bin` dumps produced by an
  env-gated patch on poc_kl side
  (POC_KL_V4_DUMP_DEVBINS=1 in mla_execute_v4_hip.inl). Bypasses the
  brittle path of reconstructing inputs from hex dumps; sidesteps the
  v4 nm kernel's degenerate code path on uniform e8m0 scale bytes
  (which would produce 100% NaN for synthetic all-1.0 inputs).
- CI-safe: libamdhip64.so is lazy-loaded inside the test body, and the
  test SKIPs when POC_KL_DUMP_DIR / devbins / gpu_SPLIT_DATA.hex are
  missing or when the GPU is not gfx950.
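The CI-safety guards described above can be sketched as a precondition probe that lazy-loads the HIP runtime and collects skip reasons instead of hard-erroring at import time. The env var names and dump filename follow the PR text; the helper name and return shape are hypothetical:

```python
import ctypes
import os

def golden_test_preconditions():
    """Return a list of skip reasons; empty means the golden test may run."""
    reasons = []
    dump_dir = os.environ.get("POC_KL_DUMP_DIR")
    if not dump_dir or not os.path.isdir(dump_dir):
        reasons.append("POC_KL_DUMP_DIR missing")
    elif not os.path.exists(os.path.join(dump_dir, "gpu_SPLIT_DATA.hex")):
        reasons.append("gpu_SPLIT_DATA.hex missing")
    try:
        # lazy-load inside the probe, never at module import time,
        # so collection succeeds on hosts without a HIP runtime
        ctypes.CDLL("libamdhip64.so")
    except OSError:
        reasons.append("libamdhip64.so not loadable")
    return reasons  # caller does pytest.skip("; ".join(reasons)) when non-empty
```

Keeping the `ctypes.CDLL` call out of module scope is the important part: pytest collection then succeeds on GPU-less CI runners, and the skipif decorators get a chance to run.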

Variant currently shipped:
  gqa_ratio=16, sub_Q=64, page_size=1, q_dtype=fp8, kv_dtype=fp8

Validated on gfx950 (MI355X):
- 6 baseline pytest cases pass.
- Byte-exact golden test passes (0/65536 differ vs poc_kl).

Co-authored-by: Cursor <cursoragent@cursor.com>
@liyjiang liyjiang force-pushed the liyjiang/asm_mla_mi350_nm branch from 39e764a to 8abec4b on May 11, 2026 11:36
import re
lines = captured.err.splitlines()
try:
start = next(i for i, l in enumerate(lines)
⚠️ [ruff] <E741> reported by reviewdog 🐶
Ambiguous variable name: l

import struct
return struct.unpack("<f", slot(i)[:4])[0]

import math, struct

⚠️ [ruff] <E401> reported by reviewdog 🐶
Multiple imports on one line

Suggested change
import math, struct
import math
import struct

import struct
return struct.unpack("<f", slot(i)[:4])[0]

import math, struct

⚠️ [ruff] <F401> reported by reviewdog 🐶
math imported but unused

Suggested change
import math, struct
import struct

# otherwise hard-error before the needs_gfx950 / needs_dumps skipif decorators
# get a chance to run).
# ---------------------------------------------------------------------------
import ctypes

⚠️ [ruff] <E402> reported by reviewdog 🐶
Module level import not at top of file
