Skip to content

Commit 1968c46

Browse files
committed
[RMSNorm] Support per-head norm
1 parent a925f40 commit 1968c46

2 files changed

Lines changed: 171 additions & 46 deletions

File tree

quack/rmsnorm.py

Lines changed: 101 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -69,11 +69,11 @@ def _set_cluster_n(self):
6969
@cute.jit
7070
def __call__(
7171
self,
72-
mX: cute.Tensor,
73-
mW: Optional[cute.Tensor],
74-
mB: Optional[cute.Tensor],
75-
mRes: Optional[cute.Tensor],
76-
mO: cute.Tensor,
72+
mX: cute.Tensor, # (b, N) or (b, H, N)
73+
mW: Optional[cute.Tensor], # (N,) or (H, N)
74+
mB: Optional[cute.Tensor], # (N,) or (H, N)
75+
mRes: Optional[cute.Tensor], # (b, N) or (b, H, N)
76+
mO: cute.Tensor, # (b, N) or (b, H, N)
7777
mResO: Optional[cute.Tensor],
7878
mRstd: Optional[cute.Tensor],
7979
mMean: Optional[cute.Tensor],
@@ -93,13 +93,16 @@ def __call__(
9393
for mT in (mW, mB)
9494
]
9595
mRstd, mMean = [
96-
layout_utils.expand(mT, dim=1, size=self.N) if const_expr(mT is not None) else None
96+
layout_utils.expand(mT, dim=cute.rank(mT), size=self.N)
97+
if const_expr(mT is not None)
98+
else None
9799
for mT in (mRstd, mMean)
98100
]
101+
num_heads = mX.shape[1] if const_expr(cute.rank(mX) == 3) else 1
99102
self.kernel(
100103
mX, mW, mB, mRes, mO, mResO, mRstd, mMean, eps, tiler_mn, tiled_copy, threads_per_row
101104
).launch(
102-
grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1],
105+
grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, num_heads],
103106
block=[num_threads, 1, 1],
104107
cluster=[1, self.cluster_n, 1] if const_expr(self.cluster_n > 1) else None,
105108
stream=stream,
@@ -122,7 +125,7 @@ def kernel(
122125
threads_per_row: cutlass.Constexpr[int],
123126
):
124127
tidx, _, _ = cute.arch.thread_idx()
125-
bidx, _, _ = cute.arch.block_idx()
128+
bidx, _, bidz = cute.arch.block_idx()
126129
cluster_y = const_expr(0) if const_expr(self.cluster_n == 1) else cute.arch.block_idx()[1]
127130
tv_layout = tiled_copy.layout_tv_tiled
128131

@@ -138,9 +141,16 @@ def kernel(
138141
)
139142
reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout)
140143

141-
shape = mX.shape
144+
# Slice per head
145+
if const_expr(cute.rank(mX) == 3):
146+
mX, mW, mB, mRes, mO, mResO, mRstd, mMean = [
147+
mT[None, bidz, None] if const_expr(mT is not None) else None
148+
for mT in (mX, mW, mB, mRes, mO, mResO, mRstd, mMean)
149+
]
150+
151+
shape = (cute.size(mX, mode=[0]), cute.size(mX, mode=[1]))
142152
idX = cute.make_identity_tensor(shape)
143-
# slice for CTAs
153+
# Slice for CTAs
144154
gX, gRes, gO, gResO, gRstd, gMean, cX = [
145155
cute.local_tile(mT, tiler_mn, (bidx, cluster_y)) if mT is not None else None
146156
for mT in (mX, mRes, mO, mResO, mRstd, mMean, idX)
@@ -323,7 +333,7 @@ def _rmsnorm_fwd(
323333
"""RMSNorm/LayerNorm forward pass.
324334
Args:
325335
x: Input tensor of shape (M, N) or (M, H, N) for per-head
326-
weight: Optional weight tensor of shape (N,)
336+
weight: Optional weight tensor of shape (N,) or (H, N) for per-head weight
327337
eps: Small value for numerical stability
328338
is_layernorm: If True, compute LayerNorm instead of RMSNorm
329339
Returns:
@@ -337,7 +347,8 @@ def _rmsnorm_fwd(
337347
if residual is not None:
338348
assert residual.dtype in supported_types, "Residual must be float16, bfloat16, or float32"
339349

340-
_, N = x.shape
350+
N = x.size(-1)
351+
per_head = (weight is not None and weight.dim() == 2) or (bias is not None and bias.dim() == 2)
341352
dtype, out_dtype, weight_dtype, bias_dtype, res_dtype, res_out_dtype = [
342353
torch2cute_dtype_map[t.dtype] if t is not None else None
343354
for t in [x, out, weight, bias, residual, residual_out]
@@ -353,6 +364,7 @@ def _rmsnorm_fwd(
353364
rstd is not None,
354365
mean is not None,
355366
is_layernorm,
367+
per_head,
356368
)(x, weight, bias, residual, out, residual_out, rstd, mean, eps)
357369

358370

@@ -372,8 +384,11 @@ def _rmsnorm_fwd_fake(
372384
# See softmax.py _softmax_fwd_fake for why register_fake is needed.
373385
from quack.cache_utils import COMPILE_ONLY
374386

375-
if COMPILE_ONLY and not isinstance(x.size(1), torch.SymInt):
376-
N = x.size(1)
387+
if COMPILE_ONLY and not isinstance(x.size(-1), torch.SymInt):
388+
N = x.size(-1)
389+
per_head = (weight is not None and weight.dim() == 2) or (
390+
bias is not None and bias.dim() == 2
391+
)
377392
dtype, out_dtype, weight_dtype, bias_dtype, res_dtype, res_out_dtype = [
378393
torch2cute_dtype_map[t.dtype] if t is not None else None
379394
for t in [x, out, weight, bias, residual, residual_out]
@@ -389,6 +404,7 @@ def _rmsnorm_fwd_fake(
389404
rstd is not None,
390405
mean is not None,
391406
is_layernorm,
407+
per_head,
392408
)
393409
_compile_rmsnorm_bwd(
394410
N,
@@ -400,6 +416,7 @@ def _rmsnorm_fwd_fake(
400416
res_dtype,
401417
res_out_dtype,
402418
weight is not None,
419+
per_head,
403420
)
404421

405422

@@ -415,16 +432,23 @@ def _compile_rmsnorm_fwd(
415432
has_rstd,
416433
has_mean,
417434
is_layernorm,
435+
per_head,
418436
):
419437
batch_sym = cute.sym_int()
438+
head_sym = cute.sym_int() if per_head else None
439+
batch_shape = (batch_sym, head_sym) if per_head else (batch_sym,)
420440
all_dtypes = [dtype, out_dtype, res_dtype, weight_dtype, bias_dtype, res_out_dtype]
421441
div = math.gcd(N, *(128 // dt.width for dt in all_dtypes if dt is not None))
422442
x_cute, out_cute, res_cute, res_out_cute = [
423-
fake_tensor(dt, (batch_sym, N), div) for dt in [dtype, out_dtype, res_dtype, res_out_dtype]
443+
fake_tensor(dt, (*batch_shape, N), div)
444+
for dt in [dtype, out_dtype, res_dtype, res_out_dtype]
445+
]
446+
weight_shape = (head_sym, N) if per_head else (N,)
447+
weight_cute, bias_cute = [
448+
fake_tensor(dt, weight_shape, div) for dt in [weight_dtype, bias_dtype]
424449
]
425-
weight_cute, bias_cute = [fake_tensor(dt, (N,), div) for dt in [weight_dtype, bias_dtype]]
426-
rstd_cute = fake_tensor(Float32, (batch_sym,)) if has_rstd else None
427-
mean_cute = fake_tensor(Float32, (batch_sym,)) if has_mean else None
450+
rstd_cute = fake_tensor(Float32, batch_shape) if has_rstd else None
451+
mean_cute = fake_tensor(Float32, batch_shape) if has_mean else None
428452
return cute.compile(
429453
RMSNorm(dtype, N, is_layernorm=is_layernorm),
430454
x_cute,
@@ -456,7 +480,7 @@ def rmsnorm_fwd(
456480
# so that _layer_norm_fwd_impl doesn't have to return them.
457481
out_dtype = x.dtype if out_dtype is None else out_dtype
458482
out = torch.empty_like(x, dtype=out_dtype)
459-
rstd = torch.empty(x.shape[0], device=x.device, dtype=torch.float32) if store_rstd else None
483+
rstd = torch.empty(*x.shape[:-1], device=x.device, dtype=torch.float32) if store_rstd else None
460484
if residual is not None:
461485
residual_dtype = residual.dtype
462486
if residual is not None or (residual_dtype is not None and residual_dtype != x.dtype):
@@ -476,7 +500,7 @@ def rmsnorm_ref(x, w=None, bias=None, residual=None, eps=1e-6):
476500
x_f32 = x.float()
477501
if residual is not None:
478502
residual_f32 = residual.float()
479-
x_f32 += residual_f32
503+
x_f32 = x_f32 + residual_f32
480504
x_norm = x_f32 / (torch.sqrt(torch.mean(x_f32.square(), dim=-1, keepdim=True) + eps))
481505
out = x_norm * w if w is not None else x_norm
482506
if bias is not None:
@@ -565,10 +589,11 @@ def __call__(
565589
layout_utils.expand(mW, dim=0, size=tiler_mn[0]) if const_expr(mW is not None) else None
566590
)
567591
num_blocks = sm_count
592+
num_heads = mX.shape[1] if const_expr(cute.rank(mX) == 3) else 1
568593
self.kernel(
569594
mX, mW, mdO, mdResO, mRstd, mdX, mdW, mdB, mdRes, tiler_mn, tiled_copy, threads_per_row
570595
).launch(
571-
grid=[num_blocks, self.cluster_n, 1],
596+
grid=[num_blocks, self.cluster_n, num_heads],
572597
block=[num_threads, 1, 1],
573598
cluster=[1, self.cluster_n, 1] if self.cluster_n > 1 else None,
574599
stream=stream,
@@ -591,11 +616,19 @@ def kernel(
591616
threads_per_row: cutlass.Constexpr[int],
592617
):
593618
tidx, _, _ = cute.arch.thread_idx()
594-
bidx_start, _, _ = cute.arch.block_idx()
619+
bidx_start, _, bidz = cute.arch.block_idx()
595620
gdim, _, _ = cute.arch.grid_dim()
596621
cluster_y = const_expr(0) if const_expr(self.cluster_n == 1) else cute.arch.block_idx()[1]
597622
tv_layout = tiled_copy.layout_tv_tiled
598623

624+
# Slice per head
625+
if const_expr(cute.rank(mX) == 3):
626+
mX, mW, mdO, mdResO, mdX, mdW, mdB, mdRes = [
627+
mT[None, bidz, None] if const_expr(mT is not None) else None
628+
for mT in (mX, mW, mdO, mdResO, mdX, mdW, mdB, mdRes)
629+
]
630+
mRstd = mRstd[None, bidz]
631+
599632
shape = mX.shape
600633
M, N = shape[0], shape[1]
601634
is_even_N = const_expr(shape[1] == tiler_mn[1] * self.cluster_n)
@@ -895,22 +928,21 @@ def _rmsnorm_bwd(
895928
) -> None:
896929
"""RMSNorm backward pass.
897930
Args:
898-
x: Input tensor of shape (M, N)
899-
weight: Optional weight tensor of shape (N,)
900-
dout: Upstream gradients tensor of shape (M, N)
901-
rstd: Reciprocal standard deviation tensor of shape (M,)
931+
x: Input tensor of shape (M, N) or (M, H, N) for per-head
932+
weight: Optional weight tensor of shape (N,) or (H, N) for per-head
933+
dout: Upstream gradients tensor of shape (M, N) or (M, H, N)
934+
rstd: Reciprocal standard deviation tensor of shape (M,) or (M, H)
902935
Returns:
903936
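None. Results are written in place into the preallocated arguments:
904937
- dx: Input gradients tensor of same shape as x
905938
- dw_partial: Optional float32 partial weight gradients of shape (P, N), or (P, H, N) for per-head, where P is the number of partial rows (None if weight gradients are not needed)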
906939
"""
907-
assert x.dim() == 2, "Input must be 2D"
940+
assert x.dim() in (2, 3), "Input must be 2D or 3D"
908941
assert x.is_cuda, "Input tensor must be on CUDA device"
909942
supported_types = {torch.float16, torch.bfloat16, torch.float32}
910943
assert x.dtype in supported_types, "Unsupported dtype"
944+
per_head = x.dim() == 3
911945
if weight is not None:
912-
assert weight.dim() == 1, "Weight must be 1D"
913-
assert x.shape[-1] == weight.shape[0], "Last dimension of input must match weight dimension"
914946
assert weight.is_cuda, "Weight tensor must be on CUDA device"
915947
assert weight.dtype in supported_types, "Weight must be float32, float16 or bfloat16"
916948
if dresidual_out is not None:
@@ -924,7 +956,7 @@ def _rmsnorm_bwd(
924956
assert dresidual.is_cuda
925957
assert dresidual.dtype in supported_types, "Residual must be float16, bfloat16, or float32"
926958

927-
N = x.size(1)
959+
N = x.size(-1)
928960
if dw_partial is None and db_partial is None:
929961
assert sm_count is not None
930962
else:
@@ -943,6 +975,7 @@ def _rmsnorm_bwd(
943975
dres_dtype,
944976
dres_out_dtype,
945977
dw_partial is not None,
978+
per_head,
946979
)(x, weight, dout, dresidual_out, rstd, dx, dw_partial, dresidual, db_partial, sm_count)
947980

948981

@@ -962,8 +995,9 @@ def _rmsnorm_bwd_fake(
962995
# See softmax.py _softmax_fwd_fake for why register_fake is needed.
963996
from quack.cache_utils import COMPILE_ONLY
964997

965-
if COMPILE_ONLY and not isinstance(x.size(1), torch.SymInt):
966-
N = x.size(1)
998+
if COMPILE_ONLY and not isinstance(x.size(-1), torch.SymInt):
999+
N = x.size(-1)
1000+
per_head = x.dim() == 3
9671001
if dw_partial is None and db_partial is None and sm_count is None:
9681002
return
9691003
dtype, dout_dtype, dx_dtype, weight_dtype, dres_dtype, dres_out_dtype = [
@@ -980,6 +1014,7 @@ def _rmsnorm_bwd_fake(
9801014
dres_dtype,
9811015
dres_out_dtype,
9821016
dw_partial is not None,
1017+
per_head,
9831018
)
9841019

9851020

@@ -994,18 +1029,23 @@ def _compile_rmsnorm_bwd(
9941029
dres_dtype,
9951030
dres_out_dtype,
9961031
has_dw_partial,
1032+
per_head=False,
9971033
):
9981034
batch_sym, batch_partial_sym = cute.sym_int(), cute.sym_int()
1035+
head_sym = cute.sym_int() if per_head else None
1036+
batch_shape = (batch_sym, head_sym) if per_head else (batch_sym,)
9991037
all_dtypes = [dtype, dout_dtype, dx_dtype, dres_dtype, dres_out_dtype]
10001038
div = math.gcd(N, *(128 // dt.width for dt in all_dtypes if dt is not None))
10011039
x_cute, dout_cute, dx_cute, dres_out_cute, dres_cute = [
1002-
fake_tensor(dt, (batch_sym, N), div)
1040+
fake_tensor(dt, (*batch_shape, N), div)
10031041
for dt in [dtype, dout_dtype, dx_dtype, dres_out_dtype, dres_dtype]
10041042
]
1005-
weight_cute = fake_tensor(weight_dtype, (N,), div)
1006-
rstd_cute = fake_tensor(Float32, (batch_sym,))
1007-
dw_partial_cute = fake_tensor(Float32, (batch_partial_sym, N), div) if has_dw_partial else None
1008-
db_partial_cute = fake_tensor(Float32, (batch_partial_sym, N), div) if has_db_partial else None
1043+
weight_shape = (head_sym, N) if per_head else (N,)
1044+
weight_cute = fake_tensor(weight_dtype, weight_shape, div)
1045+
rstd_cute = fake_tensor(Float32, batch_shape)
1046+
dw_shape = (batch_partial_sym, head_sym, N) if per_head else (batch_partial_sym, N)
1047+
dw_partial_cute = fake_tensor(Float32, dw_shape, div) if has_dw_partial else None
1048+
db_partial_cute = fake_tensor(Float32, dw_shape, div) if has_db_partial else None
10091049
return cute.compile(
10101050
RMSNormBackward(dtype, N),
10111051
x_cute,
@@ -1033,19 +1073,27 @@ def rmsnorm_bwd(
10331073
has_residual: bool = False,
10341074
) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
10351075
device = x.device
1036-
N = x.size(1)
1076+
N = x.size(-1)
1077+
per_head = x.dim() == 3
10371078
dx = torch.empty_like(x)
10381079
if dresidual_out is not None and dresidual_out.dtype != dx.dtype:
10391080
dresidual = torch.empty_like(x, dtype=dresidual_out.dtype)
10401081
else:
10411082
dresidual = None
10421083
sm_count = _get_sm_count(N, device)
1084+
if per_head:
1085+
H = x.size(1)
1086+
sm_count = max(round(sm_count / H), 1)
1087+
else:
1088+
H = None
10431089
if weight is not None:
10441090
# Always store partial gradients in fp32 for numerical accuracy
1045-
dw_partial = torch.empty(sm_count, N, device=device, dtype=torch.float32)
1091+
dw_shape = (sm_count, H, N) if per_head else (sm_count, N)
1092+
dw_partial = torch.empty(dw_shape, device=device, dtype=torch.float32)
10461093
else:
10471094
dw_partial = None
1048-
db_partial = torch.empty(sm_count, N, device=device, dtype=torch.float32) if has_bias else None
1095+
db_shape = (sm_count, H, N) if per_head else (sm_count, N)
1096+
db_partial = torch.empty(db_shape, device=device, dtype=torch.float32) if has_bias else None
10491097

10501098
_rmsnorm_bwd(
10511099
x, weight, dout, rstd, dx, dw_partial, db_partial, dresidual_out, dresidual, sm_count
@@ -1074,10 +1122,14 @@ def forward(
10741122
prenorm=False,
10751123
):
10761124
x_shape_og = x.shape
1125+
per_head = (weight is not None and weight.dim() == 2) or (
1126+
bias is not None and bias.dim() == 2
1127+
)
1128+
last_shape = x_shape_og[-1:] if not per_head else x_shape_og[-2:]
10771129
# Flatten input, ensuring last dim is contiguous
1078-
x = _ensure_contiguous(x.reshape(-1, x.shape[-1]))
1130+
x = _ensure_contiguous(x.reshape(-1, *last_shape))
10791131
if residual is not None:
1080-
residual = _ensure_contiguous(residual.reshape(-1, residual.shape[-1]))
1132+
residual = _ensure_contiguous(residual.reshape(-1, *last_shape))
10811133
need_grad = any(ctx.needs_input_grad[:3])
10821134
out, residual_out, rstd = rmsnorm_fwd(
10831135
x,
@@ -1091,6 +1143,7 @@ def forward(
10911143
)
10921144
ctx.save_for_backward(x if residual is None else residual_out, weight, rstd)
10931145
ctx.has_bias = bias is not None
1146+
ctx.per_head = per_head
10941147
ctx.eps = eps
10951148
ctx.x_shape_og = x_shape_og
10961149
ctx.residual_dtype = residual.dtype if residual is not None else None
@@ -1104,14 +1157,16 @@ def forward(
11041157
def backward(ctx, dout, *args):
11051158
x, weight, rstd = ctx.saved_tensors
11061159
has_bias = ctx.has_bias
1160+
per_head = ctx.per_head
1161+
x_shape_og = ctx.x_shape_og
1162+
last_shape = x_shape_og[-2:] if per_head else x_shape_og[-1:]
11071163
if ctx.prenorm and ctx.residual_dtype is not None:
11081164
dresidual_out = args[0]
1109-
dresidual_out = _ensure_contiguous(dresidual_out.reshape(-1, dresidual_out.shape[-1]))
1165+
dresidual_out = _ensure_contiguous(dresidual_out.reshape(-1, *last_shape))
11101166
else:
11111167
dresidual_out = None
1112-
x_shape_og = ctx.x_shape_og
1113-
# Reshape dout to match the flattened shape used in forward
1114-
dout = _ensure_contiguous(dout.reshape(-1, dout.shape[-1]))
1168+
# Reshape dout to match the shape used in forward
1169+
dout = _ensure_contiguous(dout.reshape(-1, *last_shape))
11151170
dx, dw, db, dresidual = rmsnorm_bwd(
11161171
x,
11171172
weight,

0 commit comments

Comments
 (0)