Skip to content

Commit d22aafb

Browse files
[rocm-libraries] ROCm/rocm-libraries#6479 (commit 0705c2d)
CK][fmha] Add StreamLLM sink support to batch_prefill pipeline (#6479) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Motivation The existing paged-KV attention pipelines (pagedkv, splitkv) support StreamLLM-style sink tokens — a fixed set of initial tokens kept in attention alongside the sliding window. The `batch_prefill` pipeline (chunked-prefill with VLLM-style block tables) previously hardcoded `kHasSink = false`, making it incompatible with sink-based attention patterns in LLM serving scenarios. This PR extends `batch_prefill` to support `kHasSink` and wires it into `fmha_fwd_runner` for validation against the existing CPU reference. ## Technical Details **Pipeline** (`block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp`): - When `kHasSink`, the K/V loop splits into a sink phase [0, sink_seq_end) and a window phase [seqlen_k_start, seqlen_k_end), mirroring pagedkv. - K advance at the sink→window transition jumps `seqlen_k_start - sink_seq_end + kN0` to bridge the gap. - V scatter-gather offsets are re-initialized at the transition to fix a window mismatch bug: V was lagging kN0 behind K after the large jump, loading from the wrong sequence position. - Bias window, dropout seq_offset, and mask type (LogitsSinkMask) updated for sink-awareness. **Traits / codegen** (`tile_fmha_traits.hpp`, `fmha_fwd.hpp`, `fmha_batch_prefill.py`): - `TileFmhaBatchPrefillTraits` gains `kHasSink_` (was hardcoded `false`). - Codegen adds `F_sink` field; skips batch-mode kernels (group mode required). - CMake test filter broadened from 9 → 33 instances covering fp16/bf16 × mask/nmask × lse/nlse × sink/nsink. **Runner** (`fmha_fwd_runner.hpp`, `CMakeLists.txt`): - `fmha_batch_prefill()` dispatched from `run_fwd` when: group mode + paged KV + num_splits == 1. - K/V strides corrected for runner's [num_pages, nhead_k, page_block_size, hdim] layout. - `page_block_size % 128` check relaxed: batch_prefill supports ps=16. - CPU reference paged-KV reordering guards extended with `CK_TILE_FMHA_FWD_BATCH_PREFILL_API`. ## Test Plan Build with `-DFMHA_FWD_ENABLE_APIS="fwd;batch_prefill"`, run `tile_example_fmha_fwd` in group mode with page_block_size=16. Test matrix: - Mask: no-mask, causal, sliding window - Sink: nsink, sink=1..128 - dtype: fp16, bf16 - LSE output: on/off - seqlen ∈ {512,1024,2048,4096} × window ∈ {32,256,512,1024} - GQA, chunked prefill, large batch×seqlen - page_block_size: 16, 32 ## Test Result 171 test cases, all valid:y: - nmask + nsink: ✓ - causal + nsink: ✓ - causal + sink=8: ✓ - sliding window + sink=8 (d=128, d=256): ✓ - bf16, LSE output, GQA: ✓ ## Submission Checklist - [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
1 parent b75afb4 commit d22aafb

7 files changed

Lines changed: 261 additions & 59 deletions

File tree

example/ck_tile/01_fmha/CMakeLists.txt

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ if(NOT INST_TARGETS)
1010
endif()
1111

1212
# validate user-specified fmha_fwd API list
13-
set(FMHA_FWD_KNOWN_APIS "fwd;fwd_splitkv;fwd_appendkv;pagedkv_prefill")
13+
set(FMHA_FWD_KNOWN_APIS "fwd;fwd_splitkv;fwd_appendkv;pagedkv_prefill;batch_prefill")
1414
set(FMHA_FWD_ENABLE_APIS "fwd" CACHE STRING
1515
"semicolon-separated list of APIs to generate (${FMHA_FWD_KNOWN_APIS}) & link, or \"all\".")
1616
if(BUILD_TESTING)
@@ -48,7 +48,6 @@ set(FMHA_FWD_CODE_GEN_COMMON_ARGS
4848
--targets ${FMHA_TARGETS_ARG}
4949
--api ${FMHA_FWD_APIS}
5050
--optdim 32,64,80,128,256
51-
# --filter fmha_fwd...
5251
)
5352
set(FMHA_BWD_CODE_GEN_COMMON_ARGS
5453
${CMAKE_CURRENT_LIST_DIR}/generate.py
@@ -174,6 +173,13 @@ else()
174173
list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_PAGEDKV_API=0)
175174
endif()
176175

176+
# conditionally enable call to the batch_prefill API in fmha_fwd example and tests
177+
if("batch_prefill" IN_LIST FMHA_FWD_ENABLE_APIS)
178+
list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_BATCH_PREFILL_API=1)
179+
else()
180+
list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_BATCH_PREFILL_API=0)
181+
endif()
182+
177183
# conditionally specify the use of OCP_FP8
178184
if(CK_USE_OCP_FP8)
179185
list(APPEND FMHA_FWD_PRIVATE_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)

example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py

Lines changed: 36 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@
8484
{F_qscale},
8585
{F_occupancy},
8686
false,
87+
{F_sink},
8788
{F_page_size},
8889
{F_kv_memory_layout},
8990
{F_kv_lookup_table}>;
@@ -124,7 +125,7 @@
124125
ck_tile::FmhaBatchPrefillWithPagedKVCacheKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>;
125126
126127
using trait_{F_idx} = fmha_fwd_batch_prefill_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
127-
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}>;
128+
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_sink}, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}>;
128129
129130
#include <iostream>
130131
@@ -201,9 +202,9 @@
201202
}}
202203
"""
203204

204-
FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.qscale_type == {F_qscale_check}) &&
205+
FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.qscale_type == {F_qscale_check}) && (t.has_sink == {F_sink}) &&
205206
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint}) && (t.kv_memory_layout == {F_kv_memory_layout}) && (t.kv_lookup_table == {F_kv_lookup_table}) && (t.page_size == {F_page_size})) {{
206-
using trait_ = fmha_fwd_batch_prefill_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}>;
207+
using trait_ = fmha_fwd_batch_prefill_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_sink}, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}>;
207208
return fmha_batch_prefill_<trait_>(s, a);
208209
}}
209210
"""
@@ -247,6 +248,7 @@ class FmhaFwdApiTrait:
247248
skpad: str
248249
dpad: str
249250
dvpad: str
251+
sink: str # t/f
250252
constraint: CppConstraint
251253
kv_memory_layout: str
252254
kv_lookup_table: str
@@ -343,6 +345,7 @@ class FmhaFwdPipeline:
343345
F_dropout: str #
344346
F_qscale: str # no/pertensor
345347
F_mask: str # value from MASK_MAP
348+
F_sink: str # t/f (StreamLLM sink tokens)
346349
F_kv_memory_layout: str #
347350
F_kv_lookup_table: str #
348351
F_constraint: CppConstraint = field(default_factory=lambda: CppConstraint())
@@ -406,6 +409,11 @@ def pad_name() -> str:
406409
else:
407410
n += "_nqscale"
408411

412+
if self.F_sink == "t":
413+
n += "_sink"
414+
else:
415+
n += "_nsink"
416+
409417
n += "_" + self.F_kv_memory_layout + "_" + self.F_kv_lookup_table
410418
return n
411419

@@ -472,6 +480,7 @@ def api(self) -> str:
472480
trait.kv_lookup_table
473481
],
474482
F_page_size=trait.page_size,
483+
F_sink=BOOL_MAP[trait.sink],
475484
)
476485
if_j = "if" if j == 0 else "else if"
477486
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(
@@ -578,6 +587,7 @@ def template(self) -> str:
578587
F_mode=MODE_MAP[self.F_mode],
579588
F_pipeline=FMHA_BATCH_PREFILL_PIPELINE_MAP[self.F_pipeline.tag],
580589
F_page_size=self.F_page_size,
590+
F_sink=BOOL_MAP[self.F_pipeline.F_sink],
581591
)
582592

583593
@property
@@ -617,6 +627,7 @@ def api_trait(self) -> FmhaFwdApiTrait:
617627
skpad=self.F_pipeline.F_skpad,
618628
dpad=self.F_pipeline.F_dpad,
619629
dvpad=self.F_pipeline.F_dvpad,
630+
sink=self.F_pipeline.F_sink,
620631
constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint,
621632
kv_memory_layout=self.F_pipeline.F_kv_memory_layout,
622633
kv_lookup_table=self.F_pipeline.F_kv_lookup_table,
@@ -655,6 +666,7 @@ def get_pipelines(dtype, hdim, receipt, mask_impl) -> List[FmhaFwdPipeline]:
655666
bias,
656667
lse,
657668
dropout,
669+
sink,
658670
kv_memory_layout,
659671
kv_lookup_table,
660672
) in itertools.product(
@@ -663,12 +675,13 @@ def get_pipelines(dtype, hdim, receipt, mask_impl) -> List[FmhaFwdPipeline]:
663675
BIAS_MAP.keys(),
664676
["t", "f"],
665677
["t", "f"],
678+
["t", "f"],
666679
SUPPORTED_KV_MEMORY_LAYOUT,
667680
SUPPORTED_KV_LOOKUP_TABLE,
668681
):
669-
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, kv_memory_layout, kv_lookup_table)) # fmt: skip
682+
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, sink, kv_memory_layout, kv_lookup_table)) # fmt: skip
670683
elif dtype in ["fp8bf16"]:
671-
# no need lse/dropout kernels
684+
# no need lse/dropout/sink kernels
672685
for (
673686
logits,
674687
qscale,
@@ -684,7 +697,7 @@ def get_pipelines(dtype, hdim, receipt, mask_impl) -> List[FmhaFwdPipeline]:
684697
SUPPORTED_KV_MEMORY_LAYOUT,
685698
SUPPORTED_KV_LOOKUP_TABLE,
686699
):
687-
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, kv_memory_layout, kv_lookup_table)) # fmt: skip
700+
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", kv_memory_layout, kv_lookup_table)) # fmt: skip
688701
else:
689702
assert False
690703
return pipelines
@@ -701,20 +714,34 @@ def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
701714

702715

703716
def get_fwd_blobs(
704-
kernel_filter: Optional[str], receipt, optdim_list, mask_impl
717+
kernel_filter: Optional[str], receipt, optdim_list, mask_impl,
718+
targets: Optional[List[str]] = None
705719
) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
720+
# batch_prefill pipeline uses gfx9-specific async scatter-gather buffer addressing
721+
# (amd_buffer_addressing.hpp raw buffer loads) that is not compatible with
722+
# non-gfx9 architectures (gfx11/gfx12/gfx10 are wave32 and use different
723+
# buffer instruction formats). Skip all batch_prefill kernels for non-gfx9 targets.
724+
has_non_gfx9 = targets is not None and any(
725+
not t.startswith("gfx9") for t in targets
726+
)
706727
# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
707728
# support this in future
708729

709730
gen = list()
710731
api_pool = FmhaFwdApiPool(mask_impl)
711732

733+
if has_non_gfx9:
734+
return api_pool, gen
735+
712736
for dtype in FWD_DTYPE_MAP.keys():
713737
d = CustomFactory.get_hdim_tile_size_dict(dtype)
714738
if d is None:
715739
continue
716740
# for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
717741
for (hdim, tiles), mode in itertools.product(d.items(), MODE_MAP.keys()):
742+
# batch_prefill pipeline requires group mode (static_assert in pipeline problem)
743+
if mode != "group":
744+
continue
718745
for tile, pipeline in itertools.product(
719746
tiles, CustomFactory.get_pipelines(dtype, hdim, receipt, mask_impl)
720747
):
@@ -829,7 +856,7 @@ def write_blobs(
829856
optdim_list,
830857
mask_impl,
831858
) -> None:
832-
api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
859+
api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl, targets)
833860
for kernel in kernels:
834861
write_single_fwd_kernel(kernel, output_dir)
835862
write_fwd_api(api_pool, output_dir)
@@ -844,7 +871,7 @@ def list_blobs(
844871
mask_impl,
845872
) -> None:
846873
with file_path.open("a") as f:
847-
_, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
874+
_, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl, targets)
848875
for kernel in kernels:
849876
f.write((file_path.parent / GEN_DIR / kernel.filename).as_posix() + "\n")
850877
f.write((file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME).as_posix() + "\n")

example/ck_tile/01_fmha/fmha_fwd.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1452,6 +1452,7 @@ template <ck_tile::index_t HDim_,
14521452
bool kPadDv_,
14531453
bool kUseTrLoad_,
14541454
bool kSkipMinSeqlenQ_ = false,
1455+
bool kHasSink_ = false,
14551456
ck_tile::index_t kPageBlockSize_ = 1,
14561457
ck_tile::BlockAttentionKVCacheMemoryLayoutEnum kKVMemoryLayout_ =
14571458
ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT,
@@ -1480,7 +1481,7 @@ struct fmha_fwd_batch_prefill_traits_ : public fmha_fwd_traits_<HDim_,
14801481
kPadDv_,
14811482
kUseTrLoad_,
14821483
kSkipMinSeqlenQ_,
1483-
false>
1484+
kHasSink_>
14841485
{
14851486
static constexpr auto kKVMemoryLayout = kKVMemoryLayout_;
14861487
static constexpr auto kKVLookupTable = kKVLookupTable_;

0 commit comments

Comments
 (0)