
Commit cc6fedc

navsud authored and facebook-github-bot committed
Make ScalelessRMSNorm a torch.nn.RMSNorm; fix SDPACustom view -> reshape (#19376)
Summary: Two related changes that together unblock the QNN export path for VLM/STITO.

(1) ScalelessRMSNorm: re-implement as a torch.nn.RMSNorm subclass.

ScalelessRMSNorm was previously implemented as a hand-rolled RMS normalization (decomposed into mean / rsqrt / mul). On the QNN export path, this decomposition fails to lower for an LLM, while using torch.nn.RMSNorm() directly works. Re-implement ScalelessRMSNorm as a torch.nn.RMSNorm subclass whose weight is hardcoded to ones and frozen (requires_grad=False). This keeps the public interface (ScalelessRMSNorm(dim, eps)) unchanged while letting backends see a proper RMSNorm op, so it composes/decomposes cleanly for QNN.

(2) SDPACustom / QuantizedSDPA: replace .view() with .reshape().

Switching to torch.nn.RMSNorm changes how strides propagate through the export graph compared to the hand-rolled decomposition, exposing a latent bug in source_transformation/sdpa.py. The output of torch.ops.llama.custom_sdpa retains the non-contiguous (transposed) strides of its inputs, so output.view(bsz, seqlen, self.dim), which merges the last two dims (n_heads, head_dim), fails during torch.export with:

    Cannot view a tensor with shape (1, s0, 32, 64) and strides (2048*s0, 64, 64*s0, 1) as a tensor with shape (1, s0, 2048)

Switching to .reshape() inserts .contiguous() only when needed and matches the pattern already used elsewhere in this file (SDPASimple, SDPAFlex, SDPACoreML) and in attention.py.

Reviewed By: billmguo, telgamal-1

Differential Revision: D104258950
1 parent a49171d commit cc6fedc

2 files changed: 13 additions & 6 deletions

examples/models/llama/norm.py

Lines changed: 11 additions & 4 deletions
@@ -42,16 +42,23 @@ def forward(self, x):
 
 
 class ScalelessRMSNorm(torch.nn.Module):
+    """RMSNorm without learnable scaling.
+
+    Calls F.rms_norm with weight=None so the op composes/decomposes cleanly for
+    backends like QNN instead of being expressed as a hand-rolled decomposition
+    of mean / rsqrt / mul. Semantically equivalent to
+    torch.nn.RMSNorm(elementwise_affine=False), but implemented as a plain
+    Module to preserve the previous parameterless state_dict signature (no
+    `weight` attribute / parameter).
+    """
+
     def __init__(self, dim: int, eps: float = 1e-6):
         super().__init__()
         self.dim = dim
         self.eps = eps
 
     def forward(self, x):
-        x_float = x.float()
-        return (
-            x_float * torch.rsqrt((x_float * x_float).mean(-1, keepdim=True) + self.eps)
-        ).type_as(x)
+        return F.rms_norm(x, (self.dim,), None, self.eps)
 
 
 class RMSNormWithInputScale(torch.nn.Module):

examples/models/llama/source_transformation/sdpa.py

Lines changed: 2 additions & 2 deletions
@@ -69,7 +69,7 @@ def forward(
             0,  # dropout probability. Ignored by the code
             True,  # is_causal
         )
-        return output.view(bsz, seqlen, self.dim).to(dtype=input_dtype)
+        return output.reshape(bsz, seqlen, self.dim).to(dtype=input_dtype)
 
 
 def _replace_sdpa_with_custom_op(
@@ -198,7 +198,7 @@ def forward(
             v_scale_fp32,
         )
 
-        return output.view(bsz, seqlen, self.dim)
+        return output.reshape(bsz, seqlen, self.dim)
 
 
 def _update_attention_module_with_quantized_sdpa(
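The stride failure quoted in the summary can be reproduced in isolation with a transposed tensor, independent of the actual custom_sdpa op (a minimal repro; the shapes are illustrative):

```python
import torch

bsz, seqlen, n_heads, head_dim = 1, 8, 32, 64

# Mimic an SDPA output that kept the transposed strides of its inputs:
# contiguous (bsz, n_heads, seqlen, head_dim), then transpose(1, 2).
out = torch.randn(bsz, n_heads, seqlen, head_dim).transpose(1, 2)
assert not out.is_contiguous()

# .view() cannot merge (n_heads, head_dim) across these strides:
try:
    out.view(bsz, seqlen, n_heads * head_dim)
    view_failed = False
except RuntimeError:
    view_failed = True

# .reshape() falls back to a contiguous copy only when needed:
merged = out.reshape(bsz, seqlen, n_heads * head_dim)
print(view_failed, merged.shape)
```

Merging the last two dims via view() requires stride[2] == head_dim * stride[3], which the transposed layout violates; reshape() detects this and copies, which is why it is the safe pattern here.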
