 from vllm.attention.ops.flashmla import get_mla_metadata
 from vllm.config import VllmConfig
 from vllm.forward_context import ForwardContext
+from vllm.utils import cdiv
 from vllm.v1.attention.backends.mla.common import MLACommonMetadata
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.request import Request, RequestStatus
@@ -136,6 +137,18 @@ def __init__(self, vllm_config: VllmConfig, role: UcmSparseRole):
 
         self.seq_len_threshhold = self.gsa_on_device_config.seq_len_threshhold
 
+        assert (
+            self.seq_len_threshhold
+            >= self.gsa_on_device_config.vllm_hash_attention_topk
+        ), "seq_len_threshhold must be larger than or equal to vllm_hash_attention_topk"
+        assert (
+            self.gsa_on_device_config.vllm_hash_attention_topk % self.block_size == 0
+        ), "vllm_hash_attention_topk must be divisible by block_size"
+        assert (
+            self.gsa_on_device_config.vllm_hash_attention_topk
+            <= vllm_config.model_config.max_model_len
+        ), "vllm_hash_attention_topk must be less than or equal to max_model_len"
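+        # e.g. with block_size=128, vllm_hash_attention_topk=2048 satisfies all
+        # three checks: 2048 <= seq_len_threshhold, 2048 % 128 == 0, and
+        # 2048 <= max_model_len (illustrative values, not defaults)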
+
         if role == UcmSparseRole.WORKER:
             if self.is_cuda:  # cuda only variables
                 device_properties = torch.cuda.get_device_properties(self.device)
@@ -233,7 +246,9 @@ def __init__(self, vllm_config: VllmConfig, role: UcmSparseRole):
                 [
                     self.max_batch_size,
                     self.num_key_heads,
-                    self.hash_topk_tokens // self.block_size,
+                    cdiv(
+                        vllm_config.model_config.max_model_len, self.block_size
+                    ),
                 ],
                 dtype=torch.int32,
                 device=self.device,
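Note: `cdiv` from `vllm.utils` is ceiling division, i.e. `cdiv(a, b) == (a + b - 1) // b`, so this buffer is now sized for the worst case of `max_model_len` worth of blocks instead of only the top-k window. Quick arithmetic check with illustrative values:

    cdiv(4096, 128)  # -> 32 block-table entries for max_model_len = 4096
    2048 // 128      # -> 16 entries under the old hash_topk_tokens sizing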
@@ -340,36 +355,47 @@ def attention_begin(
         else:  # NPU
             if not self.is_tensor_computed:
                 if self.decode_mask.any():  # with at least one decode request
-                    decode_req_ids = torch.nonzero(
-                        self.decode_mask, as_tuple=False
-                    ).flatten()
-                    decode_req_ids_npu = torch.nonzero(
-                        self.decode_mask_npu, as_tuple=False
-                    ).flatten()
-                    batch_size_for_hamming = self.decode_mask.sum().item()
-                    self.query_lens_device = attn_metadata.query_lens_device[
-                        decode_req_ids_npu
-                    ]
+                    if self.slice_enabled:
+                        # if slice_enabled, batch_size_for_hamming is the
+                        # number of decode requests
+                        self.batch_size_for_hamming = self.num_decode_requests
+                    else:
+                        # if not slice_enabled, batch_size_for_hamming is the
+                        # number of all requests
+                        self.batch_size_for_hamming = len(
+                            attn_metadata.seq_lens
+                        )
+                        # only compute decode_mask_npu when slice_enabled is False
+                        self.decode_mask_npu = (
+                            attn_metadata.query_lens_device == 1
+                        ) & (
+                            attn_metadata.seq_lens_device
+                            >= self.seq_len_threshhold
+                        )
+
                     self.topk_for_hamming = self.topk_for_hamming_full[
-                        :batch_size_for_hamming
+                        : self.batch_size_for_hamming
                     ]
                     self.chunk_sizes_for_hamming = (
                         self.chunk_sizes_for_hamming_full[
-                            :batch_size_for_hamming
+                            : self.batch_size_for_hamming
                         ]
                     )
+
                     self.seq_lens_for_hamming = attn_metadata.seq_lens_device[
-                        decode_req_ids_npu
+                        : self.batch_size_for_hamming
                     ]
                     self.max_seq_len_for_hamming = torch.max(
-                        attn_metadata.seq_lens[decode_req_ids]
+                        attn_metadata.seq_lens[: self.batch_size_for_hamming]
                     ).item()
+                    self.block_table_decode = self.ori_block_table_decode[
+                        : self.batch_size_for_hamming
+                    ]
+
                 self.is_tensor_computed = True
 
             k_hash_compute = self.hash_encoder.compute_hash(key)
-            assert (
-                k_hash_compute.shape[0] == attn_metadata.slot_mapping.numel()
-            ), f"shape mismatch: k_hash_compute.shape[0]={k_hash_compute.shape[0]} != attn_metadata.slot_mapping.numel()={attn_metadata.slot_mapping.numel()}"
+            # assert (
+            #     k_hash_compute.shape[0] == attn_metadata.slot_mapping.numel()
+            # ), f"shape mismatch: k_hash_compute.shape[0]={k_hash_compute.shape[0]} != attn_metadata.slot_mapping.numel()={attn_metadata.slot_mapping.numel()}"
             k_hash_compute = (
                 k_hash_compute.transpose(0, 1)
                 .reshape(-1, k_hash_compute.shape[-1])
@@ -433,7 +459,10 @@ def attention_begin(
         if not is_rollback_layer:
             if is_skip_hash_layer:
                 # skip layer: reuse the previous layer's top-k result
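+                # (note) this assumes adjacent layers attend to similar blocks,
+                # so the previous layer's selection can be reused as-is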
-                attn_metadata.block_tables = self.topk_block_table
+                if self.is_cuda:
+                    attn_metadata.block_table = self.topk_block_table
+                else:
+                    attn_metadata.block_tables = self.topk_block_table
                 attn_metadata.seq_lens = self.topk_seq_lens
             else:
                 if self.is_cuda:
@@ -475,26 +504,22 @@ def attention_begin(
                     attn_metadata.seq_lens[self.decode_mask] = (
                         self.topk_seq_lens_qwen
                     )
-                else:  # NPU
-
-                    decode_req_ids = torch.nonzero(
-                        self.decode_mask_npu, as_tuple=False
-                    ).flatten()
-                    decode_token_idx = q_start[:-1].index_select(
-                        0, decode_req_ids
-                    )
-                    q_decode = query.index_select(0, decode_token_idx)
+
+                    # top-k for skip layer
+                    self.topk_block_table = attn_metadata.block_table
+                    self.topk_seq_lens = attn_metadata.seq_lens
+                else:  # NPU
+                    if self.slice_enabled:
+                        q_decode = query[: self.batch_size_for_hamming]
+                    else:
+                        q_decode = query.index_select(0, q_start[:-1])
 
                     q_hash = (
                         self.hash_encoder.compute_hash(q_decode)
                         .unsqueeze(2)
                         .contiguous()
                     )
 
-                    block_table_decode = attn_metadata.block_table.index_select(
-                        0, decode_req_ids
-                    )
-
                     ucm_custom_ops.hamming_dist_top_k(
                         q_hash,
                         k_hash,
@@ -505,23 +530,37 @@ def attention_begin(
                         self.hamming_keep_chunks_head,
                         self.hamming_keep_chunks_tail,
                         0,  # support_offload is disabled
-                        block_table_decode,
-                        self.hamming_output[: len(decode_req_ids)],
+                        self.block_table_decode,
+                        (
+                            self.decode_mask_npu
+                            if not self.slice_enabled
+                            else None
+                        ),
+                        self.hamming_output[: self.batch_size_for_hamming],
                     )
-                    topk = self.hamming_output.shape[-1]
-                    attn_metadata.block_table[decode_req_ids, :topk] = (
-                        self.hamming_output[: len(decode_req_ids), 0, :]
-                    )
-                    attn_metadata.block_table[decode_req_ids, topk:] = 0
+                    new_seq_lens = self.topk_seq_lens_qwen
+                    attn_metadata.seq_lens = new_seq_lens
+                    if (
+                        self.slice_enabled
+                        and attn_metadata.attn_state
+                        != AscendAttentionState.DecodeOnly
+                    ):
+                        new_block_tables = attn_metadata.block_tables.clone()
+                        new_block_tables[: self.batch_size_for_hamming] = (
+                            self.hamming_output[
+                                : self.batch_size_for_hamming, 0, :
+                            ]
+                        )
+                    else:
+                        new_block_tables = self.hamming_output[
+                            : self.batch_size_for_hamming, 0, :
+                        ]
 
-                    # we have already computed the topk_seq_lens_qwen in `build_decode_attention_meta_npu()`
-                    attn_metadata.seq_lens[self.decode_mask] = (
-                        self.topk_seq_lens_qwen
-                    )
+                    attn_metadata.block_tables = new_block_tables
 
-                # topk for skip layer
-                self.topk_block_table = attn_metadata.block_table
-                self.topk_seq_lens = attn_metadata.seq_lens
+                    # top-k for skip layer
+                    self.topk_block_table = attn_metadata.block_tables
+                    self.topk_seq_lens = attn_metadata.seq_lens
 
         return query, key, value, output
 
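For reference, `ucm_custom_ops.hamming_dist_top_k` ranks KV-cache blocks by the Hamming distance between the query hash and each block's key hash and keeps the closest ones. A minimal standalone sketch of that idea (shapes, names, and the bool-bit hash layout are illustrative, not the custom op's actual signature):

    import torch

    q_hash = torch.randint(0, 2, (1, 256), dtype=torch.bool)   # query hash bits
    k_hash = torch.randint(0, 2, (64, 256), dtype=torch.bool)  # one hash per KV block
    dist = (q_hash ^ k_hash).sum(-1)                           # Hamming distance per block
    keep = dist.topk(8, largest=False).indices                 # 8 closest blocks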
@@ -549,7 +588,10 @@ def attention_finished(
                 attn_metadata.decode.num_splits = self.origin_num_splits
             else:  # the request is in its decode phase
                 if self.decode_mask.any():
-                    attn_metadata.block_table = self.ori_block_table_decode
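+                    # restore the originals cloned in
+                    # build_decode_attention_meta_npu() before top-k pruning
+                    # overwrote the attention metadata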
+                    if self.is_cuda:
+                        attn_metadata.block_table = self.ori_block_table_decode
+                    else:
+                        attn_metadata.block_tables = self.ori_block_table_decode
                     attn_metadata.seq_lens = self.ori_seq_lens_decode
 
     def request_begin(self, request_id: ReqType, prompt_token_ids: List[int]):
@@ -645,20 +687,28 @@ def build_decode_attention_meta_npu(self, query_lens, seq_lens, block_table):
 
         # self.decode_mask is on cpu in vllm-ascend under NPU device
         self.decode_mask = (query_lens == 1) & (seq_lens >= self.seq_len_threshhold)
-        self.decode_mask = self.decode_mask.pin_memory()
+        # self.decode_mask = self.decode_mask.pin_memory()
+
+        self.num_decode_requests = self.decode_mask.sum().item()
+        if self.num_decode_requests > 0:
+            self.slice_enabled = (
+                self.decode_mask[: self.num_decode_requests].all().item()
+            )
+        else:
+            self.slice_enabled = False
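+        # e.g. decode_mask = [T, T, T, F, F] gives num_decode_requests = 3 and
+        # decode_mask[:3].all() == True: the decode rows form a contiguous
+        # prefix, so plain slices can replace nonzero()/index_select gathers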
 
         self.ori_seq_lens_decode = seq_lens.clone()
         self.ori_block_table_decode = block_table.clone()
 
-        self.decode_mask_npu = self.decode_mask.to(self.device, non_blocking=True)
-
         if self.decode_mask.any():
-            decode_seq_lens = seq_lens[self.decode_mask]
+            # self.decode_mask_npu = self.decode_mask.to(self.device, non_blocking=True)
             self.topk_seq_lens_qwen = update_seq_lens(
-                decode_seq_lens,
+                seq_lens,
                 topk_token=self.hash_topk_tokens,
                 block_size=self.block_size,
             )
+            # (ldeng) set the seq_lens for the non-decode requests back to the
+            # original seq_lens
+            self.topk_seq_lens_qwen[~self.decode_mask] = seq_lens[~self.decode_mask]
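+            # e.g. seq_lens = [7000, 9000, 300] with seq_len_threshhold = 4096:
+            # rows 0 and 1 get the top-k-capped lengths while row 2 keeps its
+            # original 300 (illustrative values)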
 
     def maybe_init_cudagraph_buffers_for_topk(self, n, tile_scheduler_metadata):
         sm_parts = tile_scheduler_metadata.size(0)