Skip to content

Commit 0a82163

Browse files
authored
Back out "Make ScalelessRMSNorm a torch.nn.RMSNorm; fix SDPACustom view -> reshape"
Differential Revision: D105623266
Pull Request resolved: #19655
1 parent 869af13 commit 0a82163

2 files changed

Lines changed: 11 additions & 14 deletions

File tree

examples/models/llama/norm.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -41,20 +41,17 @@ def forward(self, x):
         return output * self.weight.type_as(x)


-class ScalelessRMSNorm(torch.nn.RMSNorm):
-    """RMSNorm with weight hardcoded to ones and not trainable.
-
-    Equivalent to a scaleless RMSNorm (no learnable scaling) but implemented as a
-    torch.nn.RMSNorm so the op composes/decomposes cleanly for backends like QNN
-    instead of being expressed as a hand-rolled decomposition.
-    """
-
+class ScalelessRMSNorm(torch.nn.Module):
     def __init__(self, dim: int, eps: float = 1e-6):
-        super().__init__(dim, eps)
+        super().__init__()
         self.dim = dim
-        with torch.no_grad():
-            self.weight.fill_(1.0)
-        self.weight.requires_grad = False
+        self.eps = eps
+
+    def forward(self, x):
+        x_float = x.float()
+        return (
+            x_float * torch.rsqrt((x_float * x_float).mean(-1, keepdim=True) + self.eps)
+        ).type_as(x)


 class RMSNormCoreML(torch.nn.Module):

examples/models/llama/source_transformation/sdpa.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def forward(
             0, # dropout probability. Ignored by the code
             True, # is_causal
         )
-        return output.reshape(bsz, seqlen, self.dim).to(dtype=input_dtype)
+        return output.view(bsz, seqlen, self.dim).to(dtype=input_dtype)


 def _replace_sdpa_with_custom_op(
@@ -198,7 +198,7 @@ def forward(
             v_scale_fp32,
         )

-        return output.reshape(bsz, seqlen, self.dim)
+        return output.view(bsz, seqlen, self.dim)


 def _update_attention_module_with_quantized_sdpa(

0 commit comments

Comments
 (0)