Skip to content

Commit 2dd5c67

Browse files
authored
[None][fix] Stabilize Mamba replay state update (#14841)
Signed-off-by: qgai <qgai@nvidia.com>
1 parent ae9226e commit 2dd5c67

3 files changed

Lines changed: 11 additions & 11 deletions

File tree

tensorrt_llm/_torch/modules/mamba/causal_conv1d_triton.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ def _causal_conv1d_fwd_kernel( # continuous batching
158158
conv_states_ptrs = prior_tokens - 3 * stride_conv_state_tok # [BLOCK_N]
159159
col0 = tl.load(conv_states_ptrs, mask_w, 0.0)
160160
else:
161-
# prior-tokens are zeros
161+
# No cached prefix: start the convolution window from zeros.
162162
if KERNEL_WIDTH >= 2: # STRATEGY1
163163
# first chunk and does not have prior-token, so just set to 0
164164
col0 = tl.zeros((BLOCK_N,), dtype=x_ptr.dtype.element_ty)

tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -348,8 +348,6 @@ def forward(
348348
has_initial_states = mamba_metadata.has_initial_states[:
349349
num_prefills]
350350

351-
has_initial_states_p = has_initial_states[:num_prefills]
352-
conv_states[state_indices_p[~has_initial_states_p]].zero_()
353351
# Fused kernel to avoid expensive .contiguous() call in causal_conv1d_fn.
354352
xbc_p_t = extract_transpose_xbc_prefill(zxbcdt, num_prefill_tokens,
355353
self.tp_d_inner,
@@ -376,6 +374,7 @@ def forward(
376374

377375
initial_states = None
378376
if mamba_metadata.use_initial_states:
377+
# Rows without cached prefix state start SSM from zero.
379378
initial_states = torch.where(
380379
has_initial_states[:, None, None, None],
381380
ssm_states[state_indices_p], 0)

tensorrt_llm/_torch/modules/mamba/replay_selective_state_update.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ def _replay_precompute_kernel(
226226
other=0.0,
227227
)
228228

229-
# Compute raw CB once — shared across all heads in this block
229+
# Compute raw CB once, shared across all heads in this block.
230230
raw_CB = tl.dot(C_all.to(tl.bfloat16), tl.trans(B_all).to(tl.bfloat16))
231231

232232
# Store B to cache (once per group, only if this block covers the first heads)
@@ -458,17 +458,18 @@ def _replay_state_update_kernel(
458458
# two back). coeff is all-zero (offs_t < 0), total_decay is 1.0, so the
459459
# replay leaves `state` unchanged — cache contents don't matter on step 0.
460460
coeff = tl.exp(total_dA_cumsum - old_dA_cumsum_all) * old_dt_all
461-
coeff = tl.where(offs_t < prev_num_accepted_tokens, coeff, 0.0)
461+
accepted_mask = t_mask & (offs_t < prev_num_accepted_tokens)
462+
coeff = tl.where(accepted_mask, coeff, 0.0)
462463

463-
# Load old_x: (BLOCK_SIZE_T, BLOCK_SIZE_M) — single-buffered
464+
# Zero stale rows beyond PNAT to prevent Inf/NaN from reaching tl.dot.
464465
old_x_base = old_x_ptr + cache_batch_idx * stride_old_x_cache + pid_h * stride_old_x_head
465466
old_x_all = tl.load(
466467
old_x_base + offs_t[:, None] * stride_old_x_T + offs_m[None, :] * stride_old_x_dim,
467-
mask=t_mask[:, None] & m_mask[None, :],
468+
mask=accepted_mask[:, None] & m_mask[None, :],
468469
other=0.0,
469470
)
470471

471-
# Load old_B from READ buffer: (BLOCK_SIZE_T, BLOCK_SIZE_DSTATE)
472+
# Apply the same accepted-row mask to old_B.
472473
old_B_base = (
473474
old_B_ptr
474475
+ cache_batch_idx * stride_old_B_cache
@@ -477,7 +478,7 @@ def _replay_state_update_kernel(
477478
)
478479
old_B_all = tl.load(
479480
old_B_base + offs_t[:, None] * stride_old_B_T + offs_n[None, :] * stride_old_B_dstate,
480-
mask=t_mask[:, None] & n_mask[None, :],
481+
mask=accepted_mask[:, None] & n_mask[None, :],
481482
other=0.0,
482483
).to(tl.float32)
483484

@@ -488,7 +489,7 @@ def _replay_state_update_kernel(
488489
total_decay = tl.where(prev_num_accepted_tokens > 0, tl.exp(total_dA_cumsum), 1.0)
489490
state *= total_decay
490491

491-
# tl.dot fast-forward: old_x^T @ dB_scaled (M, dstate)
492+
# tl.dot fast-forward: old_x^T @ dB_scaled -> (M, dstate)
492493
state += tl.dot(tl.trans(old_x_all).to(tl.bfloat16), dB_scaled.to(tl.bfloat16))
493494

494495
# Write post-replay state
@@ -771,7 +772,7 @@ def replay_selective_state_update(
771772
device = x.device
772773
BLOCK_SIZE_T = max(triton.next_power_of_2(T), 16)
773774

774-
# Allocate precomputed intermediates (per-call, not cached)
775+
# Allocate precomputed intermediates (per-call, not cached).
775776
cb_scaled = torch.empty(
776777
batch, nheads, BLOCK_SIZE_T, BLOCK_SIZE_T, device=device, dtype=torch.float32
777778
)

0 commit comments

Comments
 (0)