-
Notifications
You must be signed in to change notification settings - Fork 752
[Feature] support decode unified attention #7688
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
0988ec5
c57e336
5edc78b
633ff02
184b648
89c917b
9bea80e
4596ea9
3f7f0fd
faac077
0b30335
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -189,6 +189,84 @@ std::vector<paddle::Tensor> AppendAttentionWithOutput( | |
| const int sliding_window, | ||
| const int sink_size); | ||
|
|
||
| std::vector<paddle::Tensor> DecoderWriteCacheWithRoPE( | ||
| const paddle::Tensor& qkv, | ||
| const paddle::Tensor& key_cache, | ||
| const paddle::Tensor& value_cache, | ||
| const paddle::Tensor& seq_lens_encoder, | ||
| const paddle::Tensor& seq_lens_decoder, | ||
| const paddle::Tensor& seq_lens_this_time, | ||
| const paddle::Tensor& batch_id_per_token, | ||
| const paddle::Tensor& cu_seqlens_q, | ||
| const paddle::Tensor& block_tables, | ||
| const paddle::Tensor& set_max_lengths, | ||
| const paddle::optional<paddle::Tensor>& rotary_embs, | ||
| const paddle::optional<paddle::Tensor>& qkv_bias, | ||
| const paddle::optional<paddle::Tensor>& cache_k_quant_scales, | ||
| const paddle::optional<paddle::Tensor>& cache_v_quant_scales, | ||
| const paddle::optional<paddle::Tensor>& cache_k_dequant_scales, | ||
| const paddle::optional<paddle::Tensor>& cache_v_dequant_scales, | ||
| const paddle::optional<paddle::Tensor>& cache_k_zp, | ||
| const paddle::optional<paddle::Tensor>& cache_v_zp, | ||
| const paddle::optional<paddle::Tensor>& kv_signal_data, | ||
| const paddle::optional<paddle::Tensor>& q_norm_weight, | ||
| const paddle::optional<paddle::Tensor>& k_norm_weight, | ||
| const float rms_norm_eps, | ||
| const std::string& cache_quant_type_str, | ||
| const bool use_neox_rotary_style, | ||
| const bool rope_3d, | ||
| const int max_input_length, | ||
| const float quant_max_bound, | ||
| const float quant_min_bound, | ||
| const bool speculate_decoder); | ||
|
|
||
| std::vector<paddle::Tensor> DecodeUnifiedAttention( | ||
| const paddle::Tensor& qkv, | ||
| const paddle::Tensor& key_cache, | ||
| const paddle::Tensor& value_cache, | ||
| const paddle::Tensor& tmp_workspace, | ||
| const paddle::Tensor& tmp_m, | ||
| const paddle::Tensor& tmp_d, | ||
| const paddle::Tensor& seq_lens_encoder, | ||
| const paddle::Tensor& seq_lens_decoder, | ||
| const paddle::Tensor& seq_lens_this_time, | ||
| const paddle::Tensor& batch_id_per_token, | ||
| const paddle::Tensor& cu_seqlens_q, | ||
| const paddle::Tensor& block_tables, | ||
| const paddle::Tensor& block_indices, | ||
| const paddle::Tensor& num_blocks, | ||
| const paddle::Tensor& chunk_size, | ||
| const paddle::Tensor& set_max_lengths, | ||
| const paddle::optional<paddle::Tensor>& attn_mask, | ||
| const paddle::optional<paddle::Tensor>& cache_k_quant_scales, | ||
| const paddle::optional<paddle::Tensor>& cache_v_quant_scales, | ||
| const paddle::optional<paddle::Tensor>& cache_k_dequant_scales, | ||
| const paddle::optional<paddle::Tensor>& cache_v_dequant_scales, | ||
| const paddle::optional<paddle::Tensor>& cache_k_zp, | ||
| const paddle::optional<paddle::Tensor>& cache_v_zp, | ||
| const paddle::optional<paddle::Tensor>& mask_offset, | ||
| const paddle::optional<paddle::Tensor>& sinks, | ||
| paddle::Tensor& fmha_out, | ||
| const std::string& cache_quant_type, | ||
| const int max_input_length, | ||
| const float quant_max_bound, | ||
| const float quant_min_bound, | ||
| const int max_tokens_per_batch, | ||
| const bool causal, | ||
| const int sliding_window); | ||
|
|
||
| void ConfigForAttention(const paddle::Tensor& seq_lens_encoder, | ||
| const paddle::Tensor& seq_lens_decoder, | ||
| const paddle::Tensor& seq_lens_this_time, | ||
| paddle::Tensor& block_indices, // Inplace | ||
| paddle::Tensor& num_blocks, // Inplace | ||
| paddle::Tensor& chunk_size, // Inplace | ||
| paddle::Tensor& max_len_tensor_cpu, // Inplace, CPU | ||
| const std::string cache_quant_type, | ||
| const int group_size, | ||
| const int kv_num_heads, | ||
| const int max_tokens_per_batch); | ||
|
|
||
| std::vector<paddle::Tensor> GQARopeWriteCacheKernel( | ||
| const paddle::Tensor& qkv, | ||
| const paddle::Tensor& key_cache, | ||
|
|
@@ -1963,4 +2041,28 @@ PYBIND11_MODULE(fastdeploy_ops, m) { | |
| m.def("per_token_group_fp8_quant", | ||
| &PerTokenGroupQuantFp8, | ||
| "per_token_group_quant_fp8"); | ||
|
|
||
| /** | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ❓ 疑问 在 diff 范围内,只有 请确认:
|
||
| * decoder_write_cache_with_rope.cu | ||
| * decoder_write_cache_with_rope | ||
| */ | ||
| m.def("decoder_write_cache_with_rope", | ||
| &DecoderWriteCacheWithRoPE, | ||
| "decoder write cache with RoPE function"); | ||
|
|
||
| /** | ||
| * decode_unified_attention.cu | ||
| * decode_unified_attention | ||
| */ | ||
| m.def("decode_unified_attention", | ||
| &DecodeUnifiedAttention, | ||
| "decoder append attention function"); | ||
|
|
||
| /** | ||
| * config_for_attention.cu | ||
| * config_for_attention | ||
| */ | ||
| m.def("config_for_attention", | ||
| &ConfigForAttention, | ||
| "config for attention function"); | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🟡 建议
div_up修改超出本 PR 预期范围此改动将
(a + b - 1) / b改为a / b + (a % b != 0),虽对正整数数学等价,但修改了append_attn/目录下所有算子共用的工具函数,而本 PR 的新算子并不在该目录下。建议修复方式:
a + b - 1对超大值可能溢出),请在 PR 描述中补充说明,并增加对现有 append_attn 算子的回归测试custom_ops/gpu_ops/decode_unified_attention/utils.cuh中单独定义,避免修改无关文件