Skip to content

Commit 5348902

Browse files
[OPTIMIZE] remove decode_mla_write_cache in mla attention backend (#7834)
1 parent dad5a43 commit 5348902

6 files changed

Lines changed: 112 additions & 84 deletions

File tree

custom_ops/gpu_ops/append_attn/mla_cache_kernel.cu

Lines changed: 12 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -27,8 +27,8 @@ std::vector<paddle::Tensor> PrefillMLAWriteCache(
2727
const paddle::Tensor& batch_id_per_token,
2828
const paddle::Tensor& cu_seqlens_q,
2929
const paddle::Tensor& block_tables,
30+
const paddle::Tensor& slot_mapping,
3031
const paddle::optional<paddle::Tensor>& kv_signal_data,
31-
const int max_seq_len,
3232
cudaStream_t& stream,
3333
paddle::Tensor* kv_cache) {
3434
typedef PDTraits<T> traits_;
@@ -50,19 +50,21 @@ std::vector<paddle::Tensor> PrefillMLAWriteCache(
5050
int grid_size = 1;
5151
GetNumBlocks<128>(pack_num, &grid_size);
5252

53-
prefill_absorb_cache_kernel<DataType_, PackSize>
53+
using CT = DataType_;
54+
55+
prefill_absorb_cache_kernel<DataType_, PackSize, CT>
5456
<<<grid_size, blocksize, 0, stream>>>(
5557
reinterpret_cast<DataType_*>(
5658
const_cast<data_t*>(kv_nope.data<data_t>())),
5759
reinterpret_cast<DataType_*>(
5860
const_cast<data_t*>(kv_pe.data<data_t>())),
5961
reinterpret_cast<DataType_*>(kv_cache->data<data_t>()),
6062
block_tables.data<int>(),
63+
slot_mapping.data<int64_t>(),
6164
batch_id_per_token.data<int>(),
6265
cu_seqlens_q.data<int>(),
6366
seq_lens.data<int>(),
6467
seq_lens_decoder.data<int>(),
65-
max_seq_len,
6668
max_blocks_per_seq,
6769
kv_num_heads,
6870
nope_size,
@@ -108,9 +110,9 @@ std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
108110
const paddle::Tensor& batch_id_per_token,
109111
const paddle::Tensor& cu_seqlens_q,
110112
const paddle::Tensor& block_tables,
113+
const paddle::Tensor& slot_mapping,
111114
const paddle::optional<paddle::Tensor>& kv_signal_data,
112-
const std::string& cache_quant_type_str,
113-
const int max_seq_len) {
115+
const std::string& cache_quant_type_str) {
114116
cudaStream_t stream = kv_pe.stream();
115117
AppendAttnMetaData meta_data;
116118
const auto& kv_nope_dims = kv_nope.dims();
@@ -137,8 +139,8 @@ std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
137139
batch_id_per_token,
138140
cu_seqlens_q,
139141
block_tables,
142+
slot_mapping,
140143
kv_signal_data,
141-
max_seq_len,
142144
stream,
143145
const_cast<paddle::Tensor*>(&kv_cache));
144146
}
@@ -152,8 +154,8 @@ std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
152154
batch_id_per_token,
153155
cu_seqlens_q,
154156
block_tables,
157+
slot_mapping,
155158
kv_signal_data,
156-
max_seq_len,
157159
stream,
158160
const_cast<paddle::Tensor*>(&kv_cache));
159161
}
@@ -171,7 +173,6 @@ std::vector<paddle::Tensor> DecodeMLAWriteCache(
171173
const paddle::Tensor& batch_id_per_token,
172174
const paddle::Tensor& cu_seqlens_q,
173175
const paddle::Tensor& block_tables,
174-
const int max_seq_len,
175176
const bool speculate_decoder,
176177
cudaStream_t& stream,
177178
paddle::Tensor* kv_cache) {
@@ -207,7 +208,6 @@ std::vector<paddle::Tensor> DecodeMLAWriteCache(
207208
cu_seqlens_q.data<int>(),
208209
seq_lens.data<int>(),
209210
seq_lens_encoder.data<int>(),
210-
max_seq_len,
211211
max_blocks_per_seq,
212212
kv_num_heads,
213213
nope_size,
@@ -229,7 +229,6 @@ std::vector<paddle::Tensor> DecodeMLAWriteCache(
229229
cu_seqlens_q.data<int>(),
230230
seq_lens.data<int>(),
231231
seq_lens_encoder.data<int>(),
232-
max_seq_len,
233232
max_blocks_per_seq,
234233
kv_num_heads,
235234
nope_size,
@@ -250,7 +249,6 @@ std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(
250249
const paddle::Tensor& cu_seqlens_q,
251250
const paddle::Tensor& block_tables,
252251
const std::string& cache_quant_type_str,
253-
const int max_seq_len,
254252
const bool speculate_decoder) {
255253
cudaStream_t stream = kv_pe.stream();
256254
AppendAttnMetaData meta_data;
@@ -278,7 +276,6 @@ std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(
278276
batch_id_per_token,
279277
cu_seqlens_q,
280278
block_tables,
281-
max_seq_len,
282279
speculate_decoder,
283280
stream,
284281
const_cast<paddle::Tensor*>(&kv_cache));
@@ -293,7 +290,6 @@ std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(
293290
batch_id_per_token,
294291
cu_seqlens_q,
295292
block_tables,
296-
max_seq_len,
297293
speculate_decoder,
298294
stream,
299295
const_cast<paddle::Tensor*>(&kv_cache));
@@ -311,10 +307,11 @@ PD_BUILD_STATIC_OP(prefill_mla_write_cache)
311307
"batch_id_per_token",
312308
"cu_seqlens_q",
313309
"block_tables",
310+
"slot_mapping",
314311
paddle::Optional("kv_signal_data")})
315312
.Outputs({"kv_cache_out"})
316313
.SetInplaceMap({{"kv_cache", "kv_cache_out"}})
317-
.Attrs({"cache_quant_type_str: std::string", "max_seq_len: int"})
314+
.Attrs({"cache_quant_type_str: std::string"})
318315
.SetKernelFn(PD_KERNEL(PrefillMLAWriteCacheKernel));
319316

320317
PD_BUILD_STATIC_OP(decode_mla_write_cache)
@@ -328,7 +325,5 @@ PD_BUILD_STATIC_OP(decode_mla_write_cache)
328325
"block_tables"})
329326
.Outputs({"kv_cache_out"})
330327
.SetInplaceMap({{"kv_cache", "kv_cache_out"}})
331-
.Attrs({"cache_quant_type_str: std::string",
332-
"max_seq_len: int",
333-
"speculate_decoder: bool"})
328+
.Attrs({"cache_quant_type_str: std::string", "speculate_decoder: bool"})
334329
.SetKernelFn(PD_KERNEL(DecodeMLAWriteCacheKernel));

custom_ops/gpu_ops/append_attn/mla_cache_kernel.cuh

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ __global__ void decode_absorb_cache_kernel(
2727
const int* __restrict__ cu_seqlens_q,
2828
const int* __restrict__ seq_lens, // [bsz]
2929
const int* __restrict__ seq_lens_encoder, // [bsz]
30-
const int max_seq_len,
3130
const int max_blocks_per_seq,
3231
const int kv_num_heads,
3332
const int nope_size,
@@ -98,7 +97,6 @@ __global__ void speculate_decode_absorb_cache_kernel(
9897
const int* __restrict__ cu_seqlens_q,
9998
const int* __restrict__ seq_lens, // [bsz]
10099
const int* __restrict__ seq_lens_encoder, // [bsz]
101-
const int max_seq_len,
102100
const int max_blocks_per_seq,
103101
const int kv_num_heads,
104102
const int nope_size,
@@ -168,18 +166,18 @@ __global__ void speculate_decode_absorb_cache_kernel(
168166
}
169167
}
170168

171-
template <typename T, int VecSize = 1>
169+
template <typename T, int VecSize = 1, typename CT = T>
172170
__global__ void prefill_absorb_cache_kernel(
173171
const T* __restrict__ kv_nope, // [bsz, kv_num_heads, pe_size] 512
174172
const T* __restrict__ kv_pe, // [bsz, kv_num_heads, nope_size] 64
175-
T* __restrict__ kv_cache, // [num_blocks, kv_num_heads, block_size,
173+
CT* __restrict__ kv_cache, // [num_blocks, kv_num_heads, block_size,
176174
// nope_size]
177175
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
176+
const int64_t* __restrict__ slot_mapping,
178177
const int* __restrict__ batch_id_per_token,
179178
const int* __restrict__ cu_seqlens_q,
180179
const int* __restrict__ seq_lens, // [bsz]
181180
const int* __restrict__ seq_lens_decoder, // [bsz]
182-
const int max_seq_len,
183181
const int max_blocks_per_seq,
184182
const int kv_num_heads,
185183
const int nope_size,
@@ -201,7 +199,8 @@ __global__ void prefill_absorb_cache_kernel(
201199
linear_index += step) {
202200
const uint32_t token_idx = linear_index / hidden_size;
203201
const uint32_t bias = linear_index % hidden_size;
204-
const uint32_t ori_bi = batch_id_per_token[token_idx];
202+
const int32_t ori_bi = batch_id_per_token[token_idx];
203+
if (ori_bi == -1) continue;
205204
if (seq_lens[ori_bi] == 0) continue;
206205
const uint32_t ori_seq_id =
207206
(token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];
@@ -211,6 +210,14 @@ __global__ void prefill_absorb_cache_kernel(
211210
const uint32_t block_idx = block_table_now[ori_seq_id / block_size];
212211
const uint32_t block_offset = ori_seq_id % block_size;
213212

213+
const int32_t block_idx1 = slot_mapping[token_idx] / block_size;
214+
if (block_idx1 != block_idx) {
215+
printf("block_idx1 %d != block_idx %d\n", block_idx1, block_idx);
216+
printf("token_idx %d\n", token_idx);
217+
printf("slot_mapping %d\n", slot_mapping[token_idx]);
218+
asm volatile("trap;");
219+
}
220+
214221
if (bias < nope_hidden_size) { // pe
215222
const uint32_t inner_bias = bias;
216223
const uint32_t hi = inner_bias / nope_size;

custom_ops/gpu_ops/cpp_extensions.cc

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -556,7 +556,6 @@ std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(
556556
const paddle::Tensor& cu_seqlens_q,
557557
const paddle::Tensor& block_tables,
558558
const std::string& cache_quant_type_str,
559-
const int max_seq_len,
560559
const bool speculate_decoder);
561560

562561
std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
@@ -568,9 +567,9 @@ std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
568567
const paddle::Tensor& batch_id_per_token,
569568
const paddle::Tensor& cu_seqlens_q,
570569
const paddle::Tensor& block_tables,
570+
const paddle::Tensor& slot_mapping,
571571
const paddle::optional<paddle::Tensor>& kv_signal_data,
572-
const std::string& cache_quant_type_str,
573-
const int max_seq_len);
572+
const std::string& cache_quant_type_str);
574573

575574
void FusedRotaryPositionEncoding(
576575
paddle::Tensor& query, // [num_tokens, num_heads, head_size] or

fastdeploy/model_executor/layers/attention/mla_attention_backend.py

Lines changed: 31 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -664,7 +664,6 @@ def forward_extend(
664664
metadata.block_tables,
665665
metadata.kv_signal_data_list[layer.layer_id],
666666
"none",
667-
getattr(forward_meta, "max_input_length", -1),
668667
)
669668

670669
fmha_out = self.flash_attn_func(
@@ -720,7 +719,6 @@ def forward_decode(
720719
forward_meta.cu_seqlens_q,
721720
metadata.block_tables,
722721
"none",
723-
self.max_seq_len,
724722
speculate_decoder,
725723
)
726724

@@ -799,21 +797,23 @@ def forward_mixed(
799797

800798
latent_cache = forward_meta.caches[layer.layer_id] if hasattr(forward_meta, "caches") else None
801799

800+
assert k_pe.shape[0] == compressed_kv.shape[0]
801+
prefill_mla_write_cache(
802+
compressed_kv,
803+
k_pe,
804+
latent_cache,
805+
forward_meta.seq_lens_this_time,
806+
forward_meta.seq_lens_decoder,
807+
forward_meta.batch_id_per_token,
808+
forward_meta.cu_seqlens_q,
809+
metadata.block_tables,
810+
forward_meta.slot_mapping,
811+
metadata.kv_signal_data_list[layer.layer_id],
812+
"none",
813+
)
814+
802815
# Prefill branch: k is not None
803816
if k is not None:
804-
prefill_mla_write_cache(
805-
compressed_kv,
806-
k_pe,
807-
latent_cache,
808-
forward_meta.seq_lens_encoder,
809-
forward_meta.seq_lens_decoder,
810-
forward_meta.batch_id_per_token,
811-
forward_meta.cu_seqlens_q,
812-
metadata.block_tables,
813-
metadata.kv_signal_data_list[layer.layer_id],
814-
"none",
815-
self.max_seq_len,
816-
)
817817

818818
if self.prop.major == 10:
819819
# TODO support FA4
@@ -845,20 +845,6 @@ def forward_mixed(
845845

846846
# Decode branch: k is None
847847
if k is None:
848-
decode_mla_write_cache(
849-
compressed_kv,
850-
k_pe,
851-
latent_cache,
852-
forward_meta.seq_lens_decoder,
853-
forward_meta.seq_lens_encoder,
854-
forward_meta.batch_id_per_token,
855-
forward_meta.cu_seqlens_q,
856-
metadata.block_tables,
857-
"none",
858-
self.max_seq_len,
859-
speculate_decoder,
860-
)
861-
862848
if int(os.getenv("USE_FLASH_MLA", "0")) == 0 and self.prop.major == 9:
863849
assert self.num_heads <= 64, "paddle mla attention support failed"
864850
if self.heads_need_padding:
@@ -961,6 +947,12 @@ def forward_mixed(
961947
@staticmethod
962948
def mla_blackwell(decoder_q, latent_cache, block_table, cache_seqlens, attn_softmax_scale):
963949

950+
# decoder_q = decoder_q.cast(paddle.float8_e4m3fn)
951+
# latent_cache = latent_cache.cast(paddle.float8_e4m3fn)
952+
953+
assert decoder_q.dtype == latent_cache.dtype
954+
q_dtype = decoder_q.dtype
955+
964956
page_size = latent_cache.shape[2]
965957
q_num_heads = decoder_q.shape[2]
966958
assert decoder_q.shape[1:] == [1, q_num_heads, 576]
@@ -1008,6 +1000,8 @@ def mla_blackwell(decoder_q, latent_cache, block_table, cache_seqlens, attn_soft
10081000

10091001
from mla_decode_fp16 import BlackwellMultiHeadLatentAttentionForwardFP16
10101002

1003+
# from mla_decode_fp8 import BlackwellMultiHeadLatentAttentionForwardFP8
1004+
10111005
mla = BlackwellMultiHeadLatentAttentionForwardFP16(
10121006
cutlass.Float32,
10131007
cutlass.Float32,
@@ -1063,10 +1057,18 @@ def mla_blackwell(decoder_q, latent_cache, block_table, cache_seqlens, attn_soft
10631057
stream,
10641058
)
10651059

1060+
if q_dtype == paddle.float8_e4m3fn:
1061+
paddle_output = paddle_output.cast("bfloat16")
10661062
return paddle_output
10671063

10681064
@staticmethod
10691065
def flashmla_baseline(decoder_q, latent_cache, block_table, cache_seqlens, attn_softmax_scale):
1066+
1067+
assert decoder_q.dtype == latent_cache.dtype
1068+
1069+
decoder_q = decoder_q.cast("bfloat16")
1070+
latent_cache = latent_cache.cast("bfloat16")
1071+
10701072
page_size = latent_cache.shape[2]
10711073
q_num_heads = decoder_q.shape[2]
10721074
assert decoder_q.shape[1:] == [1, q_num_heads, 576]

0 commit comments

Comments
 (0)