From 8a2468a88840b9f9c0407114b1c270bc70dfef3c Mon Sep 17 00:00:00 2001 From: nv-guomingz <137257613+nv-guomingz@users.noreply.github.com> Date: Fri, 3 Apr 2026 15:11:02 +0000 Subject: [PATCH] Optimize qwen3.5 decode delta kernel - keep decode qkv views and make the fused recurrent kernel stride-aware - restore the decode tile choice that wins on the representative bs256 pure-decode benchmark Signed-off-by: nv-guomingz <137257613+nv-guomingz@users.noreply.github.com> --- .../fla/fused_sigmoid_gating_recurrent.py | 285 ++++++++++-------- .../_torch/modules/mamba/gdn_mixer.py | 22 +- 2 files changed, 173 insertions(+), 134 deletions(-) diff --git a/tensorrt_llm/_torch/modules/fla/fused_sigmoid_gating_recurrent.py b/tensorrt_llm/_torch/modules/fla/fused_sigmoid_gating_recurrent.py index be6f0971a5ae..ffe1fafc649c 100644 --- a/tensorrt_llm/_torch/modules/fla/fused_sigmoid_gating_recurrent.py +++ b/tensorrt_llm/_torch/modules/fla/fused_sigmoid_gating_recurrent.py @@ -6,7 +6,7 @@ import triton import triton.language as tl -from tensorrt_llm._torch.modules.fla.utils import input_guard +from tensorrt_llm._torch.modules.fla.utils import custom_device_ctx @triton.heuristics({ @@ -30,6 +30,12 @@ def fused_sigmoid_gating_delta_rule_update_kernel( cu_seqlens, scale, T, + total_nh, + stride_q, + stride_k, + stride_v, + stride_a, + stride_b, s_h0_0, h0_dim0, B: tl.constexpr, @@ -46,117 +52,127 @@ def fused_sigmoid_gating_delta_rule_update_kernel( """ Fused kernel that combines sigmoid gating computation with recurrent delta rule update. """ - i_nh, i_v, i_k = tl.program_id(0), tl.program_id(1), tl.program_id(2) - i_n, i_hv = i_nh // HV, i_nh % HV - i_h = i_hv // (HV // H) - - if IS_VARLEN: - bos, eos = ( - tl.load(cu_seqlens + i_n).to(tl.int64), - tl.load(cu_seqlens + i_n + 1).to(tl.int64), - ) - all = T - T = eos - bos - else: - bos, eos = i_n * T, i_n * T + T - all = B * T - + i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) o_k = i_k * BK + tl.arange(0, BK) o_v = i_v * BV + tl.arange(0, BV) - - p_q = q + (bos * H + i_h) * K + o_k - p_k = k + (bos * H + i_h) * K + o_k - p_v = v + (bos * HV + i_hv) * V + o_v - p_b = b + bos * HV + i_hv - p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v - - # Gating computation pointers - p_A_log = A_log + i_hv - p_a = a + bos * HV + i_hv - p_dt_bias = dt_bias + i_hv - mask_k = o_k < K mask_v = o_v < V mask_h = mask_k[:, None] & mask_v[None, :] + grid_stride_nh = tl.num_programs(2) - b_h = tl.zeros([BK, BV], dtype=tl.float32) - if USE_INITIAL_STATE: - idx = tl.load(h0_indices + i_n).to(tl.int64) # prevent int32 overflow - if idx >= 0: - tl.device_assert(idx < h0_dim0, - "idx out of bounds in h0_source load") - p_h0 = (h0_source + idx * s_h0_0 + i_hv * K * V + o_k[:, None] * V + - o_v[None, :]) - b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32) - - for _ in range(0, T): - # Load inputs - b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) - b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32) - b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32) - b_b = tl.load(p_b).to(tl.float32) - - # Compute sigmoid gating - # Load gating parameters - b_A_log = tl.load(p_A_log).to(tl.float32) - b_a = tl.load(p_a).to(tl.float32) - b_dt_bias = tl.load(p_dt_bias).to(tl.float32) - - # Compute g = -exp(A_log) * softplus(a + dt_bias) - x = b_a + b_dt_bias - beta_x = softplus_beta * x - # Apply softplus with numerical stability - softplus_x = tl.where( - beta_x <= softplus_threshold, - (1.0 / softplus_beta) * tl.log(1.0 + tl.exp(beta_x)), - x, - ) - b_g = -tl.exp(b_A_log) * softplus_x + while i_nh < total_nh: + i_n, i_hv = i_nh // HV, i_nh % HV + i_h = i_hv // (HV // H) + + if IS_VARLEN: + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int64), + tl.load(cu_seqlens + i_n + 1).to(tl.int64), + ) + all = T + seq_T = eos - bos + else: + bos, eos = i_n * T, i_n * T + T + all = B * T + seq_T = T + + # Decode q/k/v/a/b often arrive as views sliced out of larger packed tensors. + # Use the caller-provided token strides so the kernel can consume those views + # directly instead of relying on a packed contiguous layout. + p_q = q + bos * stride_q + i_h * K + o_k + p_k = k + bos * stride_k + i_h * K + o_k + p_v = v + bos * stride_v + i_hv * V + o_v + p_b = b + bos * stride_b + i_hv + # o is allocated in this wrapper and kept contiguous, so the output + # pointer arithmetic can use the packed [NK, B, T, HV, V] layout. + p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v - # Compute beta = sigmoid(b) - b_beta = 1.0 / (1.0 + tl.exp(-b_b)) + # Gating computation pointers + p_A_log = A_log + i_hv + p_a = a + bos * stride_a + i_hv + p_dt_bias = dt_bias + i_hv - # Apply L2 normalization if enabled - if USE_QK_L2NORM_IN_KERNEL: - b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q)) + 1e-6) - b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k)) + 1e-6) + b_h = tl.zeros([BK, BV], dtype=tl.float32) + if USE_INITIAL_STATE: + idx = tl.load(h0_indices + i_n).to(tl.int64) + if idx >= 0: + tl.device_assert(idx < h0_dim0, + "idx out of bounds in h0_source load") + p_h0 = (h0_source + idx * s_h0_0 + i_hv * K * V + + o_k[:, None] * V + o_v[None, :]) + b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32) - b_q = b_q * scale + for _ in range(0, seq_T): + # Load inputs + b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) + b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32) + b_b = tl.load(p_b).to(tl.float32) - # Apply gating to hidden state: h *= exp(g) - b_h *= tl.exp(b_g) + # Compute sigmoid gating + # Load gating parameters + b_A_log = tl.load(p_A_log).to(tl.float32) + b_a = tl.load(p_a).to(tl.float32) + b_dt_bias = tl.load(p_dt_bias).to(tl.float32) - # Delta rule: v -= sum(h * k, dim=0) - b_v -= tl.sum(b_h * b_k[:, None], 0) + # Compute g = -exp(A_log) * softplus(a + dt_bias) + x = b_a + b_dt_bias + beta_x = softplus_beta * x + # Apply softplus with numerical stability + softplus_x = tl.where( + beta_x <= softplus_threshold, + (1.0 / softplus_beta) * tl.log(1.0 + tl.exp(beta_x)), + x, + ) + b_g = -tl.exp(b_A_log) * softplus_x - # Apply beta gating: v *= beta - b_v *= b_beta + # Compute beta = sigmoid(b) + b_beta = 1.0 / (1.0 + tl.exp(-b_b)) - # Update hidden state: h += k[:, None] * v[None, :] - b_h += b_k[:, None] * b_v[None, :] + # Apply L2 normalization if enabled + if USE_QK_L2NORM_IN_KERNEL: + b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q)) + 1e-6) + b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k)) + 1e-6) - # Compute output: o = sum(h * q, dim=0) - b_o = tl.sum(b_h * b_q[:, None], 0) - tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v) + b_q = b_q * scale - # Update pointers for next timestep - p_q += H * K - p_k += H * K - p_o += HV * V - p_v += HV * V - p_b += HV - p_a += HV + # Apply gating to hidden state: h *= exp(g) + b_h *= tl.exp(b_g) - # Store final state back to h0_source with bounds checking - if USE_INITIAL_STATE: - idx = tl.load(h0_indices + i_n).to(tl.int64) - if idx >= 0: - p_h0 = (h0_source + idx * s_h0_0 + i_hv * K * V + o_k[:, None] * V + - o_v[None, :]) - tl.store(p_h0, b_h.to(p_h0.dtype.element_ty), mask=mask_h) + # Delta rule: v -= sum(h * k, dim=0) + b_v -= tl.sum(b_h * b_k[:, None], 0) + + # Apply beta gating: v *= beta + b_v *= b_beta + + # Update hidden state: h += k[:, None] * v[None, :] + b_h += b_k[:, None] * b_v[None, :] + + # Compute output: o = sum(h * q, dim=0) + b_o = tl.sum(b_h * b_q[:, None], 0) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v) + + # Update pointers for next timestep + p_q += stride_q + p_k += stride_k + p_o += HV * V + p_v += stride_v + p_b += stride_b + p_a += stride_a + + # Store final state back to h0_source with bounds checking + if USE_INITIAL_STATE: + idx = tl.load(h0_indices + i_n).to(tl.int64) + if idx >= 0: + tl.device_assert(idx < h0_dim0, + "idx out of bounds in h0_source store") + p_h0 = (h0_source + idx * s_h0_0 + i_hv * K * V + + o_k[:, None] * V + o_v[None, :]) + tl.store(p_h0, b_h.to(p_h0.dtype.element_ty), mask=mask_h) + + i_nh += grid_stride_nh -@input_guard(exclude_args=["initial_state_source"]) def fused_sigmoid_gating_delta_rule_update( A_log: torch.Tensor, a: torch.Tensor, @@ -181,6 +197,14 @@ def fused_sigmoid_gating_delta_rule_update( B, T, H, K, V = *k.shape, v.shape[-1] HV = v.shape[2] N = B if cu_seqlens is None else len(cu_seqlens) - 1 + + # Accept native view layouts from forward_decode rather than forcing packed + # copies through input_guard. + stride_q = q.stride(1) + stride_k = k.stride(1) + stride_v = v.stride(1) + stride_a = a.stride(-2) + stride_b = b.stride(-2) BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 32) NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) assert NK == 1, "NK > 1 is not supported yet" @@ -193,7 +217,10 @@ def fused_sigmoid_gating_delta_rule_update( assert scale > 0, "scale must be positive" o = q.new_empty(NK, *v.shape) - grid = (N * HV, NV, NK) + # (NK, NV, N * HV) is found faster than (N * HV, NV, NK) + # As max of grid.z is 65535, we cap grid.z and let each Triton program + # grid-stride across the remaining N * HV tiles. + grid = (NK, NV, min(N * HV, 65535)) if initial_state_source is not None: s_h0_0, s_h0_1, s_h0_2, s_h0_3 = initial_state_source.stride() @@ -205,34 +232,44 @@ def fused_sigmoid_gating_delta_rule_update( s_h0_0 = 0 slot_num = 0 - fused_sigmoid_gating_delta_rule_update_kernel[grid]( - A_log=A_log, - a=a, - dt_bias=dt_bias, - softplus_beta=softplus_beta, - softplus_threshold=softplus_threshold, - q=q, - k=k, - v=v, - b=b, - o=o, - h0_source=initial_state_source, - h0_indices=initial_state_indices, - cu_seqlens=cu_seqlens, - scale=scale, - T=T, - s_h0_0=s_h0_0, - h0_dim0=slot_num, - B=B, - H=H, - HV=HV, - K=K, - V=V, - BK=BK, - BV=BV, - USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel, - num_warps=num_warps, - num_stages=num_stages, - ) + # input_guard used to set the active CUDA device and make inputs contiguous. + # We keep only the device-context part here so Triton launches on q's device + # without re-packing the decode views. + with custom_device_ctx(q.device.index): + fused_sigmoid_gating_delta_rule_update_kernel[grid]( + A_log=A_log, + a=a, + dt_bias=dt_bias, + softplus_beta=softplus_beta, + softplus_threshold=softplus_threshold, + q=q, + k=k, + v=v, + b=b, + o=o, + h0_source=initial_state_source, + h0_indices=initial_state_indices, + cu_seqlens=cu_seqlens, + scale=scale, + T=T, + total_nh=N * HV, + stride_q=stride_q, + stride_k=stride_k, + stride_v=stride_v, + stride_a=stride_a, + stride_b=stride_b, + s_h0_0=s_h0_0, + h0_dim0=slot_num, + B=B, + H=H, + HV=HV, + K=K, + V=V, + BK=BK, + BV=BV, + USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel, + num_warps=num_warps, + num_stages=num_stages, + ) o = o.squeeze(0) return o diff --git a/tensorrt_llm/_torch/modules/mamba/gdn_mixer.py b/tensorrt_llm/_torch/modules/mamba/gdn_mixer.py index 2f62d40befdf..6d3002298e87 100644 --- a/tensorrt_llm/_torch/modules/mamba/gdn_mixer.py +++ b/tensorrt_llm/_torch/modules/mamba/gdn_mixer.py @@ -235,6 +235,10 @@ def __init__( self.head_v_dim = config.linear_value_head_dim self.key_dim = self.head_k_dim * self.num_k_heads self.value_dim = self.head_v_dim * self.num_v_heads + self.num_k_heads_per_tp = divide(self.num_k_heads, self.attn_tp_size) + self.num_v_heads_per_tp = divide(self.num_v_heads, self.attn_tp_size) + self.key_dim_per_tp = self.head_k_dim * self.num_k_heads_per_tp + self.value_dim_per_tp = self.head_v_dim * self.num_v_heads_per_tp self.conv_kernel_size = config.linear_conv_kernel_dim self.layer_idx = layer_idx @@ -479,17 +483,15 @@ def forward_decode( conv_state_indices=cache_indices, ) - # Direct slicing instead of torch.split for better performance - key_size = self.key_dim // self.attn_tp_size - query = mixed_qkv[..., :key_size] - key = mixed_qkv[..., key_size : key_size * 2] - value = mixed_qkv[..., key_size * 2 :] - # Reshape from [l, h*d] to [1, l, h, d] + # Keep q/k/v as views over mixed_qkv so the fused decode kernel can + # consume their native strides without forcing packed copies. + query = mixed_qkv[..., : self.key_dim_per_tp] + key = mixed_qkv[..., self.key_dim_per_tp : self.key_dim_per_tp * 2] + value = mixed_qkv[..., self.key_dim_per_tp * 2 :] seq_len = query.shape[0] - num_heads = query.shape[1] // self.head_k_dim - query = query.view(1, seq_len, num_heads, self.head_k_dim) - key = key.view(1, seq_len, num_heads, self.head_k_dim) - value = value.view(1, seq_len, value.shape[1] // self.head_v_dim, self.head_v_dim) + query = query.view(1, seq_len, self.num_k_heads_per_tp, self.head_k_dim) + key = key.view(1, seq_len, self.num_k_heads_per_tp, self.head_k_dim) + value = value.view(1, seq_len, self.num_v_heads_per_tp, self.head_v_dim) core_attn_out = fused_sigmoid_gating_delta_rule_update( A_log=self.A_log,