@@ -189,6 +189,84 @@ std::vector<paddle::Tensor> AppendAttentionWithOutput(
189189 const int sliding_window,
190190 const int sink_size);
191191
192+ std::vector<paddle::Tensor> DecoderWriteCacheWithRoPE (
193+ const paddle::Tensor& qkv,
194+ const paddle::Tensor& key_cache,
195+ const paddle::Tensor& value_cache,
196+ const paddle::Tensor& seq_lens_encoder,
197+ const paddle::Tensor& seq_lens_decoder,
198+ const paddle::Tensor& seq_lens_this_time,
199+ const paddle::Tensor& batch_id_per_token,
200+ const paddle::Tensor& cu_seqlens_q,
201+ const paddle::Tensor& block_tables,
202+ const paddle::Tensor& set_max_lengths,
203+ const paddle::optional<paddle::Tensor>& rotary_embs,
204+ const paddle::optional<paddle::Tensor>& qkv_bias,
205+ const paddle::optional<paddle::Tensor>& cache_k_quant_scales,
206+ const paddle::optional<paddle::Tensor>& cache_v_quant_scales,
207+ const paddle::optional<paddle::Tensor>& cache_k_dequant_scales,
208+ const paddle::optional<paddle::Tensor>& cache_v_dequant_scales,
209+ const paddle::optional<paddle::Tensor>& cache_k_zp,
210+ const paddle::optional<paddle::Tensor>& cache_v_zp,
211+ const paddle::optional<paddle::Tensor>& kv_signal_data,
212+ const paddle::optional<paddle::Tensor>& q_norm_weight,
213+ const paddle::optional<paddle::Tensor>& k_norm_weight,
214+ const float rms_norm_eps,
215+ const std::string& cache_quant_type_str,
216+ const bool use_neox_rotary_style,
217+ const bool rope_3d,
218+ const int max_input_length,
219+ const float quant_max_bound,
220+ const float quant_min_bound,
221+ const bool speculate_decoder);
222+
223+ std::vector<paddle::Tensor> DecodeUnifiedAttention (
224+ const paddle::Tensor& qkv,
225+ const paddle::Tensor& key_cache,
226+ const paddle::Tensor& value_cache,
227+ const paddle::Tensor& tmp_workspace,
228+ const paddle::Tensor& tmp_m,
229+ const paddle::Tensor& tmp_d,
230+ const paddle::Tensor& seq_lens_encoder,
231+ const paddle::Tensor& seq_lens_decoder,
232+ const paddle::Tensor& seq_lens_this_time,
233+ const paddle::Tensor& batch_id_per_token,
234+ const paddle::Tensor& cu_seqlens_q,
235+ const paddle::Tensor& block_tables,
236+ const paddle::Tensor& block_indices,
237+ const paddle::Tensor& num_blocks,
238+ const paddle::Tensor& chunk_size,
239+ const paddle::Tensor& set_max_lengths,
240+ const paddle::optional<paddle::Tensor>& attn_mask,
241+ const paddle::optional<paddle::Tensor>& cache_k_quant_scales,
242+ const paddle::optional<paddle::Tensor>& cache_v_quant_scales,
243+ const paddle::optional<paddle::Tensor>& cache_k_dequant_scales,
244+ const paddle::optional<paddle::Tensor>& cache_v_dequant_scales,
245+ const paddle::optional<paddle::Tensor>& cache_k_zp,
246+ const paddle::optional<paddle::Tensor>& cache_v_zp,
247+ const paddle::optional<paddle::Tensor>& mask_offset,
248+ const paddle::optional<paddle::Tensor>& sinks,
249+ paddle::Tensor& fmha_out,
250+ const std::string& cache_quant_type,
251+ const int max_input_length,
252+ const float quant_max_bound,
253+ const float quant_min_bound,
254+ const int max_tokens_per_batch,
255+ const bool causal,
256+ const int sliding_window);
257+
258+ void ConfigForAttention (const paddle::Tensor& seq_lens_encoder,
259+ const paddle::Tensor& seq_lens_decoder,
260+ const paddle::Tensor& seq_lens_this_time,
261+ paddle::Tensor& block_indices, // Inplace
262+ paddle::Tensor& num_blocks, // Inplace
263+ paddle::Tensor& chunk_size, // Inplace
264+ paddle::Tensor& max_len_tensor_cpu, // Inplace, CPU
265+ const std::string cache_quant_type,
266+ const int group_size,
267+ const int kv_num_heads,
268+ const int max_tokens_per_batch);
269+
192270std::vector<paddle::Tensor> GQARopeWriteCacheKernel (
193271 const paddle::Tensor& qkv,
194272 const paddle::Tensor& key_cache,
@@ -1943,4 +2021,28 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
19432021 m.def (" per_token_group_fp8_quant" ,
19442022 &PerTokenGroupQuantFp8,
19452023 " per_token_group_quant_fp8" );
2024+
2025+ /* *
2026+ * decoder_write_cache_with_rope.cu
2027+ * decoder_write_cache_with_rope
2028+ */
2029+ m.def (" decoder_write_cache_with_rope" ,
2030+ &DecoderWriteCacheWithRoPE,
2031+ " decoder write cache with RoPE function" );
2032+
2033+ /* *
2034+ * decode_unified_attention.cu
2035+ * decode_unified_attention
2036+ */
2037+ m.def (" decode_unified_attention" ,
2038+ &DecodeUnifiedAttention,
2039+ " decoder append attention function" );
2040+
2041+ /* *
2042+ * config_for_attention.cu
2043+ * config_for_attention
2044+ */
2045+ m.def (" config_for_attention" ,
2046+ &ConfigForAttention,
2047+ " config for attention function" );
19462048}
0 commit comments