Skip to content

Commit 2ff65f5

Browse files
authored
[None][feat] Optimize qwen3.5 decode delta kernel (#12740)
Signed-off-by: nv-guomingz <137257613+nv-guomingz@users.noreply.github.com>
1 parent a1777fd commit 2ff65f5

File tree

2 files changed

+173
-134
lines changed

2 files changed

+173
-134
lines changed

tensorrt_llm/_torch/modules/fla/fused_sigmoid_gating_recurrent.py

Lines changed: 161 additions & 124 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import triton
77
import triton.language as tl
88

9-
from tensorrt_llm._torch.modules.fla.utils import input_guard
9+
from tensorrt_llm._torch.modules.fla.utils import custom_device_ctx
1010

1111

1212
@triton.heuristics({
@@ -30,6 +30,12 @@ def fused_sigmoid_gating_delta_rule_update_kernel(
3030
cu_seqlens,
3131
scale,
3232
T,
33+
total_nh,
34+
stride_q,
35+
stride_k,
36+
stride_v,
37+
stride_a,
38+
stride_b,
3339
s_h0_0,
3440
h0_dim0,
3541
B: tl.constexpr,
@@ -46,117 +52,127 @@ def fused_sigmoid_gating_delta_rule_update_kernel(
4652
"""
4753
Fused kernel that combines sigmoid gating computation with recurrent delta rule update.
4854
"""
49-
i_nh, i_v, i_k = tl.program_id(0), tl.program_id(1), tl.program_id(2)
50-
i_n, i_hv = i_nh // HV, i_nh % HV
51-
i_h = i_hv // (HV // H)
52-
53-
if IS_VARLEN:
54-
bos, eos = (
55-
tl.load(cu_seqlens + i_n).to(tl.int64),
56-
tl.load(cu_seqlens + i_n + 1).to(tl.int64),
57-
)
58-
all = T
59-
T = eos - bos
60-
else:
61-
bos, eos = i_n * T, i_n * T + T
62-
all = B * T
63-
55+
i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
6456
o_k = i_k * BK + tl.arange(0, BK)
6557
o_v = i_v * BV + tl.arange(0, BV)
66-
67-
p_q = q + (bos * H + i_h) * K + o_k
68-
p_k = k + (bos * H + i_h) * K + o_k
69-
p_v = v + (bos * HV + i_hv) * V + o_v
70-
p_b = b + bos * HV + i_hv
71-
p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v
72-
73-
# Gating computation pointers
74-
p_A_log = A_log + i_hv
75-
p_a = a + bos * HV + i_hv
76-
p_dt_bias = dt_bias + i_hv
77-
7858
mask_k = o_k < K
7959
mask_v = o_v < V
8060
mask_h = mask_k[:, None] & mask_v[None, :]
61+
grid_stride_nh = tl.num_programs(2)
8162

82-
b_h = tl.zeros([BK, BV], dtype=tl.float32)
83-
if USE_INITIAL_STATE:
84-
idx = tl.load(h0_indices + i_n).to(tl.int64) # prevent int32 overflow
85-
if idx >= 0:
86-
tl.device_assert(idx < h0_dim0,
87-
"idx out of bounds in h0_source load")
88-
p_h0 = (h0_source + idx * s_h0_0 + i_hv * K * V + o_k[:, None] * V +
89-
o_v[None, :])
90-
b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)
91-
92-
for _ in range(0, T):
93-
# Load inputs
94-
b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32)
95-
b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
96-
b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
97-
b_b = tl.load(p_b).to(tl.float32)
98-
99-
# Compute sigmoid gating
100-
# Load gating parameters
101-
b_A_log = tl.load(p_A_log).to(tl.float32)
102-
b_a = tl.load(p_a).to(tl.float32)
103-
b_dt_bias = tl.load(p_dt_bias).to(tl.float32)
104-
105-
# Compute g = -exp(A_log) * softplus(a + dt_bias)
106-
x = b_a + b_dt_bias
107-
beta_x = softplus_beta * x
108-
# Apply softplus with numerical stability
109-
softplus_x = tl.where(
110-
beta_x <= softplus_threshold,
111-
(1.0 / softplus_beta) * tl.log(1.0 + tl.exp(beta_x)),
112-
x,
113-
)
114-
b_g = -tl.exp(b_A_log) * softplus_x
63+
while i_nh < total_nh:
64+
i_n, i_hv = i_nh // HV, i_nh % HV
65+
i_h = i_hv // (HV // H)
66+
67+
if IS_VARLEN:
68+
bos, eos = (
69+
tl.load(cu_seqlens + i_n).to(tl.int64),
70+
tl.load(cu_seqlens + i_n + 1).to(tl.int64),
71+
)
72+
all = T
73+
seq_T = eos - bos
74+
else:
75+
bos, eos = i_n * T, i_n * T + T
76+
all = B * T
77+
seq_T = T
78+
79+
# Decode q/k/v/a/b often arrive as views sliced out of larger packed tensors.
80+
# Use the caller-provided per-token strides so the kernel can consume those views
81+
# directly instead of relying on a packed contiguous layout.
82+
p_q = q + bos * stride_q + i_h * K + o_k
83+
p_k = k + bos * stride_k + i_h * K + o_k
84+
p_v = v + bos * stride_v + i_hv * V + o_v
85+
p_b = b + bos * stride_b + i_hv
86+
# o is allocated in this wrapper and kept contiguous, so the output
87+
# pointer arithmetic can use the packed [NK, B, T, HV, V] layout.
88+
p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v
11589

116-
# Compute beta = sigmoid(b)
117-
b_beta = 1.0 / (1.0 + tl.exp(-b_b))
90+
# Gating computation pointers
91+
p_A_log = A_log + i_hv
92+
p_a = a + bos * stride_a + i_hv
93+
p_dt_bias = dt_bias + i_hv
11894

119-
# Apply L2 normalization if enabled
120-
if USE_QK_L2NORM_IN_KERNEL:
121-
b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q)) + 1e-6)
122-
b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k)) + 1e-6)
95+
b_h = tl.zeros([BK, BV], dtype=tl.float32)
96+
if USE_INITIAL_STATE:
97+
idx = tl.load(h0_indices + i_n).to(tl.int64)
98+
if idx >= 0:
99+
tl.device_assert(idx < h0_dim0,
100+
"idx out of bounds in h0_source load")
101+
p_h0 = (h0_source + idx * s_h0_0 + i_hv * K * V +
102+
o_k[:, None] * V + o_v[None, :])
103+
b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)
123104

124-
b_q = b_q * scale
105+
for _ in range(0, seq_T):
106+
# Load inputs
107+
b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32)
108+
b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
109+
b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
110+
b_b = tl.load(p_b).to(tl.float32)
125111

126-
# Apply gating to hidden state: h *= exp(g)
127-
b_h *= tl.exp(b_g)
112+
# Compute sigmoid gating
113+
# Load gating parameters
114+
b_A_log = tl.load(p_A_log).to(tl.float32)
115+
b_a = tl.load(p_a).to(tl.float32)
116+
b_dt_bias = tl.load(p_dt_bias).to(tl.float32)
128117

129-
# Delta rule: v -= sum(h * k, dim=0)
130-
b_v -= tl.sum(b_h * b_k[:, None], 0)
118+
# Compute g = -exp(A_log) * softplus(a + dt_bias)
119+
x = b_a + b_dt_bias
120+
beta_x = softplus_beta * x
121+
# Apply softplus with numerical stability
122+
softplus_x = tl.where(
123+
beta_x <= softplus_threshold,
124+
(1.0 / softplus_beta) * tl.log(1.0 + tl.exp(beta_x)),
125+
x,
126+
)
127+
b_g = -tl.exp(b_A_log) * softplus_x
131128

132-
# Apply beta gating: v *= beta
133-
b_v *= b_beta
129+
# Compute beta = sigmoid(b)
130+
b_beta = 1.0 / (1.0 + tl.exp(-b_b))
134131

135-
# Update hidden state: h += k[:, None] * v[None, :]
136-
b_h += b_k[:, None] * b_v[None, :]
132+
# Apply L2 normalization if enabled
133+
if USE_QK_L2NORM_IN_KERNEL:
134+
b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q)) + 1e-6)
135+
b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k)) + 1e-6)
137136

138-
# Compute output: o = sum(h * q, dim=0)
139-
b_o = tl.sum(b_h * b_q[:, None], 0)
140-
tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)
137+
b_q = b_q * scale
141138

142-
# Update pointers for next timestep
143-
p_q += H * K
144-
p_k += H * K
145-
p_o += HV * V
146-
p_v += HV * V
147-
p_b += HV
148-
p_a += HV
139+
# Apply gating to hidden state: h *= exp(g)
140+
b_h *= tl.exp(b_g)
149141

150-
# Store final state back to h0_source with bounds checking
151-
if USE_INITIAL_STATE:
152-
idx = tl.load(h0_indices + i_n).to(tl.int64)
153-
if idx >= 0:
154-
p_h0 = (h0_source + idx * s_h0_0 + i_hv * K * V + o_k[:, None] * V +
155-
o_v[None, :])
156-
tl.store(p_h0, b_h.to(p_h0.dtype.element_ty), mask=mask_h)
142+
# Delta rule: v -= sum(h * k, dim=0)
143+
b_v -= tl.sum(b_h * b_k[:, None], 0)
144+
145+
# Apply beta gating: v *= beta
146+
b_v *= b_beta
147+
148+
# Update hidden state: h += k[:, None] * v[None, :]
149+
b_h += b_k[:, None] * b_v[None, :]
150+
151+
# Compute output: o = sum(h * q, dim=0)
152+
b_o = tl.sum(b_h * b_q[:, None], 0)
153+
tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)
154+
155+
# Update pointers for next timestep
156+
p_q += stride_q
157+
p_k += stride_k
158+
p_o += HV * V
159+
p_v += stride_v
160+
p_b += stride_b
161+
p_a += stride_a
162+
163+
# Store final state back to h0_source with bounds checking
164+
if USE_INITIAL_STATE:
165+
idx = tl.load(h0_indices + i_n).to(tl.int64)
166+
if idx >= 0:
167+
tl.device_assert(idx < h0_dim0,
168+
"idx out of bounds in h0_source store")
169+
p_h0 = (h0_source + idx * s_h0_0 + i_hv * K * V +
170+
o_k[:, None] * V + o_v[None, :])
171+
tl.store(p_h0, b_h.to(p_h0.dtype.element_ty), mask=mask_h)
172+
173+
i_nh += grid_stride_nh
157174

158175

159-
@input_guard(exclude_args=["initial_state_source"])
160176
def fused_sigmoid_gating_delta_rule_update(
161177
A_log: torch.Tensor,
162178
a: torch.Tensor,
@@ -181,6 +197,14 @@ def fused_sigmoid_gating_delta_rule_update(
181197
B, T, H, K, V = *k.shape, v.shape[-1]
182198
HV = v.shape[2]
183199
N = B if cu_seqlens is None else len(cu_seqlens) - 1
200+
201+
# Accept native view layouts from forward_decode rather than forcing packed
202+
# copies through input_guard.
203+
stride_q = q.stride(1)
204+
stride_k = k.stride(1)
205+
stride_v = v.stride(1)
206+
stride_a = a.stride(-2)
207+
stride_b = b.stride(-2)
184208
BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 32)
185209
NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
186210
assert NK == 1, "NK > 1 is not supported yet"
@@ -193,7 +217,10 @@ def fused_sigmoid_gating_delta_rule_update(
193217
assert scale > 0, "scale must be positive"
194218

195219
o = q.new_empty(NK, *v.shape)
196-
grid = (N * HV, NV, NK)
220+
# (NK, NV, N * HV) was measured to be faster than (N * HV, NV, NK)
221+
# Since CUDA limits grid.z to 65535, we cap grid.z and let each Triton program
222+
# grid-stride across the remaining N * HV tiles.
223+
grid = (NK, NV, min(N * HV, 65535))
197224

198225
if initial_state_source is not None:
199226
s_h0_0, s_h0_1, s_h0_2, s_h0_3 = initial_state_source.stride()
@@ -205,34 +232,44 @@ def fused_sigmoid_gating_delta_rule_update(
205232
s_h0_0 = 0
206233
slot_num = 0
207234

208-
fused_sigmoid_gating_delta_rule_update_kernel[grid](
209-
A_log=A_log,
210-
a=a,
211-
dt_bias=dt_bias,
212-
softplus_beta=softplus_beta,
213-
softplus_threshold=softplus_threshold,
214-
q=q,
215-
k=k,
216-
v=v,
217-
b=b,
218-
o=o,
219-
h0_source=initial_state_source,
220-
h0_indices=initial_state_indices,
221-
cu_seqlens=cu_seqlens,
222-
scale=scale,
223-
T=T,
224-
s_h0_0=s_h0_0,
225-
h0_dim0=slot_num,
226-
B=B,
227-
H=H,
228-
HV=HV,
229-
K=K,
230-
V=V,
231-
BK=BK,
232-
BV=BV,
233-
USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
234-
num_warps=num_warps,
235-
num_stages=num_stages,
236-
)
235+
# input_guard used to set the active CUDA device and make inputs contiguous.
236+
# We keep only the device-context part here so Triton launches on q's device
237+
# without re-packing the decode views.
238+
with custom_device_ctx(q.device.index):
239+
fused_sigmoid_gating_delta_rule_update_kernel[grid](
240+
A_log=A_log,
241+
a=a,
242+
dt_bias=dt_bias,
243+
softplus_beta=softplus_beta,
244+
softplus_threshold=softplus_threshold,
245+
q=q,
246+
k=k,
247+
v=v,
248+
b=b,
249+
o=o,
250+
h0_source=initial_state_source,
251+
h0_indices=initial_state_indices,
252+
cu_seqlens=cu_seqlens,
253+
scale=scale,
254+
T=T,
255+
total_nh=N * HV,
256+
stride_q=stride_q,
257+
stride_k=stride_k,
258+
stride_v=stride_v,
259+
stride_a=stride_a,
260+
stride_b=stride_b,
261+
s_h0_0=s_h0_0,
262+
h0_dim0=slot_num,
263+
B=B,
264+
H=H,
265+
HV=HV,
266+
K=K,
267+
V=V,
268+
BK=BK,
269+
BV=BV,
270+
USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
271+
num_warps=num_warps,
272+
num_stages=num_stages,
273+
)
237274
o = o.squeeze(0)
238275
return o

tensorrt_llm/_torch/modules/mamba/gdn_mixer.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,10 @@ def __init__(
236236
self.head_v_dim = config.linear_value_head_dim
237237
self.key_dim = self.head_k_dim * self.num_k_heads
238238
self.value_dim = self.head_v_dim * self.num_v_heads
239+
self.num_k_heads_per_tp = divide(self.num_k_heads, self.attn_tp_size)
240+
self.num_v_heads_per_tp = divide(self.num_v_heads, self.attn_tp_size)
241+
self.key_dim_per_tp = self.head_k_dim * self.num_k_heads_per_tp
242+
self.value_dim_per_tp = self.head_v_dim * self.num_v_heads_per_tp
239243

240244
self.conv_kernel_size = config.linear_conv_kernel_dim
241245
self.layer_idx = layer_idx
@@ -480,17 +484,15 @@ def forward_decode(
480484
conv_state_indices=cache_indices,
481485
)
482486

483-
# Direct slicing instead of torch.split for better performance
484-
key_size = self.key_dim // self.attn_tp_size
485-
query = mixed_qkv[..., :key_size]
486-
key = mixed_qkv[..., key_size : key_size * 2]
487-
value = mixed_qkv[..., key_size * 2 :]
488-
# Reshape from [l, h*d] to [1, l, h, d]
487+
# Keep q/k/v as views over mixed_qkv so the fused decode kernel can
488+
# consume their native strides without forcing packed copies.
489+
query = mixed_qkv[..., : self.key_dim_per_tp]
490+
key = mixed_qkv[..., self.key_dim_per_tp : self.key_dim_per_tp * 2]
491+
value = mixed_qkv[..., self.key_dim_per_tp * 2 :]
489492
seq_len = query.shape[0]
490-
num_heads = query.shape[1] // self.head_k_dim
491-
query = query.view(1, seq_len, num_heads, self.head_k_dim)
492-
key = key.view(1, seq_len, num_heads, self.head_k_dim)
493-
value = value.view(1, seq_len, value.shape[1] // self.head_v_dim, self.head_v_dim)
493+
query = query.view(1, seq_len, self.num_k_heads_per_tp, self.head_k_dim)
494+
key = key.view(1, seq_len, self.num_k_heads_per_tp, self.head_k_dim)
495+
value = value.view(1, seq_len, self.num_v_heads_per_tp, self.head_v_dim)
494496

495497
core_attn_out = fused_sigmoid_gating_delta_rule_update(
496498
A_log=self.A_log,

0 commit comments

Comments
 (0)