Skip to content

Commit bfa9b83

Browse files
fix & opt
1 parent 909b8e7 commit bfa9b83

2 files changed

Lines changed: 21 additions & 13 deletions

File tree

ucm/sparse/esa/esa.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -444,7 +444,7 @@ def __init__(self, vllm_config: VllmConfig, role: UcmSparseRole):
444444
self._sparse_metadata: ESASparseMetaData = ESASparseMetaData()
445445
self.request_hasher = RequestHasher(vllm_config, 0)
446446
self.block_size = vllm_config.cache_config.block_size
447-
self.block_hashes: dict[int, dict[str, list[bytes]]] = {}
447+
self.block_hashes: dict[str, dict[int, list[bytes]]] = {}
448448
global data
449449

450450
if data is None:

ucm/sparse/gsa_on_device/gsa_on_device.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ def __init__(self, vllm_config: VllmConfig, role: UcmSparseRole):
179179
self.topk_seq_lens_qwen = None
180180
self.has_pc_hit = False
181181

182-
self.cached_reqs_to_step: dict[str, int] = dict()
182+
self.is_prefill_flag: dict[str, bool] = dict()
183183

184184
self._k_scale = torch.tensor(1.0, dtype=torch.float32)
185185

@@ -790,6 +790,12 @@ def rebuild_prefix_cache_info_for_req(
790790
prefix_slot_mapping,
791791
)
792792

793+
def get_block_table_row(self, attn_metadata, req_row_id):
794+
if self.is_cuda:
795+
return attn_metadata.block_table[req_row_id]
796+
else:
797+
return attn_metadata.block_tables[req_row_id]
798+
793799
def build_sparse_meta(
794800
self, scheduler_output, requests, input_batch, attn_metadata
795801
) -> UcmSparseMetadata:
@@ -822,14 +828,12 @@ def build_sparse_meta(
822828
req = requests[req_id]
823829
# req_state: is_decode is_first_prefil is_prefill is_last_chunk
824830
is_decode = (
825-
req_id in self.cached_reqs_to_step
826-
and self.cached_reqs_to_step[req_id]
827-
> 0 # step always=0 when prefill
831+
req_id in self.is_prefill_flag and not self.is_prefill_flag[req_id]
828832
)
829833
is_first_prefil = (
830-
req_id not in self.cached_reqs_to_step
834+
req_id not in self.is_prefill_flag
831835
) # first prefill when chunkprefill
832-
is_prefill = is_first_prefil or self.cached_reqs_to_step[req_id] == 0
836+
is_prefill = is_first_prefil or self.is_prefill_flag[req_id]
833837
is_last_chunk = is_prefill and (
834838
req.num_computed_tokens + num_scheduled_tokens
835839
>= req.num_prompt_tokens
@@ -846,7 +850,7 @@ def build_sparse_meta(
846850
num_decodes += 1
847851

848852
if is_first_prefil:
849-
self.cached_reqs_to_step[req_id] = 0
853+
self.is_prefill_flag[req_id] = True
850854
# num_prompt_tokens -> store pc -> rebuild slotmapping
851855
req_row_id = input_batch.req_id_to_index[req_id]
852856
ext_tokens = int(
@@ -855,13 +859,16 @@ def build_sparse_meta(
855859
)
856860
)
857861
if ext_tokens > 0:
862+
block_table_row = self.get_block_table_row(
863+
attn_metadata, req_row_id
864+
)
858865
(
859866
num_prefix_tokens,
860867
num_prefix_blocks,
861868
prefix_block_ids,
862869
prefix_slot_mapping,
863870
) = self.rebuild_prefix_cache_info_for_req(
864-
block_table_row=attn_metadata.block_table[req_row_id],
871+
block_table_row=block_table_row,
865872
num_prompt_tokens=req.num_prompt_tokens,
866873
qlen=compute_q_lens[req_row_id],
867874
block_size=self.block_size,
@@ -879,11 +886,12 @@ def build_sparse_meta(
879886
num_pc_hit += 1
880887

881888
if is_last_chunk:
882-
self.cached_reqs_to_step[req_id] += 1
889+
self.is_prefill_flag[req_id] = False
883890

884891
self.has_decode = num_decodes > 0
885892
self.decode_only = self.has_decode and (num_decodes == self.num_reqs)
886-
if self.has_decode:
893+
# build sparse meta for cuda
894+
if self.has_decode and self.is_cuda:
887895
# for roll_back recode the full seqlens & block_table
888896
self.ori_seq_lens_decode = attn_metadata.seq_lens.clone()
889897
self.ori_block_table_decode = attn_metadata.block_table.clone()
@@ -939,9 +947,9 @@ def maybe_init_cudagraph_buffers_for_topk(self, n, tile_scheduler_metadata):
939947
return topk_tile_scheduler_metadata, topk_num_splits
940948

941949
def _free_cached_request(self, request_id: Union[int, str]) -> None:
942-
if request_id not in self.cached_reqs_to_step:
950+
if request_id not in self.is_prefill_flag:
943951
return
944-
del self.cached_reqs_to_step[request_id]
952+
del self.is_prefill_flag[request_id]
945953

946954
def update_states(self, scheduler_output: SchedulerOutput) -> None:
947955
for req_id in scheduler_output.finished_req_ids:

0 commit comments

Comments
 (0)