Skip to content

feat(flydsl): Add MQA logits FP4 kernel for gfx950#3115

Open
zhiding512 wants to merge 4 commits intomainfrom
zhimding/mqa_logits_fp4_0509
Open

feat(flydsl): Add MQA logits FP4 kernel for gfx950#3115
zhiding512 wants to merge 4 commits intomainfrom
zhimding/mqa_logits_fp4_0509

Conversation

@zhiding512
Copy link
Copy Markdown
Contributor

No description provided.

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
Signed-off-by: zhimding <zhiming.ding@amd.com>
@zhiding512 zhiding512 requested review from a team and Copilot May 11, 2026 03:49
@github-actions
Copy link
Copy Markdown
Contributor

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests on MI35X (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

Label Tests
ci:triton-300x Run an additional Triton test job on MI300X in PRs; main branch always runs both MI35X and MI300X
ci:sglang SGLang integration tests: DeepSeek-R1-MXFP4 accuracy, Qwen 3.5 accuracy
ci:atom ATOM benchmark: DeepSeek-R1-0528, GPT-OSS-120B
ci:atom_full ATOM accuracy suite for PR and main models from ATOM models_accuracy.json
ci:vllm vLLM benchmark: GPT-OSS-120B, DeepSeek-R1-0528, Kimi-K2.5
ci:all All standard extended tests (excludes ci:atom_full)

Only add ci:atom_full for FlyDSL or Triton upgrades.
Add labels via the sidebar or gh pr edit 3115 --add-label <label>

Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR adds a new FlyDSL paged MQA logits implementation for gfx950 that operates on native FP4 operands (Q FP4, KV FP4) with host-side preshuffled scales/caches, plus a Python wrapper and a dedicated test/benchmark.

Changes:

  • Introduces the gfx950 FP4 MQA logits kernel builder + persistent-grid varctx scheduler (pa_mqa_logits_fp4.py).
  • Adds a high-level cached wrapper op (flydsl_pa_mqa_logits_fp4) to hide build/schedule/launch boilerplate (pa_mqa_logits_kernels.py).
  • Adds a gfx950-gated pytest module that builds inputs (including FP4 quant + KV preshuffle) and validates correctness / prints perf (test_flydsl_pa_mqa_logits_fp4.py).
  • Exposes the new op from aiter.ops.flydsl.

Reviewed changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 5 comments.

File Description
aiter/ops/flydsl/test_flydsl_pa_mqa_logits_fp4.py Adds gfx950-only end-to-end correctness test and embedded perf benchmark for the FP4 MQA logits kernel.
aiter/ops/flydsl/pa_mqa_logits_kernels.py Adds a public, cached wrapper that schedules + launches the FP4 MQA logits kernel.
aiter/ops/flydsl/kernels/pa_mqa_logits_fp4.py Adds the FP4 MQA logits kernel implementation and host-side persistent-grid scheduling helper.
aiter/ops/flydsl/__init__.py Re-exports the new flydsl_pa_mqa_logits_fp4 API when FlyDSL is available.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines +605 to +621
N_KV = k_tiles * N_TILES_PER_WARP
last_c_i32 = chunk_count - fx.Int32(1)

phys_pre = _load_phys(c0_i32)
kv_pre, kvs_pre = _prefetch_chunk(c0_i32, phys_pre)
phys_next_pre = _load_phys(fx.Int32(1))

# === Main loop: chunk_count - 1 iterations ===
# Carry layout (flat list):
# kv_cur: K_TILES * N_TILES_PER_WARP entries (kv_list[nt*K_TILES+k])
# kvs_cur: K_TILES * N_TILES_PER_WARP entries
# phys_next: N_TILES_PER_WARP entries
# Total = (2 * K_TILES + 1) * N_TILES_PER_WARP
chunk_count_minus_1_i32 = chunk_count - fx.Int32(1)
chunk_count_minus_1_idx = fx.Index(chunk_count_minus_1_i32)
init_args = list(kv_pre) + list(kvs_pre) + list(phys_next_pre)
for c_idx, state in range(0, chunk_count_minus_1_idx, 1, init=init_args):
Comment on lines +49 to +54
"""Build kernel + JIT launcher for one shape config; cached by signature.

Cache key includes ``max_chunks_per_cta`` because it controls compile-time
pipeline unrolling (kernel re-builds when host-side ``safe_chunks_per_cta``
grows beyond the previously compiled bound).
"""
Comment on lines +144 to +166
batch, next_n, heads, head_dim_packed = q_packed.shape
head_dim = head_dim_packed * 2
kv_block_size = kv_cache.shape[3]
max_blocks_per_seq = block_tables.shape[1]

safe, cta_info, total_ctas = compute_varctx_schedule(
context_lens, block_k, parallel_unit_num, next_n=next_n
)

launch = _get_compiled_pa_mqa_logits_fp4(
block_k=block_k,
kv_block_size=kv_block_size,
max_blocks_per_seq=max_blocks_per_seq,
max_chunks_per_cta=safe,
num_warps=num_warps,
next_n=next_n,
heads=heads,
head_dim=head_dim,
)

if stream is None:
stream = torch.cuda.current_stream()

Comment on lines +49 to +52
print(
"[test] using pa_mqa_logits_fp4_qfp4_kvfp4 kernel (Q FP4, KV FP4, MFMA(Q_fp4, KV_fp4))"
)

Comment on lines +267 to +276
def test_pa_mqa_logits_fp4_qfp4_kvfp4(
batch,
max_ctx,
kv_block_size,
block_k,
next_n,
heads,
num_iters=20,
num_warmup=3,
num_warps=4,
zhiding512 and others added 2 commits May 11, 2026 05:41
Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
Signed-off-by: zhimding <zhiming.ding@amd.com>
Add optional `schedule` argument to `flydsl_pa_mqa_logits_fp4` so callers
can hoist `flydsl_pa_mqa_logits_fp4_schedule` out of perf/inference loops
and avoid the per-call D2H/H2D overhead of rebuilding the CTA assignment
table. Rename the schedule helper from `compute_varctx_schedule` to match
the op naming and export it from the public API.

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
Signed-off-by: zhimding <zhiming.ding@amd.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants