[Fix] Zero-init chunk-mode backward gradient buffers to prevent NaN propagation #892

Open
xylian86 wants to merge 1 commit into fla-org:main from xylian86:fix/zero-init-chunk-bwd-buffers

@xylian86

Summary

Swap empty_like / new_empty → zeros_like / new_zeros for every chunk-mode output / gradient buffer that downstream Triton kernels write with boundary_check. Under packed sequences where T = k.shape[1] > cu_seqlens[-1] (common with SP > 1), the trailing OOB rows are never written and inherit whatever the CUDA caching allocator hands back, often NaN from a prior intermediate. The next autograd op (e.g. dW = dk.T @ x) reduces over the full T axis and corrupts weight gradients.
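
To make the failure mode concrete, here is a minimal pure-Python sketch of the reduction over the T axis. It is not FLA code: `weight_grad` is a hypothetical helper standing in for dW = dk.T @ x, and the tiny shapes and values are made up for illustration. It only shows that a single NaN in a slack row poisons the reduction even when the matching input row is zero, because NaN * 0 is still NaN.

```python
import math

def weight_grad(dk, x):
    """Hypothetical stand-in for dW = dk.T @ x; dk is T x K, x is T x D (nested lists)."""
    T, K, D = len(dk), len(dk[0]), len(x[0])
    return [[sum(dk[t][i] * x[t][j] for t in range(T)) for j in range(D)]
            for i in range(K)]

T_valid = 3                                  # rows the kernel actually writes
x = [[1.0, 2.0]] * T_valid + [[0.0, 0.0]]    # input is zero in the single slack row
written = [[0.5], [0.25], [0.25]]            # gradient rows stored by the kernel

dk_empty = written + [[math.nan]]            # slack row keeps recycled NaN bits
dk_zeros = written + [[0.0]]                 # slack row zero-initialised

dw_empty = weight_grad(dk_empty, x)          # every entry becomes NaN
dw_zeros = weight_grad(dk_zeros, x)          # [[1.0, 2.0]]
```

With zero-initialised slack rows the padded reduction matches the unpadded one; with uninitialised rows the whole weight gradient is lost.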

Same root cause diagnosed for FA2 in Dao-AILab/flash-attention#41.

Minimal repro

import torch; torch.cuda.init()
# A previous op leaves NaN on a page, then frees it
prev = torch.softmax(torch.full((1024,), float('-inf'), device='cuda'), dim=0)
addr = prev.data_ptr(); del prev
# The next `empty` gets the SAME page back, NaN bits intact
new_buf = torch.empty(1024, dtype=torch.float32, device='cuda')
assert new_buf.data_ptr() == addr
print(new_buf.isnan().sum().item())  # → 1024

Changes

20 1:1 substitutions across 3 files (+20 / -20, no logic changes):

  • fla/ops/gated_delta_rule/wy_fast.py: w, u, dk, dv, dg, db
  • fla/ops/common/chunk_o.py: o, dv (×2), dq, dk, dg, dw
  • fla/ops/common/chunk_delta_h.py: h (×2), v_new, dh (×2), dh0, dv2

Overhead: one cudaMemsetAsync per allocation (~µs, dwarfed by the kernel it precedes).

Verification

Qwen3.5-9B SFT on 4×B200, Ulysses SP=4, packed bf16, smoke config:

| Configuration | iter 0 | iter 1 | iter 2 |
|---|---|---|---|
| fla-core==0.5.0 (unpatched) | ~390/427 | ~400/427 | crash |
| This PR | 0/427 | 0/427 | 0/427 |

bad/427 = params with any NaN/inf in .grad. Loss converges normally (5.22 → 5.05 → 5.05).

SP=1 is latent — per-rank shards rarely have trailing padding. SP > 1 exposes it deterministically.

Risk

None. zeros_like is a strict superset of empty_like behavior — every cell the kernel writes is identical; only unused / OOB cells change from indeterminate to deterministic zero.

@gemini-code-assist gemini-code-assist Bot left a comment

Contributor

Code Review

This pull request replaces uninitialized tensor creation with zero-initialized tensors across several files, including chunk_delta_h.py, chunk_o.py, and wy_fast.py, to ensure deterministic behavior. A review comment suggests using g.new_zeros instead of torch.zeros in chunk_o.py for better consistency and to avoid explicit device passing.

Comment thread fla/ops/common/chunk_o.py
dw = torch.empty_like(w) if w is not None else None
dq = q.new_zeros(B, T, HV, K)
dk = k.new_zeros(B, T, HV, K)
dg = torch.zeros(NK, *g.shape, dtype=torch.float32, device=g.device) if g is not None else None

medium

For consistency with the surrounding code (lines 692-693) and to avoid explicit device passing, consider using g.new_zeros here. This is also slightly more idiomatic when creating a tensor based on an existing one's properties.

Suggested change
- dg = torch.zeros(NK, *g.shape, dtype=torch.float32, device=g.device) if g is not None else None
+ dg = g.new_zeros(NK, *g.shape, dtype=torch.float32) if g is not None else None

@zhiyuan1i

Collaborator

T = k.shape[1] > cu_seqlens[-1] seems a bit strange; in varlen semantics the two should be strictly equal. Zeroing is safe, but it will reduce efficiency in large-scale training.

@xylian86

xylian86 commented May 13, 2026

Author

@zhiyuan1i You're right that the canonical contract is T == cu_seqlens[-1]. The precondition I'm describing isn't a varlen-semantic violation — it's a storage-vs-logical mismatch one layer above FLA.

In the SP / packed path, the host allocates k, q, v with shape[1] = T_padded (rounded up so the seq dim is divisible by the SP world size, and sometimes by BT), while cu_seqlens is built from the real per-subseq lengths, so cu_seqlens[-1] = sum(L_i) ≤ T_padded. FLA's invariants still hold — the kernel re-binds T = eos - bos, boundary_check is anchored to that local T, and every position in [0, cu_seqlens[-1]) is written exactly once. The leak is the slack rows [cu_seqlens[-1], T_padded): no kernel ever stores to them, new_empty leaves them holding allocator-recycled NaN bytes, and the next autograd matmul (dW = dk.T @ x) reduces over the full T_padded axis and pulls those NaNs into the weight gradient.

So the real contract any varlen kernel taking explicit cu_seqlens has to defend is T ≥ cu_seqlens[-1], not equality. I hit it first under Ulysses SP + Qwen3.5, but the same precondition is created by chunk-BT padding in trainers and by HF DataCollator's pad_to_multiple_of.
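
As a rough illustration of the storage-versus-logical mismatch, the sketch below builds cu_seqlens from hypothetical sub-sequence lengths and rounds the storage length up to a multiple of the SP world size. The lengths, the world size of 4, and the `round_up` helper are assumptions made for the example; real trainers may pad differently (for instance to a multiple of BT as well).

```python
from itertools import accumulate

def round_up(n, multiple):
    """Smallest multiple of `multiple` that is >= n (hypothetical helper)."""
    return -(-n // multiple) * multiple

seq_lens = [37, 21, 32]                    # assumed per-subsequence lengths L_i
cu_seqlens = [0, *accumulate(seq_lens)]    # cumulative boundaries, ends at sum(L_i)
T_padded = round_up(cu_seqlens[-1], 4)     # storage length for an assumed SP world size of 4
slack = range(cu_seqlens[-1], T_padded)    # rows that no varlen kernel stores to

print(cu_seqlens, T_padded, list(slack))
```

Here cu_seqlens[-1] is 90 while the buffers are allocated with 92 rows, so rows 90 and 91 are the slack that keeps whatever the allocator returned.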

On efficiency: I measured ~0.6 % end-to-end slowdown on Qwen3.5-9B SFT (SP=4, 32K, B200).

@zhiyuan1i

Collaborator

> @zhiyuan1i You're right that the canonical contract is T == cu_seqlens[-1]. The precondition I'm describing isn't a varlen-semantic violation; it's a storage-vs-logical mismatch one layer above FLA.
>
> In the SP / packed path, the host allocates k, q, v with shape[1] = T_padded (rounded up so the seq dim is divisible by the SP world size, and sometimes by BT), while cu_seqlens is built from the real per-subseq lengths, so cu_seqlens[-1] = sum(L_i) ≤ T_padded. FLA's invariants still hold: the kernel re-binds T = eos - bos, boundary_check is anchored to that local T, and every position in [0, cu_seqlens[-1]) is written exactly once. The leak is the slack rows [cu_seqlens[-1], T_padded): no kernel ever stores to them, new_empty leaves them holding allocator-recycled NaN bytes, and the next autograd matmul (dW = dk.T @ x) reduces over the full T_padded axis and pulls those NaNs into the weight gradient.
>
> So the real contract any varlen kernel taking explicit cu_seqlens has to defend is T ≥ cu_seqlens[-1], not equality. I hit it first under Ulysses SP + Qwen3.5, but the same precondition is created by chunk-BT padding in trainers and by HF DataCollator's pad_to_multiple_of.
>
> On efficiency: I measured ~0.6 % end-to-end slowdown on Qwen3.5-9B SFT (SP=4, 32K, B200).

I think the caller should slice to the true length before going into the kernel. Every extra kernel and check counts in large-scale training.

@xylian86

Author

@zhiyuan1i Agreed that caller-side slicing is the cleanest approach in principle. The catch is that the configurations triggering this bug are exactly the ones where the caller can't slice.

Concrete example: Ulysses SP with P = 4 and S_global = 100, on a packed batch holding one real subseq of length 90 plus 10 padding tokens. Rank 3 holds tokens [75, 100): 15 valid + 10 padding, so local cu_seqlens = [0, 15] but k.shape[1] = 25. Slicing k[:, :15] immediately breaks the next collectives: _ConvCPBoundary's all_gather_into_tensor, the next layer's Ulysses all-to-all (assumes S_global = S_local · P), and _get_cp_context's position_ids gather all require uniform shape across ranks. The only way to recover is to F.pad rank 3 back to 25 before the next collective, which costs a .item() host sync + an alloc + memcpy + zero-fill, strictly more than one cudaMemsetAsync.
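
The per-rank numbers in this example can be reproduced with a short sketch. It assumes the sequence is split into equal contiguous shards, one per rank, which matches "Rank 3 holds tokens [75, 100)"; `local_shard` is a hypothetical helper, not part of FLA or any SP library.

```python
def local_shard(rank, world_size, s_global, n_valid):
    """Return (start, end, valid, padding) for one rank's contiguous shard.

    Assumes s_global is divisible by world_size and that the first n_valid
    global positions are real tokens, with padding after them.
    """
    s_local = s_global // world_size
    start, end = rank * s_local, (rank + 1) * s_local
    valid = max(0, min(end, n_valid) - start)
    return start, end, valid, s_local - valid

for rank in range(4):
    start, end, valid, padding = local_shard(rank, 4, 100, 90)
    # Every rank stores 25 rows; only the last rank carries slack rows.
    print(f"rank {rank}: tokens [{start}, {end}), valid={valid}, padding={padding}")
```

Ranks 0 to 2 are fully valid, while rank 3 has 15 valid rows and 10 slack rows, yet all four must keep a local length of 25 for the collectives to line up.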

Also worth noting chunk_bwd_dqkwg already uses new_zeros for dq / dk today (chunk_o.py:692-693), so fla-core has already accepted this cost for half the buffers in the same function.

For the efficiency concern: happy to gate the zero-init on cu_seqlens is not None, so non-varlen pays zero and only the varlen path — where the slack exists — pays the memset (~0.6 % in my measurement).
