[CPU] Add FP32 GEMV decode kernel for GroupQueryAttention#29216
Conversation
Single-token decode (sequence_length == 1) falls back to the naive path. A dedicated FP32 decode kernel will be added in a follow-up PR. The quantized path is unchanged.
49ffc40 to
4042bd2
Compare
There was a problem hiding this comment.
Pull request overview
Adds an optimized CPU FP32 single-token decode path for com.microsoft.GroupQueryAttention, aiming to eliminate the decode regression from routing M=1 work through per-block SGEMM by introducing GEMV-based decode (and GEMV-based flash-decoding partials).
Changes:
- Add GEMV-based decode helpers/kernels for
sequence_length == 1, including optional two-phase “flash decoding” KV-chunk reduction. - Update FP32 flash gating to activate when
total_sequence_length > 1, enabling prefill via tiled flash attention and decode via the new GEMV path. - Update CPU GQA documentation to describe the new decode behavior and performance characteristics.
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
onnxruntime/core/mlas/lib/flashattn_gqa.cpp |
Adds GEMV decode helpers plus new decode/flash-decoding threaded kernels and dispatch logic. |
onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc |
Adjusts FP32 flash routing gate to include decode when total_sequence_length > 1. |
onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h |
Allocates/partitions scratch buffers for decode vs flash-decoding; wires new args fields into MLAS call. |
docs/contrib_ops/cpu/gqa.md |
Updates docs to describe FP32 decode GEMV kernel and flash-decoding behavior. |
Adds a dedicated GEMV kernel (MlasGQADecodeGQAThreaded) for single-token decode (sequence_length == 1), and converts the flash-decoding inner M=1 GEMMs to GEMV. Re-enables the FP32 flash gate for decode (total_sequence_length > 1). Verified correctness vs naive (~1e-8); long-context decode ~1.0-1.2x, fixing the prior per-block SGEMM decode regression.
4042bd2 to
284d4a5
Compare
…test - Gate use_flash_decoding on common_past_seqlen >= 0 so the small per-thread flash-decoding scratch buffer is only selected when the unified KV-split kernel runs. Ragged/per-batch decode falls back to MlasGQADecodeGQAThreaded which needs a larger scratch (scores[total_seqlen] + temp_output[head_size]); previously it reused the small buffer and threads overran each other's scratch, producing non-deterministic output for batch>1 ragged decode. - Add test_gqa_decode_flash_vs_naive_parity comparing both the flash and naive (ORT_GQA_DISABLE_FLASH_ATTENTION=1) decode paths to the reference (addresses review thread 4). - Correct flashattn_gqa.cpp file header to describe the decode GEMV helpers (addresses review thread 1).
|
So we are replacing flash decoding with a Gemv kernel here, right (removed in the previous PR) ? Are there any cases where flash decoding will perform better than the Gemv kernel ? |
|
@hariharans29 Not replacing the prefill flash attention path. The change here is specifically for FP32 single-token decode ( Now decode uses direct GEMV helpers for the QK and SV products. There are still two decode modes:
So flash-decoding can still perform better when |
Review: PR #29216 — [CPU] Add FP32 GEMV decode kernel for GroupQueryAttentionAuthor: SummaryFollow-up to the merged #28962. That PR shipped FP32 prefill flash but explicitly skipped decode because routing the Three execution modes for the FP32 flash path now:
The flash gate in What's correct
Minor things to mention as comments
Things I would NOT ask for
VerdictApprove. Correct shape for the follow-up: surgical, addresses the regression #28962 explicitly punted on, brings back flash-decoding for the |
Description
PR1 #28962 adds flash attention for prefill, and removed flash decoding. This PR will add optimized kernel for single-token decode, which will be faster than other kernels including flash decoding.
This PR builds on the prefill-only flash attention change and additionally introduces a dedicated decode kernel.
What's included
MlasGQADecodeGQAThreaded) forsequence_length == 1, parallelized over (batch, head) with a two-pass softmax, using GEMV (acc[8]-lane dot product / AXPY) helpers instead of per-block M=1 SGEMM calls. This fixes the per-block SGEMM decode regression.group_query_attention.cc) is enabled fortotal_sequence_length > 1, routing prefill to the tiled kernel and decode to the GEMV kernel.Results (AMD EPYC 7763, AVX2, 8 threads)
Motivation and Context
The naive GQA path materializes the full score matrix, which is memory-bound for long sequences. Flash attention reduces memory traffic for prefill, and the GEMV decode kernel avoids SGEMM overhead for the M=1 decode case.
Testing
--compile_no_warning_as_error.benchmark_gqa_cpu_flash.py.