Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 133 additions & 11 deletions fastdeploy/model_executor/layers/attention/dsa_attention_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,68 @@
if TYPE_CHECKING:
from fastdeploy.model_executor.forward_meta import ForwardMeta

import triton
import triton.language as tl

from fastdeploy.config import FDConfig
from fastdeploy.model_executor.layers.attention.attention import Attention
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
AttentionBackend,
AttentionMetadata,
)
from fastdeploy.model_executor.layers.attention.utils import init_rank_and_device_id
from fastdeploy.model_executor.ops.triton_ops.triton_utils import (
enable_compat_on_triton_kernel,
)


@enable_compat_on_triton_kernel
@triton.jit()
def insert_kernel_with_active_idx(
    decoder_res,
    active_idx,
    cu_seqlens_q,
    output,
    HIDDEN_DIM: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Scatter one compact decode row into its slot of the full-token output.

    Each program handles one entry of ``active_idx``. Program ``i`` reads
    ``HIDDEN_DIM`` consecutive elements starting at
    ``decoder_res + i * HIDDEN_DIM`` and writes them starting at
    ``output + cu_seqlens_q[active_idx[i]] * HIDDEN_DIM``.

    Args:
        decoder_res: Pointer to the compact source rows, addressed as a dense
            row-major buffer with ``HIDDEN_DIM`` elements per row.
        active_idx: Pointer to the per-program index used to look up
            ``cu_seqlens_q``.
        cu_seqlens_q: Pointer to the table giving the destination row for each
            index found in ``active_idx``.
        output: Pointer to the destination buffer, addressed as a dense
            row-major buffer with ``HIDDEN_DIM`` elements per row.
        HIDDEN_DIM: Number of elements copied per row.
        BLOCK_SIZE: Width of the vectorised copy; lanes at or beyond
            ``HIDDEN_DIM`` are masked off, so it must be >= ``HIDDEN_DIM``
            for the whole row to be copied.
    """
    compact_id = tl.program_id(axis=0)
    batch_id = tl.load(active_idx + compact_id)
    cu_len_this_batch = tl.load(cu_seqlens_q + batch_id)

    read_offsets = tl.arange(0, BLOCK_SIZE)
    # Same mask guards the load and the store, so padded lanes never touch memory.
    row_mask = read_offsets < HIDDEN_DIM

    # Promote row indices to int64 before scaling by HIDDEN_DIM. The dtype of
    # cu_seqlens_q is not visible here; if it is int32, the product
    # token_offset * HIDDEN_DIM would wrap once it reaches 2**31 and the store
    # would land at a wrong address.
    src_row = decoder_res + compact_id.to(tl.int64) * HIDDEN_DIM
    row_data = tl.load(src_row + read_offsets, mask=row_mask)

    dst_row = output + cu_len_this_batch.to(tl.int64) * HIDDEN_DIM
    tl.store(dst_row + read_offsets, row_data, mask=row_mask)


def insert_decoder_result_back_with_active_idx(
    decoder_result: paddle.Tensor,
    active_idx: paddle.Tensor,
    cu_seqlens_q: paddle.Tensor,
    mixed_token_num,
):
    """Scatter compact decode rows into a ``[mixed_token_num, hidden]`` tensor.

    Launches ``insert_kernel_with_active_idx`` with one program per entry of
    ``active_idx``. The row width is the product of the last two dimensions
    of ``decoder_result``.

    Args:
        decoder_result: 4-D tensor holding the compact decode rows.
        active_idx: 1-D tensor of indices into ``cu_seqlens_q``.
        cu_seqlens_q: 1-D tensor giving the destination row for each index.
        mixed_token_num: Number of rows in the returned tensor.

    Returns:
        A tensor of shape ``[mixed_token_num, hidden]`` with the dtype of
        ``decoder_result``. It is allocated with ``paddle.empty``, so rows the
        kernel does not write are not initialised here.
    """
    assert len(decoder_result.shape) == 4
    assert len(active_idx.shape) == 1
    assert len(cu_seqlens_q.shape) == 1

    # Flatten the trailing two dimensions into a single row width.
    heads, head_dim = decoder_result.shape[-2:]
    row_width = heads * head_dim

    scattered = paddle.empty([mixed_token_num, row_width], dtype=decoder_result.dtype)

    # One program per active entry; the block width is rounded up to a power
    # of two and the kernel masks the lanes beyond row_width.
    launch_grid = (active_idx.shape[0],)
    insert_kernel_with_active_idx[launch_grid](
        decoder_result,
        active_idx,
        cu_seqlens_q,
        scattered,
        row_width,
        triton.next_power_of_2(row_width),
    )

    return scattered


def yarn_get_mscale(scale=1, mscale=1):
Expand Down Expand Up @@ -336,7 +391,26 @@ def forward_mixed(
Mixed模式的前向传播
"""

latent_cache = forward_meta.caches[2 * layer.layer_id] if hasattr(forward_meta, "caches") else None
res = DSAAttentionBackend.forward_static(
q, v, compressed_kv, k_pe, forward_meta.caches[2 * layer.layer_id], forward_meta, self.attn_softmax_scale
)
return res

@staticmethod
def forward_static(
q: paddle.Tensor,
indexer_topk: paddle.Tensor,
compressed_kv: paddle.Tensor,
k_pe: paddle.Tensor,
latent_cache: paddle.Tensor,
forward_meta: ForwardMeta,
attn_softmax_scale: float,
) -> paddle.Tensor:

assert len(q.shape) == 3
assert len(compressed_kv.shape) == 2
assert len(k_pe.shape) == 3
assert len(latent_cache.shape) == 4

if current_platform.is_cuda():
import flash_mla
Expand All @@ -352,43 +426,91 @@ def forward_mixed(
"fp8_ds_mla",
)

assert len(q.shape) == 3
q_num_heads = q.shape[1]
ceil64_num_heads = (q_num_heads + 63) // 64 * 64

fmha_out_prefill = None
if forward_meta.max_len_tensor_cpu[1]: # max_enc_len_this_time
if ceil64_num_heads != q_num_heads:
new_q = paddle.empty([q.shape[0], ceil64_num_heads, q.shape[2]], dtype=q.dtype)
new_q[:, :q_num_heads, :] = q
else:
new_q = q

kv = paddle.concat([compressed_kv.unsqueeze(1), k_pe], axis=-1)
fmha_out_prefill, _, __ = flash_mla.flash_mla_sparse_fwd(
q, # q_input.contiguous(),
k, # kv.unsqueeze(1),
v, # indexer_top_k.unsqueeze(1),
sm_scale=self.attn_softmax_scale,
new_q, # q_input.contiguous(),
kv, # kv.unsqueeze(1),
indexer_topk, # indexer_top_k.unsqueeze(1),
sm_scale=attn_softmax_scale,
)

assert len(fmha_out_prefill.shape) == 3
fmha_out_prefill = fmha_out_prefill[:, :q_num_heads, :].contiguous()

# Decode
# if k is None:
if forward_meta.max_len_tensor_cpu[2]: # max_enc_len_this_time
if forward_meta.max_len_tensor_cpu[2]:

need_insert_decoder_result = False
q_total_token_num = q.shape[0]
if forward_meta.max_len_tensor_cpu[1]:
# indexer_topk is generated in full-token space. Select only
# real decode token rows before calling flash_mla_with_kvcache.
# This is feasible because the current DSA does not support chunk-related functions.
active_idx = paddle.where(forward_meta.seq_lens_decoder > 0)[0]
token_idx = forward_meta.cu_seqlens_q[active_idx]
q_decode = q[token_idx]
indexer_topk_decode = indexer_topk[token_idx]
need_insert_decoder_result = True
else:
q_decode = q
indexer_topk_decode = indexer_topk

tile_scheduler_metadata, _ = flash_mla.get_mla_metadata()
new_cache_shape = latent_cache.shape
assert new_cache_shape[1] == 1
new_cache_shape[1], new_cache_shape[2] = new_cache_shape[2], new_cache_shape[1]

if ceil64_num_heads != q_num_heads:
new_q = paddle.empty([q_decode.shape[0], ceil64_num_heads, q_decode.shape[2]], dtype=q_decode.dtype)
new_q[:, :q_num_heads, :] = q_decode
else:
new_q = q_decode

fmha_out_decode, _ = flash_mla.flash_mla_with_kvcache(
q.unsqueeze(1).contiguous(),
new_q.unsqueeze(1).contiguous(),
latent_cache.view(new_cache_shape),
None, # forward_meta.block_tables,
None, # cache_seqlens
512, # self.qk_nope_head_dim,
tile_scheduler_metadata,
None, # num_splits,
self.attn_softmax_scale,
attn_softmax_scale,
False, # casual
True, # is_fp8_kvcache
v, # indices,
indexer_topk_decode, # indices,
None, # t.attn_sink,
None, # extra_k_cache,
None, # extra_indices_in_kvcache: Optional[torch.Tensor] = None,
None, # topk_length: Optional[torch.Tensor] = None,
None, # extra_topk_length: Optional[torch.Tensor] = None
)

fmha_out_decode = fmha_out_decode[:, :, :q_num_heads, :].contiguous()

if need_insert_decoder_result:
fmha_out_decode = insert_decoder_result_back_with_active_idx(
fmha_out_decode,
active_idx,
forward_meta.cu_seqlens_q,
q_total_token_num,
)
else:
fmha_out_decode = fmha_out_decode.reshape(
[fmha_out_decode.shape[0], q_num_heads * fmha_out_decode.shape[-1]]
)

if fmha_out_prefill is not None:

from fastdeploy.model_executor.ops.gpu import (
Expand All @@ -402,7 +524,7 @@ def forward_mixed(
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_this_time,
forward_meta.cu_seqlens_q,
self.num_heads * 4,
q_num_heads * 4,
128,
1,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,8 @@ def __init__(
logger.info(
"The current platform does not support Flash Attention V3, so Flash Attention V2 will be used instead."
)
# swa config
self.window_attn_skip_freq = getattr(fd_config.model_config, "window_attn_skip_freq", None)

def init_attention_metadata(self, forward_meta: ForwardMeta):
"""Initialize attention metadata hence all layers in the forward pass can reuse it."""
Expand Down Expand Up @@ -618,8 +620,13 @@ def get_kv_cache_shape(
"""
Calculate kv cache shape for MLA
"""
key_cache_shape = [max_num_blocks, 1, self.block_size, self.kv_lora_rank + self.qk_rope_head_dim]
layer_id = self.layer_id
value_cache_shape = []
if self.window_attn_skip_freq is not None and self.window_attn_skip_freq[layer_id] == 1:
fp8_key_cahe_dim = self.kv_lora_rank + 4 * (self.kv_lora_rank // 128) + 2 * self.qk_rope_head_dim
key_cache_shape = [max_num_blocks, 1, self.block_size, fp8_key_cahe_dim]
else:
key_cache_shape = [max_num_blocks, 1, self.block_size, self.kv_lora_rank + self.qk_rope_head_dim]
return key_cache_shape, value_cache_shape

def create_kv_cache(
Expand Down
Loading
Loading