Commit d22aafb
[rocm-libraries] ROCm/rocm-libraries#6479 (commit 0705c2d)
CK][fmha] Add StreamLLM sink support to batch_prefill
pipeline (#6479)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
## Motivation
The existing paged-KV attention pipelines (pagedkv, splitkv) support
StreamLLM-style sink tokens — a fixed set of initial tokens kept in
attention alongside the sliding window. The `batch_prefill` pipeline
(chunked-prefill with VLLM-style block tables) previously hardcoded
`kHasSink = false`, making it incompatible with sink-based attention
patterns in LLM serving scenarios.
This PR extends `batch_prefill` to support `kHasSink` and wires it
into `fmha_fwd_runner` for validation against the existing CPU
reference.
## Technical Details
**Pipeline** (`block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp`):
- When `kHasSink`, the K/V loop splits into a sink phase [0,
sink_seq_end)
and a window phase [seqlen_k_start, seqlen_k_end), mirroring pagedkv.
- K advance at the sink→window transition jumps
`seqlen_k_start - sink_seq_end + kN0` to bridge the gap.
- V scatter-gather offsets are re-initialized at the transition to fix a
window mismatch bug: V was lagging kN0 behind K after the large jump,
loading from the wrong sequence position.
- Bias window, dropout seq_offset, and mask type (LogitsSinkMask)
updated
for sink-awareness.
**Traits / codegen** (`tile_fmha_traits.hpp`, `fmha_fwd.hpp`,
`fmha_batch_prefill.py`):
- `TileFmhaBatchPrefillTraits` gains `kHasSink_` (was hardcoded
`false`).
- Codegen adds `F_sink` field; skips batch-mode kernels (group mode
required).
- CMake test filter broadened from 9 → 33 instances covering
fp16/bf16 × mask/nmask × lse/nlse × sink/nsink.
**Runner** (`fmha_fwd_runner.hpp`, `CMakeLists.txt`):
- `fmha_batch_prefill()` dispatched from `run_fwd` when:
group mode + paged KV + num_splits == 1.
- K/V strides corrected for runner's [num_pages, nhead_k,
page_block_size, hdim] layout.
- `page_block_size % 128` check relaxed: batch_prefill supports ps=16.
- CPU reference paged-KV reordering guards extended with
`CK_TILE_FMHA_FWD_BATCH_PREFILL_API`.
## Test Plan
Build with `-DFMHA_FWD_ENABLE_APIS="fwd;batch_prefill"`, run
`tile_example_fmha_fwd` in group mode with page_block_size=16.
Test matrix:
- Mask: no-mask, causal, sliding window
- Sink: nsink, sink=1..128
- dtype: fp16, bf16
- LSE output: on/off
- seqlen ∈ {512,1024,2048,4096} × window ∈ {32,256,512,1024}
- GQA, chunked prefill, large batch×seqlen
- page_block_size: 16, 32
## Test Result
171 test cases, all valid:y:
- nmask + nsink: ✓
- causal + nsink: ✓
- causal + sink=8: ✓
- sliding window + sink=8 (d=128, d=256): ✓
- bf16, LSE output, GQA: ✓
## Submission Checklist
- [ ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.1 parent b75afb4 commit d22aafb
7 files changed
Lines changed: 261 additions & 59 deletions
File tree
- example/ck_tile/01_fmha
- codegen/ops
- include/ck_tile/ops/fmha
- kernel
- pipeline
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
10 | 10 | | |
11 | 11 | | |
12 | 12 | | |
13 | | - | |
| 13 | + | |
14 | 14 | | |
15 | 15 | | |
16 | 16 | | |
| |||
48 | 48 | | |
49 | 49 | | |
50 | 50 | | |
51 | | - | |
52 | 51 | | |
53 | 52 | | |
54 | 53 | | |
| |||
174 | 173 | | |
175 | 174 | | |
176 | 175 | | |
| 176 | + | |
| 177 | + | |
| 178 | + | |
| 179 | + | |
| 180 | + | |
| 181 | + | |
| 182 | + | |
177 | 183 | | |
178 | 184 | | |
179 | 185 | | |
| |||
Lines changed: 36 additions & 9 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
84 | 84 | | |
85 | 85 | | |
86 | 86 | | |
| 87 | + | |
87 | 88 | | |
88 | 89 | | |
89 | 90 | | |
| |||
124 | 125 | | |
125 | 126 | | |
126 | 127 | | |
127 | | - | |
| 128 | + | |
128 | 129 | | |
129 | 130 | | |
130 | 131 | | |
| |||
201 | 202 | | |
202 | 203 | | |
203 | 204 | | |
204 | | - | |
| 205 | + | |
205 | 206 | | |
206 | | - | |
| 207 | + | |
207 | 208 | | |
208 | 209 | | |
209 | 210 | | |
| |||
247 | 248 | | |
248 | 249 | | |
249 | 250 | | |
| 251 | + | |
250 | 252 | | |
251 | 253 | | |
252 | 254 | | |
| |||
343 | 345 | | |
344 | 346 | | |
345 | 347 | | |
| 348 | + | |
346 | 349 | | |
347 | 350 | | |
348 | 351 | | |
| |||
406 | 409 | | |
407 | 410 | | |
408 | 411 | | |
| 412 | + | |
| 413 | + | |
| 414 | + | |
| 415 | + | |
| 416 | + | |
409 | 417 | | |
410 | 418 | | |
411 | 419 | | |
| |||
472 | 480 | | |
473 | 481 | | |
474 | 482 | | |
| 483 | + | |
475 | 484 | | |
476 | 485 | | |
477 | 486 | | |
| |||
578 | 587 | | |
579 | 588 | | |
580 | 589 | | |
| 590 | + | |
581 | 591 | | |
582 | 592 | | |
583 | 593 | | |
| |||
617 | 627 | | |
618 | 628 | | |
619 | 629 | | |
| 630 | + | |
620 | 631 | | |
621 | 632 | | |
622 | 633 | | |
| |||
655 | 666 | | |
656 | 667 | | |
657 | 668 | | |
| 669 | + | |
658 | 670 | | |
659 | 671 | | |
660 | 672 | | |
| |||
663 | 675 | | |
664 | 676 | | |
665 | 677 | | |
| 678 | + | |
666 | 679 | | |
667 | 680 | | |
668 | 681 | | |
669 | | - | |
| 682 | + | |
670 | 683 | | |
671 | | - | |
| 684 | + | |
672 | 685 | | |
673 | 686 | | |
674 | 687 | | |
| |||
684 | 697 | | |
685 | 698 | | |
686 | 699 | | |
687 | | - | |
| 700 | + | |
688 | 701 | | |
689 | 702 | | |
690 | 703 | | |
| |||
701 | 714 | | |
702 | 715 | | |
703 | 716 | | |
704 | | - | |
| 717 | + | |
| 718 | + | |
705 | 719 | | |
| 720 | + | |
| 721 | + | |
| 722 | + | |
| 723 | + | |
| 724 | + | |
| 725 | + | |
| 726 | + | |
706 | 727 | | |
707 | 728 | | |
708 | 729 | | |
709 | 730 | | |
710 | 731 | | |
711 | 732 | | |
| 733 | + | |
| 734 | + | |
| 735 | + | |
712 | 736 | | |
713 | 737 | | |
714 | 738 | | |
715 | 739 | | |
716 | 740 | | |
717 | 741 | | |
| 742 | + | |
| 743 | + | |
| 744 | + | |
718 | 745 | | |
719 | 746 | | |
720 | 747 | | |
| |||
829 | 856 | | |
830 | 857 | | |
831 | 858 | | |
832 | | - | |
| 859 | + | |
833 | 860 | | |
834 | 861 | | |
835 | 862 | | |
| |||
844 | 871 | | |
845 | 872 | | |
846 | 873 | | |
847 | | - | |
| 874 | + | |
848 | 875 | | |
849 | 876 | | |
850 | 877 | | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1452 | 1452 | | |
1453 | 1453 | | |
1454 | 1454 | | |
| 1455 | + | |
1455 | 1456 | | |
1456 | 1457 | | |
1457 | 1458 | | |
| |||
1480 | 1481 | | |
1481 | 1482 | | |
1482 | 1483 | | |
1483 | | - | |
| 1484 | + | |
1484 | 1485 | | |
1485 | 1486 | | |
1486 | 1487 | | |
| |||
0 commit comments