Skip to content

Commit 813ca54

Browse files
committed
fix qwen
1 parent c6f93dd commit 813ca54

2 files changed

Lines changed: 229 additions & 32 deletions

File tree

Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
{
2+
"model_name": "Qwen/Qwen3-Coder-30B-A3B-Instruct-FP8",
3+
"is_mla": false,
4+
"hash_weight_type": "random",
5+
"num_hidden_layers": 48,
6+
"seq_len_threshhold": 2048,
7+
"chunk_size": 128,
8+
"chunk_repre_method": "max",
9+
"head_dim": 128,
10+
"hash_bits": 128,
11+
"top_k_ratio_per_layer": [
12+
1,
13+
1,
14+
0.3,
15+
0.3,
16+
0.3,
17+
0.3,
18+
0.3,
19+
0.3,
20+
0.3,
21+
0.3,
22+
0.3,
23+
0.3,
24+
0.3,
25+
0.3,
26+
0.3,
27+
0.3,
28+
0.3,
29+
0.3,
30+
0.3,
31+
0.3,
32+
0.3,
33+
0.3,
34+
0.3,
35+
0.3,
36+
0.3,
37+
0.3,
38+
0.3,
39+
0.3,
40+
0.3,
41+
0.3,
42+
0.3,
43+
0.3,
44+
0.3,
45+
0.3,
46+
0.3,
47+
0.3,
48+
0.3,
49+
0.3,
50+
0.3,
51+
0.3,
52+
0.3,
53+
0.3,
54+
0.3,
55+
0.3,
56+
0.3,
57+
1,
58+
1,
59+
1
60+
],
61+
"top_k_index_reuse": [
62+
-1,
63+
-1,
64+
-1,
65+
-1,
66+
-1,
67+
-1,
68+
-1,
69+
-1,
70+
-1,
71+
-1,
72+
-1,
73+
-1,
74+
-1,
75+
-1,
76+
-1,
77+
-1,
78+
-1,
79+
-1,
80+
-1,
81+
-1,
82+
-1,
83+
-1,
84+
-1,
85+
-1,
86+
-1,
87+
-1,
88+
-1,
89+
-1,
90+
-1,
91+
-1,
92+
-1,
93+
-1,
94+
-1,
95+
-1,
96+
-1,
97+
-1,
98+
-1,
99+
-1,
100+
-1,
101+
-1,
102+
-1,
103+
-1,
104+
-1,
105+
-1,
106+
-1,
107+
-1,
108+
-1,
109+
-1
110+
],
111+
"must_select_blocks": [
112+
0,
113+
-2,
114+
-1
115+
],
116+
"hash_weight": null,
117+
"kv_lora_rank": null,
118+
"qk_rope_head_dim": null,
119+
"hash_bits_kv_lora": null,
120+
"hash_bits_qk_rope": null,
121+
"hash_weight_kv_lora": null,
122+
"hash_weight_qk_rope": null,
123+
"vllm_hash_attention_topk": 2048,
124+
"vllm_hash_attention_reduction_head_num": null,
125+
"vllm_hash_attention_rollback_layers": [
126+
0,
127+
1,
128+
2,
129+
3,
130+
4,
131+
5,
132+
45,
133+
46,
134+
47
135+
],
136+
"vllm_hash_attention_skip_layers": [
137+
true,
138+
true,
139+
true,
140+
true,
141+
true,
142+
true,
143+
false,
144+
false,
145+
true,
146+
false,
147+
true,
148+
false,
149+
false,
150+
false,
151+
false,
152+
false,
153+
true,
154+
true,
155+
true,
156+
false,
157+
true,
158+
true,
159+
true,
160+
true,
161+
true,
162+
false,
163+
true,
164+
false,
165+
false,
166+
false,
167+
true,
168+
false,
169+
true,
170+
false,
171+
false,
172+
true,
173+
true,
174+
true,
175+
true,
176+
true,
177+
false,
178+
true,
179+
true,
180+
true,
181+
true,
182+
true,
183+
true,
184+
true
185+
]
186+
}

ucm/sparse/kvcomp/kvcomp_hbm.py

Lines changed: 43 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -198,11 +198,16 @@ def __init__(self, vllm_config: VllmConfig, role: UcmSparseRole):
198198
device=self.device,
199199
)
200200

201-
if not self.is_cuda: # NPU only variables
202-
self.decode_mask_npu = None
203-
self.is_tensor_computed = False
204-
self.max_batch_size = vllm_config.scheduler_config.max_num_seqs
201+
self.is_tensor_computed = False
202+
self.max_batch_size = vllm_config.scheduler_config.max_num_seqs
205203

204+
if self.is_cuda: # CUDA only variables
205+
self.seq_len_decode = torch.zeros(
206+
[self.max_batch_size], dtype=torch.int32, device=self.device
207+
)
208+
209+
else: # NPU only variables
210+
self.decode_mask_npu = None
206211
self.hamming_keep_chunks_head = 1
207212
self.hamming_keep_chunks_tail = 4
208213

@@ -325,6 +330,29 @@ def attention_begin(
325330
)
326331
else: # GQA
327332
if self.is_cuda:
333+
if not self.is_tensor_computed:
334+
if self.decode_mask.any(): # with at least one decode request
335+
self.decode_req_ids = torch.nonzero(
336+
self.decode_mask, as_tuple=False
337+
).flatten()
338+
339+
q_start = attn_metadata.query_start_loc
340+
341+
self.decode_token_idx = q_start[:-1].index_select(
342+
0, self.decode_req_ids
343+
)
344+
345+
self.block_table_decode = attn_metadata.block_table.index_select(
346+
0, self.decode_req_ids
347+
)
348+
349+
self.seq_len_decode = self.ori_seq_lens_decode.index_select(
350+
0, self.decode_req_ids
351+
)
352+
self.new_block_table = attn_metadata.block_table
353+
self.new_seq_lens = attn_metadata.seq_lens
354+
self.is_tensor_computed = True
355+
328356
k_hash_compute = self.hash_encoder.compute_hash(key).view(
329357
torch.bfloat16
330358
)
@@ -426,53 +454,36 @@ def attention_begin(
426454
)
427455
attn_metadata.decode.num_splits = num_splits
428456
else: # GQA
429-
q_start = attn_metadata.query_start_loc
430457
if self.decode_mask.any(): # with at least one decode request
431458
if not is_rollback_layer:
432459
if is_skip_hash_layer:
433460
# skipped layer: reuse the previous top-k result
434-
attn_metadata.block_tables = self.topk_block_table
461+
attn_metadata.block_table = self.topk_block_table
435462
attn_metadata.seq_lens = self.topk_seq_lens
436463
else:
437464
if self.is_cuda:
438-
439-
decode_req_ids = torch.nonzero(
440-
self.decode_mask, as_tuple=False
441-
).flatten()
442-
decode_token_idx = q_start[:-1].index_select(
443-
0, decode_req_ids
444-
)
445-
q_decode = query.index_select(0, decode_token_idx)
465+
q_decode = query.index_select(0, self.decode_token_idx)
446466
q_hash = self.hash_code(query=q_decode)
447-
448-
topk_token = self.hash_topk_tokens
449-
450-
block_table_decode = attn_metadata.block_table.index_select(
451-
0, decode_req_ids
452-
)
453-
seq_len_decode = self.ori_seq_lens_decode.index_select(
454-
0, decode_req_ids
455-
)
456467
block_table_decode = cuda_hamming_topk(
457468
q_hash.unsqueeze(1),
458469
k_hash,
459-
block_table_decode,
460-
seq_len_decode,
461-
topk_token=topk_token,
470+
self.block_table_decode,
471+
self.seq_len_decode,
472+
topk_token=self.hash_topk_tokens,
462473
sink_token=64,
463474
recent_token=512,
464475
is_mla=self.is_mla,
465476
)
466477
# update topk_block_table
467478
topk = block_table_decode.shape[1]
468-
attn_metadata.block_table[decode_req_ids, :topk] = (
479+
self.new_block_table[self.decode_req_ids, :topk] = (
469480
block_table_decode
470481
)
471-
attn_metadata.block_table[decode_req_ids, topk:] = 0
472-
473-
attn_metadata.seq_lens[self.decode_mask] = (
474-
self.topk_seq_lens_qwen
475-
)
482+
self.new_block_table[self.decode_req_ids, topk:] = 0
483+
attn_metadata.block_table = self.new_block_table
484+
self.new_seq_lens[self.decode_mask] = self.topk_seq_lens_qwen
485+
attn_metadata.seq_lens = self.new_seq_lens
486+
476487
else: # NPU
477488

478489
decode_req_ids = torch.nonzero(

0 commit comments

Comments
 (0)