@@ -226,7 +226,7 @@ def _replay_precompute_kernel(
226226 other = 0.0 ,
227227 )
228228
229- # Compute raw CB once — shared across all heads in this block
229+ # Compute raw CB once, shared across all heads in this block.
230230 raw_CB = tl .dot (C_all .to (tl .bfloat16 ), tl .trans (B_all ).to (tl .bfloat16 ))
231231
232232 # Store B to cache (once per group, only if this block covers the first heads)
@@ -458,17 +458,18 @@ def _replay_state_update_kernel(
458458 # two back). coeff is all-zero (offs_t < 0), total_decay is 1.0, so the
459459 # replay leaves `state` unchanged — cache contents don't matter on step 0.
460460 coeff = tl .exp (total_dA_cumsum - old_dA_cumsum_all ) * old_dt_all
461- coeff = tl .where (offs_t < prev_num_accepted_tokens , coeff , 0.0 )
461+ accepted_mask = t_mask & (offs_t < prev_num_accepted_tokens )
462+ coeff = tl .where (accepted_mask , coeff , 0.0 )
462463
463- # Load old_x: (BLOCK_SIZE_T, BLOCK_SIZE_M) — single-buffered
464+ # Zero stale rows beyond PNAT to prevent Inf/NaN from reaching tl.dot.
464465 old_x_base = old_x_ptr + cache_batch_idx * stride_old_x_cache + pid_h * stride_old_x_head
465466 old_x_all = tl .load (
466467 old_x_base + offs_t [:, None ] * stride_old_x_T + offs_m [None , :] * stride_old_x_dim ,
467- mask = t_mask [:, None ] & m_mask [None , :],
468+ mask = accepted_mask [:, None ] & m_mask [None , :],
468469 other = 0.0 ,
469470 )
470471
471- # Load old_B from READ buffer: (BLOCK_SIZE_T, BLOCK_SIZE_DSTATE)
472+ # Apply the same accepted-row mask to old_B.
472473 old_B_base = (
473474 old_B_ptr
474475 + cache_batch_idx * stride_old_B_cache
@@ -477,7 +478,7 @@ def _replay_state_update_kernel(
477478 )
478479 old_B_all = tl .load (
479480 old_B_base + offs_t [:, None ] * stride_old_B_T + offs_n [None , :] * stride_old_B_dstate ,
480- mask = t_mask [:, None ] & n_mask [None , :],
481+ mask = accepted_mask [:, None ] & n_mask [None , :],
481482 other = 0.0 ,
482483 ).to (tl .float32 )
483484
@@ -488,7 +489,7 @@ def _replay_state_update_kernel(
488489 total_decay = tl .where (prev_num_accepted_tokens > 0 , tl .exp (total_dA_cumsum ), 1.0 )
489490 state *= total_decay
490491
491- # tl.dot fast-forward: old_x^T @ dB_scaled → (M, dstate)
492+ # tl.dot fast-forward: old_x^T @ dB_scaled -> (M, dstate)
492493 state += tl .dot (tl .trans (old_x_all ).to (tl .bfloat16 ), dB_scaled .to (tl .bfloat16 ))
493494
494495 # Write post-replay state
@@ -771,7 +772,7 @@ def replay_selective_state_update(
771772 device = x .device
772773 BLOCK_SIZE_T = max (triton .next_power_of_2 (T ), 16 )
773774
774- # Allocate precomputed intermediates (per-call, not cached)
775+ # Allocate precomputed intermediates (per-call, not cached).
775776 cb_scaled = torch .empty (
776777 batch , nheads , BLOCK_SIZE_T , BLOCK_SIZE_T , device = device , dtype = torch .float32
777778 )
0 commit comments