Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 48 additions & 21 deletions custom_ops/gpu_ops/append_attn/mla_cache_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ std::vector<paddle::Tensor> PrefillMLAWriteCache(
const paddle::Tensor& slot_mapping,
const paddle::optional<paddle::Tensor>& kv_signal_data,
cudaStream_t& stream,
const std::string& cache_quant_type_str,
paddle::Tensor* kv_cache) {
typedef PDTraits<T> traits_;
typedef typename traits_::DataType DataType_;
Expand All @@ -50,27 +51,51 @@ std::vector<paddle::Tensor> PrefillMLAWriteCache(
int grid_size = 1;
GetNumBlocks<128>(pack_num, &grid_size);

using CT = DataType_;

prefill_absorb_cache_kernel<DataType_, PackSize, CT>
<<<grid_size, blocksize, 0, stream>>>(
reinterpret_cast<DataType_*>(
const_cast<data_t*>(kv_nope.data<data_t>())),
reinterpret_cast<DataType_*>(
const_cast<data_t*>(kv_pe.data<data_t>())),
reinterpret_cast<DataType_*>(kv_cache->data<data_t>()),
block_tables.data<int>(),
slot_mapping.data<int64_t>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_decoder.data<int>(),
max_blocks_per_seq,
kv_num_heads,
nope_size,
pe_size,
block_size,
elem_nums);
if (cache_quant_type_str == "cache_fp8") {
using CT = __nv_fp8_e4m3;
prefill_absorb_cache_kernel<DataType_, PackSize, CT>
<<<grid_size, blocksize, 0, stream>>>(
reinterpret_cast<DataType_*>(
const_cast<data_t*>(kv_nope.data<data_t>())),
reinterpret_cast<DataType_*>(
const_cast<data_t*>(kv_pe.data<data_t>())),
reinterpret_cast<CT*>(kv_cache->data<uint8_t>()),
block_tables.data<int>(),
slot_mapping.data<int64_t>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_decoder.data<int>(),
max_blocks_per_seq,
kv_num_heads,
nope_size,
pe_size,
block_size,
elem_nums);
} else if (cache_quant_type_str == "none") {
prefill_absorb_cache_kernel<DataType_, PackSize, DataType_>
<<<grid_size, blocksize, 0, stream>>>(
reinterpret_cast<DataType_*>(
const_cast<data_t*>(kv_nope.data<data_t>())),
reinterpret_cast<DataType_*>(
const_cast<data_t*>(kv_pe.data<data_t>())),
reinterpret_cast<DataType_*>(kv_cache->data<data_t>()),
block_tables.data<int>(),
slot_mapping.data<int64_t>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_decoder.data<int>(),
max_blocks_per_seq,
kv_num_heads,
nope_size,
pe_size,
block_size,
elem_nums);
} else {
PD_THROW("Unsupported cache_quant_type_str type: %s.",
cache_quant_type_str.c_str());
}

const char* fmt_write_cache_completed_signal_str =
std::getenv("FLAGS_fmt_write_cache_completed_signal");
Expand Down Expand Up @@ -142,6 +167,7 @@ std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
slot_mapping,
kv_signal_data,
stream,
cache_quant_type_str,
const_cast<paddle::Tensor*>(&kv_cache));
}
case paddle::DataType::FLOAT16: {
Expand All @@ -157,6 +183,7 @@ std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
slot_mapping,
kv_signal_data,
stream,
cache_quant_type_str,
const_cast<paddle::Tensor*>(&kv_cache));
}
}
Expand Down
30 changes: 28 additions & 2 deletions custom_ops/gpu_ops/append_attn/mla_cache_kernel.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,8 @@ __global__ void prefill_absorb_cache_kernel(
const uint32_t elem_cnt) {
using LoadT = AlignedVector<T, VecSize>;
LoadT src_vec;
using StoreT = AlignedVector<CT, VecSize>;
StoreT dst_vec;

int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
const uint32_t nope_hidden_size = kv_num_heads * nope_size;
Expand Down Expand Up @@ -227,7 +229,20 @@ __global__ void prefill_absorb_cache_kernel(
hi * block_size * all_size + block_offset * all_size + h_bias;
const uint32_t ori_idx = token_idx * nope_hidden_size + inner_bias;
Load<T, VecSize>(&kv_nope[ori_idx], &src_vec);
Store<T, VecSize>(src_vec, &kv_cache[tgt_idx]);

if constexpr (std::is_same_v<CT, __nv_fp8_e4m3>) {
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

❓ 疑问 FP8 量化未使用 scale 因子

当前实现等价于 scale=1.0 的静态量化(仅 clamp 至 FP8 e4m3 的 ±448 范围)。若 MLA KV cache 的激活值实际分布在较小量级(如 ±10),FP8 e4m3 在该范围内只有约 4 个指数级别,精度损失可能不可忽略。请确认:

  1. 是否有针对 DeepSeek-R1 等目标模型的量化精度对比数据?
  2. 是否刻意省略 scale(如激活已经过归一化处理)?建议在 PR 描述的 Accuracy Tests 段补充说明。

for (int i = 0; i < VecSize; i++) {
float quant_value = (float)(src_vec[i]);
quant_value = quant_value > 448.0f ? 448.0f : quant_value;
quant_value = quant_value < -448.0f ? -448.0f : quant_value;
dst_vec[i] = static_cast<__nv_fp8_e4m3>(quant_value);
}

Store<CT, VecSize>(dst_vec, &kv_cache[tgt_idx]);
} else {
Store<CT, VecSize>(src_vec, &kv_cache[tgt_idx]);
}

} else {
const uint32_t inner_bias = bias - nope_hidden_size;
const uint32_t hi = inner_bias / pe_size;
Expand All @@ -238,7 +253,18 @@ __global__ void prefill_absorb_cache_kernel(
h_bias;
const uint32_t ori_idx = token_idx * pe_hidden_size + inner_bias;
Load<T, VecSize>(&kv_pe[ori_idx], &src_vec);
Store<T, VecSize>(src_vec, &kv_cache[tgt_idx]);

if constexpr (std::is_same_v<CT, __nv_fp8_e4m3>) {
for (int i = 0; i < VecSize; i++) {
float quant_value = (float)(src_vec[i]);
quant_value = quant_value > 448.0f ? 448.0f : quant_value;
quant_value = quant_value < -448.0f ? -448.0f : quant_value;
dst_vec[i] = static_cast<__nv_fp8_e4m3>(quant_value);
}
Store<CT, VecSize>(dst_vec, &kv_cache[tgt_idx]);
} else {
Store<CT, VecSize>(src_vec, &kv_cache[tgt_idx]);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -809,7 +809,7 @@ def forward_mixed(
metadata.block_tables,
forward_meta.slot_mapping,
metadata.kv_signal_data_list[layer.layer_id],
"none",
getattr(layer, "cache_quant_type_str", "none"),
)

# Prefill branch: k is not None
Expand Down Expand Up @@ -947,11 +947,13 @@ def forward_mixed(
@staticmethod
def mla_blackwell(decoder_q, latent_cache, block_table, cache_seqlens, attn_softmax_scale):

# decoder_q = decoder_q.cast(paddle.float8_e4m3fn)
# latent_cache = latent_cache.cast(paddle.float8_e4m3fn)
assert latent_cache.dtype in [paddle.bfloat16, paddle.uint8], latent_cache.dtype
use_fp8_cache_kv = latent_cache.dtype == paddle.uint8
if use_fp8_cache_kv:
decoder_q = decoder_q.cast(paddle.float8_e4m3fn)
latent_cache = latent_cache.view(paddle.float8_e4m3fn)

assert decoder_q.dtype == latent_cache.dtype
q_dtype = decoder_q.dtype

page_size = latent_cache.shape[2]
q_num_heads = decoder_q.shape[2]
Expand Down Expand Up @@ -998,11 +1000,16 @@ def mla_blackwell(decoder_q, latent_cache, block_table, cache_seqlens, attn_soft
softmax_scale = attn_softmax_scale
output_scale = 1.0

from mla_decode_fp16 import BlackwellMultiHeadLatentAttentionForwardFP16

# from mla_decode_fp8 import BlackwellMultiHeadLatentAttentionForwardFP8
if use_fp8_cache_kv:
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 建议 compiled_mla 全局缓存不区分 FP8/FP16 kernel 类型

当前逻辑:compiled_mla 只在 None 时编译一次并持久复用。若在同一进程内曾以 use_fp8_cache_kv=False(FP16 kernel)完成初始化,后续以 use_fp8_cache_kv=True 调用时,compiled_mla 仍指向 FP16 版本,传入 FP8 tensor 会引发 dtype mismatch 运行时错误(反之亦然)。

建议修复:使用独立变量或 dict 分别缓存两个 kernel 的编译结果:

global compiled_mla_fp8, compiled_mla_fp16
if use_fp8_cache_kv:
    if compiled_mla_fp8 is None:
        compiled_mla_fp8 = cute.compile(mla, ...)
    compiled_mla_fp8(...)
else:
    if compiled_mla_fp16 is None:
        compiled_mla_fp16 = cute.compile(mla, ...)
    compiled_mla_fp16(...)

from mla_decode_fp8 import (
BlackwellMultiHeadLatentAttentionForwardFP8 as kernel,
)
else:
from mla_decode_fp16 import (
BlackwellMultiHeadLatentAttentionForwardFP16 as kernel,
)

mla = BlackwellMultiHeadLatentAttentionForwardFP16(
mla = kernel(
cutlass.Float32,
cutlass.Float32,
mma_qk_tiler_mn=(128, 128),
Expand Down Expand Up @@ -1057,7 +1064,7 @@ def mla_blackwell(decoder_q, latent_cache, block_table, cache_seqlens, attn_soft
stream,
)

if q_dtype == paddle.float8_e4m3fn:
if use_fp8_cache_kv:
paddle_output = paddle_output.cast("bfloat16")
return paddle_output

Expand Down
Loading