 
 from ucm.sparse.gsa_on_device.gsa_on_device_config import GSAOnDeviceConfig
 from ucm.sparse.gsa_on_device.hash_encoder import HashEncoder
+from ucm.sparse.utils import cdiv
 from ucm.utils import Config
 
 logger = init_logger(__name__)
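The newly imported cdiv is not defined in this diff; it is presumably the usual ceiling-division helper. A minimal sketch under that assumption:

def cdiv(a: int, b: int) -> int:
    # Smallest integer >= a / b; for positive b this equals (a + b - 1) // b.
    return -(a // -b)

assert cdiv(10, 4) == 3  # 10 tokens need 3 blocks of size 4
assert cdiv(8, 4) == 2   # exact multiples do not round up

Note that the expression replaced below computed the same ceiling but then added one extra block (`+ 1`); the cdiv call drops that padding.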
@@ -255,13 +256,14 @@ def __init__(self, vllm_config: VllmConfig, role: UcmSparseRole):
 
     def init_for_pc(self):
         # for pc hit
-        self.prefix_slot_mapping_buf = self._make_buffer(
-            self.max_num_tokens * self.max_batch_size, dtype=torch.int64
+        self.prefix_slot_mapping_buf = torch.empty(
+            self.max_num_tokens * self.max_batch_size,
+            device=self.device,
+            dtype=torch.int64,
         )
-        self.prefix_block_ids_buf = self._make_buffer(
-            (self.max_num_tokens * self.max_batch_size + self.block_size - 1)
-            // self.block_size
-            + 1,
+        self.prefix_block_ids_buf = torch.empty(
+            cdiv(self.max_num_tokens * self.max_batch_size, self.block_size),
+            device=self.device,
             dtype=torch.int32,
         )
         self.token_idx_buf = torch.arange(
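These two buffers switch from self._make_buffer (which, per the next hunk, apparently wraps a pinned host allocation with a paired device tensor exposed as .gpu) to a plain device-resident torch.empty. A minimal sketch of the two patterns, on a CUDA machine; the names here are illustrative, not the actual helper:

import torch

n = 1024
vals = torch.arange(10, dtype=torch.int64, device="cuda")

# Old pattern (sketch): pinned host staging plus a device copy, so callers
# index buffer.gpu[...] to touch the device side.
pinned_host = torch.empty(n, dtype=torch.int64, pin_memory=True)
device_side = torch.empty(n, dtype=torch.int64, device="cuda")

# New pattern: a single device tensor, indexed directly (no .gpu indirection).
buf = torch.empty(n, dtype=torch.int64, device="cuda")
buf[: vals.numel()] = vals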
@@ -275,11 +277,6 @@ def _make_buffer(
             *size, dtype=dtype, device=self.device, pin_memory=True, with_numpy=numpy
         )
 
-    def _clear_buffer(self) -> None:
-        self.decode_req_ids_buf.clear()
-        self.prefix_slot_mapping_buf.clear()
-        self.prefix_block_ids_buf.clear()
-
     def hash_code(
         self,
         nope: Optional[torch.Tensor] = None,
@@ -709,7 +706,8 @@ def rebuild_prefix_cache_info_for_req(
     assert 0 <= qlen <= num_prompt_tokens
     num_prefix_tokens = num_prompt_tokens - qlen
     if num_prefix_tokens <= 0:
-        return [], []
+        empty = block_table_row[:0]
+        return 0, 0, empty, empty
 
     num_prefix_blocks = (num_prefix_tokens + block_size - 1) // block_size
     prefix_block_ids = block_table_row[:num_prefix_blocks]  # [prefix_blocks]
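The early return now matches the four-value shape of the non-empty path (two counts plus two tensors), and block_table_row[:0] is the idiomatic way to get an empty tensor that keeps the source's dtype and device. For example:

import torch

block_table_row = torch.tensor([3, 7, 9], dtype=torch.int32)
empty = block_table_row[:0]
print(empty.shape, empty.dtype)  # torch.Size([0]) torch.int32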
@@ -749,7 +747,7 @@ def build_sparse_meta(
         compute_q_lens = (
             attn_metadata.query_start_loc[1:] - attn_metadata.query_start_loc[:-1]
         )
-        self._clear_buffer()
+        self.decode_req_ids_buf.clear()
 
         self.num_reqs = len(scheduler_output.num_scheduled_tokens)
         for (
@@ -804,10 +802,10 @@ def build_sparse_meta(
                 block_size=self.block_size,
             )
 
-            self.prefix_slot_mapping_buf.gpu[
+            self.prefix_slot_mapping_buf[
                 all_prefix_tokens : all_prefix_tokens + num_prefix_tokens
             ] = prefix_slot_mapping
-            self.prefix_block_ids_buf.gpu[
+            self.prefix_block_ids_buf[
                 all_prefix_blocks : all_prefix_blocks + num_prefix_blocks
             ] = prefix_block_ids
 
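Each request's prefix data is packed into the shared preallocated buffers at running offsets (all_prefix_tokens / all_prefix_blocks); with plain device tensors this is ordinary slice assignment. A self-contained sketch of the packing pattern:

import torch

buf = torch.empty(100, dtype=torch.int64, device="cuda")
offset = 0
for seg in (torch.arange(3, device="cuda"), torch.arange(5, device="cuda")):
    buf[offset : offset + seg.numel()] = seg  # append at the running offset
    offset += seg.numel()
packed = buf[:offset]  # view over the filled portion, as in the next hunk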
@@ -853,12 +851,10 @@ def build_sparse_meta(
 
         self.has_pc_hit = num_pc_hit > 0
         if self.has_pc_hit:
-            self.prefix_slot_mapping = self.prefix_slot_mapping_buf.gpu[
+            self.prefix_slot_mapping = self.prefix_slot_mapping_buf[
                 :all_prefix_tokens
             ]
-            self.prefix_block_ids = self.prefix_block_ids_buf.gpu[
-                :all_prefix_blocks
-            ]
+            self.prefix_block_ids = self.prefix_block_ids_buf[:all_prefix_blocks]
 
     def maybe_init_cudagraph_buffers_for_topk(self, n, tile_scheduler_metadata):
         sm_parts = tile_scheduler_metadata.size(0)
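In the hunk above, prefix_slot_mapping and prefix_block_ids are exposed as slices of the preallocated device buffers, which costs nothing extra: a tensor slice is a zero-copy view over the same storage.

import torch

buf = torch.empty(8, dtype=torch.int64, device="cuda")
view = buf[:4]
assert view.data_ptr() == buf.data_ptr()  # same underlying device memory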