Skip to content

Optimizations for Decode Attn fp8 kernel for MI350#74

Open
mycpuorg wants to merge 3 commits into
ROCm:developfrom
mycpuorg:manrao/decode_attn_fp8
Open

Optimizations for Decode Attn fp8 kernel for MI350#74
mycpuorg wants to merge 3 commits into
ROCm:developfrom
mycpuorg:manrao/decode_attn_fp8

Conversation

@mycpuorg
Copy link
Copy Markdown

@mycpuorg mycpuorg commented Sep 18, 2025

  • Baseline (MI350 FP8 split‑K forward) - Runtime ~125 µs (BF16 path was 110 µs). - rocprof showed the Triton JIT emitted a 512‑VGPR kernel with 56 spills, so each SIMD carried only a single wavefront. - The inner loop kept every K/V fragment plus dequant buffers resident, and the HIP autotuner could pick even heavier configs (≥150 VGPR), cementing the one‑wave bottleneck.
  • Streaming K and On‑Demand V ( xformers/ops/fmha/_triton/splitk_kernels.py ) - Rewrote the FP8 branch so each quantization group loads K directly into the dot product and reloads V only when updating the accumulator. - Removed the persistent register lists for K/V, collapsing VGPR usage to 108 with zero spills; MI350 can now run two waves per SIMD. - Result: FP8 runtime dropped to ~105 µs, now beating BF16’s 110 µs.
  • HIP Autotune Guardrails (same file)
    • Constrained the HIP autotuner to tiles ≤64×64 and ≤4 warps, preventing Triton from revisiting the high‑VGPR plans.
    • Ensures every new launch stays in the low‑register regime uncovered by the streaming change.
  • Forced HIP FP8 Launch Parameters ( xformers/ops/fmha/triton_splitk.py ) - Added FwOp.force_kernel_config and, by default, return the measured best tuple (BLOCK_M=16, BLOCK_N=64, num_stages=2, num_warps=1) whenever FP8 scale/shift tensors are present. - Eliminates heuristics drifting at runtime and locks in the ~101µs profile.

@mycpuorg mycpuorg changed the title Optimizations for Decode Attn fp8 kernel Optimizations for Decode Attn fp8 kernel for MI350 Sep 18, 2025
@mycpuorg mycpuorg force-pushed the manrao/decode_attn_fp8 branch from 36d31dd to e680031 Compare September 19, 2025 19:40
 - Baseline (MI350 FP8 split‑K forward)
      - Runtime ~125 µs (BF16 path was 110 µs).
      - rocprof showed the Triton JIT emitted a 512‑VGPR kernel with 56 spills, so each SIMD carried only a single wavefront.
      - The inner loop kept every K/V fragment plus dequant buffers resident, and the HIP autotuner could pick even heavier configs (≥150 VGPR), cementing the one‑wave bottleneck.
  - Streaming K and On‑Demand V ( xformers/ops/fmha/_triton/splitk_kernels.py )
      - Rewrote the FP8 branch so each quantization group loads K directly into the dot product and reloads V only when updating the accumulator.
      - Removed the persistent register lists for K/V, collapsing VGPR usage to 108 with zero spills; MI350 can now run two waves per SIMD.
      - Result: FP8 runtime dropped to ~105 µs, now beating BF16’s 110 µs.
  - HIP Autotune Guardrails (same file)
      - Constrained the HIP autotuner to tiles ≤64×64 and ≤4 warps, preventing Triton from revisiting the high‑VGPR plans.
      - Ensures every new launch stays in the low‑register regime uncovered by the streaming change.
  - Forced HIP FP8 Launch Parameters ( xformers/ops/fmha/triton_splitk.py )
      - Added FwOp.force_kernel_config and, by default, return the measured best tuple (BLOCK_M=16, BLOCK_N=64, num_stages=2, num_warps=1) whenever FP8 scale/shift tensors are present.
      - Eliminates heuristics drifting at runtime and locks in the ~105 µs profile.
- HIP Autotune Guardrails (same file)
- Constrained the HIP autotuner to tiles ≤64×64 and ≤4 warps, preventing Triton from revisiting the high‑VGPR plans.
- Ensures every new launch stays in the low‑register regime uncovered by the streaming change.
- Forced HIP FP8 Launch Parameters ( xformers/ops/fmha/triton_splitk.py )
- Added FwOp.force_kernel_config and, by default, return the measured best tuple (BLOCK_M=16, BLOCK_N=64, num_stages=2, num_warps=1) whenever FP8 scale/shift tensors are present.
- Eliminates heuristics drifting at runtime and locks in the ~105 µs profile.
@scxiao
Copy link
Copy Markdown

scxiao commented Sep 25, 2025

Hi @mycpuorg, I tried you PR with the size 8193 using the benchmark script https://github.com/scxiao/xformers/blob/scxiao/attn_decode_perf_kv_fp8/xformers/benchmarks/benchmark_attn_decoding.py
, here is what I got:

root@asrock-126-009c:/workspace/projects/xformers/xformers/benchmarks# AMDGCN_USE_BUFFER_OPS=1  python benchmark_attn_decoding.py
====== {'B': 128, 'Mq': 1, 'Mkv': 32769, 'Hq': 8, 'Hkv': 1, 'K': 128, 'attn_bias_type': <class 'xformers.ops.fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask'>} ======                                                                                                                    
  0%|                                                                                                                                                                                                                                                                          | 0/1 [00:00<?, ?it/s]/workspace/projects/xformers/xformers/benchmarks/benchmark_attn_decoding.py:666: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at /var/lib/jenkins/pytorch/aten/src/ATen/Context.cpp:328.)
  attn = (q @ k.transpose(-1, -2) * scale).softmax(-1)
pytorch: memory used: 10368.81787109375 MB                                                                                                                                                                                                                                                           
  0%|                                                                                                                                                                                                                                                                          | 0/1 [00:01<?, ?it/s]extra_args = {'BLOCK_M': 16, 'BLOCK_N': 64, 'num_warps': 1, 'num_stages': 2}
extra_args = {'BLOCK_M': 16, 'BLOCK_N': 64, 'num_warps': 1, 'num_stages': 2}
optimized: memory used: 6200.43994140625 MB                                                                                                                                                                                                                                                          
[---------------------------- attn_decodingfw ----------------------------]                                                                                                                                                                                                                          
                                                                   |       
1 threads: ----------------------------------------------------------------
      B=128 Mq=1 Mkv=32769 Hq=8 Hkv=1 K=128 TotalBytes=1074036736  |  184.1

Times are in microseconds (us).

[------------------------------------ attn_decodingfw ------------------------------------]                                                                                                                                                                                                          
                                                                   |  pytorch  |  optimized
1 threads: --------------------------------------------------------------------------------
      B=128 Mq=1 Mkv=32769 Hq=8 Hkv=1 K=128 TotalBytes=2148073472  |  30933.4  |           
      B=128 Mq=1 Mkv=32769 Hq=8 Hkv=1 K=128 TotalBytes=1074036736  |           |    184.1  

It seem like the benchmark time for fp8 is 184us. How can I get the good perf number 105us?

Comment thread xformers/ops/fmha/_triton/splitk_kernels.py
Copy link
Copy Markdown
Author

@mycpuorg mycpuorg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PYTHONPATH=../.. AMDGCN_USE_BUFFER_OPS=1 python benchmark_attn_decoding.py

on top of your branch scxiao/attn_decode_perf_kv_fp8/xformers was this was tested

Comment thread xformers/ops/fmha/_triton/splitk_kernels.py
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants