Skip to content

Commit 49ffc40

Browse files
committed
Add CPU FP32 GQA GEMV decode kernel
Adds a dedicated GEMV kernel (MlasGQADecodeGQAThreaded) for single-token decode (sequence_length == 1), and converts the flash-decoding inner M=1 GEMMs to GEMV. Re-enables the FP32 flash gate for decode (total_sequence_length > 1). Verified correctness vs naive (~1e-8); long-context decode ~1.0-1.2x, fixing the prior per-block SGEMM decode regression.
1 parent 4a4e845 commit 49ffc40

4 files changed

Lines changed: 310 additions & 59 deletions

File tree

docs/contrib_ops/cpu/gqa.md

Lines changed: 59 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -258,11 +258,45 @@ The non-quantized flash path is selected when ALL of the following hold:
258258
- No output QK capture
259259
- `present_key` and `present_value` are provided
260260

261-
Attention bias, causal masking, local window attention, GQA head grouping (`num_heads != kv_num_heads`), ragged per-batch sequence lengths, and shared past/present buffers are all supported for prefill, mirroring the quantized flash path. The non-quantized flash path is currently selected for prefill only (`sequence_length > 1`); single-token decode falls back to the naive full-materialization path (a dedicated decode kernel is added in a follow-up change). When any supported condition is not met, the kernel also falls back to the naive path.
261+
Attention bias, causal masking, local window attention, GQA head grouping (`num_heads != kv_num_heads`), ragged per-batch sequence lengths, shared past/present buffers, and flash decoding are all supported, mirroring the quantized flash path. When any condition is not met, the kernel falls back to the naive full-materialization path.
262262

263263
### Block Sizes, Threading, and Flash Decoding
264264

265-
Block-size selection (`kv_block_size`, `q_block_size`), `(batch, head, q_block)` task partitioning, and the per-thread working buffer layout (`l`, `m`, `scores`, `temp_output`) for prefill are identical to the quantized path described above. The only difference is that the per-thread `temp_output` tile is accumulated directly by the SV SGEMM rather than via a fused dequantization. The two-phase flash-decoding strategy for single-token decode is gated off for the non-quantized path in this PR (decode falls back to naive); it is enabled together with the dedicated decode kernel in a follow-up change.
265+
Block-size selection (`kv_block_size`, `q_block_size`), `(batch, head, q_block)` task partitioning, the per-thread working buffer layout (`l`, `m`, `scores`, `temp_output`), and the two-phase flash-decoding strategy for single-token decode are identical to the quantized path described above. The only difference is that the per-thread `temp_output` tile is accumulated directly by the SV SGEMM rather than via a fused dequantization.
266+
267+
#### Decode uses a dedicated GEMV kernel (`sequence_length == 1`)
268+
269+
The tiled online-softmax SGEMM kernel (`MlasFlashAttentionGQAThreaded`) is used **only for
270+
prefill** (`sequence_length > 1`), where each KV tile is reused across the `q_block_size`
271+
query rows and tiling delivers real cache-locality and SGEMM packing benefits.
272+
273+
For single-token decode the query tile has `M = 1`, so every K/V element is streamed
274+
exactly once with no reuse across query rows. Tiling provides **no** cache-locality
275+
benefit, and routing the `1 × T × H` work through `MlasSgemmOperation` pays the SGEMM
276+
B-packing/setup cost on every call — which previously made the flash decode path *slower*
277+
than the naive path (≈0.4–0.6x) for short-to-medium total sequence lengths.
278+
279+
Decode is therefore handled by a dedicated GEMV kernel (`MlasGQADecodeGQAThreaded`),
280+
dispatched whenever `sequence_length == 1` and flash decoding is not active. It
281+
parallelizes over `(batch, head)` and, per head, computes the attention directly with two
282+
matrix-vector products and a two-pass softmax:
283+
284+
- **QK GEMV**`scores[t] = scale · dot(q, K[t])` for `t ∈ [0, total_seqlen)`.
285+
- two-pass softmax over `scores` using the dispatched `ReduceMaximumF32Kernel` /
286+
`ComputeSumExpF32Kernel` helpers.
287+
- **SV GEMV**`out[h] = Σ_t probs[t] · V[t][h]`, then normalize by `1/Σ probs`.
288+
289+
Both GEMV helpers (`MlasGQADecodeQK`, `MlasGQADecodeSV`) live in the baseline-ISA MLAS
290+
translation unit, so their inner loops use independent accumulator lanes / map-style
291+
updates that vectorize under SSE2 without `-ffast-math`. Decode needs no causal mask (the
292+
single new token is the most recent position and attends to every cached token); only
293+
optional local-window masking and additive attention bias are applied. The kernel streams
294+
K and V exactly once each, so it is memory-bandwidth bound.
295+
296+
The two-phase flash-decoding path (active when `batch × heads < threads`, KV partitioned
297+
across idle threads) now also uses these GEMV helpers for its per-chunk QK and SV products
298+
instead of `M = 1` SGEMM calls, removing the same packing overhead.
299+
266300

267301
## MLAS Dispatch Paths
268302

@@ -516,11 +550,29 @@ algorithmic rather than purely from threading.
516550

517551
#### Latency — Decode (S = 1, token generation)
518552

519-
Single-token decode (`sequence_length == 1`) currently falls back to the naive path for the
520-
non-quantized FP32 cache: the flash path is gated on `sequence_length > 1` (prefill only),
521-
because routing the tiny `1 × T × H` decode work through the tiled SGEMM kernel pays
522-
per-block GEMM setup overhead with no tiling reuse benefit. A dedicated FP32 decode kernel
523-
is added in a follow-up change.
553+
For single-token decode at this head configuration (`batch\u00d7heads = 16 > threads = 8`, so
554+
flash decoding KV-partitioning is not active), the workload per `Run` is tiny (a `1 × T × H`
555+
GEMV pair per head) and operator-level latency is dominated by fixed per-`Run` overhead
556+
(session dispatch, KV-cache concatenation), so operator-level measurements on the EPYC dev
557+
box are extremely noisy. The numbers below come from a min-of-many-repeats MLAS-path harness
558+
to suppress that jitter.
559+
560+
| Total Seqlen | Naive (ms) | Flash (ms) | Speedup |
561+
|---:|---:|---:|---:|
562+
| 513 | 0.50 | 0.42 | ~1.0\u20131.2x (noisy) |
563+
| 1025 | 0.78 | 0.69 | ~1.0\u20131.1x (noisy) |
564+
| 2049 | 1.89 | 1.73 | ~1.0\u20131.1x (noisy) |
565+
| 4097 | 6.1 | 4.5 | 1.35\u20131.5x |
566+
567+
Decode is now handled by the dedicated GEMV kernel (`MlasGQADecodeGQAThreaded`) instead of
568+
the prefill tiling kernel; see *Decode uses a dedicated GEMV kernel* above. Replacing the
569+
per-head `M = 1` `MlasSgemmOperation` QK/SV calls with direct GEMVs removes the SGEMM
570+
B-packing overhead that previously made flash decode noticeably **slower** than naive
571+
(measured ≈0.4\u20130.6x across all lengths before the change). Flash decode is now at parity
572+
for short/medium sequences (where the work is memory-bandwidth bound and overhead-dominated)
573+
and consistently ahead for long contexts (T≥4097, ~1.4\u20131.5x) where the streamed
574+
single-pass KV access wins. Short decode remains overhead-bound rather than algorithm-bound,
575+
so it is not the target of the prefill-oriented causal early-termination optimization.
524576
## Current CPU Limitations
525577

526578
The current CPU GroupQueryAttention implementation has a few important limitations:

onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1100,6 +1100,11 @@ class GQAAttentionBase {
11001100
// Partials: [batch * num_heads * kv_chunk_count * (2 + head_size)] floats
11011101
partials_buffer_bytes = SafeInt<size_t>(batch_size) * num_heads_ *
11021102
kv_chunk_count * (2 + head_size) * sizeof(float);
1103+
} else if (sequence_length == 1) {
1104+
// Decode (GEMV kernel, no Q/KV tiling): per-thread scratch holds the full
1105+
// score row scores[total_seqlen] plus a temp output accumulator[head_size].
1106+
buffer_size_per_thread =
1107+
(SafeInt<size_t>(max_total_seqlen) + head_size) * sizeof(float);
11031108
} else {
11041109
buffer_size_per_thread =
11051110
(SafeInt<size_t>(q_block_size) * 2 + // l + m

onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -349,12 +349,12 @@ Status GroupQueryAttention<T>::Compute(OpKernelContext* context) const {
349349
// naive path when an unsupported feature is requested (softcap, smooth softmax,
350350
// head sink, or QK output).
351351
//
352-
// The flash path is currently used for prefill only (sequence_length > 1). Single-token
353-
// decode (sequence_length == 1) falls back to the naive path; a dedicated decode kernel
354-
// is added in a follow-up change.
352+
// Prefill (sequence_length > 1) uses the tiled kernel; single-token decode
353+
// (sequence_length == 1 with total_sequence_length > 1) uses the dedicated GEMV
354+
// decode kernel. Both are reached when total_sequence_length > 1.
355355
if constexpr (std::is_same_v<T, float>) {
356356
const bool use_flash = !disable_gqa_flash_ &&
357-
parameters.sequence_length > 1 &&
357+
parameters.total_sequence_length > 1 &&
358358
softcap_ == 0.0f &&
359359
!use_smooth_softmax_ &&
360360
head_sink_data == nullptr &&

0 commit comments

Comments
 (0)