Skip to content

Commit dc4f541

Browse files
authored
fix trtllm_mla attention backend when disabling cuda graph. (#12687)
1 parent 0648eb4 commit dc4f541

1 file changed

Lines changed: 2 additions & 2 deletions

File tree

python/sglang/srt/layers/attention/trtllm_mla_backend.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -585,7 +585,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
585585
if forward_batch.forward_mode.is_target_verify():
586586
max_seq = max_seq + self.num_draft_tokens
587587
seq_lens = seq_lens + self.num_draft_tokens
588-
self.forward_decode_metadata.seq_lens_k = seq_lens
588+
self.forward_decode_metadata.seq_lens_k = seq_lens.to(torch.int32)
589589
elif forward_batch.forward_mode.is_draft_extend(include_v2=True):
590590
max_seq = forward_batch.seq_lens_cpu.max().item()
591591

@@ -604,7 +604,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
604604
self.forward_decode_metadata.sum_seq_lens_q = sum_seq_lens_q
605605
self.forward_decode_metadata.cu_seqlens_q = cu_seqlens_q
606606
self.forward_decode_metadata.seq_lens_q = forward_batch.extend_seq_lens
607-
self.forward_decode_metadata.seq_lens_k = seq_lens
607+
self.forward_decode_metadata.seq_lens_k = seq_lens.to(torch.int32)
608608

609609
max_seqlen_pad = self._calc_padded_blocks(max_seq)
610610
block_kv_indices = self._create_block_kv_indices(

0 commit comments

Comments
 (0)