Optimizations for Decode Attn fp8 kernel for MI350#74
Open
mycpuorg wants to merge 3 commits into
36d31dd to
e680031
Compare
- Baseline (MI350 FP8 split‑K forward)
- Runtime ~125 µs (BF16 path was 110 µs).
- rocprof showed the Triton JIT emitted a 512‑VGPR kernel with 56 spills, so each SIMD carried only a single wavefront.
- The inner loop kept every K/V fragment plus dequant buffers resident, and the HIP autotuner could pick even heavier configs (≥150 VGPR), cementing the one‑wave bottleneck.
- Streaming K and On‑Demand V (xformers/ops/fmha/_triton/splitk_kernels.py)
- Rewrote the FP8 branch so each quantization group loads K directly into the dot product and reloads V only when updating the accumulator.
- Removed the persistent register lists for K/V, collapsing VGPR usage to 108 with zero spills; MI350 can now run two waves per SIMD.
- Result: FP8 runtime dropped to ~105 µs, now beating BF16’s 110 µs.
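The streaming idea above can be sketched in plain Python. This is only an illustration of the access pattern, not the Triton kernel from this PR: the toy affine quantization (`quantize`, `dequant`), the function names, and the `block_n` parameter are all assumptions made for the example. Each K row is dequantized straight into the dot product, and each V row is dequantized only at the moment the accumulator is updated, so no list of dequantized K/V fragments is kept alive across the loop.

```python
import math

def quantize(row, levels=255):
    """Toy affine quantization (an assumption, not real FP8): x ~= code * scale + shift."""
    lo, hi = min(row), max(row)
    scale = (hi - lo) / levels or 1.0  # fall back to 1.0 for constant rows
    codes = [round((x - lo) / scale) for x in row]
    return codes, scale, lo

def dequant(codes, scale, shift):
    return [c * scale + shift for c in codes]

def streaming_decode_attention(q, k_quant, v_quant, block_n=4):
    """Single-query attention with an online softmax over key blocks.

    k_quant / v_quant hold one (codes, scale, shift) tuple per key position.
    """
    d = len(q)
    sm_scale = 1.0 / math.sqrt(d)
    m = float("-inf")  # running maximum of the scores
    l = 0.0            # running softmax denominator
    acc = [0.0] * d    # running weighted sum of V rows
    for start in range(0, len(k_quant), block_n):
        stop = min(start + block_n, len(k_quant))
        # Stream K: dequantize each row directly into the dot product.
        scores = []
        for codes, scale, shift in k_quant[start:stop]:
            k_row = dequant(codes, scale, shift)
            scores.append(sm_scale * sum(a * b for a, b in zip(q, k_row)))
        # Rescale the running state to the new maximum.
        m_new = max(m, max(scores))
        alpha = math.exp(m - m_new)  # 0.0 on the first block
        l *= alpha
        acc = [a * alpha for a in acc]
        # On-demand V: dequantize a row only when it updates the accumulator.
        for s, (codes, scale, shift) in zip(scores, v_quant[start:stop]):
            p = math.exp(s - m_new)
            l += p
            v_row = dequant(codes, scale, shift)
            acc = [a + p * v for a, v in zip(acc, v_row)]
        m = m_new
    return [a / l for a in acc]

def reference_attention(q, k_quant, v_quant):
    """Materialize all of K and V first (the register-heavy pattern), for comparison."""
    k = [dequant(*t) for t in k_quant]
    v = [dequant(*t) for t in v_quant]
    sm_scale = 1.0 / math.sqrt(len(q))
    scores = [sm_scale * sum(a * b for a, b in zip(q, row)) for row in k]
    m = max(scores)
    p = [math.exp(s - m) for s in scores]
    total = sum(p)
    return [sum(pi * row[j] for pi, row in zip(p, v)) / total for j in range(len(q))]
```

Both functions compute the same result up to floating-point rounding; the difference is how much dequantized data is live at once, which is what drives register pressure in the real kernel.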
- HIP Autotune Guardrails (same file)
- Constrained the HIP autotuner to tiles ≤64×64 and ≤4 warps, preventing Triton from revisiting the high‑VGPR plans.
- Ensures every new launch stays in the low‑register regime uncovered by the streaming change.
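A minimal sketch of the guardrail described above, assuming the autotuner works from a list of candidate configurations. `CandidateConfig` and `apply_hip_guardrails` are hypothetical stand-ins for illustration; they are not Triton's `triton.Config` or the code in this PR. The filter simply drops any candidate whose tile exceeds 64×64 or that uses more than 4 warps.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class CandidateConfig:
    """Hypothetical stand-in for one autotune candidate."""
    block_m: int
    block_n: int
    num_warps: int
    num_stages: int

def apply_hip_guardrails(configs, max_tile=64, max_warps=4):
    """Keep only candidates with tiles <= max_tile x max_tile and <= max_warps warps."""
    return [
        c for c in configs
        if c.block_m <= max_tile and c.block_n <= max_tile and c.num_warps <= max_warps
    ]

candidates = [
    CandidateConfig(block_m=16, block_n=64, num_warps=1, num_stages=2),
    CandidateConfig(block_m=32, block_n=128, num_warps=4, num_stages=2),  # tile too wide
    CandidateConfig(block_m=64, block_n=64, num_warps=8, num_stages=1),   # too many warps
    CandidateConfig(block_m=64, block_n=64, num_warps=4, num_stages=1),
]
allowed = apply_hip_guardrails(candidates)
```

Here `allowed` keeps the first and last candidates only, so the search space never includes the larger configurations.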
- Forced HIP FP8 Launch Parameters (xformers/ops/fmha/triton_splitk.py)
- Added FwOp.force_kernel_config, which by default returns the measured best tuple (BLOCK_M=16, BLOCK_N=64, num_stages=2, num_warps=1) whenever FP8 scale/shift tensors are present.
- Eliminates heuristics drifting at runtime and locks in the ~105 µs profile.
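The forced-configuration behaviour can be pictured roughly as follows. The signature below is an assumption for illustration: the actual `FwOp.force_kernel_config` in the PR may take different arguments. Only the returned tuple (BLOCK_M=16, BLOCK_N=64, num_stages=2, num_warps=1) comes from the description above.

```python
from typing import NamedTuple, Optional

class KernelConfig(NamedTuple):
    BLOCK_M: int
    BLOCK_N: int
    num_stages: int
    num_warps: int

# Values quoted in the PR description as the measured best FP8 configuration.
MEASURED_BEST_FP8 = KernelConfig(BLOCK_M=16, BLOCK_N=64, num_stages=2, num_warps=1)

def force_kernel_config(is_hip: bool, has_fp8_scale_shift: bool) -> Optional[KernelConfig]:
    """Hypothetical hook: pin the launch parameters for the HIP FP8 path.

    Returning None means "no override, fall back to heuristics or autotuning".
    """
    if is_hip and has_fp8_scale_shift:
        return MEASURED_BEST_FP8
    return None
```

Returning `None` for every other case leaves the non-FP8 and non-HIP paths on their existing selection logic.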
e680031 to
6ef14fb
Compare
Hi @mycpuorg, I tried your PR with size 8193 using the benchmark script https://github.com/scxiao/xformers/blob/scxiao/attn_decode_perf_kv_fp8/xformers/benchmarks/benchmark_attn_decoding.py. The benchmark time I get for FP8 is 184 µs. How can I reproduce the 105 µs result?
scxiao
reviewed
Sep 25, 2025
mycpuorg
commented
Sep 25, 2025
Author
mycpuorg
left a comment
PYTHONPATH=../.. AMDGCN_USE_BUFFER_OPS=1 python benchmark_attn_decoding.py
This was tested with the command above, on top of your branch scxiao/attn_decode_perf_kv_fp8/xformers.