
Commit 0900104

navsud authored and facebook-github-bot committed
Make ScalelessRMSNorm a torch.nn.RMSNorm with frozen ones weight
Summary:
ScalelessRMSNorm was previously implemented as a hand-rolled RMS normalization (decomposed into mean / rsqrt / mul). On the QNN export path, this decomposition fails to lower. Using torch.nn.RMSNorm directly works.

Re-implemented ScalelessRMSNorm as a torch.nn.RMSNorm subclass whose weight is hardcoded to ones and frozen (requires_grad=False). This keeps the public interface (ScalelessRMSNorm(dim, eps)) unchanged while letting backends see a proper RMSNorm op, so it lowers to QNN correctly.

Reviewed By: billmguo

Differential Revision: D104258950
1 parent 1643611 commit 0900104

1 file changed

examples/models/llama/norm.py

Lines changed: 12 additions & 9 deletions
@@ -41,17 +41,20 @@ def forward(self, x):
         return output * self.weight.type_as(x)


-class ScalelessRMSNorm(torch.nn.Module):
+class ScalelessRMSNorm(torch.nn.RMSNorm):
+    """RMSNorm with weight hardcoded to ones and not trainable.
+
+    Equivalent to a scaleless RMSNorm (no learnable scaling) but implemented as a
+    torch.nn.RMSNorm so the op composes/decomposes cleanly for backends like QNN
+    instead of being expressed as a hand-rolled decomposition.
+    """
+
     def __init__(self, dim: int, eps: float = 1e-6):
-        super().__init__()
+        super().__init__(dim, eps)
         self.dim = dim
-        self.eps = eps
-
-    def forward(self, x):
-        x_float = x.float()
-        return (
-            x_float * torch.rsqrt((x_float * x_float).mean(-1, keepdim=True) + self.eps)
-        ).type_as(x)
+        with torch.no_grad():
+            self.weight.fill_(1.0)
+        self.weight.requires_grad = False


 class RMSNormWithInputScale(torch.nn.Module):
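
For reference, a minimal sketch (not part of the commit, and assuming a PyTorch version that ships torch.nn.RMSNorm, i.e. 2.4+) that exercises the new class and checks it against the hand-rolled decomposition the old forward() computed; the input shape is illustrative.

import torch


class ScalelessRMSNorm(torch.nn.RMSNorm):
    """RMSNorm with weight hardcoded to ones and not trainable."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__(dim, eps)
        self.dim = dim
        # Freeze the scale at 1.0 so the module stays scaleless.
        with torch.no_grad():
            self.weight.fill_(1.0)
        self.weight.requires_grad = False


x = torch.randn(2, 8, 64)  # (batch, seq_len, dim) -- illustrative shape
norm = ScalelessRMSNorm(64)

# Hand-rolled decomposition used by the previous implementation.
expected = x * torch.rsqrt((x * x).mean(-1, keepdim=True) + 1e-6)

assert torch.allclose(norm(x), expected, atol=1e-6)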
