Skip to content

Commit 2ee7939

Browse files
author
niushengxiao
committed
fix
1 parent d3a0dd8 commit 2ee7939

2 files changed

Lines changed: 10 additions & 3 deletions

File tree

lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -227,15 +227,22 @@ def _get_indices(
227227

228228
import deep_gemm
229229

230-
logits = deep_gemm.fp8_mqa_logits(q_fp8, (k_fp8_, k_scale_), weights.squeeze(-1), ks, ke)
230+
logits = deep_gemm.fp8_mqa_logits(
231+
q_fp8,
232+
(k_fp8_, k_scale_),
233+
weights.squeeze(-1),
234+
ks,
235+
ke,
236+
clean_logits=False,
237+
max_seqlen_k=infer_state.max_kv_seq_len,
238+
)
231239

232240
from sgl_kernel import fast_topk_v2
233241

234242
b_topk_index = fast_topk_v2(
235243
score=logits,
236244
lengths=lengths,
237245
topk=self.index_topk,
238-
row_starts=ks,
239246
)
240247
b_topk_index = torch.where(b_topk_index != -1, b_topk_index + ks.view(-1, 1), -1)
241248
# 将 topk index 转化为 mem index

lightllm/models/deepseek3_2/triton_kernel/extract_indexer_ks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,4 +112,4 @@ def extract_indexer_ks(
112112
num_stages=1,
113113
)
114114

115-
return O_fp8, O_scale
115+
return O_fp8, O_scale.squeeze(-1)

0 commit comments

Comments
 (0)