
Commit 07aa7e5

navsudmeta-codesync[bot] authored and committed
Make ScalelessRMSNorm a torch.nn.RMSNorm; fix SDPACustom view -> reshape (#19376)
Summary:
Pull Request resolved: #19376

Two related changes that together unblock the QNN export path for VLM/STITO:

(1) ScalelessRMSNorm: re-implement as a torch.nn.RMSNorm subclass

ScalelessRMSNorm was previously implemented as a hand-rolled RMS normalization, decomposed into mean / rsqrt / mul. On the QNN export path, this decomposition fails to lower for an LLM; using torch.nn.RMSNorm directly works. Re-implement ScalelessRMSNorm as a torch.nn.RMSNorm subclass whose weight is hardcoded to ones and frozen (requires_grad=False). This keeps the public interface (ScalelessRMSNorm(dim, eps)) unchanged while letting backends see a proper RMSNorm op, so it composes/decomposes cleanly for QNN.

(2) SDPACustom / QuantizedSDPA: replace .view() with .reshape()

Switching to torch.nn.RMSNorm changes how strides propagate through the export graph compared to the hand-rolled decomposition, exposing a latent bug in source_transformation/sdpa.py. The output of torch.ops.llama.custom_sdpa retains the non-contiguous (transposed) strides of its inputs, so output.view(bsz, seqlen, self.dim), which merges the last two dims (n_heads, head_dim), fails during torch.export with:

    Cannot view a tensor with shape (1, s0, 32, 64) and strides (2048*s0, 64, 64*s0, 1) as a tensor with shape (1, s0, 2048)

Switching to .reshape() inserts .contiguous() only when needed, matching the pattern already used elsewhere in this file (SDPASimple, SDPAFlex, SDPACoreML) and in attention.py.

Reviewed By: billmguo, telgamal-1

Differential Revision: D104258950
1 parent 91aef57 commit 07aa7e5

2 files changed: 14 additions & 11 deletions


examples/models/llama/norm.py

Lines changed: 12 additions & 9 deletions
@@ -41,17 +41,20 @@ def forward(self, x):
         return output * self.weight.type_as(x)
 
 
-class ScalelessRMSNorm(torch.nn.Module):
+class ScalelessRMSNorm(torch.nn.RMSNorm):
+    """RMSNorm with weight hardcoded to ones and not trainable.
+
+    Equivalent to a scaleless RMSNorm (no learnable scaling) but implemented as a
+    torch.nn.RMSNorm so the op composes/decomposes cleanly for backends like QNN
+    instead of being expressed as a hand-rolled decomposition.
+    """
+
     def __init__(self, dim: int, eps: float = 1e-6):
-        super().__init__()
+        super().__init__(dim, eps)
         self.dim = dim
-        self.eps = eps
-
-    def forward(self, x):
-        x_float = x.float()
-        return (
-            x_float * torch.rsqrt((x_float * x_float).mean(-1, keepdim=True) + self.eps)
-        ).type_as(x)
+        with torch.no_grad():
+            self.weight.fill_(1.0)
+        self.weight.requires_grad = False
 
 
 class RMSNormWithInputScale(torch.nn.Module):
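
To make the equivalence in the summary concrete, here is a minimal sketch (not part of this commit; it assumes PyTorch >= 2.4, where torch.nn.RMSNorm is available) checking that an RMSNorm with a frozen all-ones weight computes the same values as the removed hand-rolled decomposition for float32 inputs:

    # Not from the PR: verify torch.nn.RMSNorm with a frozen all-ones weight
    # matches the removed mean / rsqrt / mul decomposition.
    import torch

    dim, eps = 2048, 1e-6
    norm = torch.nn.RMSNorm(dim, eps=eps)
    with torch.no_grad():
        norm.weight.fill_(1.0)
    norm.weight.requires_grad = False

    x = torch.randn(1, 16, dim)
    hand_rolled = x * torch.rsqrt((x * x).mean(-1, keepdim=True) + eps)
    torch.testing.assert_close(norm(x), hand_rolled)  # same math, single RMSNorm op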

examples/models/llama/source_transformation/sdpa.py

Lines changed: 2 additions & 2 deletions
@@ -69,7 +69,7 @@ def forward(
             0,  # dropout probability. Ignored by the code
             True,  # is_causal
         )
-        return output.view(bsz, seqlen, self.dim).to(dtype=input_dtype)
+        return output.reshape(bsz, seqlen, self.dim).to(dtype=input_dtype)
 
 
 def _replace_sdpa_with_custom_op(
@@ -198,7 +198,7 @@ def forward(
             v_scale_fp32,
         )
 
-        return output.view(bsz, seqlen, self.dim)
+        return output.reshape(bsz, seqlen, self.dim)
 
 
 def _update_attention_module_with_quantized_sdpa(
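
The stride failure described in the summary can be reproduced outside torch.export with a minimal sketch (not part of this commit): merging the last two dims of a transposed, non-contiguous tensor is exactly the case where .view() raises and .reshape() copies instead. The shapes mirror the quoted error (32 heads, head_dim 64, so dim = 2048):

    # Not from the PR: demonstrate why .view() fails on the transposed output.
    import torch

    bsz, seqlen, n_heads, head_dim = 1, 7, 32, 64

    # Like the custom_sdpa output: produced as (bsz, n_heads, seqlen, head_dim)
    # and transposed back without a copy, leaving non-contiguous strides.
    out = torch.randn(bsz, n_heads, seqlen, head_dim).transpose(1, 2)
    assert not out.is_contiguous()

    try:
        out.view(bsz, seqlen, n_heads * head_dim)  # merging dims 2 and 3 needs a copy
    except RuntimeError as e:
        print(e)  # view size is not compatible with input tensor's size and stride

    # .reshape() falls back to .contiguous() only because a plain view is impossible.
    merged = out.reshape(bsz, seqlen, n_heads * head_dim)
    print(merged.shape)  # torch.Size([1, 7, 2048])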
