Skip to content

Commit b32b84e

Browse files
committed
support decode unified attention
1 parent 55eb3a6 commit b32b84e

34 files changed

Lines changed: 11192 additions & 65 deletions

custom_ops/gpu_ops/append_attn/utils.cuh

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -31,7 +31,7 @@ struct AppendAttnMetaData {
3131
};
3232

3333
__forceinline__ __host__ __device__ int div_up(int a, int b) {
34-
return (a + b - 1) / b;
34+
return a / b + (a % b != 0);
3535
}
3636

3737
enum PosEncMode { kNonePos, kRoPE, kAliBi };

custom_ops/gpu_ops/cpp_extensions.cc

Lines changed: 102 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -189,6 +189,84 @@ std::vector<paddle::Tensor> AppendAttentionWithOutput(
189189
const int sliding_window,
190190
const int sink_size);
191191

192+
std::vector<paddle::Tensor> DecoderWriteCacheWithRoPE(
193+
const paddle::Tensor& qkv,
194+
const paddle::Tensor& key_cache,
195+
const paddle::Tensor& value_cache,
196+
const paddle::Tensor& seq_lens_encoder,
197+
const paddle::Tensor& seq_lens_decoder,
198+
const paddle::Tensor& seq_lens_this_time,
199+
const paddle::Tensor& batch_id_per_token,
200+
const paddle::Tensor& cu_seqlens_q,
201+
const paddle::Tensor& block_tables,
202+
const paddle::Tensor& set_max_lengths,
203+
const paddle::optional<paddle::Tensor>& rotary_embs,
204+
const paddle::optional<paddle::Tensor>& qkv_bias,
205+
const paddle::optional<paddle::Tensor>& cache_k_quant_scales,
206+
const paddle::optional<paddle::Tensor>& cache_v_quant_scales,
207+
const paddle::optional<paddle::Tensor>& cache_k_dequant_scales,
208+
const paddle::optional<paddle::Tensor>& cache_v_dequant_scales,
209+
const paddle::optional<paddle::Tensor>& cache_k_zp,
210+
const paddle::optional<paddle::Tensor>& cache_v_zp,
211+
const paddle::optional<paddle::Tensor>& kv_signal_data,
212+
const paddle::optional<paddle::Tensor>& q_norm_weight,
213+
const paddle::optional<paddle::Tensor>& k_norm_weight,
214+
const float rms_norm_eps,
215+
const std::string& cache_quant_type_str,
216+
const bool use_neox_rotary_style,
217+
const bool rope_3d,
218+
const int max_input_length,
219+
const float quant_max_bound,
220+
const float quant_min_bound,
221+
const bool speculate_decoder);
222+
223+
std::vector<paddle::Tensor> DecodeUnifiedAttention(
224+
const paddle::Tensor& qkv,
225+
const paddle::Tensor& key_cache,
226+
const paddle::Tensor& value_cache,
227+
const paddle::Tensor& tmp_workspace,
228+
const paddle::Tensor& tmp_m,
229+
const paddle::Tensor& tmp_d,
230+
const paddle::Tensor& seq_lens_encoder,
231+
const paddle::Tensor& seq_lens_decoder,
232+
const paddle::Tensor& seq_lens_this_time,
233+
const paddle::Tensor& batch_id_per_token,
234+
const paddle::Tensor& cu_seqlens_q,
235+
const paddle::Tensor& block_tables,
236+
const paddle::Tensor& block_indices,
237+
const paddle::Tensor& num_blocks,
238+
const paddle::Tensor& chunk_size,
239+
const paddle::Tensor& set_max_lengths,
240+
const paddle::optional<paddle::Tensor>& attn_mask,
241+
const paddle::optional<paddle::Tensor>& cache_k_quant_scales,
242+
const paddle::optional<paddle::Tensor>& cache_v_quant_scales,
243+
const paddle::optional<paddle::Tensor>& cache_k_dequant_scales,
244+
const paddle::optional<paddle::Tensor>& cache_v_dequant_scales,
245+
const paddle::optional<paddle::Tensor>& cache_k_zp,
246+
const paddle::optional<paddle::Tensor>& cache_v_zp,
247+
const paddle::optional<paddle::Tensor>& mask_offset,
248+
const paddle::optional<paddle::Tensor>& sinks,
249+
paddle::Tensor& fmha_out,
250+
const std::string& cache_quant_type,
251+
const int max_input_length,
252+
const float quant_max_bound,
253+
const float quant_min_bound,
254+
const int max_tokens_per_batch,
255+
const bool causal,
256+
const int sliding_window);
257+
258+
void ConfigForAttention(const paddle::Tensor& seq_lens_encoder,
259+
const paddle::Tensor& seq_lens_decoder,
260+
const paddle::Tensor& seq_lens_this_time,
261+
paddle::Tensor& block_indices, // Inplace
262+
paddle::Tensor& num_blocks, // Inplace
263+
paddle::Tensor& chunk_size, // Inplace
264+
paddle::Tensor& max_len_tensor_cpu, // Inplace, CPU
265+
const std::string cache_quant_type,
266+
const int group_size,
267+
const int kv_num_heads,
268+
const int max_tokens_per_batch);
269+
192270
std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
193271
const paddle::Tensor& qkv,
194272
const paddle::Tensor& key_cache,
@@ -1943,4 +2021,28 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
19432021
m.def("per_token_group_fp8_quant",
19442022
&PerTokenGroupQuantFp8,
19452023
"per_token_group_quant_fp8");
2024+
2025+
/**
2026+
* decoder_write_cache_with_rope.cu
2027+
* decoder_write_cache_with_rope
2028+
*/
2029+
m.def("decoder_write_cache_with_rope",
2030+
&DecoderWriteCacheWithRoPE,
2031+
"decoder write cache with RoPE function");
2032+
2033+
/**
2034+
* decode_unified_attention.cu
2035+
* decode_unified_attention
2036+
*/
2037+
m.def("decode_unified_attention",
2038+
&DecodeUnifiedAttention,
2039+
"decoder append attention function");
2040+
2041+
/**
2042+
* config_for_attention.cu
2043+
* config_for_attention
2044+
*/
2045+
m.def("config_for_attention",
2046+
&ConfigForAttention,
2047+
"config for attention function");
19462048
}

0 commit comments

Comments
 (0)