Skip to content

Commit 9310d47

Browse files
committed
[LayerNorm] Fix OOB x being included in variance calc
1 parent 6c9976c commit 9310d47

2 files changed

Lines changed: 30 additions & 1 deletion

File tree

quack/rmsnorm.py

Lines changed: 9 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -266,8 +266,16 @@ def kernel(
266 266
if const_expr(mRes is not None):
267 267
copy(tXgRes, tXrRes)
268 268
x += tXrRes.load().to(cute.Float32)
269+
x_centered = x - mean
270+
if const_expr(not is_even_N):
271+
# OOB lanes are zero-filled for the mean pass, but they must contribute zero
272+
# to the variance pass (not mean^2 from (0 - mean)^2).
273+
tXrX_centered = cute.make_rmem_tensor_like(tXrX, Float32)
274+
tXrX_centered.store(x_centered)
275+
utils.fill_oob(tXrX_centered, tXpX, fill_value=Float32.zero)
276+
x_centered = tXrX_centered.load()
269 277
sum_sq_x_sub_mean = row_reduce(
270-
(x - mean) * (x - mean),
278+
x_centered * x_centered,
271 279
cute.ReductionOp.ADD,
272 280
threads_per_row,
273 281
reduction_buffer[None, None, 1],

tests/test_layernorm.py

Lines changed: 21 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -58,6 +58,27 @@ def test_layernorm_forward(M, N, input_dtype, eps):
58 58
torch.testing.assert_close(mean, mean_ref_val, atol=6e-4, rtol=6e-4)
59 59

60 60

61+
def test_layernorm_forward_masks_oob_variance_lanes():
62+
"""Regression: ragged N must not include padding lanes in LayerNorm variance."""
63+
device = "cuda"
64+
M, N = 3, 769 # N is not a full copy/reduction tile, so the last tile has OOB lanes.
65+
eps = 1e-5
66+
67+
cols = torch.arange(N, device=device, dtype=torch.float32)
68+
rows = torch.arange(M, device=device, dtype=torch.float32).unsqueeze(1)
69+
x = 1000.0 + rows * 100.0 + ((cols % 17) - 8.0) / 8.0
70+
weight = torch.ones(N, device=device, dtype=torch.float32)
71+
72+
out, rstd, mean = layernorm_fwd(x, weight, eps=eps, return_rstd=True, return_mean=True)
73+
out_ref = layernorm_ref(x, weight, eps=eps)
74+
rstd_ref_val = layernorm_rstd_ref(x, eps=eps)
75+
mean_ref_val = layernorm_mean_ref(x)
76+
77+
torch.testing.assert_close(out, out_ref, atol=1e-4, rtol=1e-4)
78+
torch.testing.assert_close(rstd, rstd_ref_val, atol=1e-5, rtol=1e-5)
79+
torch.testing.assert_close(mean, mean_ref_val, atol=1e-5, rtol=1e-5)
80+
81+
61 82
@pytest.mark.parametrize("return_rstd", [True, False])
62 83
@pytest.mark.parametrize("return_mean", [True, False])
63 84
def test_layernormnorm_return_rstd_option(return_rstd, return_mean):

0 commit comments

Comments
 (0)