Commit b25d23c

Guang Hou authored and meta-codesync[bot] committed
pack_segments backward CUDA: zero-init gradient buffer to fix uninit-memory NaNs (#5754)
Summary:

X-link: https://github.com/facebookresearch/FBGEMM/pull/2685

Pull Request resolved: #5754

## TL;DR

`pack_segments_backward_cuda` allocates the input-gradient tensor with `at::empty(...)`. When any `lengths[seq] > max_length`, the unpack kernel never writes the tail rows of that segment, leaving them as **uninitialized memory** that propagates into upstream gradients and causes NaN cascades. Switch the allocator to `at::zeros`.

## Bug

The unpack kernel writes only positions `cumsum[seq] + cell` for `cell < min(lengths[seq], max_length)`. When `lengths[seq] > max_length`, positions `[cumsum[seq] + max_length, cumsum[seq] + lengths[seq])` are **never written** and retain whatever was in the freshly allocated buffer.

These rows correspond to events that the forward pass truncated, so they MUST receive zero gradient. With `at::empty` they instead receive garbage. The garbage flows upstream and triggers NaN/Inf cascades in deep networks. For example, LayerNorm backward amplifies random O(1)-magnitude values via `1/sqrt(var+eps)` into Inf/NaN within a few hundred steps.

## Fix

`at::empty(shape, ...)` → `at::zeros(shape, ...)` for the output tensor. One-line change. The added cost is one device-side memset over the gradient buffer per backward call, which is negligible relative to the unpack kernel and downstream backward work.

## Repro

```
import torch
import fbgemm_gpu  # noqa: F401  (registers torch.ops.fbgemm.*)

lengths = torch.tensor([10, 5, 8], dtype=torch.int32, device="cuda")
t_in = torch.randn(23, 8, device="cuda", requires_grad=True)
out = torch.ops.fbgemm.pack_segments(t_in, lengths, max_length=4)
out.backward(torch.ones_like(out))
# Print abs-max of rows 0..9 (seq 0 has length 10 > max_length=4,
# so rows 4..9 are the truncated tail).
print(t_in.grad[:10].abs().amax(dim=1))
```

Real values captured by running the snippet 5 times. Each value is the abs-max of one row, taken across the cell dimension. `lengths[0] = 10` and `max_length = 4`, so rows 0..3 are in-bounds (expected `≈1`) and rows 4..9 are truncated (expected `0`).

**BEFORE fix (`at::empty`)**: rows 4..9 vary wildly across trials, confirming uninitialized memory:

```
row:          0      1      2      3      4      5      6      7      8      9
trial 0: 1.0000 1.0000 1.0000 1.0000 1.8152 1.9762 0.8584 2.3934 2.4721 0.0000
trial 1: 1.0000 1.0000 1.0000 1.0000 2.2231 1.6936 1.8451 1.9498 1.6774 0.5991
trial 2: 1.0000 1.0000 1.0000 1.0000 1.7331 1.6970 1.5790 1.6874 2.4351 1.9974
trial 3: 1.0000 1.0000 1.0000 1.0000 1.1584 2.8627 1.8524 3.2550 1.2574 1.0000
trial 4: 1.0000 1.0000 1.0000 1.0000 2.0911 1.8118 1.6238 1.3304 1.3858 1.6397
```

**AFTER fix (`at::zeros`)**: rows 4..9 are exactly 0 and identical across trials:

```
row:          0      1      2      3      4      5      6      7      8      9
trial 0: 1.0000 1.0000 1.0000 1.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000
trial 1: 1.0000 1.0000 1.0000 1.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000
trial 2: 1.0000 1.0000 1.0000 1.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000
trial 3: 1.0000 1.0000 1.0000 1.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000
trial 4: 1.0000 1.0000 1.0000 1.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000
```

Reviewed By: q10

Differential Revision: D104184777

fbshipit-source-id: 848b007ed535b884256d50af3095ebc5c7181028
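The "varies across trials" observation in the BEFORE table can also be checked mechanically. The following is a minimal sketch, not part of the commit: it copies the abs-max values reported above and lists the rows whose value changes between trials. The helper name `rows_varying_across_trials` and the tolerance `tol` are illustrative assumptions.

```python
# Abs-max per row (columns are rows 0..9), copied from the BEFORE table above.
before = [
    [1.0000, 1.0000, 1.0000, 1.0000, 1.8152, 1.9762, 0.8584, 2.3934, 2.4721, 0.0000],
    [1.0000, 1.0000, 1.0000, 1.0000, 2.2231, 1.6936, 1.8451, 1.9498, 1.6774, 0.5991],
    [1.0000, 1.0000, 1.0000, 1.0000, 1.7331, 1.6970, 1.5790, 1.6874, 2.4351, 1.9974],
    [1.0000, 1.0000, 1.0000, 1.0000, 1.1584, 2.8627, 1.8524, 3.2550, 1.2574, 1.0000],
    [1.0000, 1.0000, 1.0000, 1.0000, 2.0911, 1.8118, 1.6238, 1.3304, 1.3858, 1.6397],
]


def rows_varying_across_trials(trials, tol=1e-6):
    """Return the row indices whose value differs between trials by more than tol.

    Hypothetical helper for illustration; `tol` is an arbitrary choice.
    """
    varying = []
    for row in range(len(trials[0])):
        values = [trial[row] for trial in trials]
        if max(values) - min(values) > tol:
            varying.append(row)
    return varying


varying_rows = rows_varying_across_trials(before)
print(varying_rows)  # [4, 5, 6, 7, 8, 9]
```

Rows 0..3 are stable because the kernel writes them on every run; only the truncated tail changes. Note that a single trial is not enough to detect the bug: row 9 of trial 0 happens to read 0.0000 even without the fix.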
1 parent 0b8730f commit b25d23c

2 files changed

Lines changed: 80 additions & 2 deletions

fbgemm_gpu/src/sparse_ops/sparse_pack_segments_backward.cu

Lines changed: 9 additions & 2 deletions
@@ -70,11 +70,18 @@ DLL_PUBLIC Tensor pack_segments_backward_cuda(
   AT_DISPATCH_INDEX_TYPES(lengths_c.scalar_type(), "unpack_segments_cuda", [&] {
     const auto* const lengths_data = lengths_c.const_data_ptr<index_t>();
 
-    // Create output tensor of appropriate dimensions
+    // Create output tensor of appropriate dimensions.
+    // Use at::zeros (not at::empty): when lengths[seq] > max_length, the
+    // unpack kernel only writes positions cumsum[seq]+cell for cell<max_length,
+    // leaving positions [cumsum[seq]+max_length, cumsum[seq]+lengths[seq])
+    // uninitialized. Those rows correspond to events that were truncated by
+    // forward pack_segments and so MUST receive zero gradient. With at::empty
+    // they would receive uninitialized memory, corrupting upstream gradients
+    // and causing NaN cascades in deep networks.
     auto shape = data_contig->sizes().vec();
     shape.erase(shape.begin());
     shape[0] = total_length;
-    unpacked_tensor = at::empty(shape, data_contig->options());
+    unpacked_tensor = at::zeros(shape, data_contig->options());
 
     if (!(data_contig->size(0) &&
           data_contig->size(1))) { // TODO: What does this mean?
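The write pattern described in the new comment can be modelled on the CPU. The sketch below is an illustration under stated assumptions, not the CUDA kernel: each row is reduced to a single scalar, `unpack_backward_reference` is a hypothetical name, and NaN is used as a visible stand-in for whatever `at::empty` happens to return (real uninitialized memory is arbitrary, not necessarily NaN).

```python
import math


def unpack_backward_reference(lengths, max_length, grad_packed, fill):
    """CPU model of which input-gradient rows the unpack step writes.

    grad_packed[seq][cell] is the upstream gradient of one packed cell
    (a scalar here for brevity). `fill` is the initial buffer content:
    0.0 models at::zeros, NaN stands in for at::empty.
    """
    grad_input = [fill] * sum(lengths)
    offset = 0  # plays the role of cumsum[seq]
    for seq, length in enumerate(lengths):
        # Only the first min(length, max_length) rows of a segment are written.
        for cell in range(min(length, max_length)):
            grad_input[offset + cell] = grad_packed[seq][cell]
        offset += length
    return grad_input


lengths = [10, 5, 8]
max_length = 4
grad_packed = [[1.0] * max_length for _ in lengths]

with_empty = unpack_backward_reference(lengths, max_length, grad_packed, fill=math.nan)
with_zeros = unpack_backward_reference(lengths, max_length, grad_packed, fill=0.0)

# Rows that were never written keep the initial buffer content.
stale_rows = [i for i, g in enumerate(with_empty) if math.isnan(g)]
print(stale_rows)  # [4, 5, 6, 7, 8, 9, 14, 19, 20, 21, 22]
```

With `fill=0.0` the same rows hold exactly zero, which is the gradient a truncated event should receive.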

fbgemm_gpu/test/sparse/pack_segments_test.py

Lines changed: 71 additions & 0 deletions
@@ -553,6 +553,77 @@ def test_pack_segments_noncontig(
             msg="Expected input gradients to be equal but they are not",
         )
 
+    @unittest.skipIf(*gpu_unavailable)
+    @given(
+        dtype=st.sampled_from(
+            [
+                torch.float,
+                torch.half,
+                torch.bfloat16,
+            ]
+        ),
+    )
+    @settings(deadline=None)
+    def test_pack_segments_backward_truncated(self, dtype: torch.dtype) -> None:
+        """
+        Regression test: when lengths[seq] > max_length, the backward kernel
+        previously left positions [cumsum[seq]+max_length, cumsum[seq]+lengths[seq])
+        in the input gradient as uninitialized memory (allocated via at::empty).
+
+        After the fix (at::empty -> at::zeros), those positions must be exactly 0
+        because they correspond to events that were truncated by the forward pass
+        and so cannot influence the loss.
+
+        Without the fix, these positions contain garbage, which propagates upstream
+        and can cause NaN cascades in deep networks (LayerNorm backward amplification).
+        """
+        # Choose lengths intentionally larger than max_length for some segments
+        max_length = 4
+        lengths_cpu = torch.tensor([10, 5, 8, 2], dtype=torch.int)
+        total_length = int(lengths_cpu.sum().item())
+        cell_size = 8
+
+        # Run multiple trials to detect uninitialized memory:
+        # if positions are uninit, values change across trials.
+        observed_grads = []
+        for _ in range(5):
+            input_data = (
+                torch.randn(total_length, cell_size, dtype=dtype)
+                .cuda()  # noqa: CITRINE(redundant_cuda_to_device)
+                .requires_grad_(True)
+            )
+            lengths = lengths_cpu.cuda()
+
+            packed = torch.ops.fbgemm.pack_segments(
+                t_in=input_data, lengths=lengths, max_length=max_length
+            )
+            grad_out = torch.ones_like(packed)
+            packed.backward(grad_out)
+
+            # pyre-ignore[16]
+            observed_grads.append(input_data.grad.detach().cpu().clone())
+
+        # Verify: positions where cell < min(lengths[seq], max_length) get grad=1
+        # positions where cell >= max_length but cell < lengths[seq] get grad=0
+        cumsum = 0
+        for seq, L in enumerate(lengths_cpu.tolist()):
+            for cell in range(L):
+                row = cumsum + cell
+                expected = 1.0 if cell < max_length else 0.0
+                for trial, grad in enumerate(observed_grads):
+                    actual = grad[row].abs().max().item()
+                    self.assertAlmostEqual(
+                        actual,
+                        expected,
+                        places=2,
+                        msg=(
+                            f"trial={trial} seq={seq} cell={cell} row={row}: "
+                            f"expected grad abs.max={expected}, got {actual}. "
+                            "Truncated rows must receive zero gradient (not uninit memory)."
+                        ),
+                    )
+            cumsum += L
+
 
 extend_test_class(PackedSegmentsTest)
 
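The verification loop in the test above encodes a simple rule: with an all-ones upstream gradient, a row expects 1 if its position within the segment is below `max_length` and 0 otherwise. As a rough sketch (the helper `expected_row_grads` is hypothetical and not part of the test file), the same expectation can be built as a flat list without a GPU:

```python
def expected_row_grads(lengths, max_length):
    """Expected per-row abs-max gradient when the upstream gradient is all ones.

    Hypothetical helper: kept rows (cell < max_length) expect 1.0,
    truncated rows (max_length <= cell < length) expect 0.0.
    """
    expected = []
    for length in lengths:
        kept = min(length, max_length)
        expected.extend([1.0] * kept)
        expected.extend([0.0] * (length - kept))
    return expected


# Same lengths and max_length as the regression test.
expected = expected_row_grads([10, 5, 8, 2], max_length=4)
print(len(expected), sum(expected))  # 25 14.0
```

The last segment has length 2, below `max_length`, so both of its rows expect 1; this covers the `min(lengths[seq], max_length)` branch alongside the truncated segments.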
