Skip to content

Commit 2dfe8e4

Browse files
authored
Make ScalelessRMSNorm a torch.nn.RMSNorm; fix SDPACustom view -> reshape (#19376)
Differential Revision: D104258950 Pull Request resolved: #19376
1 parent ac9efa7 commit 2dfe8e4

2 files changed

Lines changed: 14 additions & 11 deletions

File tree

examples/models/llama/norm.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -41,17 +41,20 @@ def forward(self, x):
4141
return output * self.weight.type_as(x)
4242

4343

44-
class ScalelessRMSNorm(torch.nn.Module):
44+
class ScalelessRMSNorm(torch.nn.RMSNorm):
45+
"""RMSNorm with weight hardcoded to ones and not trainable.
46+
47+
Equivalent to a scaleless RMSNorm (no learnable scaling) but implemented as a
48+
torch.nn.RMSNorm so the op composes/decomposes cleanly for backends like QNN
49+
instead of being expressed as a hand-rolled decomposition.
50+
"""
51+
4552
def __init__(self, dim: int, eps: float = 1e-6):
46-
super().__init__()
53+
super().__init__(dim, eps)
4754
self.dim = dim
48-
self.eps = eps
49-
50-
def forward(self, x):
51-
x_float = x.float()
52-
return (
53-
x_float * torch.rsqrt((x_float * x_float).mean(-1, keepdim=True) + self.eps)
54-
).type_as(x)
55+
with torch.no_grad():
56+
self.weight.fill_(1.0)
57+
self.weight.requires_grad = False
5558

5659

5760
class RMSNormWithInputScale(torch.nn.Module):

examples/models/llama/source_transformation/sdpa.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def forward(
6969
0, # dropout probability. Ignored by the code
7070
True, # is_causal
7171
)
72-
return output.view(bsz, seqlen, self.dim).to(dtype=input_dtype)
72+
return output.reshape(bsz, seqlen, self.dim).to(dtype=input_dtype)
7373

7474

7575
def _replace_sdpa_with_custom_op(
@@ -198,7 +198,7 @@ def forward(
198198
v_scale_fp32,
199199
)
200200

201-
return output.view(bsz, seqlen, self.dim)
201+
return output.reshape(bsz, seqlen, self.dim)
202202

203203

204204
def _update_attention_module_with_quantized_sdpa(

0 commit comments

Comments
 (0)