feat: qwen3.5 perf opt #1351
Conversation
There was a problem hiding this comment.
Code Review
This pull request introduces several performance optimizations and features, including prioritizing the fa3 attention backend, implementing fast CUDA graph planning for FlashInfer, adding a fused add_rmsnorm Triton kernel, and optimizing the Qwen3Next model's GDN decode path and sampling backend. The code review feedback highlights critical correctness and performance improvements: reversing in-place logit division if FlashInfer sampling fails, skipping temperature scaling for greedy sampling when logprobs are not needed, removing redundant loops and memory operations in the fused RMSNorm kernel, avoiding redundant concurrent writes in the MoE alignment kernel, asserting tensor contiguity in GDN decode kernels, and replacing manual sigmoid implementations with Triton's built-in tl.sigmoid.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| if temperature != 1.0: | ||
| logits.div_(temperature) | ||
|
|
||
| if top_k == vocab_size and top_p != 1.0: | ||
| top_p_tensor = _get_uniform_tensor(top_p, logits.shape[0], torch.float32, logits.device) | ||
| return _flashinfer_top_p_sample_from_logits(logits, top_p_tensor) | ||
|
|
||
| top_p_tensor = _get_uniform_tensor(top_p, logits.shape[0], torch.float32, logits.device) | ||
| top_k_tensor = _get_uniform_tensor(top_k, logits.shape[0], torch.int32, logits.device) | ||
| return _flashinfer_top_p_top_k_sample_from_logits(logits, top_p_tensor, top_k_tensor) |
There was a problem hiding this comment.
In _try_flashinfer_sample_without_penalty, logits.div_(temperature) is performed in-place. If the subsequent flashinfer sampling function returns None (e.g., because flashinfer is not available or fails), the function returns None, and the caller falls back to standard sampling. However, logits has already been modified in-place, so the caller will divide logits by the temperature again, leading to incorrect sampling probabilities. We should restore logits by multiplying back by temperature if the flashinfer call returns None.
| if temperature != 1.0: | |
| logits.div_(temperature) | |
| if top_k == vocab_size and top_p != 1.0: | |
| top_p_tensor = _get_uniform_tensor(top_p, logits.shape[0], torch.float32, logits.device) | |
| return _flashinfer_top_p_sample_from_logits(logits, top_p_tensor) | |
| top_p_tensor = _get_uniform_tensor(top_p, logits.shape[0], torch.float32, logits.device) | |
| top_k_tensor = _get_uniform_tensor(top_k, logits.shape[0], torch.int32, logits.device) | |
| return _flashinfer_top_p_top_k_sample_from_logits(logits, top_p_tensor, top_k_tensor) | |
| if top_k == vocab_size and top_p != 1.0: | |
| top_p_tensor = _get_uniform_tensor(top_p, logits.shape[0], torch.float32, logits.device) | |
| if temperature != 1.0: | |
| logits.div_(temperature) | |
| res = _flashinfer_top_p_sample_from_logits(logits, top_p_tensor) | |
| if res is None and temperature != 1.0: | |
| logits.mul_(temperature) | |
| return res | |
| top_p_tensor = _get_uniform_tensor(top_p, logits.shape[0], torch.float32, logits.device) | |
| top_k_tensor = _get_uniform_tensor(top_k, logits.shape[0], torch.int32, logits.device) | |
| if temperature != 1.0: | |
| logits.div_(temperature) | |
| res = _flashinfer_top_p_top_k_sample_from_logits(logits, top_p_tensor, top_k_tensor) | |
| if res is None and temperature != 1.0: | |
| logits.mul_(temperature) | |
| return res |
| @triton.jit | ||
| def _add_rms_norm_fwd_fused( | ||
| X, | ||
| R, | ||
| Y, | ||
| W, | ||
| x_stride0, | ||
| x_stride1, | ||
| r_stride0, | ||
| r_stride1, | ||
| y_stride0, | ||
| y_stride1, | ||
| N, | ||
| eps, | ||
| HAS_WEIGHT: tl.constexpr, | ||
| BLOCK_SIZE: tl.constexpr, | ||
| ): | ||
| row = tl.program_id(0) | ||
| X += row * x_stride0 | ||
| R += row * r_stride0 | ||
| Y += row * y_stride0 | ||
|
|
||
| _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32) | ||
| for off in range(0, N, BLOCK_SIZE): | ||
| cols = off + tl.arange(0, BLOCK_SIZE) | ||
| mask = cols < N | ||
| x = tl.load(X + cols * x_stride1, mask=mask, other=0.0).to(tl.float32) | ||
| r = tl.load(R + cols * r_stride1, mask=mask, other=0.0).to(tl.float32) | ||
| x = x + r | ||
| tl.store(X + cols * x_stride1, x.to(X.dtype.element_ty), mask=mask) | ||
| _var += x * x | ||
|
|
||
| var = tl.sum(_var, axis=0) / N | ||
| rstd = 1 / tl.sqrt(var + eps) | ||
| for off in range(0, N, BLOCK_SIZE): | ||
| cols = off + tl.arange(0, BLOCK_SIZE) | ||
| mask = cols < N | ||
| x = tl.load(X + cols * x_stride1, mask=mask, other=0.0).to(tl.float32) | ||
| y = x * rstd | ||
| if HAS_WEIGHT: | ||
| w = tl.load(W + cols, mask=mask).to(tl.float32) | ||
| y *= w | ||
| tl.store(Y + cols * y_stride1, y.to(Y.dtype.element_ty), mask=mask) | ||
|
|
There was a problem hiding this comment.
The _add_rms_norm_fwd_fused kernel uses a loop over N with BLOCK_SIZE steps. However, N > BLOCK_SIZE is explicitly forbidden and raises a RuntimeError in the Python wrapper. Since N <= BLOCK_SIZE is guaranteed, the loop is completely redundant. More importantly, storing X and then reloading it in the second loop causes unnecessary global memory read/write operations, which is extremely slow for memory-bandwidth-bound kernels like RMSNorm. We can load X only once and avoid the loop entirely.
@triton.jit
def _add_rms_norm_fwd_fused(
X,
R,
Y,
W,
x_stride0,
x_stride1,
r_stride0,
r_stride1,
y_stride0,
y_stride1,
N,
eps,
HAS_WEIGHT: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
row = tl.program_id(0)
X += row * x_stride0
R += row * r_stride0
Y += row * y_stride0
cols = tl.arange(0, BLOCK_SIZE)
mask = cols < N
x = tl.load(X + cols * x_stride1, mask=mask, other=0.0).to(tl.float32)
r = tl.load(R + cols * r_stride1, mask=mask, other=0.0).to(tl.float32)
x = x + r
tl.store(X + cols * x_stride1, x.to(X.dtype.element_ty), mask=mask)
var = tl.sum(x * x, axis=0) / N
rstd = 1 / tl.sqrt(var + eps)
y = x * rstd
if HAS_WEIGHT:
w = tl.load(W + cols, mask=mask).to(tl.float32)
y *= w
tl.store(Y + cols * y_stride1, y.to(Y.dtype.element_ty), mask=mask)
| if b_temperatures is not None: | ||
| logits.div_(b_temperatures.view((-1, 1))) | ||
|
|
||
| if is_all_greedy and not need_logprobs: | ||
| batch_next_token_ids = torch.argmax(logits, -1) | ||
| if get_env_start_args().mtp_mode: | ||
| batch_next_token_logprobs = torch.zeros( | ||
| batch_next_token_ids.shape, dtype=torch.float32, device=batch_next_token_ids.device | ||
| ) | ||
| return batch_next_token_ids.view(-1), batch_next_token_logprobs.view(-1) | ||
| return batch_next_token_ids.view(-1), None |
There was a problem hiding this comment.
The in-place division logits.div_(b_temperatures.view((-1, 1))) is performed before checking if the request is greedy and doesn't need logprobs. For greedy sampling, temperature scaling does not affect the argmax result. Performing this division on a large logits tensor (e.g., vocab size 150k+) is a significant waste of GPU compute and memory bandwidth. Reordering the greedy check to return early before performing the division will improve performance.
| if b_temperatures is not None: | |
| logits.div_(b_temperatures.view((-1, 1))) | |
| if is_all_greedy and not need_logprobs: | |
| batch_next_token_ids = torch.argmax(logits, -1) | |
| if get_env_start_args().mtp_mode: | |
| batch_next_token_logprobs = torch.zeros( | |
| batch_next_token_ids.shape, dtype=torch.float32, device=batch_next_token_ids.device | |
| ) | |
| return batch_next_token_ids.view(-1), batch_next_token_logprobs.view(-1) | |
| return batch_next_token_ids.view(-1), None | |
| if is_all_greedy and not need_logprobs: | |
| batch_next_token_ids = torch.argmax(logits, -1) | |
| if get_env_start_args().mtp_mode: | |
| batch_next_token_logprobs = torch.zeros( | |
| batch_next_token_ids.shape, dtype=torch.float32, device=batch_next_token_ids.device | |
| ) | |
| return batch_next_token_ids.view(-1), batch_next_token_logprobs.view(-1) | |
| return batch_next_token_ids.view(-1), None | |
| if b_temperatures is not None: | |
| logits.div_(b_temperatures.view((-1, 1))) |
| if ZERO_EXPERT_TOKEN_NUM: | ||
| expert_offs = tl.arange(0, BLOCK_EXPERT) | ||
| tl.store(expert_token_num_ptr + expert_offs, 0, mask=expert_offs < expert_num) |
There was a problem hiding this comment.
In moe_align_fused_kernel, if ZERO_EXPERT_TOKEN_NUM is True, every program block in the grid will concurrently write 0 to expert_token_num_ptr. While functionally correct, this is redundant and causes unnecessary memory write contention. Adding `and token_block == 0` to the condition ensures that only the first block performs the zeroing.
| if ZERO_EXPERT_TOKEN_NUM: | |
| expert_offs = tl.arange(0, BLOCK_EXPERT) | |
| tl.store(expert_token_num_ptr + expert_offs, 0, mask=expert_offs < expert_num) | |
| if ZERO_EXPERT_TOKEN_NUM and token_block == 0: | |
| expert_offs = tl.arange(0, BLOCK_EXPERT) | |
| tl.store(expert_token_num_ptr + expert_offs, 0, mask=expert_offs < expert_num) |
| def pack_gdn_decode_inputs( | ||
| mixed_qkv: torch.Tensor, | ||
| z_raw: torch.Tensor, | ||
| a_raw: torch.Tensor, | ||
| b_raw: torch.Tensor, | ||
| num_k_heads: int, | ||
| head_k_dim: int, | ||
| num_v_heads: int, | ||
| head_v_dim: int, | ||
| ): |
There was a problem hiding this comment.
In pack_gdn_decode_inputs, z_raw is loaded in the Triton kernel using flattened indexing (z_raw + row * stride_z_b + qkv_offsets), which assumes that the last two dimensions of z_raw are contiguous. However, there is no assertion in the Python wrapper to ensure z_raw is contiguous. If a non-contiguous tensor is passed, it will silently produce incorrect results. We should add an assertion to prevent this.
| def pack_gdn_decode_inputs( | |
| mixed_qkv: torch.Tensor, | |
| z_raw: torch.Tensor, | |
| a_raw: torch.Tensor, | |
| b_raw: torch.Tensor, | |
| num_k_heads: int, | |
| head_k_dim: int, | |
| num_v_heads: int, | |
| head_v_dim: int, | |
| ): | |
| @torch.no_grad() | |
| def pack_gdn_decode_inputs( | |
| mixed_qkv: torch.Tensor, | |
| z_raw: torch.Tensor, | |
| a_raw: torch.Tensor, | |
| b_raw: torch.Tensor, | |
| num_k_heads: int, | |
| head_k_dim: int, | |
| num_v_heads: int, | |
| head_v_dim: int, | |
| ): | |
| assert z_raw.is_contiguous(), "z_raw must be contiguous" |
| def conv_pack_gdn_decode_inputs( | ||
| mixed_qkv: torch.Tensor, | ||
| z_raw: torch.Tensor, | ||
| a_raw: torch.Tensor, | ||
| b_raw: torch.Tensor, | ||
| conv_state: torch.Tensor, | ||
| conv_weight: torch.Tensor, | ||
| conv_bias: torch.Tensor, | ||
| conv_state_indices: torch.Tensor, | ||
| activation: str, | ||
| num_k_heads: int, | ||
| head_k_dim: int, | ||
| num_v_heads: int, | ||
| head_v_dim: int, | ||
| ): |
There was a problem hiding this comment.
In conv_pack_gdn_decode_inputs, z_raw is loaded in the Triton kernel using flattened indexing, which assumes that the last two dimensions of z_raw are contiguous. However, there is no assertion in the Python wrapper to ensure z_raw is contiguous. If a non-contiguous tensor is passed, it will silently produce incorrect results. We should add an assertion to prevent this.
| def conv_pack_gdn_decode_inputs( | |
| mixed_qkv: torch.Tensor, | |
| z_raw: torch.Tensor, | |
| a_raw: torch.Tensor, | |
| b_raw: torch.Tensor, | |
| conv_state: torch.Tensor, | |
| conv_weight: torch.Tensor, | |
| conv_bias: torch.Tensor, | |
| conv_state_indices: torch.Tensor, | |
| activation: str, | |
| num_k_heads: int, | |
| head_k_dim: int, | |
| num_v_heads: int, | |
| head_v_dim: int, | |
| ): | |
| @torch.no_grad() | |
| def conv_pack_gdn_decode_inputs( | |
| mixed_qkv: torch.Tensor, | |
| z_raw: torch.Tensor, | |
| a_raw: torch.Tensor, | |
| b_raw: torch.Tensor, | |
| conv_state: torch.Tensor, | |
| conv_weight: torch.Tensor, | |
| conv_bias: torch.Tensor, | |
| conv_state_indices: torch.Tensor, | |
| activation: str, | |
| num_k_heads: int, | |
| head_k_dim: int, | |
| num_v_heads: int, | |
| head_v_dim: int, | |
| ): | |
| assert z_raw.is_contiguous(), "z_raw must be contiguous" |
| mask=offs_dim < dim_end, | ||
| other=0.0, | ||
| ).to(tl.float32) | ||
| gate = 1.0 / (1.0 + tl.exp(-gate)) |
There was a problem hiding this comment.
| else: | ||
| gate_vals = tl.load(gate + row * stride_g_m + offs * stride_g_n, mask=mask, other=0.0).to(tl.float32) | ||
| hidden_vals = tl.load(hidden_ptrs, mask=mask, other=0.0).to(tl.float32) | ||
| gate_vals = 1.0 / (1.0 + tl.exp(-gate_vals)) |
There was a problem hiding this comment.
| gate_vals = tl.load(gate + row * stride_g_m).to(tl.float32) | ||
| else: | ||
| gate_vals = tl.load(gate + row * stride_g_m + offs * stride_g_n, mask=mask, other=0.0).to(tl.float32) | ||
| gate_vals = 1.0 / (1.0 + tl.exp(-gate_vals)) |
There was a problem hiding this comment.
No description provided.