Skip to content

Commit dda3f6f

Browse files
author
ssjia
committed
[ET-VK] Fused RMSNorm operator to fix fp16 overflow
Fused RMSNorm operator that performs squaring, mean, rsqrt, and weight scaling in a single shader dispatch. All accumulation is done in fp32 regardless of input dtype, preventing fp16 overflow when residual stream values exceed sqrt(65504) ≈ 256. The Python reference impl (`rms_norm_impl`) must preserve the input dtype — PyTorch type promotion would otherwise produce fp32 output from fp16 inputs, and the FusePatternsPass re-trace would propagate that incorrect dtype through the graph. Authored by Claude. Differential Revision: [D99841211](https://our.internmc.facebook.com/intern/diff/D99841211/) ghstack-source-id: 364237333 Pull Request resolved: #18772
1 parent c7b2efb commit dda3f6f

12 files changed

Lines changed: 990 additions & 0 deletions

File tree

backends/vulkan/custom_ops_lib.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -959,3 +959,24 @@ def select_as_symint_impl(x: torch.Tensor, dim: int, index: int):
959959
lib.define(f"{name}(Tensor x, int dim, int index) -> SymInt")
960960
lib.impl(name, select_as_symint_impl, "Meta")
961961
select_as_symint_op = getattr(getattr(torch.ops, namespace), name)
962+
963+
################
964+
## rms_norm ##
965+
################
966+
967+
968+
def rms_norm_impl(
    x: torch.Tensor,
    weight: torch.Tensor,
    eps: float,
) -> torch.Tensor:
    """Reference RMSNorm implementation for the fused Vulkan operator.

    All accumulation happens in fp32 regardless of the input dtype, and the
    result is cast back to the original dtype. Preserving the input dtype
    matters: plain type promotion would yield fp32 output for fp16 inputs,
    and a re-trace of the graph would then propagate the wrong dtype.
    """
    original_dtype = x.dtype
    x_fp32 = x.float()
    # Mean of squares over the last dim, computed in fp32 to avoid fp16
    # overflow for large activation values.
    mean_sq = x_fp32.pow(2).mean(dim=-1, keepdim=True)
    normalized = x_fp32 * torch.rsqrt(mean_sq + eps)
    # Scale by weight in fp32, then restore the caller's dtype.
    return (normalized * weight.float()).to(original_dtype)
977+
978+
979+
# Register the fused op's schema and bind the Python reference impl under
# the CompositeExplicitAutograd dispatch key.
name = "rms_norm"
lib.define(f"{name}(Tensor x, Tensor weight, float eps) -> Tensor")
lib.impl(name, rms_norm_impl, "CompositeExplicitAutograd")
# Resolve the registered overload packet for use elsewhere in the backend.
rms_norm_op = getattr(getattr(torch.ops, namespace), name)

backends/vulkan/op_registry.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1606,6 +1606,21 @@ def register_native_layer_norm():
16061606
)
16071607

16081608

1609+
# =============================================================================
1610+
# RmsNorm.cpp
1611+
# =============================================================================
1612+
1613+
1614+
@update_features(exir_ops.edge.et_vk.rms_norm.default)
def register_rms_norm():
    """Declare Vulkan backend support for the fused RMSNorm operator.

    Floating-point input dtypes only; contiguous storage is accepted in any
    layout, and the op supports weight prepacking and shape resize.
    """
    features = OpFeatures(
        inputs_storage=utils.CONTIGUOUS_ANY,
        inputs_dtypes=utils.FP_T,
        supports_prepacking=True,
        supports_resize=True,
    )
    return features
1622+
1623+
16091624
#######################
16101625
## Utility functions ##
16111626
#######################

backends/vulkan/patterns/BUCK

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ fbcode_target(_kind = runtime.python_library,
1616
"quantized_convolution.py",
1717
"quantized_binary.py",
1818
"quantized_unary.py",
19+
"rms_norm.py",
1920
"sdpa.py",
2021
"select_as_symint.py",
2122
],

backends/vulkan/patterns/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
import executorch.backends.vulkan.patterns.quantized_unary # noqa
1818

19+
import executorch.backends.vulkan.patterns.rms_norm # noqa
20+
1921
import executorch.backends.vulkan.patterns.rope # noqa
2022

2123
import executorch.backends.vulkan.patterns.rope_hf # noqa
Lines changed: 279 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,279 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from typing import Optional
8+
9+
import torch
10+
11+
from executorch.backends.vulkan.patterns.pattern_registry import (
12+
PatternMatch,
13+
register_pattern_detector,
14+
register_pattern_replacement,
15+
)
16+
17+
from executorch.exir import ExportedProgram
18+
from executorch.exir.dialects._ops import ops as exir_ops
19+
20+
21+
# Edge-dialect ops that only change dtype; pattern matching looks
# through chains of these when comparing tensor identities.
_CAST_OPS = {
    exir_ops.edge.aten._to_copy.default,
    exir_ops.edge.aten.to.dtype,
}
25+
26+
27+
def _skip_casts(node: torch.fx.Node) -> torch.fx.Node:
    """Follow a chain of dtype-cast nodes back to the underlying value.

    Stops at the first node that is not a cast, or at a cast whose first
    argument is not an fx Node.
    """
    while isinstance(node, torch.fx.Node) and node.target in _CAST_OPS:
        first_arg = node.args[0] if node.args else None
        if not isinstance(first_arg, torch.fx.Node):
            break
        node = first_arg
    return node
35+
36+
37+
class RmsNormMatch(PatternMatch):
    """Match for the decomposed RMSNorm pattern, including variants where
    dtype casts (to_copy) are inserted around the computation.

    The canonical pattern emitted by the Llama RMSNorm implementation is:

        x_orig (any dtype)
          -> to_copy(fp32) -> x_f32
          -> mul(x_f32, x_f32) -> mean(dim=-1, keepdim=True)
          -> add(eps) -> rsqrt -> rstd_f32
          -> mul(x_f32, rstd_f32) -> norm_f32
          -> to_copy(orig dtype) -> norm_cast
        weight -> to_copy(orig dtype) -> weight_cast
          -> mul(norm_cast, weight_cast)  <- anchor node

    Cast nodes are looked through when comparing tensor identities, so
    fp32-only graphs with no casts also match. The anchor node is the
    final mul (scale by weight).
    """

    def __init__(self, final_mul_node: torch.fx.Node) -> None:  # noqa: C901
        self.anchor_node = final_mul_node
        self.match_found = False
        self.all_nodes = [self.anchor_node]

        # Step 1: final_mul = mul(normalized_cast, weight_cast). Peel casts
        # to recover the inner norm-mul node and the weight source.
        norm_mul, self.weight_node = self._identify_norm_mul_and_weight(
            final_mul_node
        )
        if norm_mul is None:
            return

        self.all_nodes.append(norm_mul)

        # Step 2: norm_mul = mul(x_f32, rstd_f32); locate rsqrt and x.
        rsqrt_node, norm_input = self._identify_rsqrt_and_input(norm_mul)
        if rsqrt_node is None:
            return

        self.all_nodes.append(rsqrt_node)

        # Step 3: rsqrt <- add(mean_sq, eps) <- mean(x_sq, [-1], keepdim).
        add_node = self._get_single_arg_node(
            rsqrt_node, exir_ops.edge.aten.rsqrt.default
        )
        if add_node is None or add_node.target != exir_ops.edge.aten.add.Tensor:
            return

        self.all_nodes.append(add_node)

        # One add operand must be the mean node; the other is epsilon.
        self.eps_node = None
        mean_node = None
        for operand in add_node.args[:2]:
            if (
                isinstance(operand, torch.fx.Node)
                and operand.target == exir_ops.edge.aten.mean.dim
            ):
                mean_node = operand
            else:
                self.eps_node = operand

        if mean_node is None or self.eps_node is None:
            return

        self.all_nodes.append(mean_node)

        # The reduction must be over the last dim with keepdim=True.
        if len(mean_node.args) < 3:
            return
        if mean_node.args[1] != [-1]:
            return
        if not mean_node.args[2]:
            return

        # Step 4: mean's input should be x_sq = mul(x, x) or pow(x, 2).
        square_node = mean_node.args[0]
        if not isinstance(square_node, torch.fx.Node):
            return

        self.all_nodes.append(square_node)

        # Canonical fp32 source: both the square and the norm mul should be
        # fed (possibly through casts) by the same fp32 tensor.
        x_f32 = (
            _skip_casts(norm_input)
            if isinstance(norm_input, torch.fx.Node)
            else norm_input
        )

        if square_node.target == exir_ops.edge.aten.mul.Tensor:
            if square_node.args[0] != square_node.args[1]:
                return
            square_input = square_node.args[0]
            if not isinstance(square_input, torch.fx.Node):
                return
            if _skip_casts(square_input) != x_f32 and square_input != norm_input:
                return
        elif square_node.target == exir_ops.edge.aten.pow.Tensor_Scalar:
            square_input = square_node.args[0]
            if not isinstance(square_input, torch.fx.Node):
                return
            if _skip_casts(square_input) != x_f32 and square_input != norm_input:
                return
            if square_node.args[1] != 2 and square_node.args[1] != 2.0:
                return
        else:
            return

        # The canonical input exposed to the fused op is the original tensor
        # before any fp32 upcast (i.e. the input of the first to_copy). With
        # no cast present, norm_input already is the original input.
        self.input_node = (
            _skip_casts(norm_input)
            if isinstance(norm_input, torch.fx.Node)
            else norm_input
        )
        # Record the intermediate cast nodes so they can be cleaned up.
        cast_node = norm_input
        while (
            isinstance(cast_node, torch.fx.Node)
            and cast_node.target in _CAST_OPS
            and cast_node not in self.all_nodes
        ):
            self.all_nodes.append(cast_node)
            cast_node = cast_node.args[0] if cast_node.args else cast_node

        self.match_found = True

    def _identify_norm_mul_and_weight(self, final_mul_node):
        """From mul(norm_cast, weight_cast): unwrap casts and return the
        underlying norm-mul node and the weight source, or (None, None)."""
        if len(final_mul_node.args) < 2:
            return None, None

        lhs, rhs = final_mul_node.args[0], final_mul_node.args[1]

        # Try both operand orders; the norm side must be a mul fed by rsqrt.
        for norm_raw, weight_raw in ((lhs, rhs), (rhs, lhs)):
            if not isinstance(norm_raw, torch.fx.Node):
                continue
            norm_unwrapped = _skip_casts(norm_raw)
            if (
                isinstance(norm_unwrapped, torch.fx.Node)
                and norm_unwrapped.target == exir_ops.edge.aten.mul.Tensor
                and self._has_rsqrt_ancestor(norm_unwrapped)
            ):
                return norm_unwrapped, weight_raw

        return None, None

    def _has_rsqrt_ancestor(self, mul_node):
        """True if either mul operand is (possibly a cast of) an rsqrt node."""
        for operand in mul_node.args[:2]:
            if not isinstance(operand, torch.fx.Node):
                continue
            if _skip_casts(operand).target == exir_ops.edge.aten.rsqrt.default:
                return True
        return False

    def _identify_rsqrt_and_input(self, norm_mul_node):
        """From mul(x, rstd): return (rsqrt_node, x) or (None, None).
        The rsqrt may be wrapped in a cast node."""
        if len(norm_mul_node.args) < 2:
            return None, None

        lhs, rhs = norm_mul_node.args[0], norm_mul_node.args[1]

        for rsqrt_raw, input_candidate in ((lhs, rhs), (rhs, lhs)):
            if not isinstance(rsqrt_raw, torch.fx.Node):
                continue
            rsqrt_unwrapped = _skip_casts(rsqrt_raw)
            if (
                isinstance(rsqrt_unwrapped, torch.fx.Node)
                and rsqrt_unwrapped.target == exir_ops.edge.aten.rsqrt.default
            ):
                return rsqrt_unwrapped, input_candidate

        return None, None

    def _get_single_arg_node(self, node, expected_target):
        """Return the single Node input of a unary op, or None on mismatch."""
        if node.target != expected_target:
            return None
        if len(node.args) < 1 or not isinstance(node.args[0], torch.fx.Node):
            return None
        return node.args[0]
226+
227+
228+
@register_pattern_detector("rms_norm")
def find_rms_norm_patterns(
    node: torch.fx.Node,
) -> Optional[RmsNormMatch]:
    """Try to anchor an RMSNorm match at `node` (the final weight-scale mul)."""
    if node.target != exir_ops.edge.aten.mul.Tensor:
        return None

    candidate = RmsNormMatch(node)
    return candidate if candidate.match_found else None
240+
241+
242+
##
243+
## Pattern Replacement
244+
##
245+
246+
247+
def _extract_eps_value(eps_node) -> float:
248+
if isinstance(eps_node, (int, float)):
249+
return float(eps_node)
250+
if isinstance(eps_node, torch.fx.Node) and "val" in eps_node.meta:
251+
val = eps_node.meta["val"]
252+
if isinstance(val, torch.Tensor):
253+
return float(val.item())
254+
if isinstance(val, (int, float)):
255+
return float(val)
256+
raise ValueError(f"Cannot extract epsilon value from {eps_node}")
257+
258+
259+
@register_pattern_replacement("rms_norm")
def replace_rms_norm_with_fused_op(
    ep: ExportedProgram,
    graph_module: torch.fx.GraphModule,
    match: RmsNormMatch,
):
    """Swap the matched decomposed RMSNorm subgraph for the fused et_vk op."""
    eps = _extract_eps_value(match.eps_node)

    # Insert the fused node just before the anchor so graph order is kept.
    with graph_module.graph.inserting_before(match.anchor_node):
        fused_node = graph_module.graph.create_node(
            "call_function",
            exir_ops.edge.et_vk.rms_norm.default,
            args=(
                match.input_node,
                match.weight_node,
                eps,
            ),
        )

    # The fused op produces the same value as the anchor (final mul) node.
    fused_node.meta["val"] = match.anchor_node.meta["val"]
    match.anchor_node.replace_all_uses_with(fused_node)

0 commit comments

Comments
 (0)