Skip to content

Commit a095d6f

Browse files
[Cherry-Pick][Feature] support decode unified attention for mix(#7688) (#7729)
* support c8 decode attention
* support c16 attention && backend
* opt kernel
* fix
* opt larger batch
* inplace out
* fix input_batch && remove fast_math
* fix xpu
* fix bug
* fix ci
* opt and fix mtp
* fix merge
* clean code
* fix merge
* update
* update test
* fix test
* fix test
* opt buffer
* fix conflict
---------
Co-authored-by: Jiaxin Sui <95567040+plusNew001@users.noreply.github.com>
1 parent bf0dace commit a095d6f

28 files changed

Lines changed: 8172 additions & 63 deletions

custom_ops/gpu_ops/cpp_extensions.cc

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,84 @@ std::vector<paddle::Tensor> AppendAttentionWithOutput(
189189
const int sliding_window,
190190
const int sink_size);
191191

192+
// Forward declaration of the op registered with pybind11 as
// "decoder_write_cache_with_rope" in this file; the registration comment
// names decoder_write_cache_with_rope.cu as the implementing source.
//
// Signature summary (only what the declaration itself shows):
//   - Returns a std::vector<paddle::Tensor>.
//   - Ten required tensors, all taken by const reference (qkv, key_cache,
//     value_cache, the seq_lens_* trio, batch_id_per_token, cu_seqlens_q,
//     block_tables, set_max_lengths).
//   - Eleven paddle::optional tensors that callers may leave unset
//     (rotary_embs, qkv_bias, cache K/V quant and dequant scales, cache K/V
//     zero points, kv_signal_data, q_norm_weight, k_norm_weight).
//   - Trailing scalar/string attributes passed by value or const reference.
//
// NOTE(review): judging by the name, this presumably applies rotary position
// embedding to qkv and writes K/V into the paged cache for decode-phase
// tokens -- confirm against the .cu definition. Whether key_cache /
// value_cache are mutated despite being const references cannot be told
// from this declaration.
// NOTE(review): the quant-type string is named `cache_quant_type_str` here
// but `cache_quant_type` in the sibling declarations; names in a declaration
// do not affect linkage, this is only a readability inconsistency.
std::vector<paddle::Tensor> DecoderWriteCacheWithRoPE(
193+
const paddle::Tensor& qkv,
194+
const paddle::Tensor& key_cache,
195+
const paddle::Tensor& value_cache,
196+
const paddle::Tensor& seq_lens_encoder,
197+
const paddle::Tensor& seq_lens_decoder,
198+
const paddle::Tensor& seq_lens_this_time,
199+
const paddle::Tensor& batch_id_per_token,
200+
const paddle::Tensor& cu_seqlens_q,
201+
const paddle::Tensor& block_tables,
202+
const paddle::Tensor& set_max_lengths,
203+
const paddle::optional<paddle::Tensor>& rotary_embs,
204+
const paddle::optional<paddle::Tensor>& qkv_bias,
205+
const paddle::optional<paddle::Tensor>& cache_k_quant_scales,
206+
const paddle::optional<paddle::Tensor>& cache_v_quant_scales,
207+
const paddle::optional<paddle::Tensor>& cache_k_dequant_scales,
208+
const paddle::optional<paddle::Tensor>& cache_v_dequant_scales,
209+
const paddle::optional<paddle::Tensor>& cache_k_zp,
210+
const paddle::optional<paddle::Tensor>& cache_v_zp,
211+
const paddle::optional<paddle::Tensor>& kv_signal_data,
212+
const paddle::optional<paddle::Tensor>& q_norm_weight,
213+
const paddle::optional<paddle::Tensor>& k_norm_weight,
214+
const float rms_norm_eps,
215+
const std::string& cache_quant_type_str,
216+
const bool use_neox_rotary_style,
217+
const bool rope_3d,
218+
const int max_input_length,
219+
const float quant_max_bound,
220+
const float quant_min_bound,
221+
const bool speculate_decoder);
222+
223+
// Forward declaration of the op registered with pybind11 as
// "decode_unified_attention" in this file; the registration comment names
// decode_unified_attention.cu as the implementing source.
//
// Signature summary (only what the declaration itself shows):
//   - Returns a std::vector<paddle::Tensor>.
//   - Sixteen required tensors taken by const reference, including the
//     tmp_workspace / tmp_m / tmp_d buffers and the block_indices /
//     num_blocks / chunk_size tensors (the same three names that
//     ConfigForAttention takes as in-place outputs).
//   - Nine paddle::optional tensors that callers may leave unset (attn_mask,
//     cache K/V quant and dequant scales, cache K/V zero points, mask_offset,
//     sinks).
//   - `fmha_out` is the only tensor taken by non-const reference.
//   - Trailing scalar/string attributes passed by value or const reference.
//
// NOTE(review): `fmha_out` is presumably the attention output written in
// place (the commit message mentions "inplace out") -- confirm against the
// .cu definition, along with what the returned vector then contains.
// NOTE(review): tmp_workspace / tmp_m / tmp_d look like caller-provided
// scratch buffers; they are const references here, so verify whether the
// kernel writes through them.
std::vector<paddle::Tensor> DecodeUnifiedAttention(
224+
const paddle::Tensor& qkv,
225+
const paddle::Tensor& key_cache,
226+
const paddle::Tensor& value_cache,
227+
const paddle::Tensor& tmp_workspace,
228+
const paddle::Tensor& tmp_m,
229+
const paddle::Tensor& tmp_d,
230+
const paddle::Tensor& seq_lens_encoder,
231+
const paddle::Tensor& seq_lens_decoder,
232+
const paddle::Tensor& seq_lens_this_time,
233+
const paddle::Tensor& batch_id_per_token,
234+
const paddle::Tensor& cu_seqlens_q,
235+
const paddle::Tensor& block_tables,
236+
const paddle::Tensor& block_indices,
237+
const paddle::Tensor& num_blocks,
238+
const paddle::Tensor& chunk_size,
239+
const paddle::Tensor& set_max_lengths,
240+
const paddle::optional<paddle::Tensor>& attn_mask,
241+
const paddle::optional<paddle::Tensor>& cache_k_quant_scales,
242+
const paddle::optional<paddle::Tensor>& cache_v_quant_scales,
243+
const paddle::optional<paddle::Tensor>& cache_k_dequant_scales,
244+
const paddle::optional<paddle::Tensor>& cache_v_dequant_scales,
245+
const paddle::optional<paddle::Tensor>& cache_k_zp,
246+
const paddle::optional<paddle::Tensor>& cache_v_zp,
247+
const paddle::optional<paddle::Tensor>& mask_offset,
248+
const paddle::optional<paddle::Tensor>& sinks,
249+
paddle::Tensor& fmha_out,
250+
const std::string& cache_quant_type,
251+
const int max_input_length,
252+
const float quant_max_bound,
253+
const float quant_min_bound,
254+
const int max_tokens_per_batch,
255+
const bool causal,
256+
const int sliding_window);
257+
258+
// Forward declaration of the op registered with pybind11 as
// "config_for_attention" in this file; the registration comment names
// config_for_attention.cu as the implementing source.
//
// Returns nothing: the three seq_lens_* tensors are read-only inputs, and the
// four non-const tensor references are the outputs, updated in place as the
// per-parameter comments below state (max_len_tensor_cpu is additionally
// marked as a CPU tensor).
//
// NOTE(review): presumably this precomputes the block_indices / num_blocks /
// chunk_size tensors that DecodeUnifiedAttention takes as inputs -- the
// parameter names match, but confirm against the callers.
// NOTE(review): `cache_quant_type` is taken by value (`const std::string`),
// unlike the `const std::string&` used by the sibling declarations. Switching
// to a reference would avoid a copy but must be done together with the
// definition in the .cu file.
void ConfigForAttention(const paddle::Tensor& seq_lens_encoder,
259+
const paddle::Tensor& seq_lens_decoder,
260+
const paddle::Tensor& seq_lens_this_time,
261+
paddle::Tensor& block_indices, // Inplace
262+
paddle::Tensor& num_blocks, // Inplace
263+
paddle::Tensor& chunk_size, // Inplace
264+
paddle::Tensor& max_len_tensor_cpu, // Inplace, CPU
265+
const std::string cache_quant_type,
266+
const int group_size,
267+
const int kv_num_heads,
268+
const int max_tokens_per_batch);
269+
192270
std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
193271
const paddle::Tensor& qkv,
194272
const paddle::Tensor& key_cache,
@@ -1962,4 +2040,28 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
19622040
m.def("per_token_group_fp8_quant",
19632041
&PerTokenGroupQuantFp8,
19642042
"per_token_group_quant_fp8");
2043+
2044+
/**
2045+
* decoder_write_cache_with_rope.cu
2046+
* decoder_write_cache_with_rope
2047+
*/
2048+
m.def("decoder_write_cache_with_rope",
2049+
&DecoderWriteCacheWithRoPE,
2050+
"decoder write cache with RoPE function");
2051+
2052+
/**
2053+
* decode_unified_attention.cu
2054+
* decode_unified_attention
2055+
*
* Exposes DecodeUnifiedAttention to Python. The docstring names this op
* itself rather than append attention, matching how the neighbouring
* registrations describe their own ops.
*/
2056+
m.def("decode_unified_attention",
2057+
&DecodeUnifiedAttention,
2058+
"decode unified attention function");
2059+
2060+
/**
2061+
* config_for_attention.cu
2062+
* config_for_attention
2063+
*/
2064+
m.def("config_for_attention",
2065+
&ConfigForAttention,
2066+
"config for attention function");
19652067
}

0 commit comments

Comments
 (0)