File tree Expand file tree Collapse file tree
lmdeploy/pytorch/backends/cuda Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -73,15 +73,11 @@ def forward(self,
73 73        cu_q_seqlens, num_index.to(torch.int32), block_offsets,
74 74        max_q_seqlen=meta.max_q_seqlen, max_k_seqlen=max_index, causal=True)
7575
76 -    topk_width = min(self.index_topk, max_index)
76 +    topk_width = self.index_topk
77 77    topk_length = num_index.clamp(max=topk_width).to(torch.int32)
78 -
79 78    # bitonic_topk requires K to be a power of 2; fall back to torch.topk otherwise
80 -    if topk_width > 0 and (topk_width & (topk_width - 1)) == 0:
81 -        topk = bitonic_topk(scores, q_seqlens, num_index.to(torch.int32),
82 -                            k=topk_width, fill=-1, descending=True).long()
83 -    else:
84 -        topk = scores.topk(topk_width, dim=-1)[1]
79 +    topk = bitonic_topk(scores, q_seqlens, num_index.to(torch.int32),
80 +                        k=topk_width, fill=-1, descending=True).long()
8581
86 82    # Always return [total_q, topk_width] — caller handles decode/prefill dimension adaptation
87 83    return V4IndexerOutput(indices_in_kvcache=topk, topk_length=topk_length)
You can’t perform that action at this time.
0 commit comments