Skip to content

Commit b9e6069

Browse files
committed
Optimize qwen3.5 decode delta kernel
- Keep decode q/k/v views and make the fused recurrent kernel stride-aware.
- Restore the decode tile choice that wins on the representative bs256 pure-decode benchmark.

Signed-off-by: Anthony Chang <27950904+rosenrodt@users.noreply.github.com>
1 parent dc3a468 commit b9e6069

2 files changed

Lines changed: 155 additions & 135 deletions

File tree

tensorrt_llm/_torch/models/modeling_qwen3_next.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -456,6 +456,10 @@ def __init__(self,
456456
self.head_v_dim = config.linear_value_head_dim
457457
self.key_dim = self.head_k_dim * self.num_k_heads
458458
self.value_dim = self.head_v_dim * self.num_v_heads
459+
self.num_k_heads_per_tp = divide(self.num_k_heads, self.attn_tp_size)
460+
self.num_v_heads_per_tp = divide(self.num_v_heads, self.attn_tp_size)
461+
self.key_dim_per_tp = self.head_k_dim * self.num_k_heads_per_tp
462+
self.value_dim_per_tp = self.head_v_dim * self.num_v_heads_per_tp
459463

460464
self.conv_kernel_size = config.linear_conv_kernel_dim
461465
self.layer_idx = layer_idx
@@ -620,18 +624,15 @@ def forward_decode(
620624
conv_state_indices=cache_indices,
621625
)
622626

623-
# Direct slicing instead of torch.split for better performance
624-
key_size = self.key_dim // self.attn_tp_size
625-
query = mixed_qkv[..., :key_size]
626-
key = mixed_qkv[..., key_size:key_size * 2]
627-
value = mixed_qkv[..., key_size * 2:]
628-
# Reshape from [l, h*d] to [1, l, h, d]
627+
# Keep q/k/v as views over mixed_qkv so the fused decode kernel can
628+
# consume their native strides without forcing packed copies.
629+
query = mixed_qkv[..., :self.key_dim_per_tp]
630+
key = mixed_qkv[..., self.key_dim_per_tp:self.key_dim_per_tp * 2]
631+
value = mixed_qkv[..., self.key_dim_per_tp * 2:]
629632
seq_len = query.shape[0]
630-
num_heads = query.shape[1] // self.head_k_dim
631-
query = query.view(1, seq_len, num_heads, self.head_k_dim)
632-
key = key.view(1, seq_len, num_heads, self.head_k_dim)
633-
value = value.view(1, seq_len, value.shape[1] // self.head_v_dim,
634-
self.head_v_dim)
633+
query = query.view(1, seq_len, self.num_k_heads_per_tp, self.head_k_dim)
634+
key = key.view(1, seq_len, self.num_k_heads_per_tp, self.head_k_dim)
635+
value = value.view(1, seq_len, self.num_v_heads_per_tp, self.head_v_dim)
635636

636637
core_attn_out = fused_sigmoid_gating_delta_rule_update(
637638
A_log=self.A_log,

tensorrt_llm/_torch/modules/fla/fused_sigmoid_gating_recurrent.py

Lines changed: 143 additions & 124 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import triton
77
import triton.language as tl
88

9-
from tensorrt_llm._torch.modules.fla.utils import input_guard
9+
from tensorrt_llm._torch.modules.fla.utils import custom_device_ctx
1010

1111

1212
@triton.heuristics({
@@ -30,6 +30,12 @@ def fused_sigmoid_gating_delta_rule_update_kernel(
3030
cu_seqlens,
3131
scale,
3232
T,
33+
total_nh,
34+
stride_q,
35+
stride_k,
36+
stride_v,
37+
stride_a,
38+
stride_b,
3339
s_h0_0,
3440
h0_dim0,
3541
B: tl.constexpr,
@@ -46,117 +52,109 @@ def fused_sigmoid_gating_delta_rule_update_kernel(
4652
"""
4753
Fused kernel that combines sigmoid gating computation with recurrent delta rule update.
4854
"""
49-
i_nh, i_v, i_k = tl.program_id(0), tl.program_id(1), tl.program_id(2)
50-
i_n, i_hv = i_nh // HV, i_nh % HV
51-
i_h = i_hv // (HV // H)
52-
53-
if IS_VARLEN:
54-
bos, eos = (
55-
tl.load(cu_seqlens + i_n).to(tl.int64),
56-
tl.load(cu_seqlens + i_n + 1).to(tl.int64),
57-
)
58-
all = T
59-
T = eos - bos
60-
else:
61-
bos, eos = i_n * T, i_n * T + T
62-
all = B * T
63-
55+
i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
6456
o_k = i_k * BK + tl.arange(0, BK)
6557
o_v = i_v * BV + tl.arange(0, BV)
66-
67-
p_q = q + (bos * H + i_h) * K + o_k
68-
p_k = k + (bos * H + i_h) * K + o_k
69-
p_v = v + (bos * HV + i_hv) * V + o_v
70-
p_b = b + bos * HV + i_hv
71-
p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v
72-
73-
# Gating computation pointers
74-
p_A_log = A_log + i_hv
75-
p_a = a + bos * HV + i_hv
76-
p_dt_bias = dt_bias + i_hv
77-
7858
mask_k = o_k < K
7959
mask_v = o_v < V
8060
mask_h = mask_k[:, None] & mask_v[None, :]
61+
grid_stride_nh = tl.num_programs(2)
8162

82-
b_h = tl.zeros([BK, BV], dtype=tl.float32)
83-
if USE_INITIAL_STATE:
84-
idx = tl.load(h0_indices + i_n).to(tl.int64) # prevent int32 overflow
85-
if idx >= 0:
86-
tl.device_assert(idx < h0_dim0,
87-
"idx out of bounds in h0_source load")
88-
p_h0 = (h0_source + idx * s_h0_0 + i_hv * K * V + o_k[:, None] * V +
89-
o_v[None, :])
90-
b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)
91-
92-
for _ in range(0, T):
93-
# Load inputs
94-
b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32)
95-
b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
96-
b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
97-
b_b = tl.load(p_b).to(tl.float32)
98-
99-
# Compute sigmoid gating
100-
# Load gating parameters
101-
b_A_log = tl.load(p_A_log).to(tl.float32)
102-
b_a = tl.load(p_a).to(tl.float32)
103-
b_dt_bias = tl.load(p_dt_bias).to(tl.float32)
104-
105-
# Compute g = -exp(A_log) * softplus(a + dt_bias)
106-
x = b_a + b_dt_bias
107-
beta_x = softplus_beta * x
108-
# Apply softplus with numerical stability
109-
softplus_x = tl.where(
110-
beta_x <= softplus_threshold,
111-
(1.0 / softplus_beta) * tl.log(1.0 + tl.exp(beta_x)),
112-
x,
113-
)
114-
b_g = -tl.exp(b_A_log) * softplus_x
63+
while i_nh < total_nh:
64+
i_n, i_hv = i_nh // HV, i_nh % HV
65+
i_h = i_hv // (HV // H)
11566

116-
# Compute beta = sigmoid(b)
117-
b_beta = 1.0 / (1.0 + tl.exp(-b_b))
67+
if IS_VARLEN:
68+
bos, eos = (
69+
tl.load(cu_seqlens + i_n).to(tl.int64),
70+
tl.load(cu_seqlens + i_n + 1).to(tl.int64),
71+
)
72+
all = T
73+
seq_T = eos - bos
74+
else:
75+
bos, eos = i_n * T, i_n * T + T
76+
all = B * T
77+
seq_T = T
11878

119-
# Apply L2 normalization if enabled
120-
if USE_QK_L2NORM_IN_KERNEL:
121-
b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q)) + 1e-6)
122-
b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k)) + 1e-6)
79+
# Decode q/k/v/a/b often arrive as views sliced out of larger packed tensors.
80+
# Use the caller-provided token strides so the kernel can consume those views
81+
# directly instead of relying on a packed contiguous layout.
82+
p_q = q + bos * stride_q + i_h * K + o_k
83+
p_k = k + bos * stride_k + i_h * K + o_k
84+
p_v = v + bos * stride_v + i_hv * V + o_v
85+
p_b = b + bos * stride_b + i_hv
86+
# o is allocated in this wrapper and kept contiguous, so the output
87+
# pointer arithmetic can use the packed [NK, B, T, HV, V] layout.
88+
p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v
12389

124-
b_q = b_q * scale
90+
# Gating computation pointers
91+
p_A_log = A_log + i_hv
92+
p_a = a + bos * stride_a + i_hv
93+
p_dt_bias = dt_bias + i_hv
12594

126-
# Apply gating to hidden state: h *= exp(g)
127-
b_h *= tl.exp(b_g)
95+
b_h = tl.zeros([BK, BV], dtype=tl.float32)
96+
if USE_INITIAL_STATE:
97+
idx = tl.load(h0_indices + i_n).to(tl.int64)
98+
if idx >= 0:
99+
tl.device_assert(idx < h0_dim0,
100+
"idx out of bounds in h0_source load")
101+
p_h0 = (h0_source + idx * s_h0_0 + i_hv * K * V +
102+
o_k[:, None] * V + o_v[None, :])
103+
b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)
128104

129-
# Delta rule: v -= sum(h * k, dim=0)
130-
b_v -= tl.sum(b_h * b_k[:, None], 0)
105+
for _ in range(0, seq_T):
106+
b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32)
107+
b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
108+
b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
109+
b_b = tl.load(p_b).to(tl.float32)
131110

132-
# Apply beta gating: v *= beta
133-
b_v *= b_beta
111+
b_A_log = tl.load(p_A_log).to(tl.float32)
112+
b_a = tl.load(p_a).to(tl.float32)
113+
b_dt_bias = tl.load(p_dt_bias).to(tl.float32)
134114

135-
# Update hidden state: h += k[:, None] * v[None, :]
136-
b_h += b_k[:, None] * b_v[None, :]
115+
x = b_a + b_dt_bias
116+
beta_x = softplus_beta * x
117+
softplus_x = tl.where(
118+
beta_x <= softplus_threshold,
119+
(1.0 / softplus_beta) * tl.log(1.0 + tl.exp(beta_x)),
120+
x,
121+
)
122+
b_g = -tl.exp(b_A_log) * softplus_x
137123

138-
# Compute output: o = sum(h * q, dim=0)
139-
b_o = tl.sum(b_h * b_q[:, None], 0)
140-
tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)
124+
b_beta = 1.0 / (1.0 + tl.exp(-b_b))
141125

142-
# Update pointers for next timestep
143-
p_q += H * K
144-
p_k += H * K
145-
p_o += HV * V
146-
p_v += HV * V
147-
p_b += HV
148-
p_a += HV
126+
if USE_QK_L2NORM_IN_KERNEL:
127+
b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q)) + 1e-6)
128+
b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k)) + 1e-6)
149129

150-
# Store final state back to h0_source with bounds checking
151-
if USE_INITIAL_STATE:
152-
idx = tl.load(h0_indices + i_n).to(tl.int64)
153-
if idx >= 0:
154-
p_h0 = (h0_source + idx * s_h0_0 + i_hv * K * V + o_k[:, None] * V +
155-
o_v[None, :])
156-
tl.store(p_h0, b_h.to(p_h0.dtype.element_ty), mask=mask_h)
130+
b_q = b_q * scale
131+
b_h *= tl.exp(b_g)
132+
b_v -= tl.sum(b_h * b_k[:, None], 0)
133+
b_v *= b_beta
134+
b_h += b_k[:, None] * b_v[None, :]
135+
136+
b_o = tl.sum(b_h * b_q[:, None], 0)
137+
tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)
138+
139+
p_q += stride_q
140+
p_k += stride_k
141+
p_o += HV * V
142+
p_v += stride_v
143+
p_b += stride_b
144+
p_a += stride_a
145+
146+
if USE_INITIAL_STATE:
147+
idx = tl.load(h0_indices + i_n).to(tl.int64)
148+
if idx >= 0:
149+
tl.device_assert(idx < h0_dim0,
150+
"idx out of bounds in h0_source store")
151+
p_h0 = (h0_source + idx * s_h0_0 + i_hv * K * V +
152+
o_k[:, None] * V + o_v[None, :])
153+
tl.store(p_h0, b_h.to(p_h0.dtype.element_ty), mask=mask_h)
154+
155+
i_nh += grid_stride_nh
157156

158157

159-
@input_guard(exclude_args=["initial_state_source"])
160158
def fused_sigmoid_gating_delta_rule_update(
161159
A_log: torch.Tensor,
162160
a: torch.Tensor,
@@ -181,6 +179,14 @@ def fused_sigmoid_gating_delta_rule_update(
181179
B, T, H, K, V = *k.shape, v.shape[-1]
182180
HV = v.shape[2]
183181
N = B if cu_seqlens is None else len(cu_seqlens) - 1
182+
183+
# Accept native view layouts from forward_decode rather than forcing packed
184+
# copies through input_guard.
185+
stride_q = q.stride(1)
186+
stride_k = k.stride(1)
187+
stride_v = v.stride(1)
188+
stride_a = a.stride(-2)
189+
stride_b = b.stride(-2)
184190
BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 32)
185191
NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
186192
assert NK == 1, "NK > 1 is not supported yet"
@@ -193,7 +199,10 @@ def fused_sigmoid_gating_delta_rule_update(
193199
assert scale > 0, "scale must be positive"
194200

195201
o = q.new_empty(NK, *v.shape)
196-
grid = (N * HV, NV, NK)
202+
# (NK, NV, N * HV) is found faster than (N * HV, NV, NK)
203+
# As max of grid.z is 65535, we cap grid.z and let each Triton program
204+
# grid-stride across the remaining N * HV tiles.
205+
grid = (NK, NV, min(N * HV, 65535))
197206

198207
if initial_state_source is not None:
199208
s_h0_0, s_h0_1, s_h0_2, s_h0_3 = initial_state_source.stride()
@@ -205,34 +214,44 @@ def fused_sigmoid_gating_delta_rule_update(
205214
s_h0_0 = 0
206215
slot_num = 0
207216

208-
fused_sigmoid_gating_delta_rule_update_kernel[grid](
209-
A_log=A_log,
210-
a=a,
211-
dt_bias=dt_bias,
212-
softplus_beta=softplus_beta,
213-
softplus_threshold=softplus_threshold,
214-
q=q,
215-
k=k,
216-
v=v,
217-
b=b,
218-
o=o,
219-
h0_source=initial_state_source,
220-
h0_indices=initial_state_indices,
221-
cu_seqlens=cu_seqlens,
222-
scale=scale,
223-
T=T,
224-
s_h0_0=s_h0_0,
225-
h0_dim0=slot_num,
226-
B=B,
227-
H=H,
228-
HV=HV,
229-
K=K,
230-
V=V,
231-
BK=BK,
232-
BV=BV,
233-
USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
234-
num_warps=num_warps,
235-
num_stages=num_stages,
236-
)
217+
# input_guard used to set the active CUDA device and make inputs contiguous.
218+
# We keep only the device-context part here so Triton launches on q's device
219+
# without re-packing the decode views.
220+
with custom_device_ctx(q.device.index):
221+
fused_sigmoid_gating_delta_rule_update_kernel[grid](
222+
A_log=A_log,
223+
a=a,
224+
dt_bias=dt_bias,
225+
softplus_beta=softplus_beta,
226+
softplus_threshold=softplus_threshold,
227+
q=q,
228+
k=k,
229+
v=v,
230+
b=b,
231+
o=o,
232+
h0_source=initial_state_source,
233+
h0_indices=initial_state_indices,
234+
cu_seqlens=cu_seqlens,
235+
scale=scale,
236+
T=T,
237+
total_nh=N * HV,
238+
stride_q=stride_q,
239+
stride_k=stride_k,
240+
stride_v=stride_v,
241+
stride_a=stride_a,
242+
stride_b=stride_b,
243+
s_h0_0=s_h0_0,
244+
h0_dim0=slot_num,
245+
B=B,
246+
H=H,
247+
HV=HV,
248+
K=K,
249+
V=V,
250+
BK=BK,
251+
BV=BV,
252+
USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
253+
num_warps=num_warps,
254+
num_stages=num_stages,
255+
)
237256
o = o.squeeze(0)
238257
return o

0 commit comments

Comments (0)