Commit b25d23c
pack_segments backward CUDA: zero-init gradient buffer to fix uninit-memory NaNs (#5754)
Summary:
X-link: https://github.com/facebookresearch/FBGEMM/pull/2685
Pull Request resolved: #5754
## TL;DR
`pack_segments_backward_cuda` allocates the input-gradient tensor with `at::empty(...)`. When any `lengths[seq] > max_length`, the unpack kernel never writes the tail rows of that segment, leaving them as **uninitialized memory** that propagates into upstream gradients and causes NaN cascades. Switch the allocator to `at::zeros`.
## Bug
The unpack kernel writes only positions `cumsum[seq] + cell` for `cell < min(lengths[seq], max_length)`. When `lengths[seq] > max_length`, positions `[cumsum[seq] + max_length, cumsum[seq] + lengths[seq])` are **never written** and retain whatever was in the freshly-allocated buffer.
These rows correspond to events that the forward pass truncated, so they MUST receive zero gradient. With `at::empty` they instead receive garbage. The garbage flows upstream and triggers NaN/Inf cascades in deep networks — for example, LayerNorm backward amplifies random O(1) magnitude values via `1/sqrt(var+eps)` into Inf/NaN within a few hundred steps.
## Fix
`at::empty(shape, ...)` → `at::zeros(shape, ...)` for the output tensor. One-line change. The added cost is one device-side memset over the gradient buffer per backward call, which is negligible relative to the unpack kernel and downstream backward work.
## Repro
```
import torch
lengths = torch.tensor([10, 5, 8], dtype=torch.int32, device="cuda")
t_in = torch.randn(23, 8, device="cuda", requires_grad=True)
out = torch.ops.fbgemm.pack_segments(t_in, lengths, max_length=4)
out.backward(torch.ones_like(out))
# Print abs-max of rows 0..9 (seq 0 has length 10 > max_length=4,
# so rows 4..9 are the truncated tail).
```
Real values captured by running the snippet 5 times. Each row shows abs-max across the cell dimension. `lengths[0] = 10`, `max_length = 4`, so rows 0..3 are in-bounds (expected `≈1`) and rows 4..9 are truncated (expected `0`).
**BEFORE fix (`at::empty`)** — rows 4..9 vary wildly across trials, confirming uninitialized memory:
```
row: 0 1 2 3 4 5 6 7 8 9
trial 0: 1.0000 1.0000 1.0000 1.0000 1.8152 1.9762 0.8584 2.3934 2.4721 0.0000
trial 1: 1.0000 1.0000 1.0000 1.0000 2.2231 1.6936 1.8451 1.9498 1.6774 0.5991
trial 2: 1.0000 1.0000 1.0000 1.0000 1.7331 1.6970 1.5790 1.6874 2.4351 1.9974
trial 3: 1.0000 1.0000 1.0000 1.0000 1.1584 2.8627 1.8524 3.2550 1.2574 1.0000
trial 4: 1.0000 1.0000 1.0000 1.0000 2.0911 1.8118 1.6238 1.3304 1.3858 1.6397
```
**AFTER fix (`at::zeros`)** — rows 4..9 are exactly 0 and identical across trials:
```
row: 0 1 2 3 4 5 6 7 8 9
trial 0: 1.0000 1.0000 1.0000 1.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000
trial 1: 1.0000 1.0000 1.0000 1.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000
trial 2: 1.0000 1.0000 1.0000 1.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000
trial 3: 1.0000 1.0000 1.0000 1.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000
trial 4: 1.0000 1.0000 1.0000 1.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000
```
Reviewed By: q10
Differential Revision: D104184777
fbshipit-source-id: 848b007ed535b884256d50af3095ebc5c71810281 parent 0b8730f commit b25d23c
2 files changed
Lines changed: 80 additions & 2 deletions
Lines changed: 9 additions & 2 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
70 | 70 | | |
71 | 71 | | |
72 | 72 | | |
73 | | - | |
| 73 | + | |
| 74 | + | |
| 75 | + | |
| 76 | + | |
| 77 | + | |
| 78 | + | |
| 79 | + | |
| 80 | + | |
74 | 81 | | |
75 | 82 | | |
76 | 83 | | |
77 | | - | |
| 84 | + | |
78 | 85 | | |
79 | 86 | | |
80 | 87 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
553 | 553 | | |
554 | 554 | | |
555 | 555 | | |
| 556 | + | |
| 557 | + | |
| 558 | + | |
| 559 | + | |
| 560 | + | |
| 561 | + | |
| 562 | + | |
| 563 | + | |
| 564 | + | |
| 565 | + | |
| 566 | + | |
| 567 | + | |
| 568 | + | |
| 569 | + | |
| 570 | + | |
| 571 | + | |
| 572 | + | |
| 573 | + | |
| 574 | + | |
| 575 | + | |
| 576 | + | |
| 577 | + | |
| 578 | + | |
| 579 | + | |
| 580 | + | |
| 581 | + | |
| 582 | + | |
| 583 | + | |
| 584 | + | |
| 585 | + | |
| 586 | + | |
| 587 | + | |
| 588 | + | |
| 589 | + | |
| 590 | + | |
| 591 | + | |
| 592 | + | |
| 593 | + | |
| 594 | + | |
| 595 | + | |
| 596 | + | |
| 597 | + | |
| 598 | + | |
| 599 | + | |
| 600 | + | |
| 601 | + | |
| 602 | + | |
| 603 | + | |
| 604 | + | |
| 605 | + | |
| 606 | + | |
| 607 | + | |
| 608 | + | |
| 609 | + | |
| 610 | + | |
| 611 | + | |
| 612 | + | |
| 613 | + | |
| 614 | + | |
| 615 | + | |
| 616 | + | |
| 617 | + | |
| 618 | + | |
| 619 | + | |
| 620 | + | |
| 621 | + | |
| 622 | + | |
| 623 | + | |
| 624 | + | |
| 625 | + | |
| 626 | + | |
556 | 627 | | |
557 | 628 | | |
558 | 629 | | |
| |||
0 commit comments