@@ -1,4 +1,3 @@
-from dataclasses import dataclass, field
 from importlib import resources
 from pathlib import Path
 from typing import Dict, List, Optional, Union
@@ -35,8 +34,6 @@
 )
 from ucm.sparse.gsa_on_device.hash_encoder import reshape_and_cache_khash_triton
 
-from vllm.utils import cdiv
-
 from ucm.sparse.gsa_on_device.gsa_on_device_config import GSAOnDeviceConfig
 from ucm.sparse.gsa_on_device.hash_encoder import HashEncoder
 from ucm.utils import Config
@@ -231,6 +228,7 @@ def __init__(self, vllm_config: VllmConfig, role: UcmSparseRole):
         self.decode_only = False
 
         if not self.is_cuda:  # NPU-only variables
+            self.decode_mask = None
            self.decode_mask_npu = None
            self.is_tensor_computed = False
@@ -383,34 +381,37 @@ def cache_k_hash_gqa_cuda(
     def cache_k_hash_gqa_npu(self, key, k_hash, attn_metadata):
         if not self.is_tensor_computed:
             if self.decode_mask.any():  # at least one decode request
-                decode_req_ids = torch.nonzero(
-                    self.decode_mask, as_tuple=False
-                ).flatten()
-                decode_req_ids_npu = torch.nonzero(
-                    self.decode_mask_npu, as_tuple=False
-                ).flatten()
-                batch_size_for_hamming = self.decode_mask.sum().item()
-                self.query_lens_device = attn_metadata.query_lens_device[
-                    decode_req_ids_npu
-                ]
+                if self.slice_enabled:
+                    # decode requests form a contiguous prefix, so only they are processed
+                    self.batch_size_for_hamming = self.num_decode_requests
+                else:
+                    # decode requests are scattered, so every request is processed
+                    self.batch_size_for_hamming = len(attn_metadata.seq_lens)
+                    # decode_mask_npu is only needed when slice_enabled is False
+                    self.decode_mask_npu = (attn_metadata.query_lens_device == 1) & (
+                        attn_metadata.seq_lens_device >= self.seq_len_threshhold
+                    )
                 self.topk_for_hamming = self.topk_for_hamming_full[
-                    :batch_size_for_hamming
+                    : self.batch_size_for_hamming
                 ]
                 self.chunk_sizes_for_hamming = self.chunk_sizes_for_hamming_full[
-                    :batch_size_for_hamming
+                    : self.batch_size_for_hamming
                 ]
                 self.seq_lens_for_hamming = attn_metadata.seq_lens_device[
-                    decode_req_ids_npu
+                    : self.batch_size_for_hamming
                 ]
                 self.max_seq_len_for_hamming = torch.max(
-                    attn_metadata.seq_lens[decode_req_ids]
+                    attn_metadata.seq_lens[: self.batch_size_for_hamming]
                 ).item()
+                self.block_table_decode = self.ori_block_table_decode[
+                    : self.batch_size_for_hamming
+                ]
             self.is_tensor_computed = True
 
         k_hash_compute = self.hash_encoder.compute_hash(key)
-        assert (
-            k_hash_compute.shape[0] == attn_metadata.slot_mapping.numel()
-        ), f"shape mismatch: k_hash_compute.shape[0]={k_hash_compute.shape[0]} != attn_metadata.slot_mapping.numel()={attn_metadata.slot_mapping.numel()}"
+        # assert (
+        #     k_hash_compute.shape[0] == attn_metadata.slot_mapping.numel()
+        # ), f"shape mismatch: k_hash_compute.shape[0]={k_hash_compute.shape[0]} != attn_metadata.slot_mapping.numel()={attn_metadata.slot_mapping.numel()}"
         k_hash_compute = (
             k_hash_compute.transpose(0, 1)
             .reshape(-1, k_hash_compute.shape[-1])
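The hunk above replaces index-gathering with prefix slicing. The motivation: `torch.nonzero(...)` materializes indices on the device and forces a host-device synchronization before the gather, whereas `[: n]` on a contiguous leading batch is a metadata-only view. A minimal, self-contained sketch of the two selection strategies (tensor names here are illustrative, not the module's own):

```python
import torch

def select_decode_rows(seq_lens: torch.Tensor, decode_mask: torch.Tensor) -> torch.Tensor:
    """Pick the rows of per-request metadata that belong to decode requests."""
    num_decode = int(decode_mask.sum().item())
    # Fast path: all decode requests sit at the front of the batch, so a
    # plain slice returns a view with no copy and no index gather.
    if num_decode > 0 and bool(decode_mask[:num_decode].all()):
        return seq_lens[:num_decode]
    # Slow path: decode requests are scattered, so gather them by index.
    decode_ids = torch.nonzero(decode_mask, as_tuple=False).flatten()
    return seq_lens.index_select(0, decode_ids)

# A batch where the two decode requests come first:
seq_lens = torch.tensor([4096, 8192, 17, 33])
decode_mask = torch.tensor([True, True, False, False])
print(select_decode_rows(seq_lens, decode_mask))  # tensor([4096, 8192])
```

The fast path is only valid when the decode requests occupy the front of the batch, which is exactly what `slice_enabled` checks in `build_decode_attention_meta_npu()` further down.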
@@ -510,17 +511,18 @@ def update_decode_topk_gqa_cuda(self, query, k_hash, attn_metadata):
             0, self.decode_req_ids, self.topk_seq_lens_qwen
         )
         attn_metadata.seq_lens = self.new_seq_lens
+        # topk for skip layer
+        self.topk_block_table = attn_metadata.block_table
+        self.topk_seq_lens = attn_metadata.seq_lens
 
     def update_decode_topk_gqa_npu(self, query, k_hash, attn_metadata):
         q_start = attn_metadata.query_start_loc
-        decode_req_ids = torch.nonzero(self.decode_mask_npu, as_tuple=False).flatten()
-        decode_token_idx = q_start[:-1].index_select(0, decode_req_ids)
-        q_decode = query.index_select(0, decode_token_idx)
-
+        if self.slice_enabled:
+            q_decode = query[: self.batch_size_for_hamming]
+        else:
+            q_decode = query.index_select(0, q_start[:-1])
         q_hash = self.hash_encoder.compute_hash(q_decode).unsqueeze(2).contiguous()
 
-        block_table_decode = attn_metadata.block_table.index_select(0, decode_req_ids)
-
         ucm_custom_ops.hamming_dist_top_k(
             q_hash,
             k_hash,
@@ -531,17 +533,26 @@ def update_decode_topk_gqa_npu(self, query, k_hash, attn_metadata):
             self.hamming_keep_chunks_head,
             self.hamming_keep_chunks_tail,
             0,  # support_offload is disabled
-            block_table_decode,
-            self.hamming_output[: len(decode_req_ids)],
+            self.block_table_decode,
+            (self.decode_mask_npu if not self.slice_enabled else None),
+            self.hamming_output[: self.batch_size_for_hamming],
         )
-        topk = self.hamming_output.shape[-1]
-        attn_metadata.block_table[decode_req_ids, :topk] = self.hamming_output[
-            : len(decode_req_ids), 0, :
-        ]
-        attn_metadata.block_table[decode_req_ids, topk:] = 0
-
-        # we have already computed the topk_seq_lens_qwen in `build_decode_attention_meta_npu()`
-        attn_metadata.seq_lens[self.decode_mask] = self.topk_seq_lens_qwen
+        # topk_seq_lens_qwen was precomputed in `build_decode_attention_meta_npu()`
+        attn_metadata.seq_lens = self.topk_seq_lens_qwen
+        if (
+            self.slice_enabled
+            and attn_metadata.attn_state != AscendAttentionState.DecodeOnly
+        ):
+            new_block_tables = attn_metadata.block_tables.clone()
+            new_block_tables[: self.batch_size_for_hamming] = self.hamming_output[
+                : self.batch_size_for_hamming, 0, :
+            ]
+        else:
+            new_block_tables = self.hamming_output[: self.batch_size_for_hamming, 0, :]
+        attn_metadata.block_tables = new_block_tables
+        # topk for skip layer
+        self.topk_block_table = attn_metadata.block_tables
+        self.topk_seq_lens = attn_metadata.seq_lens
 
     def attention_begin(
         self,
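To make the block-table rewrite above concrete: in the sliced mixed-batch case the decode requests sit at the front, so only that prefix is overwritten in a clone and the prefill rows keep their original block mappings; in the remaining cases the Hamming top-k output becomes the table wholesale. A hedged sketch of that logic, with shapes assumed as `[batch, max_blocks]` and `[num_decode, k]` (not verified against the kernel's actual output layout):

```python
import torch

def rebuild_block_tables(
    block_tables: torch.Tensor,  # [batch, max_blocks]
    topk_blocks: torch.Tensor,   # [num_decode, k], k <= max_blocks
    num_decode: int,
    decode_only: bool,
) -> torch.Tensor:
    """Swap in the top-k block ids selected by the Hamming search."""
    if decode_only:
        # Every request is a decode request: the top-k table replaces
        # the block table wholesale.
        return topk_blocks
    # Mixed batch with decode requests at the front: clone so prefill
    # rows keep their original block mappings, overwrite only the prefix.
    new_tables = block_tables.clone()
    new_tables[:num_decode, : topk_blocks.shape[1]] = topk_blocks
    return new_tables
```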
@@ -608,9 +619,6 @@ def attention_begin(
                     self.update_decode_topk_gqa_npu(
                         query, k_hash, attn_metadata
                     )
-                    # topk for skip layer
-                    self.topk_block_table = attn_metadata.block_table
-                    self.topk_seq_lens = attn_metadata.seq_lens
 
         return query, key, value, output
 
@@ -719,6 +727,35 @@ def build_decode_hash(self, seq_lens):
         )
         return topk_seq_lens, topk_tile_scheduler_metadata, topk_num_splits
 
+    def build_decode_attention_meta_npu(self, query_lens, seq_lens, block_table):
+        from ucm.sparse.gsa_on_device.hamming_topk import update_seq_lens
+
+        # self.decode_mask is on the CPU in vllm-ascend when running on NPU
+        self.decode_mask = (query_lens == 1) & (seq_lens >= self.seq_len_threshhold)
+        # self.decode_mask = self.decode_mask.pin_memory()
+
+        self.num_decode_requests = self.decode_mask.sum().item()
+        if self.num_decode_requests > 0:
+            # slicing is only valid when the decode requests form a contiguous prefix
+            self.slice_enabled = (
+                self.decode_mask[: self.num_decode_requests].all().item()
+            )
+        else:
+            self.slice_enabled = False
+
+        self.ori_seq_lens_decode = seq_lens.clone()
+        self.ori_block_table_decode = block_table.clone()
+
+        if self.decode_mask.any():
+            # self.decode_mask_npu = self.decode_mask.to(self.device, non_blocking=True)
+            self.topk_seq_lens_qwen = update_seq_lens(
+                seq_lens,
+                topk_token=self.hash_topk_tokens,
+                block_size=self.block_size,
+            )
+            # restore the original seq_lens for the non-decode requests
+            self.topk_seq_lens_qwen[~self.decode_mask] = seq_lens[~self.decode_mask]
+
     def rebuild_prefix_cache_info_for_req(
         self,
         block_table_row: torch.Tensor,
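`update_seq_lens` is imported from `ucm.sparse.gsa_on_device.hamming_topk` and is not shown in this diff. A plausible sketch of its effect, assuming it clamps each sequence to the top-k token budget rounded up to whole KV-cache blocks (an assumption, not the verified implementation):

```python
import torch

def update_seq_lens_sketch(
    seq_lens: torch.Tensor, topk_token: int, block_size: int
) -> torch.Tensor:
    """Clamp each sequence length to the top-k token budget, rounded up to
    whole KV-cache blocks. This mirrors the *assumed* behaviour of
    update_seq_lens(); the real helper may differ."""
    budget = (topk_token + block_size - 1) // block_size * block_size
    return seq_lens.clamp(max=budget)

# With a 2048-token budget and 128-token blocks, long decodes are capped:
print(update_seq_lens_sketch(torch.tensor([512, 4096, 10000]), 2048, 128))
# tensor([ 512, 2048, 2048])
```

Restoring `seq_lens` for the non-decode rows afterwards matters because the helper is applied to the whole batch, while only decode requests should attend over the truncated top-k window.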
@@ -915,4 +952,4 @@ def update_states(self, scheduler_output: SchedulerOutput) -> None:
                 req_data.req_ids, req_data.resumed_from_preemption
             ):
                 if resumed_from_preemption:
-                    self._free_cached_request(req_id)
+                    self._free_cached_request(req_id)