Skip to content

[CPU] Add FP32 GEMV decode kernel for GroupQueryAttention#29216

Merged
tianleiwu merged 3 commits into
mainfrom
tlwu/20260608/gqa_cpu_decode_gemv
Jun 26, 2026
Merged

[CPU] Add FP32 GEMV decode kernel for GroupQueryAttention#29216
tianleiwu merged 3 commits into
mainfrom
tlwu/20260608/gqa_cpu_decode_gemv

Conversation

@tianleiwu

@tianleiwu tianleiwu commented Jun 22, 2026

Copy link
Copy Markdown
Contributor

Description

PR1 #28962 adds flash attention for prefill, and removed flash decoding. This PR will add optimized kernel for single-token decode, which will be faster than other kernels including flash decoding.

This PR builds on the prefill-only flash attention change and additionally introduces a dedicated decode kernel.

What's included

  • Decode (GEMV) kernel — A dedicated single-token decode kernel (MlasGQADecodeGQAThreaded) for sequence_length == 1, parallelized over (batch, head) with a two-pass softmax, using GEMV (acc[8]-lane dot product / AXPY) helpers instead of per-block M=1 SGEMM calls. This fixes the per-block SGEMM decode regression.
  • The FP32 flash gate (group_query_attention.cc) is enabled for total_sequence_length > 1, routing prefill to the tiled kernel and decode to the GEMV kernel.
  • The quantized KV-cache path is unchanged (FP32-only scope).

Results (AMD EPYC 7763, AVX2, 8 threads)

  • Decode: correctness ~1e-8 vs naive; long-context decode ~1.0–1.5x (T = 4097 ~1.3–1.5x).

Motivation and Context

The naive GQA path materializes the full score matrix, which is memory-bound for long sequences. Flash attention reduces memory traffic for prefill, and the GEMV decode kernel avoids SGEMM overhead for the M=1 decode case.

Testing

  • Built with --compile_no_warning_as_error.
  • Correctness verified against the naive path for both prefill and decode (max abs diff ~1e-8).
  • Benchmarked via benchmark_gqa_cpu_flash.py.

@tianleiwu tianleiwu changed the title [CPU] Add FP32 flash attention (prefill) and GEMV decode kernel for GroupQueryAttention [CPU] Add FP32 GEMV decode kernel for GroupQueryAttention Jun 23, 2026
Single-token decode (sequence_length == 1) falls back to the naive path. A dedicated FP32 decode kernel will be added in a follow-up PR. The quantized path is unchanged.

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Adds an optimized CPU FP32 single-token decode path for com.microsoft.GroupQueryAttention, aiming to eliminate the decode regression from routing M=1 work through per-block SGEMM by introducing GEMV-based decode (and GEMV-based flash-decoding partials).

Changes:

  • Add GEMV-based decode helpers/kernels for sequence_length == 1, including optional two-phase “flash decoding” KV-chunk reduction.
  • Update FP32 flash gating to activate when total_sequence_length > 1, enabling prefill via tiled flash attention and decode via the new GEMV path.
  • Update CPU GQA documentation to describe the new decode behavior and performance characteristics.

Reviewed changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 4 comments.

File Description
onnxruntime/core/mlas/lib/flashattn_gqa.cpp Adds GEMV decode helpers plus new decode/flash-decoding threaded kernels and dispatch logic.
onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc Adjusts FP32 flash routing gate to include decode when total_sequence_length > 1.
onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h Allocates/partitions scratch buffers for decode vs flash-decoding; wires new args fields into MLAS call.
docs/contrib_ops/cpu/gqa.md Updates docs to describe FP32 decode GEMV kernel and flash-decoding behavior.

Comment thread onnxruntime/core/mlas/lib/flashattn_gqa.cpp Outdated
Comment thread onnxruntime/core/mlas/lib/flashattn_gqa.cpp
Comment thread onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h
Comment thread onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc
Adds a dedicated GEMV kernel (MlasGQADecodeGQAThreaded) for single-token decode (sequence_length == 1), and converts the flash-decoding inner M=1 GEMMs to GEMV. Re-enables the FP32 flash gate for decode (total_sequence_length > 1). Verified correctness vs naive (~1e-8); long-context decode ~1.0-1.2x, fixing the prior per-block SGEMM decode regression.
@tianleiwu tianleiwu force-pushed the tlwu/20260608/gqa_cpu_decode_gemv branch from 4042bd2 to 284d4a5 Compare June 25, 2026 02:36
…test

- Gate use_flash_decoding on common_past_seqlen >= 0 so the small per-thread
  flash-decoding scratch buffer is only selected when the unified KV-split
  kernel runs. Ragged/per-batch decode falls back to MlasGQADecodeGQAThreaded
  which needs a larger scratch (scores[total_seqlen] + temp_output[head_size]);
  previously it reused the small buffer and threads overran each other's
  scratch, producing non-deterministic output for batch>1 ragged decode.
- Add test_gqa_decode_flash_vs_naive_parity comparing both the flash and naive
  (ORT_GQA_DISABLE_FLASH_ATTENTION=1) decode paths to the reference (addresses
  review thread 4).
- Correct flashattn_gqa.cpp file header to describe the decode GEMV helpers
  (addresses review thread 1).
@tianleiwu tianleiwu marked this pull request as ready for review June 25, 2026 04:50
@tianleiwu tianleiwu requested a review from Copilot June 25, 2026 04:50
@tianleiwu tianleiwu requested a review from hariharans29 June 25, 2026 04:51

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 6 out of 6 changed files in this pull request and generated no new comments.

@hariharans29

hariharans29 commented Jun 25, 2026

Copy link
Copy Markdown
Member

So we are replacing flash decoding with a Gemv kernel here, right (removed in the previous PR) ? Are there any cases where flash decoding will perform better than the Gemv kernel ?

Comment thread onnxruntime/core/mlas/lib/flashattn_gqa.cpp
Comment thread onnxruntime/core/mlas/lib/flashattn_gqa.cpp
Comment thread onnxruntime/core/mlas/lib/flashattn_gqa.cpp
@tianleiwu

Copy link
Copy Markdown
Contributor Author

@hariharans29 Not replacing the prefill flash attention path. The change here is specifically for FP32 single-token decode (sequence_length == 1). Previously that path could route the degenerate M = 1 QK/SV work through SGEMM inside the flash/flash-decoding code, which paid setup/packing overhead and was slower for decode.

Now decode uses direct GEMV helpers for the QK and SV products. There are still two decode modes:

  • single-pass decode, parallelized over (batch, head), when B * N has enough work for the thread pool;
  • two-phase flash-decoding, when B * N < thread_count, where KV is split across chunks so otherwise idle threads can help. That path still exists, but each chunk now uses GEMV instead of M = 1 SGEMM.

So flash-decoding can still perform better when B * N is small and the KV length is long enough that splitting the KV dimension gives useful extra parallelism. For short/medium decode it is mostly overhead-bound and is around parity; the main gain is avoiding the old SGEMM overhead/regression.

@tianleiwu tianleiwu requested a review from hariharans29 June 26, 2026 00:30
@hariharans29

Copy link
Copy Markdown
Member

Review: PR #29216 — [CPU] Add FP32 GEMV decode kernel for GroupQueryAttention

Author: tianleiwu · Branch: tlwu/20260608/gqa_cpu_decode_gemvmain
Size: +707 / −39 across 6 files · CI: 86 / 86 green · Status: awaiting approval; reviewer thread already engaged


Summary

Follow-up to the merged #28962. That PR shipped FP32 prefill flash but explicitly skipped decode because routing the M = 1 QK/SV work through MlasSgemmOperation was slower than naive (~0.4–0.6×) — SGEMM B-packing setup cost dominated the tiny work. This PR fixes that by introducing direct GEMV helpers for decode and re-enabling flash decoding with those helpers under the hood.

Three execution modes for the FP32 flash path now:

Case Dispatch Inner kernel
sequence_length > 1 (prefill) tiled MlasFlashAttentionGQAThreaded (unchanged) SGEMM
sequence_length == 1, B·N < threads, kv_chunks > 1 two-phase MlasFlashDecodingGQAThreaded + MlasFlashDecodingGQAReduceThreaded GEMV
sequence_length == 1, otherwise single-pass MlasGQADecodeGQAThreaded GEMV

The flash gate in group_query_attention.cc flips from sequence_length > 1 to total_sequence_length > 1, so decode (s = 1 with non-empty KV) now reaches the dispatcher.


What's correct

  • GEMV implementation. MlasGQADecodeQK uses an 8-lane explicit accumulator with a paired reduction tree (((a0+a1)+(a2+a3)) + ((a4+a5)+(a6+a7))) — this is the textbook pattern that lets a compiler emit SIMD FMAs without -ffast-math because reassociation isn't required. Placed in the baseline-ISA TU so SSE2 codegen is guaranteed. Correct choice.
  • Flash-decoding reduce math is the standard FA-2 form. Phase 1 stores (m_c, l_c, partial_out_c) per chunk. Phase 2 finds global_m, rescales each chunk by exp(m_c - global_m), accumulates global_l = Σ rescale·l_c and output = Σ rescale·partial_out_c, then divides by global_l. Masked-chunk sentinels (m_c = lowest, l_c = 0) are handled correctly. Race-free: each Phase-1 task writes to a unique partials slot; MlasExecuteThreaded between phases acts as the barrier.
  • Decode causal masking absence is correct. The single new query is at position past_seqlen; it attends to KV positions [0, past_seqlen], i.e. the entire cache. No causal mask needed — only optional local-window and bias.
  • Buffer sizing is mode-aware. gqa_attention_base.h picks the right per-thread scratch (kv_block_size for flash-decoding, total_seqlen + head_size for single-pass decode, prefill layout otherwise) plus a shared partials buffer for flash decoding, all via SafeInt<size_t>. Partials pointer placement (flash_buffer + per_thread_scratch * thread_count) is correct.
  • Policy gate is in the right place. The "single-pass vs flash-decoding" decision is in the EP layer (gqa_attention_base.h), signalled to MLAS via flash_decoding_partials = nullptr | allocated. Clean separation.
  • The ragged-seqlens determinism fix in the latest commit (bb024f2) is the right shape — ragged paths force common_past_seqlen = -1, which makes the EP set flash_decoding_partials = nullptr and kv_chunk_count = 0, so ragged decode always uses the single-pass kernel (deterministic per (batch, head)). The KV-split reduce was the thread-order-dependent piece, and excluding it from ragged is the simplest correct fix.
  • Parity test in test_gqa_cpu.pytest_gqa_decode_flash_vs_naive_parity runs decode configs twice (flash on, flash off via ORT_GQA_DISABLE_FLASH_ATTENTION flipped by scoped_env_var). The (1, 2048) config explicitly exercises kv_chunk_count > 1. Addresses the test-coverage gap from Add flash attention for non-quantized CPU GroupQueryAttention #28962.

Minor things to mention as comments

  1. Vestigial causal mask in MlasFlashDecodingGQAThreaded. With past_seqlen = total_seqlen - 1 and causal_limit = past_seqlen + 1 = total_seqlen, the if (kv_pos >= causal_limit) loop never fires because kv_pos < row_size_kv ≤ total_seqlen - ir. It's a no-op for decode. Either remove it or add a one-line comment:

    // no-op for s=1; kept for robustness if past_seqlen ever != total_seqlen-1

    Stylistic.

  2. MlasGQADecodeSV doesn't have explicit lane unrolling the way MlasGQADecodeQK does — it's a plain doubly-nested scalar loop with map-style out[h] += p * vrow[h]. The compiler should still auto-vectorize since the inner loop has independent updates across h, but worth eyeballing the codegen at -O2//O2 on SSE2 baseline to confirm. If it doesn't vectorize, SV becomes the bottleneck of decode (it's the bigger of the two GEMVs since it streams V and writes accumulators).

  3. Parity test grid is narrower than the manual sweep behind Add flash attention for non-quantized CPU GroupQueryAttention #28962. Just h_sizes={64,128}, num_h={(9,3)}, batches={1,3}, two seq configs, with/without local/bias ≈ 16 configs × 2 env states = 32. Author's earlier 600-config sweep covered all head sizes 32–256 and four GQA ratios. Not a blocker — committing even a thin parity test is a real improvement over zero — but consider a follow-up to widen the grid as a slow-marked test.

  4. Doc table for decode now shows ≈1.0–1.1× speedup for short/medium T and 1.35–1.5× at T = 4097, with the honest caveat that short-decode is overhead-bound. That's the right framing — the win here is not regressing (vs. Add flash attention for non-quantized CPU GroupQueryAttention #28962's pre-GEMV ~0.4–0.6×) plus the long-context gain. Don't oversell it.


Things I would NOT ask for

  • Removing the vestigial causal mask in flash-decoding Phase 1 — it's free defense if past_seqlen ever gets used differently (e.g., speculative decoding with s = 1 but past_seqlen != total_seqlen - 1). Just comment it.
  • Per-ISA specialized GEMV helpers (AVX/AVX-512 explicit intrinsics). The current map-style pattern in the baseline TU is the right starting point; targeted specialization can be a follow-up if profiling shows a specific ISA leaving perf on the table.

Verdict

Approve. Correct shape for the follow-up: surgical, addresses the regression #28962 explicitly punted on, brings back flash-decoding for the B·N < threads case (now profitable because GEMV avoids the SGEMM overhead that killed it before), and adds the parity testing the previous review asked for. The math and threading are clean. Already-resolved Copilot comments and the existing reviewer thread cover the rest.

@tianleiwu tianleiwu enabled auto-merge (squash) June 26, 2026 06:29
@tianleiwu tianleiwu merged commit a203dfa into main Jun 26, 2026
87 checks passed
@tianleiwu tianleiwu deleted the tlwu/20260608/gqa_cpu_decode_gemv branch June 26, 2026 06:29
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants