Commit 8dd8aeb
ssjia
Update on "[ET-VK] Fused RMSNorm operator to fix fp16 overflow"
Fused RMSNorm operator that performs squaring, mean, rsqrt, and
weight scaling in a single shader dispatch. All accumulation is done
in fp32 regardless of input dtype, preventing fp16 overflow when
residual stream values exceed sqrt(65504) ≈ 256.
The Python reference impl (`rms_norm_impl`) must preserve the input
dtype — PyTorch type promotion would otherwise produce fp32 output
from fp16 inputs, and the FusePatternsPass re-trace would propagate
that incorrect dtype through the graph.
Authored by Claude.
Differential Revision: [D99841211](https://our.internmc.facebook.com/intern/diff/D99841211/)
[ghstack-poisoned]1 file changed
Lines changed: 5 additions & 4 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
26 | 26 | | |
27 | 27 | | |
28 | 28 | | |
29 | | - | |
30 | | - | |
31 | | - | |
32 | | - | |
| 29 | + | |
| 30 | + | |
| 31 | + | |
33 | 32 | | |
| 33 | + | |
| 34 | + | |
34 | 35 | | |
35 | 36 | | |
36 | 37 | | |
| |||
0 commit comments