@@ -24,9 +24,9 @@ template <typename OffsetVecType,
2424 BlockAttentionKVCacheMemoryLayoutEnum kKVMemoryLayout ,
2525 bool kIsKcache ,
2626 index_t kVectorSize >
27- CK_TILE_HOST_DEVICE void kv_offset_array_transform (const index_t * page_vec ,
28- const index_t & stride_kv ,
29- const index_t & page_stride_kv ,
27+ CK_TILE_HOST_DEVICE void kv_offset_array_transform (const index_t * page_idx ,
28+ const index_t & stride_token ,
29+ const index_t & stride_page_block ,
3030 const CoordVecType& coord_vec,
3131 OffsetVecType& kv_offset_vec,
3232 index_t global_seq_offset = 0 )
@@ -39,47 +39,70 @@ CK_TILE_HOST_DEVICE void kv_offset_array_transform(const index_t* page_vec,
3939 static_for<0 , kLoopCount , 1 >{}([&](auto k0) {
4040 const index_t global_token_idx =
4141 global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value ;
42- const index_t page_id = global_token_idx >> kLog2PageSize ;
43- const index_t page_offset = global_token_idx & kInPageOffsetMask ;
44- kv_offset_vec[k0] = static_cast <long_index_t >(page_vec [page_id]) * page_stride_kv +
45- static_cast <long_index_t >(page_offset ) * stride_kv ;
42+ const index_t page_id = global_token_idx >> kLog2PageSize ;
43+ const index_t token_idx_in_page = global_token_idx & kInPageOffsetMask ;
44+ kv_offset_vec[k0] = static_cast <long_index_t >(page_idx [page_id]) * stride_page_block +
45+ static_cast <long_index_t >(token_idx_in_page ) * stride_token ;
4646 });
4747 }
4848 else
4949 {
5050 // for v offsets
51- const index_t lane0_start = __builtin_amdgcn_readfirstlane (thread_coord_start);
52- const index_t lane0_page_id =
53- (global_seq_offset + lane0_start + kLoopStart ) >> kLog2PageSize ;
54-
55- const long_index_t page_loc =
56- static_cast <long_index_t >(page_vec[lane0_page_id]) * page_stride_kv;
51+ if constexpr (kLog2PageSize == 0 &&
52+ kKVMemoryLayout == BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT )
53+ {
54+ // page size = 1, per-token page lookup.
55+ // Here page_idx maps token_idx -> physical_page_id, so global_seq_offset must be
56+ // the absolute token index within the batch's kv_page_indices slice.
57+ static_for<0 , kLoopCount , 1 >{}([&](auto k0) {
58+ const index_t global_token_idx =
59+ global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value ;
5760
58- static_for<0 , kLoopCount , 1 >{}([&](auto k0) {
59- const index_t page_offset =
60- (global_seq_offset + thread_coord_start + kLoopStart + k0.value ) &
61- kInPageOffsetMask ;
61+ const long_index_t page_base_offset =
62+ static_cast <long_index_t >(page_idx[global_token_idx]) * stride_page_block;
6263
63- if constexpr (kKVMemoryLayout ==
64- BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT )
65- {
66- // Vectorized layout offset
67- // Layout: [BlockSize/kVectorSize, HeadDim, kVectorSize]
68- // Offset(s) = (s / kVectorSize) * (HeadDim * kVectorSize) + (s % kVectorSize)
69- const index_t s = page_offset;
70- const index_t D = stride_kv;
64+ kv_offset_vec[k0] = page_base_offset;
65+ });
66+ }
67+ else
68+ {
69+ // This path handles page_size > 1 and/or non-linear KV layout, where page_idx is
70+ // indexed by page_id (token_idx >> log2_page_size) with an in-page offset.
71+ // Assumes the V tile stays within a single page so lane0 can broadcast the page id.
72+ const index_t lane0_start = __builtin_amdgcn_readfirstlane (thread_coord_start);
73+ const index_t lane0_page_id =
74+ (global_seq_offset + lane0_start + kLoopStart ) >> kLog2PageSize ;
75+
76+ const long_index_t page_base_offset =
77+ static_cast <long_index_t >(page_idx[lane0_page_id]) * stride_page_block;
78+
79+ static_for<0 , kLoopCount , 1 >{}([&](auto k0) {
80+ const index_t token_idx_in_page =
81+ (global_seq_offset + thread_coord_start + kLoopStart + k0.value ) &
82+ kInPageOffsetMask ;
83+
84+ if constexpr (kKVMemoryLayout ==
85+ BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT )
86+ {
87+ // Vectorized layout offset
88+ // Layout: [BlockSize/kVectorSize, HeadDim, kVectorSize]
89+ // Offset = (token_idx_in_page / kVectorSize) * (HeadDim * kVectorSize) +
90+ // (token_idx_in_page % kVectorSize)
7191
72- const long_index_t s_offset =
73- static_cast <long_index_t >((s / kVectorSize ) * (D * kVectorSize )) +
74- (s % kVectorSize );
92+ const long_index_t token_offset =
93+ static_cast <long_index_t >((token_idx_in_page / kVectorSize ) *
94+ (stride_token * kVectorSize )) +
95+ (token_idx_in_page % kVectorSize );
7596
76- kv_offset_vec[k0] = page_loc + s_offset;
77- }
78- else // BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT
79- {
80- kv_offset_vec[k0] = page_loc + static_cast <long_index_t >(page_offset) * stride_kv;
81- }
82- });
97+ kv_offset_vec[k0] = page_base_offset + token_offset;
98+ }
99+ else // BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT
100+ {
101+ kv_offset_vec[k0] = page_base_offset +
102+ static_cast <long_index_t >(token_idx_in_page) * stride_token;
103+ }
104+ });
105+ }
83106 }
84107}
85108
@@ -127,9 +150,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
127150 static constexpr auto I3 = number<3 >{};
128151
129152 static_assert (kSubQKHeaddim <= 256 , " hdim bigger than 256 is not suitable for this pipeline!" );
130- static_assert (kPageBlockSize % kN0 == 0 ,
131- " V offset assumes each tile stays within a page; kPageBlockSize must be "
132- " divisible by kN0." );
153+ static_assert (kPageBlockSize % kN0 == 0 || kLog2PageSize == 0 ,
154+ " Page size must be 1, or a multiple of the tile size (kN0)." );
133155
134156 static constexpr bool kIsGroupMode = Problem::kIsGroupMode ;
135157 // TODO: seq_q always support padding, hdim_q/v support multiple of vector(like 8x)
0 commit comments