Skip to content

Commit 6076add

Browse files
authored
[OP] Support MLA sliding-window attention (#8060)
1 parent 5372fe5 commit 6076add

6 files changed

Lines changed: 449 additions & 980 deletions

File tree

fastdeploy/model_executor/layers/attention/dsa_attention_backend.py

Lines changed: 133 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,68 @@
3838
if TYPE_CHECKING:
3939
from fastdeploy.model_executor.forward_meta import ForwardMeta
4040

41+
import triton
42+
import triton.language as tl
43+
4144
from fastdeploy.config import FDConfig
4245
from fastdeploy.model_executor.layers.attention.attention import Attention
4346
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
4447
AttentionBackend,
4548
AttentionMetadata,
4649
)
4750
from fastdeploy.model_executor.layers.attention.utils import init_rank_and_device_id
51+
from fastdeploy.model_executor.ops.triton_ops.triton_utils import (
52+
enable_compat_on_triton_kernel,
53+
)
54+
55+
56+
@enable_compat_on_triton_kernel
@triton.jit()
def insert_kernel_with_active_idx(
    decoder_res,
    active_idx,
    cu_seqlens_q,
    output,
    HIDDEN_DIM: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Scatter one compact row of ``decoder_res`` into ``output``.

    Each program instance on axis 0 handles a single entry of ``active_idx``.
    Both ``decoder_res`` and ``output`` are addressed as flat buffers made of
    consecutive rows of ``HIDDEN_DIM`` elements.

    Args:
        decoder_res: Source buffer; row ``i`` belongs to ``active_idx[i]``.
        active_idx: Indices used to look up ``cu_seqlens_q``.
        cu_seqlens_q: Lookup table giving the destination row in ``output``.
        output: Destination buffer; only the selected row is written.
        HIDDEN_DIM: Number of valid elements per row.
        BLOCK_SIZE: Number of lanes per program; lanes at or beyond
            ``HIDDEN_DIM`` are masked off for both the load and the store.
    """
    src_row = tl.program_id(axis=0)

    # Resolve where this compact row has to land in the destination buffer.
    request_id = tl.load(active_idx + src_row)
    dst_row = tl.load(cu_seqlens_q + request_id)

    lanes = tl.arange(0, BLOCK_SIZE)
    in_row = lanes < HIDDEN_DIM

    src_ptrs = decoder_res + src_row * HIDDEN_DIM + lanes
    dst_ptrs = output + dst_row * HIDDEN_DIM + lanes

    values = tl.load(src_ptrs, mask=in_row)
    tl.store(dst_ptrs, values, mask=in_row)
76+
77+
78+
def insert_decoder_result_back_with_active_idx(
    decoder_result: paddle.Tensor,
    active_idx: paddle.Tensor,
    cu_seqlens_q: paddle.Tensor,
    mixed_token_num,
):
    """Scatter compact decode results back into a full-token-sized buffer.

    Launches ``insert_kernel_with_active_idx`` with one program per entry of
    ``active_idx``. Row ``i`` of ``decoder_result`` (its last two dimensions
    flattened) is written to row ``cu_seqlens_q[active_idx[i]]`` of the
    returned tensor.

    Args:
        decoder_result: 4-D tensor of compact decode outputs.
        active_idx: 1-D tensor of indices into ``cu_seqlens_q``.
        cu_seqlens_q: 1-D tensor mapping each index to its destination row.
        mixed_token_num: Number of rows in the returned tensor.

    Returns:
        A ``[mixed_token_num, shape[-2] * shape[-1]]`` tensor with the same
        dtype as ``decoder_result``. It is allocated with ``paddle.empty``, so
        rows that are not scatter targets are left uninitialized.

    Raises:
        AssertionError: If any input does not have the expected rank.
    """
    expected_ranks = (
        (decoder_result, 4),
        (active_idx, 1),
        (cu_seqlens_q, 1),
    )
    for tensor, rank in expected_ranks:
        assert len(tensor.shape) == rank

    # Flatten the two trailing dimensions into a single row width.
    second_last, last = decoder_result.shape[-2:]
    row_width = second_last * last

    scattered = paddle.empty([mixed_token_num, row_width], dtype=decoder_result.dtype)

    # The kernel masks lanes >= row_width, so round the block up to a power of two.
    block_size = triton.next_power_of_2(row_width)
    launch_grid = (active_idx.shape[0],)

    insert_kernel_with_active_idx[launch_grid](
        decoder_result,
        active_idx,
        cu_seqlens_q,
        scattered,
        row_width,
        block_size,
    )

    return scattered
48103

49104

50105
def yarn_get_mscale(scale=1, mscale=1):
@@ -336,7 +391,26 @@ def forward_mixed(
336391
Mixed模式的前向传播
337392
"""
338393

339-
latent_cache = forward_meta.caches[2 * layer.layer_id] if hasattr(forward_meta, "caches") else None
394+
res = DSAAttentionBackend.forward_static(
395+
q, v, compressed_kv, k_pe, forward_meta.caches[2 * layer.layer_id], forward_meta, self.attn_softmax_scale
396+
)
397+
return res
398+
399+
@staticmethod
400+
def forward_static(
401+
q: paddle.Tensor,
402+
indexer_topk: paddle.Tensor,
403+
compressed_kv: paddle.Tensor,
404+
k_pe: paddle.Tensor,
405+
latent_cache: paddle.Tensor,
406+
forward_meta: ForwardMeta,
407+
attn_softmax_scale: float,
408+
) -> paddle.Tensor:
409+
410+
assert len(q.shape) == 3
411+
assert len(compressed_kv.shape) == 2
412+
assert len(k_pe.shape) == 3
413+
assert len(latent_cache.shape) == 4
340414

341415
if current_platform.is_cuda():
342416
import flash_mla
@@ -352,43 +426,91 @@ def forward_mixed(
352426
"fp8_ds_mla",
353427
)
354428

429+
assert len(q.shape) == 3
430+
q_num_heads = q.shape[1]
431+
ceil64_num_heads = (q_num_heads + 63) // 64 * 64
432+
355433
fmha_out_prefill = None
356434
if forward_meta.max_len_tensor_cpu[1]: # max_enc_len_this_time
435+
if ceil64_num_heads != q_num_heads:
436+
new_q = paddle.empty([q.shape[0], ceil64_num_heads, q.shape[2]], dtype=q.dtype)
437+
new_q[:, :q_num_heads, :] = q
438+
else:
439+
new_q = q
357440

441+
kv = paddle.concat([compressed_kv.unsqueeze(1), k_pe], axis=-1)
358442
fmha_out_prefill, _, __ = flash_mla.flash_mla_sparse_fwd(
359-
q, # q_input.contiguous(),
360-
k, # kv.unsqueeze(1),
361-
v, # indexer_top_k.unsqueeze(1),
362-
sm_scale=self.attn_softmax_scale,
443+
new_q, # q_input.contiguous(),
444+
kv, # kv.unsqueeze(1),
445+
indexer_topk, # indexer_top_k.unsqueeze(1),
446+
sm_scale=attn_softmax_scale,
363447
)
364448

449+
assert len(fmha_out_prefill.shape) == 3
450+
fmha_out_prefill = fmha_out_prefill[:, :q_num_heads, :].contiguous()
451+
365452
# Decode
366-
# if k is None:
367-
if forward_meta.max_len_tensor_cpu[2]: # max_enc_len_this_time
453+
if forward_meta.max_len_tensor_cpu[2]:
454+
455+
need_insert_decoder_result = False
456+
q_total_token_num = q.shape[0]
457+
if forward_meta.max_len_tensor_cpu[1]:
458+
# indexer_topk is generated in full-token space. Select only
459+
# real decode token rows before calling flash_mla_with_kvcache.
460+
# This is feasible because the current DSA does not support chunk-related functions.
461+
active_idx = paddle.where(forward_meta.seq_lens_decoder > 0)[0]
462+
token_idx = forward_meta.cu_seqlens_q[active_idx]
463+
q_decode = q[token_idx]
464+
indexer_topk_decode = indexer_topk[token_idx]
465+
need_insert_decoder_result = True
466+
else:
467+
q_decode = q
468+
indexer_topk_decode = indexer_topk
368469

369470
tile_scheduler_metadata, _ = flash_mla.get_mla_metadata()
370471
new_cache_shape = latent_cache.shape
371472
assert new_cache_shape[1] == 1
372473
new_cache_shape[1], new_cache_shape[2] = new_cache_shape[2], new_cache_shape[1]
474+
475+
if ceil64_num_heads != q_num_heads:
476+
new_q = paddle.empty([q_decode.shape[0], ceil64_num_heads, q_decode.shape[2]], dtype=q_decode.dtype)
477+
new_q[:, :q_num_heads, :] = q_decode
478+
else:
479+
new_q = q_decode
480+
373481
fmha_out_decode, _ = flash_mla.flash_mla_with_kvcache(
374-
q.unsqueeze(1).contiguous(),
482+
new_q.unsqueeze(1).contiguous(),
375483
latent_cache.view(new_cache_shape),
376484
None, # forward_meta.block_tables,
377485
None, # cache_seqlens
378486
512, # self.qk_nope_head_dim,
379487
tile_scheduler_metadata,
380488
None, # num_splits,
381-
self.attn_softmax_scale,
489+
attn_softmax_scale,
382490
False, # causal
383491
True, # is_fp8_kvcache
384-
v, # indices,
492+
indexer_topk_decode, # indices,
385493
None, # t.attn_sink,
386494
None, # extra_k_cache,
387495
None, # extra_indices_in_kvcache: Optional[torch.Tensor] = None,
388496
None, # topk_length: Optional[torch.Tensor] = None,
389497
None, # extra_topk_length: Optional[torch.Tensor] = None
390498
)
391499

500+
fmha_out_decode = fmha_out_decode[:, :, :q_num_heads, :].contiguous()
501+
502+
if need_insert_decoder_result:
503+
fmha_out_decode = insert_decoder_result_back_with_active_idx(
504+
fmha_out_decode,
505+
active_idx,
506+
forward_meta.cu_seqlens_q,
507+
q_total_token_num,
508+
)
509+
else:
510+
fmha_out_decode = fmha_out_decode.reshape(
511+
[fmha_out_decode.shape[0], q_num_heads * fmha_out_decode.shape[-1]]
512+
)
513+
392514
if fmha_out_prefill is not None:
393515

394516
from fastdeploy.model_executor.ops.gpu import (
@@ -402,7 +524,7 @@ def forward_mixed(
402524
forward_meta.seq_lens_decoder,
403525
forward_meta.seq_lens_this_time,
404526
forward_meta.cu_seqlens_q,
405-
self.num_heads * 4,
527+
q_num_heads * 4,
406528
128,
407529
1,
408530
)

fastdeploy/model_executor/layers/attention/mla_attention_backend.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -542,6 +542,8 @@ def __init__(
542542
logger.info(
543543
"The current platform does not support Flash Attention V3, so Flash Attention V2 will be used instead."
544544
)
545+
# swa config
546+
self.window_attn_skip_freq = getattr(fd_config.model_config, "window_attn_skip_freq", None)
545547

546548
def init_attention_metadata(self, forward_meta: ForwardMeta):
547549
"""Initialize attention metadata hence all layers in the forward pass can reuse it."""
@@ -618,8 +620,13 @@ def get_kv_cache_shape(
618620
"""
619621
Calculate kv cache shape for MLA
620622
"""
621-
key_cache_shape = [max_num_blocks, 1, self.block_size, self.kv_lora_rank + self.qk_rope_head_dim]
623+
layer_id = self.layer_id
622624
value_cache_shape = []
625+
if self.window_attn_skip_freq is not None and self.window_attn_skip_freq[layer_id] == 1:
626+
fp8_key_cahe_dim = self.kv_lora_rank + 4 * (self.kv_lora_rank // 128) + 2 * self.qk_rope_head_dim
627+
key_cache_shape = [max_num_blocks, 1, self.block_size, fp8_key_cahe_dim]
628+
else:
629+
key_cache_shape = [max_num_blocks, 1, self.block_size, self.kv_lora_rank + self.qk_rope_head_dim]
623630
return key_cache_shape, value_cache_shape
624631

625632
def create_kv_cache(

0 commit comments

Comments
 (0)