Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 31 additions & 11 deletions tensorrt_llm/_torch/modules/mamba/fuse_elementwise_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,36 @@ def _extract_transpose_prefill_kernel(
tl.store(dst_ptr + dst_offsets, tl.trans(data), mask=conv_mask[:, None] & seq_mask[None, :])


def extract_transpose_prefill_slice(
    src: torch.Tensor,
    num_prefill_tokens: int,
    start_col: int,
    width: int,
) -> torch.Tensor:
    """
    Extract and transpose a contiguous prefill slice for causal_conv1d_fn.

    Copies columns [start_col, start_col + width) of the first
    num_prefill_tokens rows of *src* into a freshly allocated, transposed
    output via a single Triton kernel launch.

    Input: src[num_tokens, num_cols]
    Output: [width, num_prefill_tokens]
    """
    dst = torch.empty(width, num_prefill_tokens, dtype=src.dtype, device=src.device)

    # Tile sizes for the 2-D launch grid: one program per
    # (sequence-block, conv-channel-block) tile.
    block_seq = 32
    block_conv = 128
    launch_grid = (
        triton.cdiv(num_prefill_tokens, block_seq),
        triton.cdiv(width, block_conv),
    )

    _extract_transpose_prefill_kernel[launch_grid](
        src,
        dst,
        num_prefill_tokens,
        src.shape[1],  # row stride of src (number of columns)
        start_col,
        width,
        block_seq,
        block_conv,
    )
    return dst


def extract_transpose_xbc_prefill(
zxbcdt: torch.Tensor,
num_prefill_tokens: int,
Expand All @@ -63,22 +93,12 @@ def extract_transpose_xbc_prefill(
Input: zxbcdt[num_tokens, d_in_proj]
Output: [conv_dim, num_prefill_tokens]
"""
out = torch.empty(conv_dim, num_prefill_tokens, dtype=zxbcdt.dtype, device=zxbcdt.device)

BLOCK_SEQ, BLOCK_CONV = 32, 128
grid = (triton.cdiv(num_prefill_tokens, BLOCK_SEQ), triton.cdiv(conv_dim, BLOCK_CONV))

_extract_transpose_prefill_kernel[grid](
return extract_transpose_prefill_slice(
zxbcdt,
out,
num_prefill_tokens,
zxbcdt.shape[1],
d_inner,
conv_dim,
BLOCK_SEQ,
BLOCK_CONV,
)
return out


@triton.jit
Expand Down
22 changes: 18 additions & 4 deletions tensorrt_llm/_torch/modules/mamba/gdn_mixer.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from ..multi_stream_utils import maybe_execute_in_parallel
from .causal_conv1d import causal_conv1d_fn, causal_conv1d_update
from .causal_conv1d_triton import causal_conv1d_update as causal_conv1d_update_triton
from .fuse_elementwise_ops import extract_transpose_prefill_slice
from .layernorm_gated import RMSNorm as RMSNormGated
from .mamba2_metadata import Mamba2Metadata

Expand Down Expand Up @@ -544,16 +545,22 @@ def forward_extend(
query_start_loc_p = query_start_loc[: num_prefill + 1]
has_initial_states_p = has_initial_states[:num_prefill]

mixed_qkv_p = causal_conv1d_fn(
mixed_qkv_p.transpose(0, 1),
mixed_qkv_p_t = extract_transpose_prefill_slice(
mixed_qkv_p,
mixed_qkv_p.shape[0],
0,
mixed_qkv_p.shape[1],
)
mixed_qkv_p_t = causal_conv1d_fn(
mixed_qkv_p_t,
self.conv1d.weight,
self.conv1d.bias,
activation=self.activation,
conv_states=conv_states_to_use,
has_initial_state=has_initial_states_p,
cache_indices=state_indices_p,
query_start_loc=query_start_loc_p,
).transpose(0, 1)
)

if is_target_verify:
draft_token_num = spec_metadata.max_draft_len + 1
Expand Down Expand Up @@ -588,10 +595,17 @@ def forward_extend(
activation=self.activation,
conv_state_indices=state_indices_d,
)
mixed_qkv_p.copy_(mixed_qkv_p_t.transpose(0, 1))
mixed_qkv = torch.cat((mixed_qkv_p, mixed_qkv_d), dim=0)
else:
mixed_qkv_t = extract_transpose_prefill_slice(
mixed_qkv,
mixed_qkv.shape[0],
0,
mixed_qkv.shape[1],
)
mixed_qkv = causal_conv1d_fn(
mixed_qkv.transpose(0, 1),
mixed_qkv_t,
self.conv1d.weight,
self.conv1d.bias,
activation=self.activation,
Expand Down
Loading