Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 12 additions & 17 deletions custom_ops/gpu_ops/append_attn/mla_cache_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ std::vector<paddle::Tensor> PrefillMLAWriteCache(
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::Tensor& slot_mapping,
const paddle::optional<paddle::Tensor>& kv_signal_data,
const int max_seq_len,
cudaStream_t& stream,
paddle::Tensor* kv_cache) {
typedef PDTraits<T> traits_;
Expand All @@ -50,19 +50,21 @@ std::vector<paddle::Tensor> PrefillMLAWriteCache(
int grid_size = 1;
GetNumBlocks<128>(pack_num, &grid_size);

prefill_absorb_cache_kernel<DataType_, PackSize>
using CT = DataType_;

prefill_absorb_cache_kernel<DataType_, PackSize, CT>
<<<grid_size, blocksize, 0, stream>>>(
reinterpret_cast<DataType_*>(
const_cast<data_t*>(kv_nope.data<data_t>())),
reinterpret_cast<DataType_*>(
const_cast<data_t*>(kv_pe.data<data_t>())),
reinterpret_cast<DataType_*>(kv_cache->data<data_t>()),
block_tables.data<int>(),
slot_mapping.data<int64_t>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_decoder.data<int>(),
max_seq_len,
max_blocks_per_seq,
kv_num_heads,
nope_size,
Expand Down Expand Up @@ -108,9 +110,9 @@ std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::Tensor& slot_mapping,
const paddle::optional<paddle::Tensor>& kv_signal_data,
const std::string& cache_quant_type_str,
const int max_seq_len) {
const std::string& cache_quant_type_str) {
cudaStream_t stream = kv_pe.stream();
AppendAttnMetaData meta_data;
const auto& kv_nope_dims = kv_nope.dims();
Expand All @@ -137,8 +139,8 @@ std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
batch_id_per_token,
cu_seqlens_q,
block_tables,
slot_mapping,
kv_signal_data,
max_seq_len,
stream,
const_cast<paddle::Tensor*>(&kv_cache));
}
Expand All @@ -152,8 +154,8 @@ std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
batch_id_per_token,
cu_seqlens_q,
block_tables,
slot_mapping,
kv_signal_data,
max_seq_len,
stream,
const_cast<paddle::Tensor*>(&kv_cache));
}
Expand All @@ -171,7 +173,6 @@ std::vector<paddle::Tensor> DecodeMLAWriteCache(
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const int max_seq_len,
const bool speculate_decoder,
cudaStream_t& stream,
paddle::Tensor* kv_cache) {
Expand Down Expand Up @@ -207,7 +208,6 @@ std::vector<paddle::Tensor> DecodeMLAWriteCache(
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
max_seq_len,
max_blocks_per_seq,
kv_num_heads,
nope_size,
Expand All @@ -229,7 +229,6 @@ std::vector<paddle::Tensor> DecodeMLAWriteCache(
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
max_seq_len,
max_blocks_per_seq,
kv_num_heads,
nope_size,
Expand All @@ -250,7 +249,6 @@ std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const std::string& cache_quant_type_str,
const int max_seq_len,
const bool speculate_decoder) {
cudaStream_t stream = kv_pe.stream();
AppendAttnMetaData meta_data;
Expand Down Expand Up @@ -278,7 +276,6 @@ std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(
batch_id_per_token,
cu_seqlens_q,
block_tables,
max_seq_len,
speculate_decoder,
stream,
const_cast<paddle::Tensor*>(&kv_cache));
Expand All @@ -293,7 +290,6 @@ std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(
batch_id_per_token,
cu_seqlens_q,
block_tables,
max_seq_len,
speculate_decoder,
stream,
const_cast<paddle::Tensor*>(&kv_cache));
Expand All @@ -311,10 +307,11 @@ PD_BUILD_STATIC_OP(prefill_mla_write_cache)
"batch_id_per_token",
"cu_seqlens_q",
"block_tables",
"slot_mapping",
paddle::Optional("kv_signal_data")})
.Outputs({"kv_cache_out"})
.SetInplaceMap({{"kv_cache", "kv_cache_out"}})
.Attrs({"cache_quant_type_str: std::string", "max_seq_len: int"})
.Attrs({"cache_quant_type_str: std::string"})
.SetKernelFn(PD_KERNEL(PrefillMLAWriteCacheKernel));

PD_BUILD_STATIC_OP(decode_mla_write_cache)
Expand All @@ -328,7 +325,5 @@ PD_BUILD_STATIC_OP(decode_mla_write_cache)
"block_tables"})
.Outputs({"kv_cache_out"})
.SetInplaceMap({{"kv_cache", "kv_cache_out"}})
.Attrs({"cache_quant_type_str: std::string",
"max_seq_len: int",
"speculate_decoder: bool"})
.Attrs({"cache_quant_type_str: std::string", "speculate_decoder: bool"})
.SetKernelFn(PD_KERNEL(DecodeMLAWriteCacheKernel));
19 changes: 13 additions & 6 deletions custom_ops/gpu_ops/append_attn/mla_cache_kernel.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ __global__ void decode_absorb_cache_kernel(
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
const int max_seq_len,
const int max_blocks_per_seq,
const int kv_num_heads,
const int nope_size,
Expand Down Expand Up @@ -98,7 +97,6 @@ __global__ void speculate_decode_absorb_cache_kernel(
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
const int max_seq_len,
const int max_blocks_per_seq,
const int kv_num_heads,
const int nope_size,
Expand Down Expand Up @@ -168,18 +166,18 @@ __global__ void speculate_decode_absorb_cache_kernel(
}
}

template <typename T, int VecSize = 1>
template <typename T, int VecSize = 1, typename CT = T>
__global__ void prefill_absorb_cache_kernel(
const T* __restrict__ kv_nope, // [bsz, kv_num_heads, pe_size] 512
const T* __restrict__ kv_pe, // [bsz, kv_num_heads, nope_size] 64
T* __restrict__ kv_cache, // [num_blocks, kv_num_heads, block_size,
CT* __restrict__ kv_cache, // [num_blocks, kv_num_heads, block_size,
// nope_size]
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int64_t* __restrict__ slot_mapping,
const int* __restrict__ batch_id_per_token,
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_decoder, // [bsz]
const int max_seq_len,
const int max_blocks_per_seq,
const int kv_num_heads,
const int nope_size,
Expand All @@ -201,7 +199,8 @@ __global__ void prefill_absorb_cache_kernel(
linear_index += step) {
const uint32_t token_idx = linear_index / hidden_size;
const uint32_t bias = linear_index % hidden_size;
const uint32_t ori_bi = batch_id_per_token[token_idx];
const int32_t ori_bi = batch_id_per_token[token_idx];
if (ori_bi == -1) continue;
if (seq_lens[ori_bi] == 0) continue;
const uint32_t ori_seq_id =
(token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];
Expand All @@ -211,6 +210,14 @@ __global__ void prefill_absorb_cache_kernel(
const uint32_t block_idx = block_table_now[ori_seq_id / block_size];
const uint32_t block_offset = ori_seq_id % block_size;

const int32_t block_idx1 = slot_mapping[token_idx] / block_size;
if (block_idx1 != block_idx) {

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🔴 Bug CUDA kernel 中残留调试代码:printf + asm volatile("trap;") 在生产环境会导致 GPU 崩溃。

asm volatile("trap;") 相当于 GPU 上的 abort(),一旦触发将终止整个 CUDA context,导致服务不可用。这段代码明显是用于对齐验证 slot_mappingblock_tables 两种寻址路径是否一致的临时调试代码,不应合入主干。

建议修复:

  1. 验证通过后,直接删除整个 block(第 213-219 行)
  2. 如需保留作为 debug 模式开关,应改为 #ifdef DEBUG_MLA_CACHE ... #endif

printf("block_idx1 %d != block_idx %d\n", block_idx1, block_idx);
printf("token_idx %d\n", token_idx);
printf("slot_mapping %d\n", slot_mapping[token_idx]);
asm volatile("trap;");

This comment was marked as outdated.

}

if (bias < nope_hidden_size) { // pe
const uint32_t inner_bias = bias;
const uint32_t hi = inner_bias / nope_size;
Expand Down
5 changes: 2 additions & 3 deletions custom_ops/gpu_ops/cpp_extensions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -556,7 +556,6 @@ std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const std::string& cache_quant_type_str,
const int max_seq_len,
const bool speculate_decoder);

std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
Expand All @@ -568,9 +567,9 @@ std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::Tensor& slot_mapping,
const paddle::optional<paddle::Tensor>& kv_signal_data,
const std::string& cache_quant_type_str,
const int max_seq_len);
const std::string& cache_quant_type_str);

void FusedRotaryPositionEncoding(
paddle::Tensor& query, // [num_tokens, num_heads, head_size] or
Expand Down
60 changes: 31 additions & 29 deletions fastdeploy/model_executor/layers/attention/mla_attention_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -664,7 +664,6 @@ def forward_extend(
metadata.block_tables,
metadata.kv_signal_data_list[layer.layer_id],
"none",
getattr(forward_meta, "max_input_length", -1),
)

fmha_out = self.flash_attn_func(
Expand Down Expand Up @@ -720,7 +719,6 @@ def forward_decode(
forward_meta.cu_seqlens_q,
metadata.block_tables,
"none",
self.max_seq_len,
speculate_decoder,
)

Expand Down Expand Up @@ -799,21 +797,23 @@ def forward_mixed(

latent_cache = forward_meta.caches[layer.layer_id] if hasattr(forward_meta, "caches") else None

assert k_pe.shape[0] == compressed_kv.shape[0]
prefill_mla_write_cache(
compressed_kv,
k_pe,
latent_cache,
forward_meta.seq_lens_this_time,
forward_meta.seq_lens_decoder,
forward_meta.batch_id_per_token,
forward_meta.cu_seqlens_q,
metadata.block_tables,
forward_meta.slot_mapping,
metadata.kv_signal_data_list[layer.layer_id],
"none",
)

# Prefill branch: k is not None
if k is not None:
prefill_mla_write_cache(
compressed_kv,
k_pe,
latent_cache,
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
forward_meta.batch_id_per_token,
forward_meta.cu_seqlens_q,
metadata.block_tables,
metadata.kv_signal_data_list[layer.layer_id],
"none",
self.max_seq_len,
)

if self.prop.major == 10:
# TODO support FA4
Expand Down Expand Up @@ -845,20 +845,6 @@ def forward_mixed(

# Decode branch: k is None

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

❓ 疑问 forward_mixed 中 decode 分支移除了 decode_mla_write_cache 调用,但 decode token 的 KV 写缓存通过提升到分支前的 prefill_mla_write_cache 统一处理。

请确认:prefill_mla_write_cachebatch_id_per_token[token_idx] == -1(decode token 在 mixed batch 中被标记为 -1)时通过 if (ori_bi == -1) continue; 正确跳过,但 decode token 的实际写缓存是否仍然有效?原 decode_mla_write_cache 使用 seq_lens_decoder/seq_lens_encoder 参数顺序,而新代码使用 seq_lens_this_time/seq_lens_decoder,语义是否一致?

if k is None:
decode_mla_write_cache(
compressed_kv,
k_pe,
latent_cache,
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_encoder,
forward_meta.batch_id_per_token,
forward_meta.cu_seqlens_q,
metadata.block_tables,
"none",
self.max_seq_len,
speculate_decoder,
)

if int(os.getenv("USE_FLASH_MLA", "0")) == 0 and self.prop.major == 9:
assert self.num_heads <= 64, "paddle mla attention support failed"
if self.heads_need_padding:
Expand Down Expand Up @@ -961,6 +947,12 @@ def forward_mixed(
@staticmethod
def mla_blackwell(decoder_q, latent_cache, block_table, cache_seqlens, attn_softmax_scale):

# decoder_q = decoder_q.cast(paddle.float8_e4m3fn)
# latent_cache = latent_cache.cast(paddle.float8_e4m3fn)

assert decoder_q.dtype == latent_cache.dtype

This comment was marked as outdated.

This comment was marked as outdated.

q_dtype = decoder_q.dtype

page_size = latent_cache.shape[2]
q_num_heads = decoder_q.shape[2]
assert decoder_q.shape[1:] == [1, q_num_heads, 576]
Expand Down Expand Up @@ -1008,6 +1000,8 @@ def mla_blackwell(decoder_q, latent_cache, block_table, cache_seqlens, attn_soft

from mla_decode_fp16 import BlackwellMultiHeadLatentAttentionForwardFP16

# from mla_decode_fp8 import BlackwellMultiHeadLatentAttentionForwardFP8

mla = BlackwellMultiHeadLatentAttentionForwardFP16(
cutlass.Float32,
cutlass.Float32,
Expand Down Expand Up @@ -1063,10 +1057,18 @@ def mla_blackwell(decoder_q, latent_cache, block_table, cache_seqlens, attn_soft
stream,
)

if q_dtype == paddle.float8_e4m3fn:
paddle_output = paddle_output.cast("bfloat16")
return paddle_output

This comment was marked as outdated.


@staticmethod
def flashmla_baseline(decoder_q, latent_cache, block_table, cache_seqlens, attn_softmax_scale):

assert decoder_q.dtype == latent_cache.dtype

decoder_q = decoder_q.cast("bfloat16")
latent_cache = latent_cache.cast("bfloat16")

page_size = latent_cache.shape[2]
q_num_heads = decoder_q.shape[2]
assert decoder_q.shape[1:] == [1, q_num_heads, 576]
Expand Down
Loading
Loading