Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
193 changes: 14 additions & 179 deletions custom_ops/gpu_ops/append_attn/mla_cache_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@
#include "remote_cache_kv_ipc.h"

template <paddle::DataType T>
std::vector<paddle::Tensor> PrefillMLAWriteCache(
const AppendAttnMetaData& meta_data,
std::vector<paddle::Tensor> MLAWriteCache(
const paddle::Tensor& kv_nope,
const paddle::Tensor& kv_pe,
const paddle::Tensor& seq_lens,
Expand All @@ -36,15 +35,16 @@ std::vector<paddle::Tensor> PrefillMLAWriteCache(
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;

auto max_blocks_per_seq = meta_data.max_blocks_per_seq;
auto num_tokens = meta_data.token_nums;
auto block_size = meta_data.block_size;
auto nope_size = meta_data.head_dims_v;
auto all_size = meta_data.head_dims;
const auto& kv_nope_dims = kv_nope.dims();
const auto& kv_cache_dims = (*kv_cache).dims();
auto max_blocks_per_seq = block_tables.dims()[1];
auto num_tokens = kv_nope_dims[0];
auto block_size = kv_cache_dims[2];
auto kv_num_heads = kv_cache_dims[1];
auto nope_size = kv_nope_dims[kv_nope_dims.size() - 1] / kv_num_heads;
auto all_size = kv_cache_dims[3];
int pe_size = all_size - nope_size;
auto kv_num_heads = meta_data.kv_num_heads;
const uint32_t elem_nums = num_tokens * kv_num_heads * all_size;

constexpr int PackSize = 16 / sizeof(DataType_);
const int pack_num = elem_nums / PackSize;
const int blocksize = 128;
Expand Down Expand Up @@ -126,7 +126,7 @@ std::vector<paddle::Tensor> PrefillMLAWriteCache(
return {};
}

std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
std::vector<paddle::Tensor> MLAWriteCacheKernel(
const paddle::Tensor& kv_nope,
const paddle::Tensor& kv_pe,
const paddle::Tensor& kv_cache,
Expand All @@ -139,24 +139,9 @@ std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
const paddle::optional<paddle::Tensor>& kv_signal_data,
const std::string& cache_quant_type_str) {
cudaStream_t stream = kv_pe.stream();
AppendAttnMetaData meta_data;
const auto& kv_nope_dims = kv_nope.dims();
const auto& kv_pe_dims = kv_pe.dims();
const auto& kv_cache_dims = kv_cache.dims();
meta_data.kv_num_heads = kv_cache_dims[1];
const auto nope_size =
kv_nope_dims[kv_nope_dims.size() - 1] / meta_data.kv_num_heads;
meta_data.token_nums = kv_nope_dims[0];
meta_data.head_dims = kv_cache_dims[3];
meta_data.head_dims_v = nope_size;

meta_data.max_blocks_per_seq = block_tables.dims()[1];
meta_data.block_size = kv_cache_dims[2];
meta_data.batch_size = seq_lens_decoder.dims()[0];
switch (kv_pe.dtype()) {
case paddle::DataType::BFLOAT16: {
return PrefillMLAWriteCache<paddle::DataType::BFLOAT16>(
meta_data,
return MLAWriteCache<paddle::DataType::BFLOAT16>(
kv_nope,
kv_pe,
seq_lens,
Expand All @@ -171,8 +156,7 @@ std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
const_cast<paddle::Tensor*>(&kv_cache));
}
case paddle::DataType::FLOAT16: {
return PrefillMLAWriteCache<paddle::DataType::FLOAT16>(
meta_data,
return MLAWriteCache<paddle::DataType::FLOAT16>(
kv_nope,
kv_pe,
seq_lens,
Expand All @@ -190,142 +174,7 @@ std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
return {};
}

// Launches the cache-write CUDA kernel for one element type T.
//
// Geometry is taken from `meta_data`:
//   - head_dims    : full per-head width ("all_size")
//   - head_dims_v  : "nope" part of the width; the remainder is the "pe" part
//   - token_nums / batch_size : leading extent used to size the launch in the
//     speculative / non-speculative path respectively
//   - kv_num_heads, block_size, max_blocks_per_seq : forwarded to the kernel
//
// When `speculate_decoder` is true, speculate_decode_absorb_cache_kernel is
// launched (it additionally receives `batch_id_per_token`); otherwise
// decode_absorb_cache_kernel is launched. Both launches use 128 threads per
// block on `stream`, with the grid size chosen by GetNumBlocks<128>.
//
// `kv_cache` is passed to the kernel through a mutable pointer -- presumably
// the kernel writes the cache in place; confirm against the kernel source.
//
// Always returns an empty vector.
template <paddle::DataType T>
std::vector<paddle::Tensor> DecodeMLAWriteCache(
    const AppendAttnMetaData& meta_data,
    const paddle::Tensor& kv_nope,
    const paddle::Tensor& kv_pe,
    const paddle::Tensor& seq_lens,
    const paddle::Tensor& seq_lens_encoder,
    const paddle::Tensor& batch_id_per_token,
    const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& block_tables,
    const bool speculate_decoder,
    cudaStream_t& stream,
    paddle::Tensor* kv_cache) {
  using Traits = PDTraits<T>;
  using DataType_ = typename Traits::DataType;
  using data_t = typename Traits::data_t;

  // Number of elements covered by one 16-byte pack.
  constexpr int kPackSize = 16 / sizeof(DataType_);
  const int threads_per_block = 128;

  const auto kv_num_heads = meta_data.kv_num_heads;
  const auto full_head_dim = meta_data.head_dims;
  const auto nope_dim = meta_data.head_dims_v;
  const int pe_dim = full_head_dim - nope_dim;
  const auto block_size = meta_data.block_size;
  const auto max_blocks_per_seq = meta_data.max_blocks_per_seq;

  // Pointers shared by both launch paths. The input tensors are const, so the
  // constness is cast away to match the kernels' pointer parameters.
  auto* nope_ptr = reinterpret_cast<DataType_*>(
      const_cast<data_t*>(kv_nope.data<data_t>()));
  auto* pe_ptr = reinterpret_cast<DataType_*>(
      const_cast<data_t*>(kv_pe.data<data_t>()));
  auto* cache_ptr = reinterpret_cast<DataType_*>(kv_cache->data<data_t>());
  auto block_tables_ptr = block_tables.data<int>();
  auto cu_seqlens_q_ptr = cu_seqlens_q.data<int>();
  auto seq_lens_ptr = seq_lens.data<int>();
  auto seq_lens_encoder_ptr = seq_lens_encoder.data<int>();

  int num_blocks = 1;

  if (!speculate_decoder) {
    // Non-speculative path: the launch is sized by the batch size.
    const auto batch = meta_data.batch_size;
    const uint32_t total_elems = batch * kv_num_heads * full_head_dim;
    const int total_packs = total_elems / kPackSize;
    GetNumBlocks<128>(total_packs, &num_blocks);
    decode_absorb_cache_kernel<DataType_, kPackSize>
        <<<num_blocks, threads_per_block, 0, stream>>>(nope_ptr,
                                                       pe_ptr,
                                                       cache_ptr,
                                                       block_tables_ptr,
                                                       cu_seqlens_q_ptr,
                                                       seq_lens_ptr,
                                                       seq_lens_encoder_ptr,
                                                       max_blocks_per_seq,
                                                       kv_num_heads,
                                                       nope_dim,
                                                       pe_dim,
                                                       block_size,
                                                       total_elems);
  } else {
    // Speculative path: the launch is sized by the token count, and the
    // kernel also needs the token -> batch mapping. That tensor is only
    // touched here, never on the non-speculative path.
    const auto tokens = meta_data.token_nums;
    const uint32_t total_elems = tokens * kv_num_heads * full_head_dim;
    const int total_packs = total_elems / kPackSize;
    GetNumBlocks<128>(total_packs, &num_blocks);
    speculate_decode_absorb_cache_kernel<DataType_, kPackSize>
        <<<num_blocks, threads_per_block, 0, stream>>>(
            nope_ptr,
            pe_ptr,
            cache_ptr,
            block_tables_ptr,
            batch_id_per_token.data<int>(),
            cu_seqlens_q_ptr,
            seq_lens_ptr,
            seq_lens_encoder_ptr,
            max_blocks_per_seq,
            kv_num_heads,
            nope_dim,
            pe_dim,
            block_size,
            total_elems);
  }

  return {};
}

// Op entry point: derives the launch geometry from the tensor shapes and
// dispatches to DecodeMLAWriteCache<T> according to the dtype of `kv_pe`.
//
// Shape-derived values (as read by this function):
//   kv_cache.dims()[1]            -> kv_num_heads
//   kv_cache.dims()[2]            -> block_size
//   kv_cache.dims()[3]            -> head_dims (full per-head width)
//   kv_nope.dims()[0]             -> token_nums
//   kv_nope.dims().back() / heads -> head_dims_v ("nope" width)
//   block_tables.dims()[1]        -> max_blocks_per_seq
//   seq_lens_encoder.dims()[0]    -> batch_size
//
// `kv_cache` is received as const but handed on as a mutable pointer; the op
// registration maps it in place to the output.
// `cache_quant_type_str` is accepted for interface compatibility and is not
// read here.
//
// Returns whatever DecodeMLAWriteCache<T> returns for BFLOAT16 / FLOAT16
// inputs. For any other dtype nothing is launched and an empty vector is
// returned.
std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(
    const paddle::Tensor& kv_nope,
    const paddle::Tensor& kv_pe,
    const paddle::Tensor& kv_cache,
    const paddle::Tensor& seq_lens,
    const paddle::Tensor& seq_lens_encoder,
    const paddle::Tensor& batch_id_per_token,
    const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& block_tables,
    const std::string& cache_quant_type_str,
    const bool speculate_decoder) {
  cudaStream_t stream = kv_pe.stream();

  const auto& kv_nope_dims = kv_nope.dims();
  const auto& kv_cache_dims = kv_cache.dims();

  AppendAttnMetaData meta_data;
  meta_data.kv_num_heads = kv_cache_dims[1];
  // The last axis of kv_nope is split evenly across the KV heads.
  const auto nope_size =
      kv_nope_dims[kv_nope_dims.size() - 1] / meta_data.kv_num_heads;
  meta_data.token_nums = kv_nope_dims[0];
  meta_data.head_dims = kv_cache_dims[3];
  meta_data.head_dims_v = nope_size;

  meta_data.max_blocks_per_seq = block_tables.dims()[1];
  meta_data.block_size = kv_cache_dims[2];
  meta_data.batch_size = seq_lens_encoder.dims()[0];

  // The typed launcher takes the cache through a mutable pointer.
  paddle::Tensor* mutable_kv_cache = const_cast<paddle::Tensor*>(&kv_cache);

  switch (kv_pe.dtype()) {
    case paddle::DataType::BFLOAT16: {
      return DecodeMLAWriteCache<paddle::DataType::BFLOAT16>(
          meta_data,
          kv_nope,
          kv_pe,
          seq_lens,
          seq_lens_encoder,
          batch_id_per_token,
          cu_seqlens_q,
          block_tables,
          speculate_decoder,
          stream,
          mutable_kv_cache);
    }
    case paddle::DataType::FLOAT16: {
      return DecodeMLAWriteCache<paddle::DataType::FLOAT16>(
          meta_data,
          kv_nope,
          kv_pe,
          seq_lens,
          seq_lens_encoder,
          batch_id_per_token,
          cu_seqlens_q,
          block_tables,
          speculate_decoder,
          stream,
          mutable_kv_cache);
    }
    default:
      // Unsupported dtype: no kernel is launched (same outcome as before,
      // now stated explicitly instead of relying on implicit fallthrough).
      break;
  }
  return {};
}

PD_BUILD_STATIC_OP(prefill_mla_write_cache)
PD_BUILD_STATIC_OP(mla_write_cache)
.Inputs({"kv_nope",
"kv_pe",
"kv_cache",
Expand All @@ -339,18 +188,4 @@ PD_BUILD_STATIC_OP(prefill_mla_write_cache)
.Outputs({"kv_cache_out"})
.SetInplaceMap({{"kv_cache", "kv_cache_out"}})
.Attrs({"cache_quant_type_str: std::string"})
.SetKernelFn(PD_KERNEL(PrefillMLAWriteCacheKernel));

// Registers the static op `decode_mla_write_cache`.
// The eight tensor inputs are listed in the same order as the tensor
// parameters of DecodeMLAWriteCacheKernel, followed by its two attributes.
PD_BUILD_STATIC_OP(decode_mla_write_cache)
.Inputs({"kv_nope",
"kv_pe",
"kv_cache",
"seq_lens",
"seq_lens_encoder",
"batch_id_per_token",
"cu_seqlens_q",
"block_tables"})
.Outputs({"kv_cache_out"})
// The single output is mapped in place onto the `kv_cache` input.
.SetInplaceMap({{"kv_cache", "kv_cache_out"}})
.Attrs({"cache_quant_type_str: std::string", "speculate_decoder: bool"})
.SetKernelFn(PD_KERNEL(DecodeMLAWriteCacheKernel));
.SetKernelFn(PD_KERNEL(MLAWriteCacheKernel));
24 changes: 4 additions & 20 deletions custom_ops/gpu_ops/cpp_extensions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -628,19 +628,7 @@ void GetPositionIdsAndMaskEncoderBatch(const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& position_ids);

// Forward declaration of DecodeMLAWriteCacheKernel so that it can be exposed
// through the pybind module in this file (bound as "decode_mla_write_cache").
// Parameters: eight tensors, the cache quantization type string, and a flag
// selecting the speculative-decoding path.
std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(
const paddle::Tensor& kv_nope,
const paddle::Tensor& kv_pe,
const paddle::Tensor& kv_cache,
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const std::string& cache_quant_type_str,
const bool speculate_decoder);

std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
std::vector<paddle::Tensor> MLAWriteCacheKernel(
const paddle::Tensor& kv_nope,
const paddle::Tensor& kv_pe,
const paddle::Tensor& kv_cache,
Expand Down Expand Up @@ -1777,13 +1765,9 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
py::arg("scales"),
py::arg("scale_ub"));
#ifdef ENABLE_SM80_EXT_OPS
m.def("decode_mla_write_cache",
&DecodeMLAWriteCacheKernel,
"decode_mla_write_cache function");

m.def("prefill_mla_write_cache",
&PrefillMLAWriteCacheKernel,
"prefill_mla_write_cache function");
m.def("mla_write_cache",
&MLAWriteCacheKernel,
"mla_write_cache function");
#endif

m.def("fused_rotary_position_encoding",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,8 @@
compiled_mla = None
if current_platform.is_cuda():
from fastdeploy.model_executor.ops.gpu import (
decode_mla_write_cache,
multi_head_latent_attention,
prefill_mla_write_cache,
mla_write_cache,

This comment was marked as outdated.

)

if TYPE_CHECKING:
Expand Down Expand Up @@ -849,7 +848,7 @@ def forward_mixed(
latent_cache = forward_meta.caches[layer.layer_id] if hasattr(forward_meta, "caches") else None

assert k_pe.shape[0] == compressed_kv.shape[0]
prefill_mla_write_cache(
mla_write_cache(
compressed_kv,
k_pe,
latent_cache,
Expand Down