Commit 8abec4b
[MLA ASM kernel] add v4 MLA decode for mi350 (recompile mi300)
The v4 new-model kernel uses an 18-slot kernarg ABI that's incompatible
with v3's 14-slot layout:
- slot 8 raw `gqa_ratio` (not `s_MQA = gqa_ratio * max_seqlen_q`)
- slot 10 `total_kv = kv_seq_lens * num_seqs`
- 4 new tail slots: ptr_STP, out_16_nosplit, ptr_QROPE, ptr_KVROPE
- scalar = 1/sqrt(kV4DimNope + kV4DimRope) = 1/sqrt(512), independent of
head_size
- kernel hardcodes wave64 (gfx950) block dim = wv_tg * 64 = 256
Integration follows the canonical aiter JIT module pattern (peer of
asm_mla.cu / module_mla_asm, not the compile_template_op jinja path):
- csrc/py_itfs_cu/asm_mla_v4.cu: C-ABI dispatcher using aiter_tensor_t*,
wrapped in AITER_CTYPES_DEFINE_ENTRYPOINT_VOID so AITER_CHECK failures
surface as clean Python RuntimeError via the aiter_get_last_error TLS
bridge (no std::abort()).
- hsa/gfx950/mla_v4/mla_v4_asm.csv: CSV registry consumed by
hsa/codegen.py -m mla_v4 -> asm_mla_v4_configs.hpp / cfg_mla_v4_asm.
Schema (qType, kvType, gqa, sub_Q, page_size, num_kv_splits, prefill,
causal, knl_name, co_name) is independent of mla_asm.csv so v3 / v4
dispatchers don't collide.
- hsa/gfx950/mla_v4/mla_a8w8_qh64_qseqlen4_gqaratio16_nm_recmp.co:
compiled shader. Built from poc_kl/mi350/mla_asm/shaders/mla_a8w8_qh64
_1tg_16mx4_16nx1_nm_recompile.s with the literal `mla_kernel_func`
symbol sed-replaced to `_ZN5aiter42mla_a8w8_qh64_qseqlen4_gqaratio16
_nm_recmpE` then reassembled with `clang -x assembler
--offload-arch=gfx950`.
- aiter/jit/optCompilerConfig.json: module_mla_v4_asm entry with
blob_gen_cmd = codegen.py -m mla_v4.
- aiter/ops/attention.py: mla_decode_v4_asm Python signature
(@compile_ops ffi_type="ctypes").
- aiter/mla.py: thin wrapper `mla_decode_fwd_v4_nm` routes through
aiter.mla_decode_v4_asm and allocates the canonical 5D
[num_seqs, num_kv_splits, num_kv_heads, gqa*max_seqlen_q, v_head_dim]
split-logits / attn_lse internally.
- aiter/utility/dtypes.py: _torch_to_aiter_dtype now accepts both
torch.float8_e4m3fn (OCP) and torch.float8_e4m3fnuz (NUZ); both map
to AITER_DTYPE_fp8, so callers on gfx942 / gfx950 can use either
variant transparently.
Tests (op_tests/test_mla_v4_nm.py):
- test_v4_nm_smoke_default_shape : dispatcher -> kernel launch ->
SENTINEL coverage check.
- test_v4_nm_no_half_zero_pattern : wave64-launch regression guard
(the historic 256-NaN + 256-zero
output pattern on gfx950).
- test_v4_nm_determinism : two back-to-back launches are
bit-identical.
- test_v4_nm_out_16_nosplit_arg_accepted
: out_16_nosplit=1 path doesn't
crash the dispatcher.
- test_v4_nm_unknown_variant_raises : sub_Q=128 (not in cfg_mla_v4_asm)
triggers a clean RuntimeError via
the C-ABI error bridge, no
silently-loaded wrong .co.
- test_v4_nm_kernarg_scalar_slots : env-gated kernarg hexdump
captured via capfd; scalar slots
7-12 + 15 byte-exact match
poc_kl's MlaV4HipKernelArgs.
Numerical correctness (op_tests/test_mla_v4_nm_golden.py):
- Byte-exact compare against poc_kl gpu_SPLIT_DATA.hex. Inputs loaded
via raw hipMemcpy H2D from device-side `.bin` dumps produced by an
env-gated patch on poc_kl side
(POC_KL_V4_DUMP_DEVBINS=1 in mla_execute_v4_hip.inl). Bypasses the
brittle path of reconstructing inputs from hex dumps; sidesteps the
v4 nm kernel's degenerate code path on uniform e8m0 scale bytes
(which would produce 100% NaN for synthetic all-1.0 inputs).
- CI-safe: libamdhip64.so is lazy-loaded inside the test body, and the
test SKIPs when POC_KL_DUMP_DIR / devbins / gpu_SPLIT_DATA.hex are
missing or when the GPU is not gfx950.
Variant currently shipped:
gqa_ratio=16, sub_Q=64, page_size=1, q_dtype=fp8, kv_dtype=fp8
Validated on gfx950 (MI355X):
- 6 baseline pytest cases pass.
- Byte-exact golden test passes (0/65536 differ vs poc_kl).
Co-authored-by: Cursor <cursoragent@cursor.com>1 parent 45c428e commit 8abec4b
11 files changed
Lines changed: 1135 additions & 1 deletion
File tree
- aiter
- jit
- ops
- utility
- csrc/py_itfs_cu
- hsa/gfx950
- mla_v4
- mla
- op_tests
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
138 | 138 | | |
139 | 139 | | |
140 | 140 | | |
| 141 | + | |
| 142 | + | |
| 143 | + | |
| 144 | + | |
| 145 | + | |
| 146 | + | |
| 147 | + | |
| 148 | + | |
| 149 | + | |
| 150 | + | |
| 151 | + | |
141 | 152 | | |
142 | 153 | | |
143 | 154 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
973 | 973 | | |
974 | 974 | | |
975 | 975 | | |
| 976 | + | |
| 977 | + | |
| 978 | + | |
| 979 | + | |
| 980 | + | |
| 981 | + | |
| 982 | + | |
| 983 | + | |
| 984 | + | |
| 985 | + | |
| 986 | + | |
| 987 | + | |
| 988 | + | |
| 989 | + | |
| 990 | + | |
| 991 | + | |
| 992 | + | |
| 993 | + | |
| 994 | + | |
| 995 | + | |
| 996 | + | |
| 997 | + | |
| 998 | + | |
| 999 | + | |
| 1000 | + | |
| 1001 | + | |
| 1002 | + | |
| 1003 | + | |
| 1004 | + | |
| 1005 | + | |
| 1006 | + | |
| 1007 | + | |
| 1008 | + | |
| 1009 | + | |
| 1010 | + | |
| 1011 | + | |
| 1012 | + | |
| 1013 | + | |
| 1014 | + | |
| 1015 | + | |
| 1016 | + | |
| 1017 | + | |
| 1018 | + | |
| 1019 | + | |
| 1020 | + | |
| 1021 | + | |
| 1022 | + | |
| 1023 | + | |
| 1024 | + | |
| 1025 | + | |
| 1026 | + | |
| 1027 | + | |
| 1028 | + | |
| 1029 | + | |
| 1030 | + | |
| 1031 | + | |
| 1032 | + | |
| 1033 | + | |
| 1034 | + | |
| 1035 | + | |
| 1036 | + | |
| 1037 | + | |
| 1038 | + | |
| 1039 | + | |
| 1040 | + | |
| 1041 | + | |
| 1042 | + | |
| 1043 | + | |
| 1044 | + | |
| 1045 | + | |
| 1046 | + | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
688 | 688 | | |
689 | 689 | | |
690 | 690 | | |
| 691 | + | |
| 692 | + | |
| 693 | + | |
| 694 | + | |
| 695 | + | |
| 696 | + | |
| 697 | + | |
| 698 | + | |
| 699 | + | |
| 700 | + | |
| 701 | + | |
| 702 | + | |
| 703 | + | |
| 704 | + | |
| 705 | + | |
| 706 | + | |
| 707 | + | |
| 708 | + | |
| 709 | + | |
| 710 | + | |
| 711 | + | |
| 712 | + | |
| 713 | + | |
| 714 | + | |
| 715 | + | |
| 716 | + | |
| 717 | + | |
| 718 | + | |
| 719 | + | |
| 720 | + | |
| 721 | + | |
| 722 | + | |
| 723 | + | |
| 724 | + | |
| 725 | + | |
| 726 | + | |
| 727 | + | |
| 728 | + | |
| 729 | + | |
| 730 | + | |
| 731 | + | |
| 732 | + | |
691 | 733 | | |
692 | 734 | | |
693 | 735 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
39 | 39 | | |
40 | 40 | | |
41 | 41 | | |
| 42 | + | |
| 43 | + | |
| 44 | + | |
| 45 | + | |
| 46 | + | |
| 47 | + | |
| 48 | + | |
| 49 | + | |
| 50 | + | |
42 | 51 | | |
43 | 52 | | |
44 | 53 | | |
| |||
0 commit comments