Skip to content

Commit 4496e69

Browse files
authored
[None][feat] fix mamba metadata prefill bubble in chunked prefill serving (#12736)
Signed-off-by: Shijie Wang <jaywan@nvidia.com>
1 parent 6488d7f commit 4496e69

File tree

1 file changed

+35
-10
lines changed

1 file changed

+35
-10
lines changed

tensorrt_llm/_torch/modules/mamba/mamba2_metadata.py

Lines changed: 35 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,9 @@ def _cu_seqlens_triton_kernel(
3131
cu_seqlens_ptr, # [num_seqs + 1]
3232
chunk_indices_ptr, # [N] output
3333
chunk_offsets_ptr, # [N] output
34-
num_seqs: tl.constexpr,
34+
num_seqs,
3535
chunk_size: tl.constexpr,
36-
N: tl.constexpr,
36+
N,
3737
BLOCK_SIZE: tl.constexpr,
3838
):
3939
"""Computes chunk_indices and chunk_offsets in a single kernel launch."""
@@ -65,8 +65,18 @@ def _cu_seqlens_triton_kernel(
6565

6666
def cu_seqlens_to_chunk_indices_offsets_triton(
6767
cu_seqlens: torch.Tensor,
68-
chunk_size: int) -> Tuple[torch.Tensor, torch.Tensor]:
69-
"""Optimized version of cu_seqlens_to_chunk_indices_offsets."""
68+
chunk_size: int,
69+
total_seqlens: int = -1,
70+
extra_chunks: int = -1) -> Tuple[torch.Tensor, torch.Tensor]:
71+
"""Optimized version of cu_seqlens_to_chunk_indices_offsets.
72+
73+
Args:
74+
total_seqlens: If provided (>= 0), avoids a GPU->CPU sync to read
75+
cu_seqlens[-1]. Callers that already know the total number of
76+
context tokens should pass it here.
77+
extra_chunks: If provided (>= 0), avoids a GPU->CPU sync to compute
78+
the number of extra chunks from misaligned sequence boundaries.
79+
"""
7080
device = cu_seqlens.device
7181
num_seqs = cu_seqlens.numel() - 1
7282

@@ -75,18 +85,20 @@ def cu_seqlens_to_chunk_indices_offsets_triton(
7585
torch.empty(0, dtype=torch.int, device=device))
7686

7787
cu = cu_seqlens.to(dtype=torch.int64)
78-
total_seqlens = cu[-1].item()
88+
if total_seqlens < 0:
89+
total_seqlens = cu[-1].item()
7990

8091
if num_seqs == 1:
8192
# Fast path for single sequence (no boundaries to process)
8293
N = (total_seqlens + chunk_size - 1) // chunk_size
8394
return (torch.arange(N, device=device, dtype=torch.int),
8495
torch.zeros(N, device=device, dtype=torch.int))
8596

86-
seq_starts = cu[1:-1]
87-
misaligned = ((seq_starts % chunk_size) > 0).to(torch.int64)
88-
p = torch.cumsum(misaligned, dim=0)
89-
extra_chunks = p[-1].item() if p.numel() > 0 else 0
97+
if extra_chunks < 0:
98+
seq_starts = cu[1:-1]
99+
misaligned = ((seq_starts % chunk_size) > 0).to(torch.int64)
100+
p = torch.cumsum(misaligned, dim=0)
101+
extra_chunks = p[-1].item() if p.numel() > 0 else 0
90102
N = (total_seqlens + chunk_size - 1) // chunk_size + extra_chunks
91103
chunk_indices = torch.empty(N, device=device, dtype=torch.int)
92104
chunk_offsets = torch.empty(N, device=device, dtype=torch.int)
@@ -279,8 +291,21 @@ def prepare(self, attn_metadata: AttentionMetadata):
279291
self.has_initial_states_cpu[:num_contexts].any())
280292

281293
if self.use_initial_states:
294+
# Compute extra_chunks using pure Python arithmetic on CPU
295+
# seq_lens to avoid any GPU->CPU sync point.
296+
_cs = self.chunk_size
297+
_cumsum = 0
298+
_extra = 0
299+
for i in range(num_contexts - 1):
300+
_cumsum += int(attn_metadata.seq_lens[i])
301+
if _cumsum % _cs != 0:
302+
_extra += 1
303+
282304
self.chunk_indices, self.chunk_offsets = cu_seqlens_to_chunk_indices_offsets_triton(
283-
self.cu_seqlens[:num_contexts + 1], self.chunk_size)
305+
self.cu_seqlens[:num_contexts + 1],
306+
self.chunk_size,
307+
total_seqlens=num_ctx_tokens,
308+
extra_chunks=_extra)
284309
else:
285310
self.chunk_indices = None
286311
self.chunk_offsets = None

0 commit comments

Comments (0)