Skip to content

Commit ae17fd1

Browse files
ssjiaSS-JIA
authored andcommitted
[ET-VK] Fused RMSNorm operator to fix fp16 overflow
Pull Request resolved: #18772 Fused RMSNorm operator that performs squaring, mean, rsqrt, and weight scaling in a single shader dispatch. All accumulation is done in fp32 regardless of input dtype, preventing fp16 overflow when residual stream values exceed sqrt(65504) ≈ 256. The Python reference impl (`rms_norm_impl`) must preserve the input dtype — PyTorch type promotion would otherwise produce fp32 output from fp16 inputs, and the FusePatternsPass re-trace would propagate that incorrect dtype through the graph. Authored by Claude. ghstack-source-id: 364514329 @exported-using-ghexport Differential Revision: [D99841211](https://our.internmc.facebook.com/intern/diff/D99841211/)
1 parent a231fbc commit ae17fd1

12 files changed

Lines changed: 991 additions & 0 deletions

File tree

backends/vulkan/custom_ops_lib.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -959,3 +959,24 @@ def select_as_symint_impl(x: torch.Tensor, dim: int, index: int):
959959
lib.define(f"{name}(Tensor x, int dim, int index) -> SymInt")
960960
lib.impl(name, select_as_symint_impl, "Meta")
961961
select_as_symint_op = getattr(getattr(torch.ops, namespace), name)
962+
963+
################
964+
## rms_norm ##
965+
################
966+
967+
968+
def rms_norm_impl(
969+
x: torch.Tensor,
970+
weight: torch.Tensor,
971+
eps: float,
972+
) -> torch.Tensor:
973+
input_dtype = x.dtype
974+
variance = x.float().pow(2).mean(-1, keepdim=True)
975+
x_normed = x.float() * torch.rsqrt(variance + eps)
976+
return (x_normed * weight.float()).to(input_dtype)
977+
978+
979+
name = "rms_norm"
980+
lib.define(f"{name}(Tensor x, Tensor weight, float eps) -> Tensor")
981+
lib.impl(name, rms_norm_impl, "CompositeExplicitAutograd")
982+
rms_norm_op = getattr(getattr(torch.ops, namespace), name)

backends/vulkan/op_registry.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1606,6 +1606,21 @@ def register_native_layer_norm():
16061606
)
16071607

16081608

1609+
# =============================================================================
1610+
# RmsNorm.cpp
1611+
# =============================================================================
1612+
1613+
1614+
@update_features(exir_ops.edge.et_vk.rms_norm.default)
1615+
def register_rms_norm():
1616+
return OpFeatures(
1617+
inputs_storage=utils.CONTIGUOUS_ANY,
1618+
inputs_dtypes=utils.FP_T,
1619+
supports_prepacking=True,
1620+
supports_resize=True,
1621+
)
1622+
1623+
16091624
#######################
16101625
## Utility functions ##
16111626
#######################

backends/vulkan/patterns/BUCK

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ fbcode_target(_kind = runtime.python_library,
1616
"quantized_convolution.py",
1717
"quantized_binary.py",
1818
"quantized_unary.py",
19+
"rms_norm.py",
1920
"sdpa.py",
2021
"select_as_symint.py",
2122
],

backends/vulkan/patterns/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
import executorch.backends.vulkan.patterns.quantized_unary # noqa
1818

19+
import executorch.backends.vulkan.patterns.rms_norm # noqa
20+
1921
import executorch.backends.vulkan.patterns.rope # noqa
2022

2123
import executorch.backends.vulkan.patterns.rope_hf # noqa
Lines changed: 280 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,280 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from typing import Optional
8+
9+
import torch
10+
11+
from executorch.backends.vulkan.patterns.pattern_registry import (
12+
PatternMatch,
13+
register_pattern_detector,
14+
register_pattern_replacement,
15+
)
16+
17+
from executorch.exir import ExportedProgram
18+
from executorch.exir.dialects._ops import ops as exir_ops
19+
20+
21+
_CAST_OPS = {
22+
exir_ops.edge.aten._to_copy.default,
23+
exir_ops.edge.aten.to.dtype,
24+
}
25+
26+
27+
def _skip_casts(node: torch.fx.Node) -> torch.fx.Node:
28+
"""Unwrap chains of dtype-cast nodes to find the underlying value."""
29+
while node.target in _CAST_OPS:
30+
arg0 = node.args[0] if node.args else None
31+
if not isinstance(arg0, torch.fx.Node):
32+
break
33+
node = arg0
34+
# pyre-ignore[7]: node is always a Node; Pyre cannot narrow through loops
35+
return node
36+
37+
38+
class RmsNormMatch(PatternMatch):
39+
"""
40+
Detects the decomposed RMSNorm pattern, including variants where dtype
41+
casts (to_copy) are inserted around the computation.
42+
43+
The canonical pattern emitted by the Llama RMSNorm implementation is:
44+
45+
x_orig (any dtype)
46+
-> to_copy(fp32) -> x_f32
47+
-> mul(x_f32, x_f32) -> mean(dim=-1, keepdim=True)
48+
-> add(eps) -> rsqrt -> rstd_f32
49+
-> mul(x_f32, rstd_f32) -> norm_f32
50+
-> to_copy(orig dtype) -> norm_cast
51+
weight -> to_copy(orig dtype) -> weight_cast
52+
-> mul(norm_cast, weight_cast) ← anchor node
53+
54+
We look through to_copy nodes when comparing tensor identities so that
55+
the match also handles fp32-only models where no casts are present.
56+
57+
The anchor node is the final mul (scale by weight).
58+
"""
59+
60+
def __init__(self, final_mul_node: torch.fx.Node) -> None: # noqa: C901
61+
self.anchor_node = final_mul_node
62+
self.match_found = False
63+
self.all_nodes = [self.anchor_node]
64+
65+
# final_mul: mul(normalized_cast, weight_cast)
66+
# Unwrap casts to reach the underlying norm_mul and weight.
67+
norm_mul_node, self.weight_node = self._identify_norm_mul_and_weight(
68+
final_mul_node
69+
)
70+
if norm_mul_node is None:
71+
return
72+
73+
self.all_nodes.append(norm_mul_node)
74+
75+
# norm_mul: mul(x_f32, rstd_f32)
76+
rsqrt_node, x_for_norm = self._identify_rsqrt_and_input(norm_mul_node)
77+
if rsqrt_node is None:
78+
return
79+
80+
self.all_nodes.append(rsqrt_node)
81+
82+
# rsqrt -> add(mean_sq, eps) -> mean(x_sq, dim=-1, keepdim=True)
83+
add_node = self._get_single_arg_node(
84+
rsqrt_node, exir_ops.edge.aten.rsqrt.default
85+
)
86+
if add_node is None or add_node.target != exir_ops.edge.aten.add.Tensor:
87+
return
88+
89+
self.all_nodes.append(add_node)
90+
91+
self.eps_node = None
92+
mean_node = None
93+
for arg in add_node.args[:2]:
94+
if (
95+
isinstance(arg, torch.fx.Node)
96+
and arg.target == exir_ops.edge.aten.mean.dim
97+
):
98+
mean_node = arg
99+
else:
100+
self.eps_node = arg
101+
102+
if mean_node is None or self.eps_node is None:
103+
return
104+
105+
self.all_nodes.append(mean_node)
106+
107+
# Verify mean has keepdim=True and dim=[-1]
108+
if len(mean_node.args) < 3:
109+
return
110+
mean_dims = mean_node.args[1]
111+
if mean_dims != [-1]:
112+
return
113+
if not mean_node.args[2]:
114+
return
115+
116+
# mean's input should be x_sq = mul(x, x) or pow(x, 2)
117+
sq_node = mean_node.args[0]
118+
if not isinstance(sq_node, torch.fx.Node):
119+
return
120+
121+
self.all_nodes.append(sq_node)
122+
123+
# Use the fp32 x (x_for_norm) as the canonical fp32 input.
124+
# Both mul(x,x) and the norm mul should share the same fp32 source.
125+
x_f32 = (
126+
_skip_casts(x_for_norm)
127+
if isinstance(x_for_norm, torch.fx.Node)
128+
else x_for_norm
129+
)
130+
131+
if sq_node.target == exir_ops.edge.aten.mul.Tensor:
132+
if sq_node.args[0] != sq_node.args[1]:
133+
return
134+
sq_input = sq_node.args[0]
135+
if not isinstance(sq_input, torch.fx.Node):
136+
return
137+
if _skip_casts(sq_input) != x_f32 and sq_input != x_for_norm:
138+
return
139+
elif sq_node.target == exir_ops.edge.aten.pow.Tensor_Scalar:
140+
sq_input = sq_node.args[0]
141+
if not isinstance(sq_input, torch.fx.Node):
142+
return
143+
if _skip_casts(sq_input) != x_f32 and sq_input != x_for_norm:
144+
return
145+
if sq_node.args[1] != 2 and sq_node.args[1] != 2.0:
146+
return
147+
else:
148+
return
149+
150+
# The canonical input node to expose to the fused op is the original
151+
# tensor before any fp32 upcast (i.e. the input to the first to_copy).
152+
# If there's no cast, x_for_norm is already the original input.
153+
self.input_node = (
154+
_skip_casts(x_for_norm)
155+
if isinstance(x_for_norm, torch.fx.Node)
156+
else x_for_norm
157+
)
158+
# Also collect the intermediate cast nodes so they can be cleaned up
159+
cast_node = x_for_norm
160+
while (
161+
isinstance(cast_node, torch.fx.Node)
162+
and cast_node.target in _CAST_OPS
163+
and cast_node not in self.all_nodes
164+
):
165+
self.all_nodes.append(cast_node)
166+
cast_node = cast_node.args[0] if cast_node.args else cast_node
167+
168+
self.match_found = True
169+
170+
def _identify_norm_mul_and_weight(self, final_mul_node):
171+
"""From mul(norm_cast, weight_cast), unwrap casts and find the
172+
underlying norm-mul node and the weight source node."""
173+
if len(final_mul_node.args) < 2:
174+
return None, None
175+
176+
a, b = final_mul_node.args[0], final_mul_node.args[1]
177+
178+
for norm_candidate_raw, weight_candidate_raw in [(a, b), (b, a)]:
179+
if not isinstance(norm_candidate_raw, torch.fx.Node):
180+
continue
181+
norm_candidate = _skip_casts(norm_candidate_raw)
182+
if (
183+
isinstance(norm_candidate, torch.fx.Node)
184+
and norm_candidate.target == exir_ops.edge.aten.mul.Tensor
185+
and self._has_rsqrt_ancestor(norm_candidate)
186+
):
187+
return norm_candidate, weight_candidate_raw
188+
189+
return None, None
190+
191+
def _has_rsqrt_ancestor(self, mul_node):
192+
"""Check if one of mul_node's args is an rsqrt node (possibly through casts)."""
193+
for arg in mul_node.args[:2]:
194+
if not isinstance(arg, torch.fx.Node):
195+
continue
196+
if _skip_casts(arg).target == exir_ops.edge.aten.rsqrt.default:
197+
return True
198+
return False
199+
200+
def _identify_rsqrt_and_input(self, norm_mul_node):
201+
"""From mul(x, rstd), find the rsqrt node and the input x.
202+
The rsqrt may be wrapped in a cast node."""
203+
if len(norm_mul_node.args) < 2:
204+
return None, None
205+
206+
a, b = norm_mul_node.args[0], norm_mul_node.args[1]
207+
208+
for rsqrt_candidate_raw, input_candidate in [(a, b), (b, a)]:
209+
if not isinstance(rsqrt_candidate_raw, torch.fx.Node):
210+
continue
211+
rsqrt_candidate = _skip_casts(rsqrt_candidate_raw)
212+
if (
213+
isinstance(rsqrt_candidate, torch.fx.Node)
214+
and rsqrt_candidate.target == exir_ops.edge.aten.rsqrt.default
215+
):
216+
return rsqrt_candidate, input_candidate
217+
218+
return None, None
219+
220+
def _get_single_arg_node(self, node, expected_target):
221+
"""Get the single input arg of a unary op node."""
222+
if node.target != expected_target:
223+
return None
224+
if len(node.args) < 1 or not isinstance(node.args[0], torch.fx.Node):
225+
return None
226+
return node.args[0]
227+
228+
229+
@register_pattern_detector("rms_norm")
230+
def find_rms_norm_patterns(
231+
node: torch.fx.Node,
232+
) -> Optional[RmsNormMatch]:
233+
if node.target != exir_ops.edge.aten.mul.Tensor:
234+
return None
235+
236+
matched_pattern = RmsNormMatch(node)
237+
if matched_pattern.match_found:
238+
return matched_pattern
239+
240+
return None
241+
242+
243+
##
244+
## Pattern Replacement
245+
##
246+
247+
248+
def _extract_eps_value(eps_node) -> float:
249+
if isinstance(eps_node, (int, float)):
250+
return float(eps_node)
251+
if isinstance(eps_node, torch.fx.Node) and "val" in eps_node.meta:
252+
val = eps_node.meta["val"]
253+
if isinstance(val, torch.Tensor):
254+
return float(val.item())
255+
if isinstance(val, (int, float)):
256+
return float(val)
257+
raise ValueError(f"Cannot extract epsilon value from {eps_node}")
258+
259+
260+
@register_pattern_replacement("rms_norm")
261+
def replace_rms_norm_with_fused_op(
262+
ep: ExportedProgram,
263+
graph_module: torch.fx.GraphModule,
264+
match: RmsNormMatch,
265+
):
266+
eps_val = _extract_eps_value(match.eps_node)
267+
268+
with graph_module.graph.inserting_before(match.anchor_node):
269+
rms_norm_node = graph_module.graph.create_node(
270+
"call_function",
271+
exir_ops.edge.et_vk.rms_norm.default,
272+
args=(
273+
match.input_node,
274+
match.weight_node,
275+
eps_val,
276+
),
277+
)
278+
279+
rms_norm_node.meta["val"] = match.anchor_node.meta["val"]
280+
match.anchor_node.replace_all_uses_with(rms_norm_node)

0 commit comments

Comments
 (0)