Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion custom_ops/gpu_ops/append_attn/utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ struct AppendAttnMetaData {
};

// Integer ceiling division: how many b-sized chunks are needed to cover a.
//
// Computed as the truncated quotient plus a "has remainder" flag instead of
// (a + b - 1) / b, so the intermediate sum cannot overflow int when a is
// close to INT_MAX.
//
// Intended for a >= 0 and b > 0. b == 0 is undefined (division by zero), and
// for negative operands the result follows C++ truncating division and is not
// guaranteed to be a mathematical ceiling.
__forceinline__ __host__ __device__ int div_up(int a, int b) {
  return a / b + (a % b != 0);
}

// Positional-encoding selector. The enumerator names suggest "none", rotary
// (RoPE) and ALiBi respectively — confirm against the code that switches on it.
enum PosEncMode { kNonePos, kRoPE, kAliBi };
Expand Down
102 changes: 102 additions & 0 deletions custom_ops/gpu_ops/cpp_extensions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,84 @@ std::vector<paddle::Tensor> AppendAttentionWithOutput(
const int sliding_window,
const int sink_size);

// Forward declaration of the DecoderWriteCacheWithRoPE custom op so it can be
// registered in this file's PYBIND11_MODULE, where it is exposed to Python as
// "decoder_write_cache_with_rope". The comment on that binding attributes the
// implementation to decoder_write_cache_with_rope.cu.
//
// Every tensor is taken by const reference; the paddle::optional arguments may
// be left empty by the caller. The op returns a std::vector<paddle::Tensor>.
//
// NOTE(review): tensor shapes/dtypes, the accepted values of
// cache_quant_type_str, and whether key_cache / value_cache are updated in
// place despite being const references are not visible from this declaration —
// confirm against the definition before relying on them.
std::vector<paddle::Tensor> DecoderWriteCacheWithRoPE(
const paddle::Tensor& qkv,
const paddle::Tensor& key_cache,
const paddle::Tensor& value_cache,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::Tensor& set_max_lengths,
// Optional inputs: each may be absent.
const paddle::optional<paddle::Tensor>& rotary_embs,
const paddle::optional<paddle::Tensor>& qkv_bias,
const paddle::optional<paddle::Tensor>& cache_k_quant_scales,
const paddle::optional<paddle::Tensor>& cache_v_quant_scales,
const paddle::optional<paddle::Tensor>& cache_k_dequant_scales,
const paddle::optional<paddle::Tensor>& cache_v_dequant_scales,
const paddle::optional<paddle::Tensor>& cache_k_zp,
const paddle::optional<paddle::Tensor>& cache_v_zp,
const paddle::optional<paddle::Tensor>& kv_signal_data,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
// Scalar / string attributes.
const float rms_norm_eps,
const std::string& cache_quant_type_str,
const bool use_neox_rotary_style,
const bool rope_3d,
const int max_input_length,
const float quant_max_bound,
const float quant_min_bound,
const bool speculate_decoder);

// Forward declaration of the DecodeUnifiedAttention custom op so it can be
// registered in this file's PYBIND11_MODULE, where it is exposed to Python as
// "decode_unified_attention". The comment on that binding attributes the
// implementation to decode_unified_attention.cu.
//
// fmha_out is the only tensor taken by non-const reference, so it is the one
// argument the callee is allowed to modify through this signature; presumably
// it is the attention output buffer — TODO confirm. The op also returns a
// std::vector<paddle::Tensor>.
//
// NOTE(review): block_indices, num_blocks and chunk_size share their names
// with the in-place outputs of ConfigForAttention declared in this file, which
// suggests they are produced by that call — verify against the Python caller.
// Tensor shapes/dtypes and the accepted cache_quant_type values are not
// visible from this declaration.
std::vector<paddle::Tensor> DecodeUnifiedAttention(
const paddle::Tensor& qkv,
const paddle::Tensor& key_cache,
const paddle::Tensor& value_cache,
const paddle::Tensor& tmp_workspace,
const paddle::Tensor& tmp_m,
const paddle::Tensor& tmp_d,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::Tensor& block_indices,
const paddle::Tensor& num_blocks,
const paddle::Tensor& chunk_size,
const paddle::Tensor& set_max_lengths,
// Optional inputs: each may be absent.
const paddle::optional<paddle::Tensor>& attn_mask,
const paddle::optional<paddle::Tensor>& cache_k_quant_scales,
const paddle::optional<paddle::Tensor>& cache_v_quant_scales,
const paddle::optional<paddle::Tensor>& cache_k_dequant_scales,
const paddle::optional<paddle::Tensor>& cache_v_dequant_scales,
const paddle::optional<paddle::Tensor>& cache_k_zp,
const paddle::optional<paddle::Tensor>& cache_v_zp,
const paddle::optional<paddle::Tensor>& mask_offset,
const paddle::optional<paddle::Tensor>& sinks,
// Mutable tensor argument (non-const reference).
paddle::Tensor& fmha_out,
// Scalar / string attributes.
const std::string& cache_quant_type,
const int max_input_length,
const float quant_max_bound,
const float quant_min_bound,
const int max_tokens_per_batch,
const bool causal,
const int sliding_window);

// Forward declaration of the ConfigForAttention custom op so it can be
// registered in this file's PYBIND11_MODULE, where it is exposed to Python as
// "config_for_attention". The comment on that binding attributes the
// implementation to config_for_attention.cu.
//
// Returns nothing: results are delivered through the four non-const tensor
// references marked "Inplace" below. The three seq_lens_* tensors are
// read-only inputs (const references).
//
// NOTE(review): cache_quant_type is passed by value here, whereas the other
// ops declared in this file take their quant-type string by const reference;
// this declaration must stay in sync with the definition, so it is left
// as-is. The meaning of group_size / kv_num_heads / max_tokens_per_batch is
// not visible from this declaration — confirm against the definition.
void ConfigForAttention(const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_this_time,
paddle::Tensor& block_indices, // Inplace
paddle::Tensor& num_blocks, // Inplace
paddle::Tensor& chunk_size, // Inplace
paddle::Tensor& max_len_tensor_cpu, // Inplace, CPU
const std::string cache_quant_type,
const int group_size,
const int kv_num_heads,
const int max_tokens_per_batch);

std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
const paddle::Tensor& qkv,
const paddle::Tensor& key_cache,
Expand Down Expand Up @@ -1943,4 +2021,28 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("per_token_group_fp8_quant",
&PerTokenGroupQuantFp8,
"per_token_group_quant_fp8");

/**
* decoder_write_cache_with_rope.cu
* decoder_write_cache_with_rope
*/
m.def("decoder_write_cache_with_rope",
&DecoderWriteCacheWithRoPE,
"decoder write cache with RoPE function");

/**
* decode_unified_attention.cu
* decode_unified_attention
*/
m.def("decode_unified_attention",
&DecodeUnifiedAttention,
"decoder append attention function");

/**
* config_for_attention.cu
* config_for_attention
*/
m.def("config_for_attention",
&ConfigForAttention,
"config for attention function");
}
Loading
Loading