@@ -189,6 +189,84 @@ std::vector<paddle::Tensor> AppendAttentionWithOutput(
189189 const int sliding_window,
190190 const int sink_size);
191191
// Forward declaration of DecoderWriteCacheWithRoPE; only the prototype is
// visible here. It is registered with Python as `decoder_write_cache_with_rope`
// in the PYBIND11_MODULE block further down, whose comment names
// decoder_write_cache_with_rope.cu as the associated source file.
//
// Returns a std::vector<paddle::Tensor>; the number and meaning of the returned
// tensors are not visible from this declaration.
//
// Parameter groups, as declared:
//   - qkv, key_cache, value_cache, seq_lens_*, batch_id_per_token,
//     cu_seqlens_q, block_tables, set_max_lengths: required tensors, all taken
//     by const reference.
//   - rotary_embs ... k_norm_weight: paddle::optional tensors that callers may
//     leave unset.
//   - rms_norm_eps ... speculate_decoder: scalar / string configuration values.
//
// NOTE(review): judging by the name and the binding docstring ("decoder write
// cache with RoPE function"), this presumably applies rotary position embedding
// to `qkv` and writes K/V into `key_cache` / `value_cache` -- confirm against
// the definition. The caches are passed as const references here, so whether
// their underlying storage is modified cannot be told from this prototype.
192+ std::vector<paddle::Tensor> DecoderWriteCacheWithRoPE (
193+ const paddle::Tensor& qkv,
194+ const paddle::Tensor& key_cache,
195+ const paddle::Tensor& value_cache,
196+ const paddle::Tensor& seq_lens_encoder,
197+ const paddle::Tensor& seq_lens_decoder,
198+ const paddle::Tensor& seq_lens_this_time,
199+ const paddle::Tensor& batch_id_per_token,
200+ const paddle::Tensor& cu_seqlens_q,
201+ const paddle::Tensor& block_tables,
202+ const paddle::Tensor& set_max_lengths,
// Optional tensors start here.
203+ const paddle::optional<paddle::Tensor>& rotary_embs,
204+ const paddle::optional<paddle::Tensor>& qkv_bias,
205+ const paddle::optional<paddle::Tensor>& cache_k_quant_scales,
206+ const paddle::optional<paddle::Tensor>& cache_v_quant_scales,
207+ const paddle::optional<paddle::Tensor>& cache_k_dequant_scales,
208+ const paddle::optional<paddle::Tensor>& cache_v_dequant_scales,
209+ const paddle::optional<paddle::Tensor>& cache_k_zp,
210+ const paddle::optional<paddle::Tensor>& cache_v_zp,
211+ const paddle::optional<paddle::Tensor>& kv_signal_data,
212+ const paddle::optional<paddle::Tensor>& q_norm_weight,
213+ const paddle::optional<paddle::Tensor>& k_norm_weight,
// Scalar / string configuration starts here.
214+ const float rms_norm_eps,
// NOTE(review): named `cache_quant_type_str` here but `cache_quant_type` in the
// DecodeUnifiedAttention and ConfigForAttention declarations below.
215+ const std::string& cache_quant_type_str,
216+ const bool use_neox_rotary_style,
217+ const bool rope_3d,
218+ const int max_input_length,
219+ const float quant_max_bound,
220+ const float quant_min_bound,
221+ const bool speculate_decoder);
222+
// Forward declaration of DecodeUnifiedAttention; only the prototype is visible
// here. It is registered with Python as `decode_unified_attention` in the
// PYBIND11_MODULE block further down, whose comment names
// decode_unified_attention.cu as the associated source file.
//
// Returns a std::vector<paddle::Tensor>; the contents of that vector are not
// visible from this declaration.
//
// Parameter groups, as declared:
//   - qkv ... set_max_lengths: required tensors, all taken by const reference.
//   - attn_mask ... sinks: paddle::optional tensors that callers may leave
//     unset.
//   - fmha_out: the only tensor taken by non-const reference.
//   - cache_quant_type ... sliding_window: scalar / string configuration.
//
// NOTE(review): `block_indices`, `num_blocks` and `chunk_size` share their
// names with the in-place tensors of ConfigForAttention declared below, so they
// are presumably produced by that function -- confirm against the callers.
223+ std::vector<paddle::Tensor> DecodeUnifiedAttention (
224+ const paddle::Tensor& qkv,
225+ const paddle::Tensor& key_cache,
226+ const paddle::Tensor& value_cache,
// NOTE(review): tmp_workspace / tmp_m / tmp_d look like caller-provided scratch
// buffers, but they are const references here -- confirm how they are used.
227+ const paddle::Tensor& tmp_workspace,
228+ const paddle::Tensor& tmp_m,
229+ const paddle::Tensor& tmp_d,
230+ const paddle::Tensor& seq_lens_encoder,
231+ const paddle::Tensor& seq_lens_decoder,
232+ const paddle::Tensor& seq_lens_this_time,
233+ const paddle::Tensor& batch_id_per_token,
234+ const paddle::Tensor& cu_seqlens_q,
235+ const paddle::Tensor& block_tables,
236+ const paddle::Tensor& block_indices,
237+ const paddle::Tensor& num_blocks,
238+ const paddle::Tensor& chunk_size,
239+ const paddle::Tensor& set_max_lengths,
// Optional tensors start here.
240+ const paddle::optional<paddle::Tensor>& attn_mask,
241+ const paddle::optional<paddle::Tensor>& cache_k_quant_scales,
242+ const paddle::optional<paddle::Tensor>& cache_v_quant_scales,
243+ const paddle::optional<paddle::Tensor>& cache_k_dequant_scales,
244+ const paddle::optional<paddle::Tensor>& cache_v_dequant_scales,
245+ const paddle::optional<paddle::Tensor>& cache_k_zp,
246+ const paddle::optional<paddle::Tensor>& cache_v_zp,
247+ const paddle::optional<paddle::Tensor>& mask_offset,
248+ const paddle::optional<paddle::Tensor>& sinks,
// Mutable reference: presumably the attention output written in place --
// TODO confirm against the definition.
249+ paddle::Tensor& fmha_out,
// Scalar / string configuration starts here.
250+ const std::string& cache_quant_type,
251+ const int max_input_length,
252+ const float quant_max_bound,
253+ const float quant_min_bound,
254+ const int max_tokens_per_batch,
255+ const bool causal,
256+ const int sliding_window);
257+
// Forward declaration of ConfigForAttention; only the prototype is visible
// here. It is registered with Python as `config_for_attention` in the
// PYBIND11_MODULE block further down, whose comment names
// config_for_attention.cu as the associated source file.
//
// Returns void. The three seq_lens_* tensors are taken by const reference; the
// four tensors taken by non-const reference are annotated "Inplace" by the
// original author, i.e. they are the outputs of this call.
//
// NOTE(review): `cache_quant_type` is passed as `const std::string` by value,
// whereas the sibling declarations above take `const std::string&`. Switching
// to a reference would avoid a copy per call, but it changes the function
// signature, so the out-of-view definition must be updated in the same change.
258+ void ConfigForAttention (const paddle::Tensor& seq_lens_encoder,
259+ const paddle::Tensor& seq_lens_decoder,
260+ const paddle::Tensor& seq_lens_this_time,
261+ paddle::Tensor& block_indices, // Inplace
262+ paddle::Tensor& num_blocks, // Inplace
263+ paddle::Tensor& chunk_size, // Inplace
264+ paddle::Tensor& max_len_tensor_cpu, // Inplace, CPU
265+ const std::string cache_quant_type,
266+ const int group_size,
267+ const int kv_num_heads,
268+ const int max_tokens_per_batch);
269+
192270std::vector<paddle::Tensor> GQARopeWriteCacheKernel (
193271 const paddle::Tensor& qkv,
194272 const paddle::Tensor& key_cache,
@@ -1962,4 +2040,28 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
19622040 m.def (" per_token_group_fp8_quant" ,
19632041 &PerTokenGroupQuantFp8,
19642042 " per_token_group_quant_fp8" );
2043+
2044+ /* *
2045+ * decoder_write_cache_with_rope.cu
2046+ * decoder_write_cache_with_rope
2047+ */
2048+ m.def (" decoder_write_cache_with_rope" ,
2049+ &DecoderWriteCacheWithRoPE,
2050+ " decoder write cache with RoPE function" );
2051+
2052+ /* *
2053+ * decode_unified_attention.cu
2054+ * decode_unified_attention
2055+ */
2056+ m.def (" decode_unified_attention" ,
2057+ &DecodeUnifiedAttention,
2058+ " decoder append attention function" );
2059+
2060+ /* *
2061+ * config_for_attention.cu
2062+ * config_for_attention
2063+ */
2064+ m.def (" config_for_attention" ,
2065+ &ConfigForAttention,
2066+ " config for attention function" );
19652067}
0 commit comments