feat(flydsl): Add MQA logits FP4 kernel for gfx950 #3115
Open
zhiding512 wants to merge 4 commits into main
Conversation
Co-Authored-By: Claude Opus 4 <noreply@anthropic.com> Signed-off-by: zhimding <zhiming.ding@amd.com>
Contributor
🏷️ CI Guide

Runs automatically on every PR:
Extended tests (opt-in via labels):
Contributor
Pull request overview
This PR adds a new FlyDSL paged MQA logits implementation for gfx950 that operates on native FP4 operands (Q FP4, KV FP4) with host-side preshuffled scales/caches, plus a Python wrapper and a dedicated test/benchmark.
Changes:
- Introduces the gfx950 FP4 MQA logits kernel builder + persistent-grid varctx scheduler (pa_mqa_logits_fp4.py).
- Adds a high-level cached wrapper op (flydsl_pa_mqa_logits_fp4) to hide build/schedule/launch boilerplate (pa_mqa_logits_kernels.py).
- Adds a gfx950-gated pytest module that builds inputs (including FP4 quant + KV preshuffle) and validates correctness / prints perf (test_flydsl_pa_mqa_logits_fp4.py).
- Exposes the new op from aiter.ops.flydsl.
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| aiter/ops/flydsl/test_flydsl_pa_mqa_logits_fp4.py | Adds gfx950-only end-to-end correctness test and embedded perf benchmark for the FP4 MQA logits kernel. |
| aiter/ops/flydsl/pa_mqa_logits_kernels.py | Adds a public, cached wrapper that schedules + launches the FP4 MQA logits kernel. |
| aiter/ops/flydsl/kernels/pa_mqa_logits_fp4.py | Adds the FP4 MQA logits kernel implementation and host-side persistent-grid scheduling helper. |
| aiter/ops/flydsl/__init__.py | Re-exports the new flydsl_pa_mqa_logits_fp4 API when FlyDSL is available. |
Comment on lines +605 to +621

```python
N_KV = k_tiles * N_TILES_PER_WARP
last_c_i32 = chunk_count - fx.Int32(1)

phys_pre = _load_phys(c0_i32)
kv_pre, kvs_pre = _prefetch_chunk(c0_i32, phys_pre)
phys_next_pre = _load_phys(fx.Int32(1))

# === Main loop: chunk_count - 1 iterations ===
# Carry layout (flat list):
#   kv_cur:    K_TILES * N_TILES_PER_WARP entries (kv_list[nt*K_TILES+k])
#   kvs_cur:   K_TILES * N_TILES_PER_WARP entries
#   phys_next: N_TILES_PER_WARP entries
#   Total = (2 * K_TILES + 1) * N_TILES_PER_WARP
chunk_count_minus_1_i32 = chunk_count - fx.Int32(1)
chunk_count_minus_1_idx = fx.Index(chunk_count_minus_1_i32)
init_args = list(kv_pre) + list(kvs_pre) + list(phys_next_pre)
for c_idx, state in range(0, chunk_count_minus_1_idx, 1, init=init_args):
```
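The flat carry layout in that comment can be sanity-checked with a small plain-Python sketch. The tile counts below are hypothetical values chosen only for illustration; they are not taken from the kernel:

```python
# Hypothetical tile counts, for illustrating the flat carry layout only.
K_TILES = 4
N_TILES_PER_WARP = 2

def kv_carry_index(nt, k):
    # Matches the kv_list[nt * K_TILES + k] indexing in the carry comment.
    return nt * K_TILES + k

kv_entries = K_TILES * N_TILES_PER_WARP    # kv_cur
kvs_entries = K_TILES * N_TILES_PER_WARP   # kvs_cur (scales)
phys_entries = N_TILES_PER_WARP            # phys_next
total = (2 * K_TILES + 1) * N_TILES_PER_WARP
# The three segments tile the flat list exactly: 8 + 8 + 2 == 18.
assert kv_entries + kvs_entries + phys_entries == total
```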
Comment on lines +49 to +54

```python
"""Build kernel + JIT launcher for one shape config; cached by signature.

Cache key includes ``max_chunks_per_cta`` because it controls compile-time
pipeline unrolling (kernel re-builds when host-side ``safe_chunks_per_cta``
grows beyond the previously compiled bound).
"""
```
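The cache-by-signature behavior the docstring describes can be mimicked with `functools.lru_cache`. The stub below is an assumption-laden stand-in: the real builder compiles a kernel, and the parameter list here only mirrors the signature shown in the diff:

```python
from functools import lru_cache

@lru_cache(maxsize=None)
def _get_compiled_stub(block_k, kv_block_size, max_blocks_per_seq,
                       max_chunks_per_cta, num_warps, next_n,
                       heads, head_dim):
    # Stand-in for an expensive kernel build. Every argument is part of
    # the key, so a larger max_chunks_per_cta misses the cache and
    # triggers a "rebuild" with deeper pipeline unrolling.
    return ("compiled", block_k, max_chunks_per_cta)

a = _get_compiled_stub(128, 16, 32, 4, 4, 2, 16, 128)
b = _get_compiled_stub(128, 16, 32, 4, 4, 2, 16, 128)  # same key: cache hit
c = _get_compiled_stub(128, 16, 32, 8, 4, 2, 16, 128)  # new bound: rebuild
assert a is b and a is not c
```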
Comment on lines +144 to +166

```python
batch, next_n, heads, head_dim_packed = q_packed.shape
head_dim = head_dim_packed * 2
kv_block_size = kv_cache.shape[3]
max_blocks_per_seq = block_tables.shape[1]

safe, cta_info, total_ctas = compute_varctx_schedule(
    context_lens, block_k, parallel_unit_num, next_n=next_n
)

launch = _get_compiled_pa_mqa_logits_fp4(
    block_k=block_k,
    kv_block_size=kv_block_size,
    max_blocks_per_seq=max_blocks_per_seq,
    max_chunks_per_cta=safe,
    num_warps=num_warps,
    next_n=next_n,
    heads=heads,
    head_dim=head_dim,
)

if stream is None:
    stream = torch.cuda.current_stream()
```
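The `head_dim = head_dim_packed * 2` line reflects FP4 packing: two 4-bit values share each stored byte, so the logical head dimension is twice the packed last axis. A minimal sketch, with made-up shape values:

```python
def unpacked_head_dim(q_packed_shape):
    # Two FP4 values per stored byte -> logical head dim is 2x the packed one.
    batch, next_n, heads, head_dim_packed = q_packed_shape
    return head_dim_packed * 2

# Hypothetical Q shape: packed last axis of 64 bytes holds 128 FP4 values.
assert unpacked_head_dim((1, 2, 16, 64)) == 128
```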
Comment on lines +49 to +52

```python
print(
    "[test] using pa_mqa_logits_fp4_qfp4_kvfp4 kernel (Q FP4, KV FP4, MFMA(Q_fp4, KV_fp4))"
)
```
Comment on lines +267 to +276

```python
def test_pa_mqa_logits_fp4_qfp4_kvfp4(
    batch,
    max_ctx,
    kv_block_size,
    block_k,
    next_n,
    heads,
    num_iters=20,
    num_warmup=3,
    num_warps=4,
```
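The `num_warmup`/`num_iters` parameters follow the usual warmup-then-measure benchmarking pattern. A generic sketch of that pattern (not the test's actual timing code, which would also synchronize the GPU stream around the timed region):

```python
import time

def bench(fn, num_warmup=3, num_iters=20):
    # Warmup runs absorb one-time costs (JIT build, cache fills).
    for _ in range(num_warmup):
        fn()
    # Timed runs; for GPU work you would synchronize before and after.
    t0 = time.perf_counter()
    for _ in range(num_iters):
        fn()
    return (time.perf_counter() - t0) / num_iters
```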
Add optional `schedule` argument to `flydsl_pa_mqa_logits_fp4` so callers can hoist `flydsl_pa_mqa_logits_fp4_schedule` out of perf/inference loops and avoid the per-call D2H/H2D overhead of rebuilding the CTA assignment table. Rename the schedule helper from `compute_varctx_schedule` to match the op naming and export it from the public API. Co-Authored-By: Claude Opus 4 <noreply@anthropic.com> Signed-off-by: zhimding <zhiming.ding@amd.com>
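The hoisting this commit enables can be sketched generically. `run_hoisted`, `call_kernel`, and `build_schedule` below are placeholder names, not the real API beyond what the commit message states:

```python
def run_hoisted(call_kernel, build_schedule, num_iters):
    # Build the CTA assignment table once, outside the hot loop ...
    schedule = build_schedule()
    # ... then pass it to every call, avoiding per-call D2H/H2D traffic.
    for _ in range(num_iters):
        call_kernel(schedule=schedule)

calls = {"build": 0, "kernel": 0}

def fake_build():
    calls["build"] += 1
    return "cta-table"

def fake_kernel(schedule):
    assert schedule == "cta-table"
    calls["kernel"] += 1

run_hoisted(fake_kernel, fake_build, num_iters=5)
assert calls == {"build": 1, "kernel": 5}
```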
No description provided.