Skip to content

Commit 8dd8aeb

Browse files
author
ssjia
committed
Update on "[ET-VK] Fused RMSNorm operator to fix fp16 overflow"
Fused RMSNorm operator that performs squaring, mean, rsqrt, and weight scaling in a single shader dispatch. All accumulation is done in fp32 regardless of input dtype, preventing fp16 overflow when residual stream values exceed sqrt(65504) ≈ 256. The Python reference impl (`rms_norm_impl`) must preserve the input dtype — PyTorch type promotion would otherwise produce fp32 output from fp16 inputs, and the FusePatternsPass re-trace would propagate that incorrect dtype through the graph. Authored by Claude. Differential Revision: [D99841211](https://our.internmc.facebook.com/intern/diff/D99841211/) [ghstack-poisoned]
2 parents 0889ac4 + 0aa9212 commit 8dd8aeb

1 file changed

Lines changed: 5 additions & 4 deletions

File tree

backends/vulkan/patterns/rms_norm.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,12 @@
2626

2727
def _skip_casts(node: torch.fx.Node) -> torch.fx.Node:
2828
"""Unwrap chains of dtype-cast nodes to find the underlying value."""
29-
while isinstance(node, torch.fx.Node) and node.target in _CAST_OPS:
30-
if node.args and isinstance(node.args[0], torch.fx.Node):
31-
node = node.args[0]
32-
else:
29+
while node.target in _CAST_OPS:
30+
arg0 = node.args[0] if node.args else None
31+
if not isinstance(arg0, torch.fx.Node):
3332
break
33+
node = arg0
34+
# pyre-ignore[7]: node is always a Node; Pyre cannot narrow through loops
3435
return node
3536

3637

0 commit comments

Comments
 (0)