Skip to content

Commit e0e2e1e

Browse files
bugfix
1 parent ce3710d commit e0e2e1e

2 files changed

Lines changed: 10 additions & 7 deletions

File tree

ucm/sparse/gsa_on_device/configs/gsa_on_device_qwen3_coder_30B_A3B_config.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@
120120
"hash_bits_qk_rope": null,
121121
"hash_weight_kv_lora": null,
122122
"hash_weight_qk_rope": null,
123-
"vllm_hash_attention_topk": 2048,
123+
"vllm_hash_attention_topk": 4096,
124124
"vllm_hash_attention_reduction_head_num": null,
125125
"vllm_hash_attention_rollback_layers": [
126126
0,

ucm/sparse/gsa_on_device/gsa_on_device.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -847,10 +847,16 @@ def build_sparse_meta(
847847
self.has_decode = num_decodes > 0
848848
self.decode_only = self.has_decode and (num_decodes == self.num_reqs)
849849
if self.has_decode:
850+
# for roll_back: record the full seq_lens & block_table
851+
self.ori_seq_lens_decode = attn_metadata.seq_lens.clone()
852+
self.ori_block_table_decode = attn_metadata.block_table.clone()
853+
850854
if self.decode_only:
851-
decode_seq_lens = attn_metadata.seq_lens[: self.num_reqs]
852-
self.block_table_decode = attn_metadata.block_table[: self.num_reqs]
853-
self.seq_len_decode = attn_metadata.seq_lens[: self.num_reqs]
855+
decode_seq_lens = self.ori_seq_lens_decode[: self.num_reqs]
856+
self.block_table_decode = self.ori_block_table_decode[
857+
: self.num_reqs
858+
]
859+
self.seq_len_decode = self.ori_seq_lens_decode[: self.num_reqs]
854860
else:
855861
self.decode_req_ids_buf.copy_to_gpu(num_decodes)
856862
self.decode_req_ids = self.decode_req_ids_buf.gpu[:num_decodes]
@@ -870,9 +876,6 @@ def build_sparse_meta(
870876
topk_token=self.hash_topk_tokens,
871877
block_size=self.block_size,
872878
)
873-
# for roll_back
874-
self.ori_seq_lens_decode = attn_metadata.seq_lens.clone()
875-
self.ori_block_table_decode = attn_metadata.block_table.clone()
876879

877880
self.new_block_table = attn_metadata.block_table
878881
self.new_seq_lens = attn_metadata.seq_lens

0 commit comments

Comments
 (0)