3838if TYPE_CHECKING :
3939 from fastdeploy .model_executor .forward_meta import ForwardMeta
4040
41+ import triton
42+ import triton .language as tl
43+
4144from fastdeploy .config import FDConfig
4245from fastdeploy .model_executor .layers .attention .attention import Attention
4346from fastdeploy .model_executor .layers .attention .base_attention_backend import (
4447 AttentionBackend ,
4548 AttentionMetadata ,
4649)
4750from fastdeploy .model_executor .layers .attention .utils import init_rank_and_device_id
51+ from fastdeploy .model_executor .ops .triton_ops .triton_utils import (
52+ enable_compat_on_triton_kernel ,
53+ )
54+
55+
56+ @enable_compat_on_triton_kernel
57+ @triton .jit ()
58+ def insert_kernel_with_active_idx (
59+ decoder_res ,
60+ active_idx ,
61+ cu_seqlens_q ,
62+ output ,
63+ HIDDEN_DIM : tl .constexpr ,
64+ BLOCK_SIZE : tl .constexpr ,
65+ ):
66+ compact_id = tl .program_id (axis = 0 )
67+ batch_id = tl .load (active_idx + compact_id )
68+ cu_len_this_batch = tl .load (cu_seqlens_q + batch_id )
69+
70+ read_offsets = tl .arange (0 , BLOCK_SIZE )
71+ decoder_res += compact_id * HIDDEN_DIM
72+ row_data = tl .load (decoder_res + read_offsets , mask = read_offsets < HIDDEN_DIM )
73+
74+ output += cu_len_this_batch * HIDDEN_DIM
75+ tl .store (output + read_offsets , row_data , mask = read_offsets < HIDDEN_DIM )
76+
77+
78+ def insert_decoder_result_back_with_active_idx (
79+ decoder_result : paddle .Tensor ,
80+ active_idx : paddle .Tensor ,
81+ cu_seqlens_q : paddle .Tensor ,
82+ mixed_token_num ,
83+ ):
84+ assert len (decoder_result .shape ) == 4
85+ assert len (active_idx .shape ) == 1
86+ assert len (cu_seqlens_q .shape ) == 1
87+
88+ hidden_dim = decoder_result .shape [- 2 ] * decoder_result .shape [- 1 ]
89+ out = paddle .empty ([mixed_token_num , hidden_dim ], dtype = decoder_result .dtype )
90+
91+ BLOCK_SIZE = triton .next_power_of_2 (hidden_dim )
92+
93+ insert_kernel_with_active_idx [(active_idx .shape [0 ],)](
94+ decoder_result ,
95+ active_idx ,
96+ cu_seqlens_q ,
97+ out ,
98+ hidden_dim ,
99+ BLOCK_SIZE ,
100+ )
101+
102+ return out
48103
49104
50105def yarn_get_mscale (scale = 1 , mscale = 1 ):
@@ -336,7 +391,26 @@ def forward_mixed(
336391 Mixed模式的前向传播
337392 """
338393
339- latent_cache = forward_meta .caches [2 * layer .layer_id ] if hasattr (forward_meta , "caches" ) else None
394+ res = DSAAttentionBackend .forward_static (
395+ q , v , compressed_kv , k_pe , forward_meta .caches [2 * layer .layer_id ], forward_meta , self .attn_softmax_scale
396+ )
397+ return res
398+
399+ @staticmethod
400+ def forward_static (
401+ q : paddle .Tensor ,
402+ indexer_topk : paddle .Tensor ,
403+ compressed_kv : paddle .Tensor ,
404+ k_pe : paddle .Tensor ,
405+ latent_cache : paddle .Tensor ,
406+ forward_meta : ForwardMeta ,
407+ attn_softmax_scale : float ,
408+ ) -> paddle .Tensor :
409+
410+ assert len (q .shape ) == 3
411+ assert len (compressed_kv .shape ) == 2
412+ assert len (k_pe .shape ) == 3
413+ assert len (latent_cache .shape ) == 4
340414
341415 if current_platform .is_cuda ():
342416 import flash_mla
@@ -352,43 +426,91 @@ def forward_mixed(
352426 "fp8_ds_mla" ,
353427 )
354428
429+ assert len (q .shape ) == 3
430+ q_num_heads = q .shape [1 ]
431+ ceil64_num_heads = (q_num_heads + 63 ) // 64 * 64
432+
355433 fmha_out_prefill = None
356434 if forward_meta .max_len_tensor_cpu [1 ]: # max_enc_len_this_time
435+ if ceil64_num_heads != q_num_heads :
436+ new_q = paddle .empty ([q .shape [0 ], ceil64_num_heads , q .shape [2 ]], dtype = q .dtype )
437+ new_q [:, :q_num_heads , :] = q
438+ else :
439+ new_q = q
357440
441+ kv = paddle .concat ([compressed_kv .unsqueeze (1 ), k_pe ], axis = - 1 )
358442 fmha_out_prefill , _ , __ = flash_mla .flash_mla_sparse_fwd (
359- q , # q_input.contiguous(),
360- k , # kv.unsqueeze(1),
361- v , # indexer_top_k.unsqueeze(1),
362- sm_scale = self . attn_softmax_scale ,
443+ new_q , # q_input.contiguous(),
444+ kv , # kv.unsqueeze(1),
445+ indexer_topk , # indexer_top_k.unsqueeze(1),
446+ sm_scale = attn_softmax_scale ,
363447 )
364448
449+ assert len (fmha_out_prefill .shape ) == 3
450+ fmha_out_prefill = fmha_out_prefill [:, :q_num_heads , :].contiguous ()
451+
365452 # Decode
366- # if k is None:
367- if forward_meta .max_len_tensor_cpu [2 ]: # max_enc_len_this_time
453+ if forward_meta .max_len_tensor_cpu [2 ]:
454+
455+ need_insert_decoder_result = False
456+ q_total_token_num = q .shape [0 ]
457+ if forward_meta .max_len_tensor_cpu [1 ]:
458+ # indexer_topk is generated in full-token space. Select only
459+ # real decode token rows before calling flash_mla_with_kvcache.
460+ # This is feasible because the current DSA does not support chunk-related functions.
461+ active_idx = paddle .where (forward_meta .seq_lens_decoder > 0 )[0 ]
462+ token_idx = forward_meta .cu_seqlens_q [active_idx ]
463+ q_decode = q [token_idx ]
464+ indexer_topk_decode = indexer_topk [token_idx ]
465+ need_insert_decoder_result = True
466+ else :
467+ q_decode = q
468+ indexer_topk_decode = indexer_topk
368469
369470 tile_scheduler_metadata , _ = flash_mla .get_mla_metadata ()
370471 new_cache_shape = latent_cache .shape
371472 assert new_cache_shape [1 ] == 1
372473 new_cache_shape [1 ], new_cache_shape [2 ] = new_cache_shape [2 ], new_cache_shape [1 ]
474+
475+ if ceil64_num_heads != q_num_heads :
476+ new_q = paddle .empty ([q_decode .shape [0 ], ceil64_num_heads , q_decode .shape [2 ]], dtype = q_decode .dtype )
477+ new_q [:, :q_num_heads , :] = q_decode
478+ else :
479+ new_q = q_decode
480+
373481 fmha_out_decode , _ = flash_mla .flash_mla_with_kvcache (
374- q .unsqueeze (1 ).contiguous (),
482+ new_q .unsqueeze (1 ).contiguous (),
375483 latent_cache .view (new_cache_shape ),
376484 None , # forward_meta.block_tables,
377485 None , # cache_seqlens
378486 512 , # self.qk_nope_head_dim,
379487 tile_scheduler_metadata ,
380488 None , # num_splits,
381- self . attn_softmax_scale ,
489+ attn_softmax_scale ,
382490 False , # casual
383491 True , # is_fp8_kvcache
384- v , # indices,
492+ indexer_topk_decode , # indices,
385493 None , # t.attn_sink,
386494 None , # extra_k_cache,
387495 None , # extra_indices_in_kvcache: Optional[torch.Tensor] = None,
388496 None , # topk_length: Optional[torch.Tensor] = None,
389497 None , # extra_topk_length: Optional[torch.Tensor] = None
390498 )
391499
500+ fmha_out_decode = fmha_out_decode [:, :, :q_num_heads , :].contiguous ()
501+
502+ if need_insert_decoder_result :
503+ fmha_out_decode = insert_decoder_result_back_with_active_idx (
504+ fmha_out_decode ,
505+ active_idx ,
506+ forward_meta .cu_seqlens_q ,
507+ q_total_token_num ,
508+ )
509+ else :
510+ fmha_out_decode = fmha_out_decode .reshape (
511+ [fmha_out_decode .shape [0 ], q_num_heads * fmha_out_decode .shape [- 1 ]]
512+ )
513+
392514 if fmha_out_prefill is not None :
393515
394516 from fastdeploy .model_executor .ops .gpu import (
@@ -402,7 +524,7 @@ def forward_mixed(
402524 forward_meta .seq_lens_decoder ,
403525 forward_meta .seq_lens_this_time ,
404526 forward_meta .cu_seqlens_q ,
405- self . num_heads * 4 ,
527+ q_num_heads * 4 ,
406528 128 ,
407529 1 ,
408530 )
0 commit comments