@@ -179,7 +179,7 @@ def __init__(self, vllm_config: VllmConfig, role: UcmSparseRole):
179179 self .topk_seq_lens_qwen = None
180180 self .has_pc_hit = False
181181
182- self .cached_reqs_to_step : dict [str , int ] = dict ()
182+ self .is_prefill_flag : dict [str , bool ] = dict ()
183183
184184 self ._k_scale = torch .tensor (1.0 , dtype = torch .float32 )
185185
@@ -790,6 +790,12 @@ def rebuild_prefix_cache_info_for_req(
790790 prefix_slot_mapping ,
791791 )
792792
def get_block_table_row(self, attn_metadata, req_row_id):
    """Return one request's row of the block table held by ``attn_metadata``.

    The attribute that is read depends on ``self.is_cuda``:
    ``attn_metadata.block_table`` when it is truthy, and
    ``attn_metadata.block_tables`` otherwise. Only the selected attribute
    is accessed.

    Args:
        attn_metadata: Object exposing ``block_table`` or ``block_tables``.
        req_row_id: Index of the request's row within that table.

    Returns:
        The result of indexing the selected table with ``req_row_id``.
    """
    if self.is_cuda:
        return attn_metadata.block_table[req_row_id]
    # Non-CUDA metadata names the attribute in the plural.
    return attn_metadata.block_tables[req_row_id]
798+
793799 def build_sparse_meta (
794800 self , scheduler_output , requests , input_batch , attn_metadata
795801 ) -> UcmSparseMetadata :
@@ -822,14 +828,12 @@ def build_sparse_meta(
822828 req = requests [req_id ]
823829 # req_state: is_decode is_first_prefil is_prefill is_last_chunk
824830 is_decode = (
825- req_id in self .cached_reqs_to_step
826- and self .cached_reqs_to_step [req_id ]
827- > 0 # step always=0 when prefill
831+ req_id in self .is_prefill_flag and not self .is_prefill_flag [req_id ]
828832 )
829833 is_first_prefil = (
830- req_id not in self .cached_reqs_to_step
834+ req_id not in self .is_prefill_flag
831835 ) # first prefill when chunkprefill
832- is_prefill = is_first_prefil or self .cached_reqs_to_step [req_id ] == 0
836+ is_prefill = is_first_prefil or self .is_prefill_flag [req_id ]
833837 is_last_chunk = is_prefill and (
834838 req .num_computed_tokens + num_scheduled_tokens
835839 >= req .num_prompt_tokens
@@ -846,7 +850,7 @@ def build_sparse_meta(
846850 num_decodes += 1
847851
848852 if is_first_prefil :
849- self .cached_reqs_to_step [req_id ] = 0
853+ self .is_prefill_flag [req_id ] = True
850854 # num_prompt_tokens -> store pc -> rebuild slotmapping
851855 req_row_id = input_batch .req_id_to_index [req_id ]
852856 ext_tokens = int (
@@ -855,13 +859,16 @@ def build_sparse_meta(
855859 )
856860 )
857861 if ext_tokens > 0 :
862+ block_table_row = self .get_block_table_row (
863+ attn_metadata , req_row_id
864+ )
858865 (
859866 num_prefix_tokens ,
860867 num_prefix_blocks ,
861868 prefix_block_ids ,
862869 prefix_slot_mapping ,
863870 ) = self .rebuild_prefix_cache_info_for_req (
864- block_table_row = attn_metadata . block_table [ req_row_id ] ,
871+ block_table_row = block_table_row ,
865872 num_prompt_tokens = req .num_prompt_tokens ,
866873 qlen = compute_q_lens [req_row_id ],
867874 block_size = self .block_size ,
@@ -879,11 +886,12 @@ def build_sparse_meta(
879886 num_pc_hit += 1
880887
881888 if is_last_chunk :
882- self .cached_reqs_to_step [req_id ] += 1
889+ self .is_prefill_flag [req_id ] = False
883890
884891 self .has_decode = num_decodes > 0
885892 self .decode_only = self .has_decode and (num_decodes == self .num_reqs )
886- if self .has_decode :
893+ # build sparse meta for cuda
894+ if self .has_decode and self .is_cuda :
887895 # for rollback, record the full seq_lens & block_table
888896 self .ori_seq_lens_decode = attn_metadata .seq_lens .clone ()
889897 self .ori_block_table_decode = attn_metadata .block_table .clone ()
@@ -939,9 +947,9 @@ def maybe_init_cudagraph_buffers_for_topk(self, n, tile_scheduler_metadata):
939947 return topk_tile_scheduler_metadata , topk_num_splits
940948
def _free_cached_request(self, request_id: Union[int, str]) -> None:
    """Forget the prefill-state entry tracked for ``request_id``.

    Removes ``request_id`` from ``self.is_prefill_flag``. An id that is not
    present is ignored, so repeated calls for the same request are harmless.

    Args:
        request_id: Key of the request whose entry should be dropped.
    """
    # pop() with a default folds the membership test and the delete into a
    # single dictionary lookup and never raises KeyError.
    self.is_prefill_flag.pop(request_id, None)
945953
946954 def update_states (self , scheduler_output : SchedulerOutput ) -> None :
947955 for req_id in scheduler_output .finished_req_ids :
0 commit comments