docs/contrib_ops/cpu/gqa.md: 59 additions & 7 deletions
@@ -258,11 +258,45 @@ The non-quantized flash path is selected when ALL of the following hold:
 - No output QK capture
 - `present_key` and `present_value` are provided

-Attention bias, causal masking, local window attention, GQA head grouping (`num_heads != kv_num_heads`), ragged per-batch sequence lengths, and shared past/present buffers are all supported for prefill, mirroring the quantized flash path. The non-quantized flash path is currently selected for prefill only (`sequence_length > 1`); single-token decode falls back to the naive full-materialization path (a dedicated decode kernel is added in a follow-up change). When any supported condition is not met, the kernel also falls back to the naive path.
+Attention bias, causal masking, local window attention, GQA head grouping (`num_heads != kv_num_heads`), ragged per-batch sequence lengths, shared past/present buffers, and flash decoding are all supported, mirroring the quantized flash path. When any condition is not met, the kernel falls back to the naive full-materialization path.

 ### Block Sizes, Threading, and Flash Decoding

-Block-size selection (`kv_block_size`, `q_block_size`), `(batch, head, q_block)` task partitioning, and the per-thread working buffer layout (`l`, `m`, `scores`, `temp_output`) for prefill are identical to the quantized path described above. The only difference is that the per-thread `temp_output` tile is accumulated directly by the SV SGEMM rather than via a fused dequantization. The two-phase flash-decoding strategy for single-token decode is gated off for the non-quantized path in this PR (decode falls back to naive); it is enabled together with the dedicated decode kernel in a follow-up change.
+Block-size selection (`kv_block_size`, `q_block_size`), `(batch, head, q_block)` task partitioning, the per-thread working buffer layout (`l`, `m`, `scores`, `temp_output`), and the two-phase flash-decoding strategy for single-token decode are identical to the quantized path described above. The only difference is that the per-thread `temp_output` tile is accumulated directly by the SV SGEMM rather than via a fused dequantization.
+
+#### Decode uses a dedicated GEMV kernel (`sequence_length == 1`)
+
+The tiled online-softmax SGEMM kernel (`MlasFlashAttentionGQAThreaded`) is used **only for
+prefill** (`sequence_length > 1`), where each KV tile is reused across the `q_block_size`
+query rows and tiling delivers real cache-locality and SGEMM packing benefits.
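
For readers unfamiliar with the online-softmax recurrence that the tiled prefill kernel relies on, the following is a minimal pure-Python sketch for a single query row. It is not the MLAS implementation: the function names `attention_reference` and `attention_online` are hypothetical, the `1/sqrt(H)` scale is the usual attention convention and is assumed here, and the variables `m`, `l`, `scores`, and `temp_output` only borrow the names of the working buffers mentioned above.

```python
import math


def attention_reference(q, K, V):
    """Naive path: materialize every score, then softmax, then weight V."""
    scale = 1.0 / math.sqrt(len(q))  # assumed scaling convention
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in K]
    m = max(scores)
    p = [math.exp(s - m) for s in scores]
    l = sum(p)
    width = len(V[0])
    return [sum(p[t] * V[t][h] for t in range(len(V))) / l for h in range(width)]


def attention_online(q, K, V, kv_block_size):
    """Tiled path: visit K/V block by block with a running max m and sum l."""
    scale = 1.0 / math.sqrt(len(q))  # assumed scaling convention
    m = -math.inf                    # running maximum of the scores seen so far
    l = 0.0                          # running softmax denominator
    temp_output = [0.0] * len(V[0])  # unnormalized output accumulator
    for start in range(0, len(K), kv_block_size):
        k_blk = K[start:start + kv_block_size]
        v_blk = V[start:start + kv_block_size]
        scores = [scale * sum(a * b for a, b in zip(q, k)) for k in k_blk]
        m_new = max(m, max(scores))
        # Rescale earlier partial results to the new maximum.
        # On the first block m is -inf, so the correction is exp(-inf) = 0.0.
        correction = math.exp(m - m_new)
        p = [math.exp(s - m_new) for s in scores]
        l = l * correction + sum(p)
        temp_output = [
            o * correction + sum(p[t] * v_blk[t][h] for t in range(len(v_blk)))
            for h, o in enumerate(temp_output)
        ]
        m = m_new
    return [o / l for o in temp_output]
```

Both functions should agree up to floating-point rounding for any `kv_block_size`; the tiled version simply never holds more than one block of scores at a time.
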
+
+For single-token decode the query tile has `M = 1`, so every K/V element is streamed
+exactly once with no reuse across query rows. Tiling provides **no** cache-locality
+benefit, and routing the `1 × T × H` work through `MlasSgemmOperation` pays the SGEMM
+B-packing/setup cost on every call, which previously made the flash decode path *slower*
+than the naive path (≈0.4–0.6x) for short-to-medium total sequence lengths.
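
To make the `M = 1` case concrete, here is a minimal pure-Python sketch of single-query attention written as matrix-vector products, in which each row of K and V is read exactly once. This illustrates the general math only and does not describe how the kernel in this PR is written. The helper names `gemv` and `decode_one_head` are hypothetical, and the `1/sqrt(H)` scale is an assumed convention.

```python
import math


def gemv(M, x):
    """Matrix-vector product; every row of M is touched exactly once."""
    return [sum(a * b for a, b in zip(row, x)) for row in M]


def decode_one_head(q, K, V):
    """Attention for one query vector q against T cached keys and values."""
    T = len(K)
    scale = 1.0 / math.sqrt(len(q))  # assumed scaling convention
    # Pass over K: scores[t] = scale * dot(K[t], q).
    scores = [scale * s for s in gemv(K, q)]
    # Numerically stable softmax over the T scores.
    m = max(scores)
    p = [math.exp(s - m) for s in scores]
    l = sum(p)
    # Pass over V: out[h] = sum_t p[t] * V[t][h], accumulated row by row.
    out = [0.0] * len(V[0])
    for t in range(T):
        row = V[t]
        for h in range(len(row)):
            out[h] += p[t] * row[h]
    return [o / l for o in out]
```

With only one query row there is nothing to amortize a packed copy of K or V against, which is the situation the paragraph above describes.
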
+
+Decode is therefore handled by a dedicated GEMV kernel (`MlasGQADecodeGQAThreaded`),
+dispatched whenever `sequence_length == 1` and flash decoding is not active. It
+parallelizes over `(batch, head)` and, per head, computes the attention directly with two