[MLA ASM kernel] add v4 MLA decode for mi350 (recompile mi300) #3112
The v4 new-model kernel uses an 18-slot kernarg ABI that's incompatible
with v3's 14-slot layout:
- slot 8 raw `gqa_ratio` (not `s_MQA = gqa_ratio * max_seqlen_q`)
- slot 10 `total_kv = kv_seq_lens * num_seqs`
- 4 new tail slots: ptr_STP, out_16_nosplit, ptr_QROPE, ptr_KVROPE
- scalar = 1/sqrt(kV4DimNope + kV4DimRope) = 1/sqrt(512), independent of head_size
- the kernel hardcodes wave64 (gfx950); block dim = wv_tg * 64 = 256
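For illustration, a minimal Python sketch of the 18-slot buffer under the constraints above. Only slot 8, slot 10, and the existence of four new tail slots are stated by this PR; the concrete tail indices (14-17, deduced from 18 slots vs. v3's 14), the 8-byte slot width, and the helper name `pack_v4_kernargs` are assumptions.

```python
import math
import struct

# Softmax scale is fixed by the nope+rope head-dim split, not head_size:
SCALE = 1.0 / math.sqrt(512)  # 1/sqrt(kV4DimNope + kV4DimRope)

# wave64 on gfx950: block dim = wv_tg * 64 = 256, i.e. wv_tg = 4
assert 4 * 64 == 256

def pack_v4_kernargs(gqa_ratio, kv_seq_lens, num_seqs,
                     ptr_stp, out_16_nosplit, ptr_qrope, ptr_kvrope):
    """Hypothetical packing of the 18-slot kernarg buffer (8-byte slots)."""
    slots = [0] * 18                      # slots not named here left zeroed
    slots[8] = gqa_ratio                  # raw gqa_ratio (v3 passed s_MQA)
    slots[10] = kv_seq_lens * num_seqs    # total_kv
    slots[14] = ptr_stp                   # four new tail slots; indices
    slots[15] = out_16_nosplit            # 14-17 are deduced, not stated
    slots[16] = ptr_qrope
    slots[17] = ptr_kvrope
    return b"".join(struct.pack("<q", s) for s in slots)
```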
Integration follows the canonical aiter JIT module pattern (peer of
asm_mla.cu / module_mla_asm, not the compile_template_op jinja path):
- csrc/py_itfs_cu/asm_mla_v4.cu: C-ABI dispatcher using aiter_tensor_t*,
wrapped in AITER_CTYPES_DEFINE_ENTRYPOINT_VOID so AITER_CHECK failures
surface as clean Python RuntimeError via the aiter_get_last_error TLS
bridge (no std::abort()).
- hsa/gfx950/mla_v4/mla_v4_asm.csv: CSV registry consumed by
hsa/codegen.py -m mla_v4 -> asm_mla_v4_configs.hpp / cfg_mla_v4_asm.
Schema (qType, kvType, gqa, sub_Q, page_size, num_kv_splits, prefill,
causal, knl_name, co_name) is independent of mla_asm.csv so v3 / v4
dispatchers don't collide.
- hsa/gfx950/mla_v4/mla_a8w8_qh64_qseqlen4_gqaratio16_nm_recmp.co:
  compiled shader. Built from
  poc_kl/mi350/mla_asm/shaders/mla_a8w8_qh64_1tg_16mx4_16nx1_nm_recompile.s
  with the literal `mla_kernel_func` symbol sed-replaced to
  `_ZN5aiter42mla_a8w8_qh64_qseqlen4_gqaratio16_nm_recmpE`, then
  reassembled with `clang -x assembler --offload-arch=gfx950`.
- aiter/jit/optCompilerConfig.json: module_mla_v4_asm entry with
blob_gen_cmd = codegen.py -m mla_v4.
- aiter/ops/attention.py: mla_decode_v4_asm Python signature
(@compile_ops ffi_type="ctypes").
- aiter/mla.py: thin wrapper `mla_decode_fwd_v4_nm` routes through
aiter.mla_decode_v4_asm and allocates the canonical 5D
[num_seqs, num_kv_splits, num_kv_heads, gqa*max_seqlen_q, v_head_dim]
split-logits / attn_lse internally.
- aiter/utility/dtypes.py: _torch_to_aiter_dtype now accepts both
torch.float8_e4m3fn (OCP) and torch.float8_e4m3fnuz (NUZ); both map
to AITER_DTYPE_fp8, so callers on gfx942 / gfx950 can use either
variant transparently.
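As a reading aid, a condensed sketch of what the dtype change amounts to; the real `_torch_to_aiter_dtype` table covers more dtypes and returns the actual `AITER_DTYPE_*` enum values, both simplified here:

```python
import torch

AITER_DTYPE_fp8 = "AITER_DTYPE_fp8"  # stand-in for the real enum value

_FP8_VARIANTS = {
    torch.float8_e4m3fn,    # OCP variant (native on gfx950)
    torch.float8_e4m3fnuz,  # NUZ variant (native on gfx942)
}

def _torch_to_aiter_dtype(dtype):
    # Both fp8 flavors collapse to one AITER dtype code, so callers don't
    # have to branch on the GPU architecture.
    if dtype in _FP8_VARIANTS:
        return AITER_DTYPE_fp8
    raise NotImplementedError(dtype)  # the real table handles the rest
```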
Tests (op_tests/test_mla_v4_nm.py):
- test_v4_nm_smoke_default_shape: dispatcher -> kernel launch -> SENTINEL
  coverage check.
- test_v4_nm_no_half_zero_pattern: wave64-launch regression guard (the
  historic 256-NaN + 256-zero output pattern on gfx950).
- test_v4_nm_determinism: two back-to-back launches are bit-identical.
- test_v4_nm_out_16_nosplit_arg_accepted: the out_16_nosplit=1 path doesn't
  crash the dispatcher.
- test_v4_nm_unknown_variant_raises: sub_Q=128 (not in cfg_mla_v4_asm)
  triggers a clean RuntimeError via the C-ABI error bridge, with no
  silently-loaded wrong .co.
- test_v4_nm_kernarg_scalar_slots: env-gated kernarg hexdump captured via
  capfd; scalar slots 7-12 + 15 byte-exact match poc_kl's
  MlaV4HipKernelArgs.
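The kernarg-slot check has roughly the following shape, reconstructed from the snippets quoted in review; the hexdump marker string and the helper names are illustrative, while the `struct.unpack("<f", slot(i)[:4])` reinterpretation comes from the test itself:

```python
import re
import struct

def parse_kernarg_dump(captured_err):
    # Locate the env-gated hexdump in capfd's captured stderr; the real
    # test keys off whatever tag the dispatcher prints before the dump.
    lines = captured_err.splitlines()
    start = next(i for i, l in enumerate(lines) if "kernarg" in l.lower())
    pairs = re.findall(r"\b[0-9a-fA-F]{2}\b", " ".join(lines[start + 1:]))
    return bytes.fromhex("".join(pairs))

def assert_scalar_slots(raw, expected):
    """expected: {slot_index: float} taken from poc_kl's MlaV4HipKernelArgs."""
    def slot(i):
        return raw[8 * i : 8 * i + 8]              # 8-byte slots assumed
    for i, want in expected.items():               # e.g. slots 7-12 and 15
        got = struct.unpack("<f", slot(i)[:4])[0]  # scalar in low 4 bytes
        # byte-exact, not approximate, comparison
        assert struct.pack("<f", got) == struct.pack("<f", want), (i, got, want)
```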
Numerical correctness (op_tests/test_mla_v4_nm_golden.py):
- Byte-exact compare against poc_kl's gpu_SPLIT_DATA.hex. Inputs are loaded
  via raw hipMemcpy H2D from device-side `.bin` dumps produced by an
  env-gated patch on the poc_kl side (POC_KL_V4_DUMP_DEVBINS=1 in
  mla_execute_v4_hip.inl). This bypasses the brittle path of reconstructing
  inputs from hex dumps and sidesteps the v4 nm kernel's degenerate code
  path on uniform e8m0 scale bytes (which would produce 100% NaN for
  synthetic all-1.0 inputs).
- CI-safe: libamdhip64.so is lazy-loaded inside the test body, and the
test SKIPs when POC_KL_DUMP_DIR / devbins / gpu_SPLIT_DATA.hex are
missing or when the GPU is not gfx950.
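A minimal sketch of that CI-safe pattern, assuming the standard `hipMemcpy(dst, src, sizeBytes, kind)` signature; the helper name `upload_devbin` and the skip message are illustrative:

```python
import ctypes
import pathlib

import pytest

HIP_MEMCPY_HOST_TO_DEVICE = 1  # hipMemcpyKind enum value

def _hip():
    # Lazy-load inside the test body so collection never hard-errors on
    # machines without a ROCm runtime.
    try:
        return ctypes.CDLL("libamdhip64.so")
    except OSError:
        pytest.skip("libamdhip64.so not available")

def upload_devbin(dev_ptr, path):
    """Raw H2D copy of a poc_kl device-side .bin dump."""
    blob = pathlib.Path(path).read_bytes()
    rc = _hip().hipMemcpy(ctypes.c_void_p(dev_ptr), blob,
                          ctypes.c_size_t(len(blob)),
                          HIP_MEMCPY_HOST_TO_DEVICE)
    assert rc == 0, f"hipMemcpy failed with {rc}"
```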
Variant currently shipped:
gqa_ratio=16, sub_Q=64, page_size=1, q_dtype=fp8, kv_dtype=fp8
Validated on gfx950 (MI355X):
- 6 baseline pytest cases pass.
- Byte-exact golden test passes (0/65536 differ vs poc_kl).
Co-authored-by: Cursor <cursoragent@cursor.com>