[Fix] Zero-init chunk-mode backward gradient buffers to prevent NaN propagation #892
xylian86 wants to merge 1 commit into
Code Review
This pull request replaces uninitialized tensor creation with zero-initialized tensors across several files, including chunk_delta_h.py, chunk_o.py, and wy_fast.py, to ensure deterministic behavior. A review comment suggests using g.new_zeros instead of torch.zeros in chunk_o.py for better consistency and to avoid explicit device passing.
dw = torch.empty_like(w) if w is not None else None
dq = q.new_zeros(B, T, HV, K)
dk = k.new_zeros(B, T, HV, K)
dg = torch.zeros(NK, *g.shape, dtype=torch.float32, device=g.device) if g is not None else None
For consistency with the surrounding code (lines 692-693) and to avoid explicit device passing, consider using g.new_zeros here. This is also slightly more idiomatic when creating a tensor based on an existing one's properties.
Suggested change:
- dg = torch.zeros(NK, *g.shape, dtype=torch.float32, device=g.device) if g is not None else None
+ dg = g.new_zeros(NK, *g.shape, dtype=torch.float32) if g is not None else None
T = k.shape[1] > cu_seqlens[-1] seems a bit strange; in varlen semantics they should be strictly equal. Zeroing is safe, but it will reduce efficiency in large-scale training.
@zhiyuan1i You're right that the canonical contract is In the SP / packed path, the host allocates So the real contract any varlen kernel taking explicit On efficiency: I measured ~0.6% end-to-end slowdown on Qwen3.5-9B SFT (SP=4, 32K, B200).
I think the caller should slice to the true length before going into the kernel. Every kernel and every check counts in large-scale training.
@zhiyuan1i Agreed that caller-side slicing is the cleanest approach in principle. The catch is that the configurations triggering this bug are exactly the ones where the caller can't slice. Concrete example, Ulysses SP, Also worth noting For the efficiency concern: happy to gate the zero-init on
Summary
Swap empty_like / new_empty → zeros_like / new_zeros for every chunk-mode output / gradient buffer that downstream Triton kernels write with boundary_check. Under packed sequences where T = k.shape[1] > cu_seqlens[-1] (common with SP > 1), the trailing OOB rows are never written and inherit whatever the CUDA caching allocator hands back, often NaN from a prior intermediate. The next autograd op (e.g. dW = dk.T @ x) reduces over the full T axis and corrupts weight gradients.
Same root cause diagnosed for FA2 in Dao-AILab/flash-attention#41.
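The mechanism can be illustrated without a GPU. The sketch below is a pure-Python toy, not the repro from this PR: NaN-filled lists stand in for an uninitialized buffer that happens to reuse stale memory, and the helper names boundary_checked_write and weight_grad are hypothetical, introduced only for this illustration.

```python
import math


def boundary_checked_write(buf, valid_rows, value=1.0):
    """Store only rows below the true length, as a boundary-checked kernel would."""
    for t in range(valid_rows):
        buf[t] = [value] * len(buf[t])
    return buf


def weight_grad(dk, x):
    """Compute dW = dk.T @ x, reducing over every row of dk (the full T axis)."""
    rows, k_dim, d_dim = len(dk), len(dk[0]), len(x[0])
    return [
        [sum(dk[t][i] * x[t][j] for t in range(rows)) for j in range(d_dim)]
        for i in range(k_dim)
    ]


T, VALID, K, D = 6, 4, 2, 3  # T > VALID models k.shape[1] > cu_seqlens[-1]

# Activations: ones on valid rows, zeros on the padded tail.
x = [[1.0] * D if t < VALID else [0.0] * D for t in range(T)]

# Assumption: stand-in for an uninitialized buffer that reuses memory holding NaN.
dk_stale = boundary_checked_write([[math.nan] * K for _ in range(T)], VALID)
# Zero-initialized buffer, analogous to the patched allocations.
dk_zero = boundary_checked_write([[0.0] * K for _ in range(T)], VALID)

dW_stale = weight_grad(dk_stale, x)
dW_zero = weight_grad(dk_zero, x)

print(all(math.isnan(v) for row in dW_stale for v in row))
print(dW_zero)
```

Note that zero-padding x on the tail rows does not rescue the stale case, because NaN * 0 is still NaN under IEEE 754 arithmetic; only the zero-initialized buffer keeps the reduction finite.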
Minimal repro
Changes
20 1:1 substitutions across 3 files (+20 / -20, no logic changes):
fla/ops/gated_delta_rule/wy_fast.py: w, u, dk, dv, dg, db
fla/ops/common/chunk_o.py: o, dv (×2), dq, dk, dg, dw
fla/ops/common/chunk_delta_h.py: h (×2), v_new, dh (×2), dh0, dv2
Overhead: one cudaMemsetAsync per allocation (~µs, dwarfed by the kernel it precedes).
Verification
Qwen3.5-9B SFT on 4×B200, Ulysses SP=4, packed bf16, smoke config:
fla-core==0.5.0 (unpatched)
bad/427 = params with any NaN/inf in .grad. Loss converges normally (5.22 → 5.05 → 5.05).
SP=1 is latent: per-rank shards rarely have trailing padding. SP > 1 exposes it deterministically.
Risk
None. zeros_like is a strict superset of empty_like behavior: every cell the kernel writes is identical; only unused / OOB cells change from indeterminate to deterministic zero.