@@ -10,10 +10,13 @@ INFINIOP_MOORE_KERNEL pagedCaching(
1010 const int64_t * slot_mapping ,
1111 const size_t head_size , const size_t block_size ,
1212 const ptrdiff_t k_src_stride , const ptrdiff_t v_src_stride ,
13- const ptrdiff_t k_cache_block_stride , const ptrdiff_t v_cache_block_stride ) {
13+ const ptrdiff_t k_cache_block_stride , const ptrdiff_t v_cache_block_stride ,
14+ const ptrdiff_t k_cache_head_stride , const ptrdiff_t v_cache_head_stride ,
15+ const ptrdiff_t k_cache_slot_stride , const ptrdiff_t v_cache_slot_stride ) {
1416 op ::paged_caching ::cuda ::pagedCachingKernel < Tdata , NUM_THREADS > (
1517 k_cache , v_cache , k , v , slot_mapping , head_size ,
16- block_size , k_src_stride , v_src_stride , k_cache_block_stride , v_cache_block_stride );
18+ block_size , k_src_stride , v_src_stride ,
19+ k_cache_block_stride , v_cache_block_stride , k_cache_head_stride , v_cache_head_stride , k_cache_slot_stride , v_cache_slot_stride );
1720}
1821
1922namespace op ::paged_caching ::moore {
@@ -59,6 +62,8 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
5962 size_t num_tokens , size_t num_kv_heads , size_t head_size , size_t block_size ,
6063 ptrdiff_t k_src_stride , ptrdiff_t v_src_stride ,
6164 ptrdiff_t k_cache_block_stride , ptrdiff_t v_cache_block_stride ,
65+ ptrdiff_t k_cache_head_stride , ptrdiff_t v_cache_head_stride ,
66+ ptrdiff_t k_cache_slot_stride , ptrdiff_t v_cache_slot_stride ,
6267 musaStream_t stream ) {
6368
6469 // Grid dimension is 1D, with one block per token, as we decided.
@@ -83,7 +88,11 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
8388 k_src_stride ,
8489 v_src_stride ,
8590 k_cache_block_stride ,
86- v_cache_block_stride );
91+ v_cache_block_stride ,
92+ k_cache_head_stride ,
93+ v_cache_head_stride ,
94+ k_cache_slot_stride ,
95+ v_cache_slot_stride );
8796 } else if (dtype == INFINI_DTYPE_BF16 ) {
8897 pagedCaching < __mt_bfloat16 , NUM_THREADS >
8998 <<< grid , block , shared_mem_size , stream >>> (
@@ -97,7 +106,11 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
97106 k_src_stride ,
98107 v_src_stride ,
99108 k_cache_block_stride ,
100- v_cache_block_stride );
109+ v_cache_block_stride ,
110+ k_cache_head_stride ,
111+ v_cache_head_stride ,
112+ k_cache_slot_stride ,
113+ v_cache_slot_stride );
101114 } else if (dtype == INFINI_DTYPE_F32 ) {
102115 pagedCaching < float , NUM_THREADS >
103116 <<< grid , block , shared_mem_size , stream >>> (
@@ -111,7 +124,11 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
111124 k_src_stride ,
112125 v_src_stride ,
113126 k_cache_block_stride ,
114- v_cache_block_stride );
127+ v_cache_block_stride ,
128+ k_cache_head_stride ,
129+ v_cache_head_stride ,
130+ k_cache_slot_stride ,
131+ v_cache_slot_stride );
115132 } else {
116133 return INFINI_STATUS_BAD_TENSOR_DTYPE ;
117134 }
@@ -137,13 +154,17 @@ infiniStatus_t Descriptor::calculate(
137154 _info . num_tokens , _info . num_kv_heads , _info . head_size , _info . block_size ,
138155 _info . k_src_stride , _info . v_src_stride ,
139156 _info . k_cache_block_stride , _info . v_cache_block_stride ,
157+ _info . k_cache_head_stride , _info . v_cache_head_stride ,
158+ _info . k_cache_slot_stride , _info . v_cache_slot_stride ,
140159 stream );
141160 } else if (_opaque -> internal -> maxThreadsPerBlock () >= MOORE_BLOCK_SIZE_512 ) {
142161 launchKernel < MOORE_BLOCK_SIZE_512 > (
143162 _info , k_cache , v_cache , _info . dtype , k , v , slot_mapping ,
144163 _info . num_tokens , _info . num_kv_heads , _info . head_size , _info . block_size ,
145164 _info . k_src_stride , _info . v_src_stride ,
146165 _info . k_cache_block_stride , _info . v_cache_block_stride ,
166+ _info . k_cache_head_stride , _info . v_cache_head_stride ,
167+ _info . k_cache_slot_stride , _info . v_cache_slot_stride ,
147168 stream );
148169 } else {
149170 // If the GPU is older and supports fewer threads, return an error.
0 commit comments