Skip to content

Commit c9f112b

Browse files
authored
[FMHA] Support page_size=1 (linear layout) in batch prefill pipeline (#3545)
- Enable page_size=1 support in batch prefill codegen (linear layout only).
- Implement per-token page lookup in `kv_offset_array_transform` for page_size=1 to handle 3D input tensors correctly.
- Relax the `kPageBlockSize` alignment assertion for the page_size=1 case.
1 parent a575acb commit c9f112b

2 files changed

Lines changed: 63 additions & 39 deletions

File tree

example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636

3737
K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 256: 256}
3838

39-
SUPPORTED_PAGE_SIZE = [128, 256, 1024]
39+
SUPPORTED_PAGE_SIZE = [1, 128, 256, 1024]
4040
SUPPORTED_KV_MEMORY_LAYOUT = ["vectorized", "linear"]
4141
SUPPORTED_KV_LOOKUP_TABLE = ["vllm", "sglang"]
4242
KV_MEMORY_LAYOUT_ENUM_MAP = {
@@ -737,6 +737,8 @@ def get_fwd_blobs(
737737

738738
# Generate kernels for every page size in SUPPORTED_PAGE_SIZE (page_size=1 is emitted only for the "linear" KV memory layout)
739739
for page_size in SUPPORTED_PAGE_SIZE:
740+
if page_size == 1 and pipeline.F_kv_memory_layout != "linear":
741+
continue
740742
k = FmhaFwdKernel(
741743
F_idx=0,
742744
F_hdim=hdim,

include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp

Lines changed: 60 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@ template <typename OffsetVecType,
2424
BlockAttentionKVCacheMemoryLayoutEnum kKVMemoryLayout,
2525
bool kIsKcache,
2626
index_t kVectorSize>
27-
CK_TILE_HOST_DEVICE void kv_offset_array_transform(const index_t* page_vec,
28-
const index_t& stride_kv,
29-
const index_t& page_stride_kv,
27+
CK_TILE_HOST_DEVICE void kv_offset_array_transform(const index_t* page_idx,
28+
const index_t& stride_token,
29+
const index_t& stride_page_block,
3030
const CoordVecType& coord_vec,
3131
OffsetVecType& kv_offset_vec,
3232
index_t global_seq_offset = 0)
@@ -39,47 +39,70 @@ CK_TILE_HOST_DEVICE void kv_offset_array_transform(const index_t* page_vec,
3939
static_for<0, kLoopCount, 1>{}([&](auto k0) {
4040
const index_t global_token_idx =
4141
global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value;
42-
const index_t page_id = global_token_idx >> kLog2PageSize;
43-
const index_t page_offset = global_token_idx & kInPageOffsetMask;
44-
kv_offset_vec[k0] = static_cast<long_index_t>(page_vec[page_id]) * page_stride_kv +
45-
static_cast<long_index_t>(page_offset) * stride_kv;
42+
const index_t page_id = global_token_idx >> kLog2PageSize;
43+
const index_t token_idx_in_page = global_token_idx & kInPageOffsetMask;
44+
kv_offset_vec[k0] = static_cast<long_index_t>(page_idx[page_id]) * stride_page_block +
45+
static_cast<long_index_t>(token_idx_in_page) * stride_token;
4646
});
4747
}
4848
else
4949
{
5050
// for v offsets
51-
const index_t lane0_start = __builtin_amdgcn_readfirstlane(thread_coord_start);
52-
const index_t lane0_page_id =
53-
(global_seq_offset + lane0_start + kLoopStart) >> kLog2PageSize;
54-
55-
const long_index_t page_loc =
56-
static_cast<long_index_t>(page_vec[lane0_page_id]) * page_stride_kv;
51+
if constexpr(kLog2PageSize == 0 &&
52+
kKVMemoryLayout == BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT)
53+
{
54+
// page size = 1, per-token page lookup.
55+
// Here page_idx maps token_idx -> physical_page_id, so global_seq_offset must be
56+
// the absolute token index within the batch's kv_page_indices slice.
57+
static_for<0, kLoopCount, 1>{}([&](auto k0) {
58+
const index_t global_token_idx =
59+
global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value;
5760

58-
static_for<0, kLoopCount, 1>{}([&](auto k0) {
59-
const index_t page_offset =
60-
(global_seq_offset + thread_coord_start + kLoopStart + k0.value) &
61-
kInPageOffsetMask;
61+
const long_index_t page_base_offset =
62+
static_cast<long_index_t>(page_idx[global_token_idx]) * stride_page_block;
6263

63-
if constexpr(kKVMemoryLayout ==
64-
BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT)
65-
{
66-
// Vectorized layout offset
67-
// Layout: [BlockSize/kVectorSize, HeadDim, kVectorSize]
68-
// Offset(s) = (s / kVectorSize) * (HeadDim * kVectorSize) + (s % kVectorSize)
69-
const index_t s = page_offset;
70-
const index_t D = stride_kv;
64+
kv_offset_vec[k0] = page_base_offset;
65+
});
66+
}
67+
else
68+
{
69+
// This path handles page_size > 1 and/or non-linear KV layout, where page_idx is
70+
// indexed by page_id (token_idx >> log2_page_size) with an in-page offset.
71+
// Assumes the V tile stays within a single page so lane0 can broadcast the page id.
72+
const index_t lane0_start = __builtin_amdgcn_readfirstlane(thread_coord_start);
73+
const index_t lane0_page_id =
74+
(global_seq_offset + lane0_start + kLoopStart) >> kLog2PageSize;
75+
76+
const long_index_t page_base_offset =
77+
static_cast<long_index_t>(page_idx[lane0_page_id]) * stride_page_block;
78+
79+
static_for<0, kLoopCount, 1>{}([&](auto k0) {
80+
const index_t token_idx_in_page =
81+
(global_seq_offset + thread_coord_start + kLoopStart + k0.value) &
82+
kInPageOffsetMask;
83+
84+
if constexpr(kKVMemoryLayout ==
85+
BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT)
86+
{
87+
// Vectorized layout offset
88+
// Layout: [BlockSize/kVectorSize, HeadDim, kVectorSize]
89+
// Offset = (token_idx_in_page / kVectorSize) * (HeadDim * kVectorSize) +
90+
// (token_idx_in_page % kVectorSize)
7191

72-
const long_index_t s_offset =
73-
static_cast<long_index_t>((s / kVectorSize) * (D * kVectorSize)) +
74-
(s % kVectorSize);
92+
const long_index_t token_offset =
93+
static_cast<long_index_t>((token_idx_in_page / kVectorSize) *
94+
(stride_token * kVectorSize)) +
95+
(token_idx_in_page % kVectorSize);
7596

76-
kv_offset_vec[k0] = page_loc + s_offset;
77-
}
78-
else // BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT
79-
{
80-
kv_offset_vec[k0] = page_loc + static_cast<long_index_t>(page_offset) * stride_kv;
81-
}
82-
});
97+
kv_offset_vec[k0] = page_base_offset + token_offset;
98+
}
99+
else // BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT
100+
{
101+
kv_offset_vec[k0] = page_base_offset +
102+
static_cast<long_index_t>(token_idx_in_page) * stride_token;
103+
}
104+
});
105+
}
83106
}
84107
}
85108

@@ -127,9 +150,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
127150
static constexpr auto I3 = number<3>{};
128151

129152
static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
130-
static_assert(kPageBlockSize % kN0 == 0,
131-
"V offset assumes each tile stays within a page; kPageBlockSize must be "
132-
"divisible by kN0.");
153+
static_assert(kPageBlockSize % kN0 == 0 || kLog2PageSize == 0,
154+
"Page size must be 1, or a multiple of the tile size (kN0).");
133155

134156
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
135157
// TODO: seq_q always support padding, hdim_q/v support multiple of vector(like 8x)

0 commit comments

Comments
 (0)