Commit a44ced6
Make ScalelessRMSNorm a torch.nn.RMSNorm; fix SDPACustom view -> reshape (#19376)
Summary:
Pull Request resolved: #19376
Two related changes that together unblock the QNN export path for VLM/STITO:
(1) ScalelessRMSNorm: re-implement as torch.nn.RMSNorm subclass
ScalelessRMSNorm was previously implemented as a hand-rolled RMS normalization
(decomposed into mean / rsqrt / mul). On the QNN export path, this decomposition
fails to lower for an LLM. Using torch.nn.RMSNorm() directly works.
Re-implement ScalelessRMSNorm as a torch.nn.RMSNorm subclass whose weight is
hardcoded to ones and frozen (requires_grad=False). This keeps the public
interface (ScalelessRMSNorm(dim, eps)) unchanged while letting backends see a
proper RMSNorm op so it composes/decomposes cleanly for QNN.
(2) SDPACustom / QuantizedSDPA: replace .view() with .reshape()
Switching to torch.nn.RMSNorm changes how strides propagate through the export
graph compared to the hand-rolled decomposition, exposing a latent bug in
source_transformation/sdpa.py. The output of torch.ops.llama.custom_sdpa retains
the non-contiguous (transposed) strides of its inputs, so
output.view(bsz, seqlen, self.dim) — which merges the last two dims
(n_heads, head_dim) — fails during torch.export with:
    Cannot view a tensor with shape (1, s0, 32, 64) and strides
    (2048*s0, 64, 64*s0, 1) as a tensor with shape (1, s0, 2048)
Switching to .reshape() returns a view when the strides allow it and copies into a contiguous layout only when needed. It also matches the pattern already used elsewhere in this file (SDPASimple, SDPAFlex, SDPACoreML) and in attention.py.
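The stride problem in (2) can be reproduced in eager mode with a small sketch. This is an illustration under assumptions, not the code in sdpa.py: a plain transposed tensor stands in for the output of torch.ops.llama.custom_sdpa, and seqlen is fixed to 5 in place of the symbolic s0.

```python
import torch

bsz, seqlen, n_heads, head_dim = 1, 5, 32, 64
dim = n_heads * head_dim

# Stand-in for the SDPA output: shape (bsz, seqlen, n_heads, head_dim) with
# transposed strides (2048*seqlen, 64, 64*seqlen, 1), as in the error above.
output = torch.randn(bsz, n_heads, seqlen, head_dim).transpose(1, 2)

# .view() cannot merge (n_heads, head_dim) because those dims are not
# laid out contiguously relative to each other.
view_failed = False
try:
    output.view(bsz, seqlen, dim)
except RuntimeError:
    view_failed = True

# .reshape() falls back to a copy when a view is impossible.
merged = output.reshape(bsz, seqlen, dim)
```

Here view_failed ends up True, while merged has shape (bsz, seqlen, dim) and is contiguous.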
Reviewed By: billmguo, telgamal-1
Differential Revision: D1042589501
parent ada8e35 commit a44ced6
2 files changed
Lines changed: 14 additions & 11 deletions