
Commit 43543d8

force bitonic topk
1 parent 89d89c2 commit 43543d8

1 file changed: lmdeploy/pytorch/backends/cuda/v4_indexer.py (3 additions, 7 deletions)
@@ -73,15 +73,11 @@ def forward(self,
             cu_q_seqlens, num_index.to(torch.int32), block_offsets,
             max_q_seqlen=meta.max_q_seqlen, max_k_seqlen=max_index, causal=True)
 
-        topk_width = min(self.index_topk, max_index)
+        topk_width = self.index_topk
         topk_length = num_index.clamp(max=topk_width).to(torch.int32)
-
         # bitonic_topk requires K to be a power of 2; fall back to torch.topk otherwise
-        if topk_width > 0 and (topk_width & (topk_width - 1)) == 0:
-            topk = bitonic_topk(scores, q_seqlens, num_index.to(torch.int32),
-                                k=topk_width, fill=-1, descending=True).long()
-        else:
-            topk = scores.topk(topk_width, dim=-1)[1]
+        topk = bitonic_topk(scores, q_seqlens, num_index.to(torch.int32),
+                            k=topk_width, fill=-1, descending=True).long()
 
         # Always return [total_q, topk_width] — caller handles decode/prefill dimension adaptation
         return V4IndexerOutput(indices_in_kvcache=topk, topk_length=topk_length)
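Note on the change: with the torch.topk fallback removed, the forward path rests on the precondition stated in the retained comment (bitonic_topk only supports a power-of-two K), and topk_width is now self.index_topk rather than min(self.index_topk, max_index). The sketch below is a hypothetical illustration, not code from v4_indexer.py; only the power-of-two bit test and the clamp expression come from the diff, while is_power_of_two, index_topk's value, and the sample num_index tensor are made up. It shows the removed branch's K check applied once up front, and how topk_length bounds the per-query count of valid indices in output rows that the fill=-1 argument suggests are padded.

import torch

# Hypothetical sketch, not lmdeploy code: validate the bitonic_topk
# power-of-two precondition once, then reproduce the topk_length clamp
# from the diff with made-up numbers.

def is_power_of_two(n: int) -> bool:
    # Same bit trick the removed fallback branch used to test K.
    return n > 0 and (n & (n - 1)) == 0

index_topk = 2048                                # assumed config value; 2048 == 2 ** 11
assert is_power_of_two(index_topk), 'bitonic_topk needs a power-of-two K'

topk_width = index_topk                          # per this commit, no longer min(..., max_index)
num_index = torch.tensor([512, 2048, 5000])      # made-up per-query candidate counts
topk_length = num_index.clamp(max=topk_width).to(torch.int32)
print(topk_length)                               # tensor([ 512, 2048, 2048], dtype=torch.int32)

Before this commit, a non-power-of-two topk_width silently fell back to torch.topk; after it, index_topk presumably has to be chosen or validated as a power of two elsewhere.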
