diff --git a/custom_ops/gpu_ops/append_attn/mla_cache_kernel.cu b/custom_ops/gpu_ops/append_attn/mla_cache_kernel.cu index 7e77f1ba9e3..c7a71fcc48b 100644 --- a/custom_ops/gpu_ops/append_attn/mla_cache_kernel.cu +++ b/custom_ops/gpu_ops/append_attn/mla_cache_kernel.cu @@ -18,8 +18,7 @@ #include "remote_cache_kv_ipc.h" template -std::vector PrefillMLAWriteCache( - const AppendAttnMetaData& meta_data, +std::vector MLAWriteCache( const paddle::Tensor& kv_nope, const paddle::Tensor& kv_pe, const paddle::Tensor& seq_lens, @@ -36,15 +35,16 @@ std::vector PrefillMLAWriteCache( typedef typename traits_::DataType DataType_; typedef typename traits_::data_t data_t; - auto max_blocks_per_seq = meta_data.max_blocks_per_seq; - auto num_tokens = meta_data.token_nums; - auto block_size = meta_data.block_size; - auto nope_size = meta_data.head_dims_v; - auto all_size = meta_data.head_dims; + const auto& kv_nope_dims = kv_nope.dims(); + const auto& kv_cache_dims = (*kv_cache).dims(); + auto max_blocks_per_seq = block_tables.dims()[1]; + auto num_tokens = kv_nope_dims[0]; + auto block_size = kv_cache_dims[2]; + auto kv_num_heads = kv_cache_dims[1]; + auto nope_size = kv_nope_dims[kv_nope_dims.size() - 1] / kv_num_heads; + auto all_size = kv_cache_dims[3]; int pe_size = all_size - nope_size; - auto kv_num_heads = meta_data.kv_num_heads; const uint32_t elem_nums = num_tokens * kv_num_heads * all_size; - constexpr int PackSize = 16 / sizeof(DataType_); const int pack_num = elem_nums / PackSize; const int blocksize = 128; @@ -126,7 +126,7 @@ std::vector PrefillMLAWriteCache( return {}; } -std::vector PrefillMLAWriteCacheKernel( +std::vector MLAWriteCacheKernel( const paddle::Tensor& kv_nope, const paddle::Tensor& kv_pe, const paddle::Tensor& kv_cache, @@ -139,24 +139,9 @@ std::vector PrefillMLAWriteCacheKernel( const paddle::optional& kv_signal_data, const std::string& cache_quant_type_str) { cudaStream_t stream = kv_pe.stream(); - AppendAttnMetaData meta_data; - const auto& kv_nope_dims = kv_nope.dims(); - const auto& kv_pe_dims = kv_pe.dims(); - const auto& kv_cache_dims = kv_cache.dims(); - meta_data.kv_num_heads = kv_cache_dims[1]; - const auto nope_size = - kv_nope_dims[kv_nope_dims.size() - 1] / meta_data.kv_num_heads; - meta_data.token_nums = kv_nope_dims[0]; - meta_data.head_dims = kv_cache_dims[3]; - meta_data.head_dims_v = nope_size; - - meta_data.max_blocks_per_seq = block_tables.dims()[1]; - meta_data.block_size = kv_cache_dims[2]; - meta_data.batch_size = seq_lens_decoder.dims()[0]; switch (kv_pe.dtype()) { case paddle::DataType::BFLOAT16: { - return PrefillMLAWriteCache( - meta_data, + return MLAWriteCache( kv_nope, kv_pe, seq_lens, @@ -171,8 +156,7 @@ std::vector PrefillMLAWriteCacheKernel( const_cast(&kv_cache)); } case paddle::DataType::FLOAT16: { - return PrefillMLAWriteCache( - meta_data, + return MLAWriteCache( kv_nope, kv_pe, seq_lens, @@ -190,142 +174,7 @@ std::vector PrefillMLAWriteCacheKernel( return {}; } -template -std::vector DecodeMLAWriteCache( - const AppendAttnMetaData& meta_data, - const paddle::Tensor& kv_nope, - const paddle::Tensor& kv_pe, - const paddle::Tensor& seq_lens, - const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& batch_id_per_token, - const paddle::Tensor& cu_seqlens_q, - const paddle::Tensor& block_tables, - const bool speculate_decoder, - cudaStream_t& stream, - paddle::Tensor* kv_cache) { - typedef PDTraits traits_; - typedef typename traits_::DataType DataType_; - typedef typename traits_::data_t data_t; - - auto max_blocks_per_seq = meta_data.max_blocks_per_seq; - auto bsz = meta_data.batch_size; - auto token_num = meta_data.token_nums; - auto block_size = meta_data.block_size; - auto nope_size = meta_data.head_dims_v; - auto all_size = meta_data.head_dims; - int pe_size = all_size - nope_size; - auto kv_num_heads = meta_data.kv_num_heads; - constexpr int PackSize = 16 / sizeof(DataType_); - const int blocksize = 128; - int grid_size = 1; - - if (speculate_decoder) { - const uint32_t elem_nums = token_num * kv_num_heads * all_size; - const int pack_num = elem_nums / PackSize; - GetNumBlocks<128>(pack_num, &grid_size); - speculate_decode_absorb_cache_kernel - <<>>( - reinterpret_cast( - const_cast(kv_nope.data())), - reinterpret_cast( - const_cast(kv_pe.data())), - reinterpret_cast(kv_cache->data()), - block_tables.data(), - batch_id_per_token.data(), - cu_seqlens_q.data(), - seq_lens.data(), - seq_lens_encoder.data(), - max_blocks_per_seq, - kv_num_heads, - nope_size, - pe_size, - block_size, - elem_nums); - } else { - const uint32_t elem_nums = bsz * kv_num_heads * all_size; - const int pack_num = elem_nums / PackSize; - GetNumBlocks<128>(pack_num, &grid_size); - decode_absorb_cache_kernel - <<>>( - reinterpret_cast( - const_cast(kv_nope.data())), - reinterpret_cast( - const_cast(kv_pe.data())), - reinterpret_cast(kv_cache->data()), - block_tables.data(), - cu_seqlens_q.data(), - seq_lens.data(), - seq_lens_encoder.data(), - max_blocks_per_seq, - kv_num_heads, - nope_size, - pe_size, - block_size, - elem_nums); - } - return {}; -} - -std::vector DecodeMLAWriteCacheKernel( - const paddle::Tensor& kv_nope, - const paddle::Tensor& kv_pe, - const paddle::Tensor& kv_cache, - const paddle::Tensor& seq_lens, - const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& batch_id_per_token, - const paddle::Tensor& cu_seqlens_q, - const paddle::Tensor& block_tables, - const std::string& cache_quant_type_str, - const bool speculate_decoder) { - cudaStream_t stream = kv_pe.stream(); - AppendAttnMetaData meta_data; - const auto& kv_nope_dims = kv_nope.dims(); - const auto& kv_pe_dims = kv_pe.dims(); - const auto& kv_cache_dims = kv_cache.dims(); - meta_data.kv_num_heads = kv_cache_dims[1]; - const auto nope_size = - kv_nope_dims[kv_nope_dims.size() - 1] / meta_data.kv_num_heads; - meta_data.token_nums = kv_nope_dims[0]; - meta_data.head_dims = kv_cache_dims[3]; - meta_data.head_dims_v = nope_size; - - meta_data.max_blocks_per_seq = block_tables.dims()[1]; - meta_data.block_size = kv_cache_dims[2]; - meta_data.batch_size = seq_lens_encoder.dims()[0]; - switch (kv_pe.dtype()) { - case paddle::DataType::BFLOAT16: { - return DecodeMLAWriteCache( - meta_data, - kv_nope, - kv_pe, - seq_lens, - seq_lens_encoder, - batch_id_per_token, - cu_seqlens_q, - block_tables, - speculate_decoder, - stream, - const_cast(&kv_cache)); - } - case paddle::DataType::FLOAT16: { - return DecodeMLAWriteCache( - meta_data, - kv_nope, - kv_pe, - seq_lens, - seq_lens_encoder, - batch_id_per_token, - cu_seqlens_q, - block_tables, - speculate_decoder, - stream, - const_cast(&kv_cache)); - } - } - return {}; -} - -PD_BUILD_STATIC_OP(prefill_mla_write_cache) +PD_BUILD_STATIC_OP(mla_write_cache) .Inputs({"kv_nope", "kv_pe", "kv_cache", @@ -339,18 +188,4 @@ PD_BUILD_STATIC_OP(prefill_mla_write_cache) .Outputs({"kv_cache_out"}) .SetInplaceMap({{"kv_cache", "kv_cache_out"}}) .Attrs({"cache_quant_type_str: std::string"}) - .SetKernelFn(PD_KERNEL(PrefillMLAWriteCacheKernel)); - -PD_BUILD_STATIC_OP(decode_mla_write_cache) - .Inputs({"kv_nope", - "kv_pe", - "kv_cache", - "seq_lens", - "seq_lens_encoder", - "batch_id_per_token", - "cu_seqlens_q", - "block_tables"}) - .Outputs({"kv_cache_out"}) - .SetInplaceMap({{"kv_cache", "kv_cache_out"}}) - .Attrs({"cache_quant_type_str: std::string", "speculate_decoder: bool"}) - .SetKernelFn(PD_KERNEL(DecodeMLAWriteCacheKernel)); + .SetKernelFn(PD_KERNEL(MLAWriteCacheKernel)); diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc index 911695e7412..230dce23cd7 100644 --- a/custom_ops/gpu_ops/cpp_extensions.cc +++ b/custom_ops/gpu_ops/cpp_extensions.cc @@ -628,19 +628,7 @@ void GetPositionIdsAndMaskEncoderBatch(const paddle::Tensor& seq_lens_encoder, const paddle::Tensor& seq_lens_this_time, const paddle::Tensor& position_ids); -std::vector DecodeMLAWriteCacheKernel( - const paddle::Tensor& kv_nope, - const paddle::Tensor& kv_pe, - const paddle::Tensor& kv_cache, - const paddle::Tensor& seq_lens, - const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& batch_id_per_token, - const paddle::Tensor& cu_seqlens_q, - const paddle::Tensor& block_tables, - const std::string& cache_quant_type_str, - const bool speculate_decoder); - -std::vector PrefillMLAWriteCacheKernel( +std::vector MLAWriteCacheKernel( const paddle::Tensor& kv_nope, const paddle::Tensor& kv_pe, const paddle::Tensor& kv_cache, @@ -1777,13 +1765,9 @@ PYBIND11_MODULE(fastdeploy_ops, m) { py::arg("scales"), py::arg("scale_ub")); #ifdef ENABLE_SM80_EXT_OPS - m.def("decode_mla_write_cache", - &DecodeMLAWriteCacheKernel, - "decode_mla_write_cache function"); - - m.def("prefill_mla_write_cache", - &PrefillMLAWriteCacheKernel, - "prefill_mla_write_cache function"); + m.def("mla_write_cache", + &MLAWriteCacheKernel, + "mla_write_cache function"); #endif m.def("fused_rotary_position_encoding", diff --git a/fastdeploy/model_executor/layers/attention/mla_attention_backend.py b/fastdeploy/model_executor/layers/attention/mla_attention_backend.py index ba1ef6fab0c..848d2fcbd2a 100644 --- a/fastdeploy/model_executor/layers/attention/mla_attention_backend.py +++ b/fastdeploy/model_executor/layers/attention/mla_attention_backend.py @@ -45,9 +45,8 @@ compiled_mla = None if current_platform.is_cuda(): from fastdeploy.model_executor.ops.gpu import ( - decode_mla_write_cache, multi_head_latent_attention, - prefill_mla_write_cache, + mla_write_cache, ) if TYPE_CHECKING: @@ -849,7 +848,7 @@ def forward_mixed( latent_cache = forward_meta.caches[layer.layer_id] if hasattr(forward_meta, "caches") else None assert k_pe.shape[0] == compressed_kv.shape[0] - prefill_mla_write_cache( + mla_write_cache( compressed_kv, k_pe, latent_cache,