Merged
2 changes: 1 addition & 1 deletion custom_ops/gpu_ops/append_attn/utils.cuh
@@ -31,7 +31,7 @@ struct AppendAttnMetaData {
};

__forceinline__ __host__ __device__ int div_up(int a, int b) {
-  return (a + b - 1) / b;
+  return a / b + (a % b != 0);

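🟡 Suggestion: the div_up change is outside the expected scope of this PR

This change replaces (a + b - 1) / b with a / b + (a % b != 0). The two are mathematically equivalent for positive integers, but the change modifies a utility function shared by every operator under append_attn/, and the new operators in this PR do not live in that directory.

Suggested fixes:

  • If the goal is to fix a potential integer overflow (a + b - 1 can overflow for very large values), please say so in the PR description and add regression tests for the existing append_attn operators.
  • If only decode_unified_attention needs it, define it separately in custom_ops/gpu_ops/decode_unified_attention/utils.cuh to avoid modifying an unrelated file.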
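
The equivalence and overflow claims in this comment can be checked with a small host-only sketch. The names div_up_add, div_up_mod, div_up_ref, and forms_agree are hypothetical and do not exist in the repository, and the sketch drops the __forceinline__ __host__ __device__ qualifiers so that it compiles as plain C++.

```cpp
#include <cassert>
#include <climits>

// Form removed by the diff: the intermediate sum a + b - 1 is computed in int,
// so it can overflow when a is close to INT_MAX.
constexpr int div_up_add(int a, int b) { return (a + b - 1) / b; }

// Form added by the diff: no intermediate sum is formed.
constexpr int div_up_mod(int a, int b) { return a / b + (a % b != 0); }

// Hypothetical reference that widens to long long before adding.
constexpr long long div_up_ref(long long a, long long b) {
  return (a + b - 1) / b;
}

// Compare both forms over 0 <= a <= max_a and 1 <= b <= max_b.
inline bool forms_agree(int max_a, int max_b) {
  for (int a = 0; a <= max_a; ++a) {
    for (int b = 1; b <= max_b; ++b) {
      if (div_up_add(a, b) != div_up_mod(a, b)) return false;
    }
  }
  return true;
}

// Near INT_MAX the new form still matches the widened reference.
static_assert(div_up_mod(INT_MAX, 2) == div_up_ref(INT_MAX, 2),
              "new form matches the widened reference at INT_MAX");

// The two forms diverge for negative a (integer division truncates toward zero).
static_assert(div_up_add(-1, 2) == 0 && div_up_mod(-1, 2) == 1,
              "forms differ for negative a");
```

The last check suggests the rewrite is only a drop-in replacement if every caller passes a non-negative a and a positive b, which is an assumption about the append_attn call sites that this sketch does not verify.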

}

enum PosEncMode { kNonePos, kRoPE, kAliBi };
102 changes: 102 additions & 0 deletions custom_ops/gpu_ops/cpp_extensions.cc
@@ -189,6 +189,84 @@ std::vector<paddle::Tensor> AppendAttentionWithOutput(
const int sliding_window,
const int sink_size);

std::vector<paddle::Tensor> DecoderWriteCacheWithRoPE(
const paddle::Tensor& qkv,
const paddle::Tensor& key_cache,
const paddle::Tensor& value_cache,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::Tensor& set_max_lengths,
const paddle::optional<paddle::Tensor>& rotary_embs,
const paddle::optional<paddle::Tensor>& qkv_bias,
const paddle::optional<paddle::Tensor>& cache_k_quant_scales,
const paddle::optional<paddle::Tensor>& cache_v_quant_scales,
const paddle::optional<paddle::Tensor>& cache_k_dequant_scales,
const paddle::optional<paddle::Tensor>& cache_v_dequant_scales,
const paddle::optional<paddle::Tensor>& cache_k_zp,
const paddle::optional<paddle::Tensor>& cache_v_zp,
const paddle::optional<paddle::Tensor>& kv_signal_data,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps,
const std::string& cache_quant_type_str,
const bool use_neox_rotary_style,
const bool rope_3d,
const int max_input_length,
const float quant_max_bound,
const float quant_min_bound,
const bool speculate_decoder);

std::vector<paddle::Tensor> DecodeUnifiedAttention(
const paddle::Tensor& qkv,
const paddle::Tensor& key_cache,
const paddle::Tensor& value_cache,
const paddle::Tensor& tmp_workspace,
const paddle::Tensor& tmp_m,
const paddle::Tensor& tmp_d,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::Tensor& block_indices,
const paddle::Tensor& num_blocks,
const paddle::Tensor& chunk_size,
const paddle::Tensor& set_max_lengths,
const paddle::optional<paddle::Tensor>& attn_mask,
const paddle::optional<paddle::Tensor>& cache_k_quant_scales,
const paddle::optional<paddle::Tensor>& cache_v_quant_scales,
const paddle::optional<paddle::Tensor>& cache_k_dequant_scales,
const paddle::optional<paddle::Tensor>& cache_v_dequant_scales,
const paddle::optional<paddle::Tensor>& cache_k_zp,
const paddle::optional<paddle::Tensor>& cache_v_zp,
const paddle::optional<paddle::Tensor>& mask_offset,
const paddle::optional<paddle::Tensor>& sinks,
paddle::Tensor& fmha_out,
const std::string& cache_quant_type,
const int max_input_length,
const float quant_max_bound,
const float quant_min_bound,
const int max_tokens_per_batch,
const bool causal,
const int sliding_window);

void ConfigForAttention(const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_this_time,
paddle::Tensor& block_indices, // Inplace
paddle::Tensor& num_blocks, // Inplace
paddle::Tensor& chunk_size, // Inplace
paddle::Tensor& max_len_tensor_cpu, // Inplace, CPU
const std::string cache_quant_type,
const int group_size,
const int kv_num_heads,
const int max_tokens_per_batch);

std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
const paddle::Tensor& qkv,
const paddle::Tensor& key_cache,
@@ -1963,4 +2041,28 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("per_token_group_fp8_quant",
&PerTokenGroupQuantFp8,
"per_token_group_quant_fp8");

/**

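❓ Question: decoder_write_cache_with_rope and config_for_attention are registered only through pybind11; no PD_BUILD_STATIC_OP is visible

Within the diff, the only visible macro registration is PD_BUILD_STATIC_OP(decode_unified_attention) in decode_unified_attention.cu. The diff does not show a static op registration for decoder_write_cache_with_rope.cu or decode_unified_attention/config_for_attention.cu.

Please confirm:

  1. Whether both files also contain the corresponding PD_BUILD_STATIC_OP(decoder_write_cache_with_rope) / PD_BUILD_STATIC_OP(config_for_attention) macros;
  2. If these two operators are called only through pybind11 (dynamic graph), add a code comment explaining why static registration is not needed.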

* decoder_write_cache_with_rope.cu
* decoder_write_cache_with_rope
*/
m.def("decoder_write_cache_with_rope",
&DecoderWriteCacheWithRoPE,
"decoder write cache with RoPE function");

/**
* decode_unified_attention.cu
* decode_unified_attention
*/
m.def("decode_unified_attention",
&DecodeUnifiedAttention,
"decoder append attention function");

/**
* config_for_attention.cu
* config_for_attention
*/
m.def("config_for_attention",
&ConfigForAttention,
"config for attention function");
}