diff --git a/tensorrt_llm/_torch/modules/mamba/fuse_elementwise_ops.py b/tensorrt_llm/_torch/modules/mamba/fuse_elementwise_ops.py index 21e10a51b327..e270503ca71f 100644 --- a/tensorrt_llm/_torch/modules/mamba/fuse_elementwise_ops.py +++ b/tensorrt_llm/_torch/modules/mamba/fuse_elementwise_ops.py @@ -51,6 +51,36 @@ def _extract_transpose_prefill_kernel( tl.store(dst_ptr + dst_offsets, tl.trans(data), mask=conv_mask[:, None] & seq_mask[None, :]) +def extract_transpose_prefill_slice( + src: torch.Tensor, + num_prefill_tokens: int, + start_col: int, + width: int, +) -> torch.Tensor: + """ + Extract and transpose a contiguous prefill slice for causal_conv1d_fn. + + Input: src[num_tokens, num_cols] + Output: [width, num_prefill_tokens] + """ + out = torch.empty(width, num_prefill_tokens, dtype=src.dtype, device=src.device) + + BLOCK_SEQ, BLOCK_CONV = 32, 128 + grid = (triton.cdiv(num_prefill_tokens, BLOCK_SEQ), triton.cdiv(width, BLOCK_CONV)) + + _extract_transpose_prefill_kernel[grid]( + src, + out, + num_prefill_tokens, + src.shape[1], + start_col, + width, + BLOCK_SEQ, + BLOCK_CONV, + ) + return out + + def extract_transpose_xbc_prefill( zxbcdt: torch.Tensor, num_prefill_tokens: int, @@ -63,22 +93,12 @@ def extract_transpose_xbc_prefill( Input: zxbcdt[num_tokens, d_in_proj] Output: [conv_dim, num_prefill_tokens] """ - out = torch.empty(conv_dim, num_prefill_tokens, dtype=zxbcdt.dtype, device=zxbcdt.device) - - BLOCK_SEQ, BLOCK_CONV = 32, 128 - grid = (triton.cdiv(num_prefill_tokens, BLOCK_SEQ), triton.cdiv(conv_dim, BLOCK_CONV)) - - _extract_transpose_prefill_kernel[grid]( + return extract_transpose_prefill_slice( zxbcdt, - out, num_prefill_tokens, - zxbcdt.shape[1], d_inner, conv_dim, - BLOCK_SEQ, - BLOCK_CONV, ) - return out @triton.jit diff --git a/tensorrt_llm/_torch/modules/mamba/gdn_mixer.py b/tensorrt_llm/_torch/modules/mamba/gdn_mixer.py index 2f62d40befdf..18989c901d22 100644 --- a/tensorrt_llm/_torch/modules/mamba/gdn_mixer.py +++ b/tensorrt_llm/_torch/modules/mamba/gdn_mixer.py @@ -28,6 +28,7 @@ from ..multi_stream_utils import maybe_execute_in_parallel from .causal_conv1d import causal_conv1d_fn, causal_conv1d_update from .causal_conv1d_triton import causal_conv1d_update as causal_conv1d_update_triton +from .fuse_elementwise_ops import extract_transpose_prefill_slice from .layernorm_gated import RMSNorm as RMSNormGated from .mamba2_metadata import Mamba2Metadata @@ -544,8 +545,14 @@ def forward_extend( query_start_loc_p = query_start_loc[: num_prefill + 1] has_initial_states_p = has_initial_states[:num_prefill] - mixed_qkv_p = causal_conv1d_fn( - mixed_qkv_p.transpose(0, 1), + mixed_qkv_p_t = extract_transpose_prefill_slice( + mixed_qkv_p, + mixed_qkv_p.shape[0], + 0, + mixed_qkv_p.shape[1], + ) + mixed_qkv_p_t = causal_conv1d_fn( + mixed_qkv_p_t, self.conv1d.weight, self.conv1d.bias, activation=self.activation, @@ -553,7 +560,7 @@ def forward_extend( has_initial_state=has_initial_states_p, cache_indices=state_indices_p, query_start_loc=query_start_loc_p, - ).transpose(0, 1) + ) if is_target_verify: draft_token_num = spec_metadata.max_draft_len + 1 @@ -588,10 +595,17 @@ def forward_extend( activation=self.activation, conv_state_indices=state_indices_d, ) + mixed_qkv_p.copy_(mixed_qkv_p_t.transpose(0, 1)) mixed_qkv = torch.cat((mixed_qkv_p, mixed_qkv_d), dim=0) else: + mixed_qkv_t = extract_transpose_prefill_slice( + mixed_qkv, + mixed_qkv.shape[0], + 0, + mixed_qkv.shape[1], + ) mixed_qkv = causal_conv1d_fn( - mixed_qkv.transpose(0, 1), + mixed_qkv_t, self.conv1d.weight, self.conv1d.bias, activation=self.activation,