Commit 0900104
Make ScalelessRMSNorm a torch.nn.RMSNorm with frozen ones weight
Summary:
ScalelessRMSNorm was previously implemented as a hand-rolled RMS normalization
(decomposed into mean / rsqrt / mul). On the QNN export path, this decomposition
fails to lower. Using torch.nn.RMSNorm() directly works.
Re-implemented ScalelessRMSNorm as a torch.nn.RMSNorm subclass whose weight is
hardcoded to ones and frozen (requires_grad=False). This keeps the public
interface (ScalelessRMSNorm(dim, eps)) unchanged while letting backends see a
proper RMSNorm op so it lowers to QNN correctly.
Reviewed By: billmguo
Differential Revision: D1042589501 parent 1643611 commit 0900104
1 file changed
Lines changed: 12 additions & 9 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
41 | 41 | | |
42 | 42 | | |
43 | 43 | | |
44 | | - | |
| 44 | + | |
| 45 | + | |
| 46 | + | |
| 47 | + | |
| 48 | + | |
| 49 | + | |
| 50 | + | |
| 51 | + | |
45 | 52 | | |
46 | | - | |
| 53 | + | |
47 | 54 | | |
48 | | - | |
49 | | - | |
50 | | - | |
51 | | - | |
52 | | - | |
53 | | - | |
54 | | - | |
| 55 | + | |
| 56 | + | |
| 57 | + | |
55 | 58 | | |
56 | 59 | | |
57 | 60 | | |
| |||
0 commit comments