@@ -27,8 +27,8 @@ std::vector<paddle::Tensor> PrefillMLAWriteCache(
2727 const paddle::Tensor& batch_id_per_token,
2828 const paddle::Tensor& cu_seqlens_q,
2929 const paddle::Tensor& block_tables,
30+ const paddle::Tensor& slot_mapping,
3031 const paddle::optional<paddle::Tensor>& kv_signal_data,
31- const int max_seq_len,
3232 cudaStream_t& stream,
3333 paddle::Tensor* kv_cache) {
3434 typedef PDTraits<T> traits_;
@@ -50,19 +50,21 @@ std::vector<paddle::Tensor> PrefillMLAWriteCache(
5050 int grid_size = 1 ;
5151 GetNumBlocks<128 >(pack_num, &grid_size);
5252
53- prefill_absorb_cache_kernel<DataType_, PackSize>
53+ using CT = DataType_;
54+
55+ prefill_absorb_cache_kernel<DataType_, PackSize, CT>
5456 <<<grid_size, blocksize, 0 , stream>>> (
5557 reinterpret_cast <DataType_*>(
5658 const_cast <data_t *>(kv_nope.data <data_t >())),
5759 reinterpret_cast <DataType_*>(
5860 const_cast <data_t *>(kv_pe.data <data_t >())),
5961 reinterpret_cast <DataType_*>(kv_cache->data <data_t >()),
6062 block_tables.data <int >(),
63+ slot_mapping.data <int64_t >(),
6164 batch_id_per_token.data <int >(),
6265 cu_seqlens_q.data <int >(),
6366 seq_lens.data <int >(),
6467 seq_lens_decoder.data <int >(),
65- max_seq_len,
6668 max_blocks_per_seq,
6769 kv_num_heads,
6870 nope_size,
@@ -108,9 +110,9 @@ std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
108110 const paddle::Tensor& batch_id_per_token,
109111 const paddle::Tensor& cu_seqlens_q,
110112 const paddle::Tensor& block_tables,
113+ const paddle::Tensor& slot_mapping,
111114 const paddle::optional<paddle::Tensor>& kv_signal_data,
112- const std::string& cache_quant_type_str,
113- const int max_seq_len) {
115+ const std::string& cache_quant_type_str) {
114116 cudaStream_t stream = kv_pe.stream ();
115117 AppendAttnMetaData meta_data;
116118 const auto & kv_nope_dims = kv_nope.dims ();
@@ -137,8 +139,8 @@ std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
137139 batch_id_per_token,
138140 cu_seqlens_q,
139141 block_tables,
142+ slot_mapping,
140143 kv_signal_data,
141- max_seq_len,
142144 stream,
143145 const_cast <paddle::Tensor*>(&kv_cache));
144146 }
@@ -152,8 +154,8 @@ std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
152154 batch_id_per_token,
153155 cu_seqlens_q,
154156 block_tables,
157+ slot_mapping,
155158 kv_signal_data,
156- max_seq_len,
157159 stream,
158160 const_cast <paddle::Tensor*>(&kv_cache));
159161 }
@@ -171,7 +173,6 @@ std::vector<paddle::Tensor> DecodeMLAWriteCache(
171173 const paddle::Tensor& batch_id_per_token,
172174 const paddle::Tensor& cu_seqlens_q,
173175 const paddle::Tensor& block_tables,
174- const int max_seq_len,
175176 const bool speculate_decoder,
176177 cudaStream_t& stream,
177178 paddle::Tensor* kv_cache) {
@@ -207,7 +208,6 @@ std::vector<paddle::Tensor> DecodeMLAWriteCache(
207208 cu_seqlens_q.data <int >(),
208209 seq_lens.data <int >(),
209210 seq_lens_encoder.data <int >(),
210- max_seq_len,
211211 max_blocks_per_seq,
212212 kv_num_heads,
213213 nope_size,
@@ -229,7 +229,6 @@ std::vector<paddle::Tensor> DecodeMLAWriteCache(
229229 cu_seqlens_q.data <int >(),
230230 seq_lens.data <int >(),
231231 seq_lens_encoder.data <int >(),
232- max_seq_len,
233232 max_blocks_per_seq,
234233 kv_num_heads,
235234 nope_size,
@@ -250,7 +249,6 @@ std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(
250249 const paddle::Tensor& cu_seqlens_q,
251250 const paddle::Tensor& block_tables,
252251 const std::string& cache_quant_type_str,
253- const int max_seq_len,
254252 const bool speculate_decoder) {
255253 cudaStream_t stream = kv_pe.stream ();
256254 AppendAttnMetaData meta_data;
@@ -278,7 +276,6 @@ std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(
278276 batch_id_per_token,
279277 cu_seqlens_q,
280278 block_tables,
281- max_seq_len,
282279 speculate_decoder,
283280 stream,
284281 const_cast <paddle::Tensor*>(&kv_cache));
@@ -293,7 +290,6 @@ std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(
293290 batch_id_per_token,
294291 cu_seqlens_q,
295292 block_tables,
296- max_seq_len,
297293 speculate_decoder,
298294 stream,
299295 const_cast <paddle::Tensor*>(&kv_cache));
@@ -311,10 +307,11 @@ PD_BUILD_STATIC_OP(prefill_mla_write_cache)
311307 " batch_id_per_token" ,
312308 " cu_seqlens_q" ,
313309 " block_tables" ,
310+ " slot_mapping" ,
314311 paddle::Optional (" kv_signal_data" )})
315312 .Outputs({" kv_cache_out" })
316313 .SetInplaceMap({{" kv_cache" , " kv_cache_out" }})
317- .Attrs({" cache_quant_type_str: std::string" , " max_seq_len: int " })
314+ .Attrs({" cache_quant_type_str: std::string" })
318315 .SetKernelFn(PD_KERNEL(PrefillMLAWriteCacheKernel));
319316
320317PD_BUILD_STATIC_OP (decode_mla_write_cache)
@@ -328,7 +325,5 @@ PD_BUILD_STATIC_OP(decode_mla_write_cache)
328325 " block_tables" })
329326 .Outputs({" kv_cache_out" })
330327 .SetInplaceMap({{" kv_cache" , " kv_cache_out" }})
331- .Attrs({" cache_quant_type_str: std::string" ,
332- " max_seq_len: int" ,
333- " speculate_decoder: bool" })
328+ .Attrs({" cache_quant_type_str: std::string" , " speculate_decoder: bool" })
334329 .SetKernelFn(PD_KERNEL(DecodeMLAWriteCacheKernel));
0 commit comments