Skip to content

Commit d522412

Browse files
committed
[None][feat] reuse triton slicing kernel for GDN prefill transpose
Signed-off-by: nv-guomingz <137257613+nv-guomingz@users.noreply.github.com>
1 parent 1045f38 commit d522412

File tree

2 files changed

+50
-16
lines changed

2 files changed

+50
-16
lines changed

tensorrt_llm/_torch/modules/mamba/fuse_elementwise_ops.py

Lines changed: 31 additions & 11 deletions
Original file line number · Diff line number · Diff line change
@@ -51,6 +51,36 @@ def _extract_transpose_prefill_kernel(
5151
tl.store(dst_ptr + dst_offsets, tl.trans(data), mask=conv_mask[:, None] & seq_mask[None, :])
5252

5353

54+
def extract_transpose_prefill_slice(
    src: torch.Tensor,
    num_prefill_tokens: int,
    start_col: int,
    width: int,
) -> torch.Tensor:
    """Extract and transpose a contiguous prefill slice for causal_conv1d_fn.

    Copies src[:num_prefill_tokens, start_col:start_col + width] into a
    freshly allocated transposed buffer via the shared Triton kernel.

    Input:  src[num_tokens, num_cols]
    Output: [width, num_prefill_tokens]
    """
    result = torch.empty(
        width, num_prefill_tokens, dtype=src.dtype, device=src.device
    )

    # Tile sizes for the 2D launch grid: one axis over the sequence
    # (prefill tokens), one over the sliced columns.
    seq_block, conv_block = 32, 128
    launch_grid = (
        triton.cdiv(num_prefill_tokens, seq_block),
        triton.cdiv(width, conv_block),
    )

    _extract_transpose_prefill_kernel[launch_grid](
        src,
        result,
        num_prefill_tokens,
        src.shape[1],  # source row stride: total number of columns
        start_col,
        width,
        seq_block,
        conv_block,
    )
    return result
82+
83+
5484
def extract_transpose_xbc_prefill(
5585
zxbcdt: torch.Tensor,
5686
num_prefill_tokens: int,
@@ -63,22 +93,12 @@ def extract_transpose_xbc_prefill(
6393
Input: zxbcdt[num_tokens, d_in_proj]
6494
Output: [conv_dim, num_prefill_tokens]
6595
"""
66-
out = torch.empty(conv_dim, num_prefill_tokens, dtype=zxbcdt.dtype, device=zxbcdt.device)
67-
68-
BLOCK_SEQ, BLOCK_CONV = 32, 128
69-
grid = (triton.cdiv(num_prefill_tokens, BLOCK_SEQ), triton.cdiv(conv_dim, BLOCK_CONV))
70-
71-
_extract_transpose_prefill_kernel[grid](
96+
return extract_transpose_prefill_slice(
7297
zxbcdt,
73-
out,
7498
num_prefill_tokens,
75-
zxbcdt.shape[1],
7699
d_inner,
77100
conv_dim,
78-
BLOCK_SEQ,
79-
BLOCK_CONV,
80101
)
81-
return out
82102

83103

84104
@triton.jit

tensorrt_llm/_torch/modules/mamba/gdn_mixer.py

Lines changed: 19 additions & 5 deletions
Original file line number · Diff line number · Diff line change
@@ -28,6 +28,7 @@
2828
from ..multi_stream_utils import maybe_execute_in_parallel
2929
from .causal_conv1d import causal_conv1d_fn, causal_conv1d_update
3030
from .causal_conv1d_triton import causal_conv1d_update as causal_conv1d_update_triton
31+
from .fuse_elementwise_ops import extract_transpose_prefill_slice
3132
from .layernorm_gated import RMSNorm as RMSNormGated
3233
from .mamba2_metadata import Mamba2Metadata
3334

@@ -544,16 +545,22 @@ def forward_extend(
544545
query_start_loc_p = query_start_loc[: num_prefill + 1]
545546
has_initial_states_p = has_initial_states[:num_prefill]
546547

547-
mixed_qkv_p = causal_conv1d_fn(
548-
mixed_qkv_p.transpose(0, 1),
548+
mixed_qkv_p_t = extract_transpose_prefill_slice(
549+
mixed_qkv_p,
550+
mixed_qkv_p.shape[0],
551+
0,
552+
mixed_qkv_p.shape[1],
553+
)
554+
mixed_qkv_p_t = causal_conv1d_fn(
555+
mixed_qkv_p_t,
549556
self.conv1d.weight,
550557
self.conv1d.bias,
551558
activation=self.activation,
552559
conv_states=conv_states_to_use,
553560
has_initial_state=has_initial_states_p,
554561
cache_indices=state_indices_p,
555562
query_start_loc=query_start_loc_p,
556-
).transpose(0, 1)
563+
)
557564

558565
if is_target_verify:
559566
draft_token_num = spec_metadata.max_draft_len + 1
@@ -588,18 +595,25 @@ def forward_extend(
588595
activation=self.activation,
589596
conv_state_indices=state_indices_d,
590597
)
598+
mixed_qkv_p.copy_(mixed_qkv_p_t.transpose(0, 1))
591599
mixed_qkv = torch.cat((mixed_qkv_p, mixed_qkv_d), dim=0)
592600
else:
601+
mixed_qkv_t = extract_transpose_prefill_slice(
602+
mixed_qkv,
603+
mixed_qkv.shape[0],
604+
0,
605+
mixed_qkv.shape[1],
606+
)
593607
mixed_qkv = causal_conv1d_fn(
594-
mixed_qkv.transpose(0, 1),
608+
mixed_qkv_t,
595609
self.conv1d.weight,
596610
self.conv1d.bias,
597611
activation=self.activation,
598612
conv_states=conv_states_to_use,
599613
has_initial_state=has_initial_states,
600614
cache_indices=cache_indices,
601615
query_start_loc=query_start_loc,
602-
).transpose(0, 1)
616+
)
603617

604618
key_split_dim = self.key_dim // self.attn_tp_size
605619
value_split_dim = self.value_dim // self.attn_tp_size

0 commit comments

Comments
 (0)