Skip to content

Commit eab3599

Browse files
committed
update
1 parent bee1e8a commit eab3599

3 files changed

Lines changed: 78 additions & 45 deletions

File tree

fastdeploy/model_executor/layers/attention/triton_mla_attention_backend.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def __init__(
107107
self.causal: bool = getattr(fd_config.model_config, "causal", True)
108108

109109
self.num_heads: int = num_heads
110-
self.head_dim: int = fd_config.model_config.head_dim
110+
self.head_dim: int = head_dim
111111
self.num_layers: int = fd_config.model_config.num_hidden_layers
112112

113113
self.kv_lora_rank: int = fd_config.model_config.kv_lora_rank
@@ -124,7 +124,7 @@ def __init__(
124124
self.max_kv_splits: int = 32
125125

126126
self.rank, self.device_id = init_rank_and_device_id(fd_config)
127-
self.useless_tensor = paddle.randn([1]).cast("int32")
127+
self.useless_tensor = paddle.zeros([1], dtype="int32")
128128

129129
# Pre-allocate buffers for CUDAGraph compatibility (stable memory addresses)
130130
self.max_num_seqs = fd_config.scheduler_config.max_num_seqs
@@ -133,6 +133,12 @@ def __init__(
133133
self._kv_indices_buf = paddle.zeros([self.max_num_seqs * max_blocks_per_seq * self.block_size], dtype="int32")
134134
self._num_kv_splits_buf = paddle.ones([self.max_num_seqs], dtype="int32")
135135

136+
# Pre-allocate decode kernel intermediate buffers for CUDAGraph address stability
137+
Lv = fd_config.model_config.kv_lora_rank
138+
self._attn_logits_buf = paddle.empty([self.max_num_seqs, num_heads, self.max_kv_splits, Lv], dtype="float32")
139+
self._attn_lse_buf = paddle.empty([self.max_num_seqs, num_heads, self.max_kv_splits], dtype="float32")
140+
self._o_buf = paddle.empty([self.max_num_seqs, num_heads, Lv], dtype=paddle.get_default_dtype())
141+
136142
if self.flash_attn_func is None:
137143
prop = paddle.device.cuda.get_device_properties()
138144
cc = prop.major * 10 + prop.minor
@@ -191,7 +197,10 @@ def init_attention_metadata(self, forward_meta: ForwardMeta):
191197
total_kv_len = int(paddle.sum(decode_seq_lens).item())
192198

193199
build_kv_indices_from_block_tables(
194-
decode_block_tables, decode_seq_lens, self.block_size, decode_bs,
200+
decode_block_tables,
201+
decode_seq_lens,
202+
self.block_size,
203+
decode_bs,
195204
total_kv_len=total_kv_len,
196205
kv_indptr_buf=self._kv_indptr_buf,
197206
kv_indices_buf=self._kv_indices_buf,
@@ -200,11 +209,10 @@ def init_attention_metadata(self, forward_meta: ForwardMeta):
200209
# kv_indptr[decode_bs] == total_kv_len; every entry after it must hold that same value
201210
# so that (kv_indptr[i + 1] - kv_indptr[i]) == 0 for the padded batch slots.
202211
if decode_bs < self.max_num_seqs:
203-
self._kv_indptr_buf[decode_bs + 1:] = total_kv_len
212+
self._kv_indptr_buf[decode_bs + 1 :] = total_kv_len
204213

205214
# Compute num_kv_splits into the pre-allocated buffer
206-
compute_num_kv_splits(decode_seq_lens, decode_bs, self.max_kv_splits,
207-
out_buf=self._num_kv_splits_buf)
215+
compute_num_kv_splits(decode_seq_lens, decode_bs, self.max_kv_splits, out_buf=self._num_kv_splits_buf)
208216
# Padded entries must be >= 1 to avoid division by zero in kernel
209217
if decode_bs < self.max_num_seqs:
210218
self._num_kv_splits_buf[decode_bs:] = 1
@@ -346,14 +354,15 @@ def _run_decode_kernel(
346354
latent_dim = self.kv_lora_rank + self.qk_rope_head_dim
347355
q_reshaped = q.reshape([bs, self.num_heads, latent_dim])
348356

349-
attn_logits = paddle.empty([bs, self.num_heads, self.max_kv_splits, Lv], dtype="float32")
350-
attn_lse = paddle.empty([bs, self.num_heads, self.max_kv_splits], dtype="float32")
351-
o = paddle.empty([bs, self.num_heads, Lv], dtype=q.dtype)
357+
# Use pre-allocated buffers sliced to current batch size for CUDAGraph address stability
358+
attn_logits = self._attn_logits_buf[:bs]
359+
attn_lse = self._attn_lse_buf[:bs]
360+
o = self._o_buf[:bs]
352361

353362
decode_attention_fwd(
354363
q_reshaped,
355364
latent_cache,
356-
latent_cache[:, :, :, :self.kv_lora_rank],
365+
latent_cache[:, :, :, : self.kv_lora_rank],
357366
o,
358367
metadata.kv_indptr,
359368
metadata.kv_indices,

fastdeploy/model_executor/layers/attention/triton_ops/decode_attention.py

Lines changed: 11 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -28,16 +28,9 @@
2828
enable_compat_on_triton_kernel,
2929
)
3030

31-
3231
_MIN_BLOCK_KV = 32
3332

3433

35-
@enable_compat_on_triton_kernel
36-
@triton.jit
37-
def tanh(x):
38-
return 2 * tl.sigmoid(2 * x) - 1
39-
40-
4134
@enable_compat_on_triton_kernel
4235
@triton.jit
4336
def _fwd_grouped_kernel_stage1(
@@ -104,13 +97,9 @@ def _fwd_grouped_kernel_stage1(
10497
if BLOCK_DPE > 0:
10598
offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
10699
mask_dpe = offs_dpe < Lk
107-
off_qpe = (
108-
cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :]
109-
)
100+
off_qpe = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :]
110101

111-
kv_len_per_split = (
112-
tl.cdiv(tl.cdiv(cur_batch_seq_len, kv_splits), MIN_BLOCK_KV) * MIN_BLOCK_KV
113-
)
102+
kv_len_per_split = tl.cdiv(tl.cdiv(cur_batch_seq_len, kv_splits), MIN_BLOCK_KV) * MIN_BLOCK_KV
114103
split_kv_start = kv_len_per_split * split_kv_id
115104
split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len)
116105

@@ -121,9 +110,7 @@ def _fwd_grouped_kernel_stage1(
121110
if split_kv_end > split_kv_start:
122111
q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_d[None, :]), other=0.0)
123112
if BLOCK_DPE > 0:
124-
qpe = tl.load(
125-
Q + off_qpe, mask=(mask_h[:, None]) & (mask_dpe[None, :]), other=0.0
126-
)
113+
qpe = tl.load(Q + off_qpe, mask=(mask_h[:, None]) & (mask_dpe[None, :]), other=0.0)
127114
for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
128115
offs_n = start_n + tl.arange(0, BLOCK_N)
129116
kv_loc = tl.load(
@@ -163,9 +150,7 @@ def _fwd_grouped_kernel_stage1(
163150
qk += tl.dot(qpe, kpe.to(qpe.dtype))
164151
qk *= sm_scale
165152

166-
qk = tl.where(
167-
mask_h[:, None] & (offs_n[None, :] < split_kv_end), qk, float("-inf")
168-
)
153+
qk = tl.where(mask_h[:, None] & (offs_n[None, :] < split_kv_end), qk, float("-inf"))
169154

170155
# Load V from paged cache
171156
offs_buf_v = (
@@ -202,11 +187,7 @@ def _fwd_grouped_kernel_stage1(
202187
mask=(mask_h[:, None]) & (mask_dv[None, :]),
203188
)
204189

205-
offs_mid_o_1 = (
206-
cur_batch * stride_mid_ob
207-
+ cur_head * stride_mid_oh
208-
+ split_kv_id * stride_mid_os
209-
) // Lv
190+
offs_mid_o_1 = (cur_batch * stride_mid_ob + cur_head * stride_mid_oh + split_kv_id * stride_mid_os) // Lv
210191

211192
tl.store(
212193
Att_Lse + offs_mid_o_1,
@@ -239,9 +220,7 @@ def _fwd_kernel_stage2(
239220
cur_batch = tl.program_id(0)
240221
cur_head = tl.program_id(1)
241222

242-
cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - tl.load(
243-
kv_indptr + cur_batch
244-
)
223+
cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - tl.load(kv_indptr + cur_batch)
245224
kv_splits = tl.load(num_kv_splits + cur_batch)
246225

247226
offs_d = tl.arange(0, BLOCK_DV)
@@ -253,18 +232,14 @@ def _fwd_kernel_stage2(
253232

254233
offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d
255234
offs_logic = (cur_batch * stride_mid_ob + cur_head * stride_mid_oh) // Lv
256-
kv_len_per_split = (
257-
tl.cdiv(tl.cdiv(cur_batch_seq_len, kv_splits), MIN_BLOCK_KV) * MIN_BLOCK_KV
258-
)
235+
kv_len_per_split = tl.cdiv(tl.cdiv(cur_batch_seq_len, kv_splits), MIN_BLOCK_KV) * MIN_BLOCK_KV
259236

260237
for split_kv_id in range(0, MAX_KV_SPLITS):
261238
split_kv_start = kv_len_per_split * split_kv_id
262239
split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len)
263240

264241
if split_kv_end > split_kv_start:
265-
tv = tl.load(
266-
Mid_O + offs_v + split_kv_id * stride_mid_os, mask=mask_d, other=0.0
267-
)
242+
tv = tl.load(Mid_O + offs_v + split_kv_id * stride_mid_os, mask=mask_d, other=0.0)
268243
tlogic = tl.load(Mid_O_1 + offs_logic + split_kv_id * stride_mid_os // Lv)
269244
n_e_max = tl.maximum(tlogic, e_max)
270245

@@ -276,9 +251,11 @@ def _fwd_kernel_stage2(
276251
e_sum = e_sum * old_scale + exp_logic
277252
e_max = n_e_max
278253

254+
# Guard against e_sum==0 (empty sequences from CUDAGraph padding) to avoid NaN
255+
safe_e_sum = tl.where(e_sum == 0.0, 1.0, e_sum)
279256
tl.store(
280257
O + cur_batch * stride_obs + cur_head * stride_oh + offs_d,
281-
acc / e_sum,
258+
tl.where(e_sum == 0.0, 0.0, acc / safe_e_sum),
282259
mask=mask_d,
283260
)
284261

fastdeploy/model_executor/models/deepseek_v3.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from __future__ import annotations
1818

1919
import math
20+
import os
2021
import re
2122
from typing import Dict
2223

@@ -344,6 +345,9 @@ def __init__(self, fd_config: FDConfig, layer_id: int, prefix: str = "") -> None
344345

345346
self.prefix = prefix
346347

348+
prop = paddle.device.cuda.get_device_properties()
349+
self.prop = prop
350+
347351
@staticmethod
348352
def yarn_get_mscale(scale=1, mscale=1):
349353
""" """
@@ -362,6 +366,8 @@ def forward(
362366
fused_read_cache_and_interleave,
363367
)
364368

369+
q_total_token_num = hidden_states.shape[0]
370+
365371
attn_out = None
366372
if self.use_gated_attn:
367373
gate_out = self.gate(hidden_states)
@@ -439,6 +445,36 @@ def forward(
439445
attn_out = fmha_out
440446

441447
if need_do_decode: # max_dec_len_this_time
448+
449+
if int(os.getenv("USE_FLASH_MLA", "0")) == 0 and self.prop.major == 9:
450+
pass
451+
else:
452+
from fastdeploy.model_executor.layers.attention.mla_attention_backend import (
453+
extract_decoder_token_from_q,
454+
insert_decoder_result_back,
455+
)
456+
457+
decoder_query_nope, cache_seqlens = extract_decoder_token_from_q(
458+
query_nope.reshape([0, -1]),
459+
forward_meta.cu_seqlens_q,
460+
forward_meta.seq_lens_encoder,
461+
forward_meta.seq_lens_decoder,
462+
)
463+
464+
decoder_query_pe, cache_seqlens = extract_decoder_token_from_q(
465+
query_pe.reshape([0, -1]),
466+
forward_meta.cu_seqlens_q,
467+
forward_meta.seq_lens_encoder,
468+
forward_meta.seq_lens_decoder,
469+
)
470+
assert decoder_query_nope.shape[0] == forward_meta.seq_lens_encoder.shape[0]
471+
assert decoder_query_pe.shape[0] == forward_meta.seq_lens_encoder.shape[0]
472+
473+
forward_meta.cache_seqlens = cache_seqlens
474+
475+
query_nope = decoder_query_nope.reshape([0, -1, self.qk_nope_head_dim])
476+
query_pe = decoder_query_pe.reshape([0, -1, self.qk_rope_head_dim])
477+
442478
q_nope_out = self.kv_b_proj_bmm(query_nope.transpose([1, 0, 2]), proj_type="k").transpose([1, 0, 2])
443479

444480
q_input = paddle.concat([q_nope_out, query_pe], axis=-1)
@@ -467,6 +503,17 @@ def forward(
467503
.reshape_([-1, self.num_attention_heads_tp * self.v_head_dim])
468504
)
469505

506+
if int(os.getenv("USE_FLASH_MLA", "0")) == 0 and self.prop.major == 9:
507+
pass
508+
else:
509+
fmqa_out = insert_decoder_result_back(
510+
fmqa_out.reshape([0, 1, self.num_attention_heads_tp, self.v_head_dim]),
511+
forward_meta.cu_seqlens_q,
512+
forward_meta.seq_lens_encoder,
513+
forward_meta.seq_lens_decoder,
514+
q_total_token_num,
515+
)
516+
470517
if need_do_prefill:
471518
merge_prefill_decode_output(
472519
attn_out,
@@ -1062,7 +1109,7 @@ def forward(
10621109
need_do_decode = forward_meta.max_len_tensor_cpu[2] > 0
10631110

10641111
if not need_do_prefill and not need_do_decode:
1065-
return hidden_states
1112+
return hidden_states, residual
10661113

10671114
if hidden_states.shape[0] > 0:
10681115
hidden_states, residual = self.input_layernorm(

0 commit comments

Comments
 (0)