Commit 6478082

Use XNN fused RoPE for HF-style RoPE patterns

1 parent 19bbeac commit 6478082

14 files changed

Lines changed: 687 additions & 46 deletions

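For context, the "HF-style" RoPE this commit targets is the rotate_half formulation used in Hugging Face transformers attention modules. A representative sketch follows; the pass matches its exported, decomposed slice/cat/unsqueeze/mul/add form, not this Python source:

    # Representative HF-style RoPE (sketch of transformers' apply_rotary_pos_emb).
    import torch

    def rotate_half(x):
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
        # unsqueeze_dim=1 -> BHSD layout; unsqueeze_dim=2 -> BSHD layout.
        cos = cos.unsqueeze(unsqueeze_dim)
        sin = sin.unsqueeze(unsqueeze_dim)
        q_embed = (q * cos) + (rotate_half(q) * sin)
        k_embed = (k * cos) + (rotate_half(k) * sin)
        return q_embed, k_embed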
backends/xnnpack/_passes/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -19,6 +19,7 @@
     Conv1dUnsqueezePass,
 )
 from executorch.backends.xnnpack._passes.convert_to_linear import ConvertToLinearPass
+from executorch.backends.xnnpack._passes.convert_to_rope import ConvertToRopePass
 from executorch.backends.xnnpack._passes.convert_to_sdpa import ConvertToSDPAPass
 from executorch.backends.xnnpack._passes.convert_to_upsample_bilinear2d import (
     ConvertToUpsampleBilinear2d,
@@ -75,6 +76,7 @@ def __init__(
             ConvertToLinearPass,
             PropagateCustomMetaPass,
             ConvertToSDPAPass,
+            ConvertToRopePass,
             ConstPropPass,
             FuseBatchNormPass,
             DecomposeBatchNorm,
backends/xnnpack/_passes/convert_to_rope.py (new file)

Lines changed: 235 additions & 0 deletions

@@ -0,0 +1,235 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import enum
import logging

import torch
from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass
from executorch.backends.xnnpack.partition.graphs import rope
from executorch.exir.dialects._ops import ops as exir_ops

from torch.fx.passes.infra.pass_base import PassResult
from torch.fx.passes.utils.matcher_utils import InternalMatch, SubgraphMatcher

logger = logging.getLogger(__name__)


class _Layout(enum.Enum):
    BSHD = enum.auto()
    BHSD = enum.auto()


class ConvertToRopePass(XNNPACKPass):
    _BHSD_TO_BSHD_PERM = [0, 2, 1, 3]

    def _build_weights(
        self,
        graph_module: torch.fx.GraphModule,
        cos_node: torch.fx.Node,
        sin_node: torch.fx.Node,
        output_node: torch.fx.Node,
    ) -> torch.fx.Node:
        """
        Construct the XNNPACK RoPE weights tensor from the cos and sin inputs.

        The most common HF RoPE pattern doubles the frequencies:
            cos/sin shape: [batch, seq, head_dim] where head_dim = 2 * (dim // 2),
            and the first and second halves are identical.

        XNNPACK expects weights of shape [tokens, channels] where:
            weights[:, :C/2] = cos values (unique half)
            weights[:, C/2:] = sin values (unique half)

        We insert graph nodes that slice out the unique halves and concatenate
        them. For this to be sound, cos and sin must come from a cat([x, x])
        node; _has_doubled_freqs checks this.
        """
        head_dim = cos_node.meta["val"].shape[-1]
        half_dim = head_dim // 2

        with graph_module.graph.inserting_before(output_node):
            cos_half = graph_module.graph.call_function(
                exir_ops.edge.aten.slice_copy.Tensor,
                args=(cos_node, -1, 0, half_dim),
            )
            sin_half = graph_module.graph.call_function(
                exir_ops.edge.aten.slice_copy.Tensor,
                args=(sin_node, -1, 0, half_dim),
            )
            weights = graph_module.graph.call_function(
                exir_ops.edge.aten.cat.default,
                args=([cos_half, sin_half], -1),
            )

        return weights

    @staticmethod
    def _trace_through_unsqueezes(node: torch.fx.Node) -> torch.fx.Node:
        """Walk backwards through consecutive unsqueeze_copy ops to find the source."""
        current = node
        while (
            current.op == "call_function"
            and current.target == exir_ops.edge.aten.unsqueeze_copy.default
        ):
            current = current.args[0]
        return current

    @staticmethod
    def _find_trig_source(node: torch.fx.Node) -> torch.fx.Node | None:
        """Walk backwards through unsqueeze_copy ops to find the cos/sin op."""
        current = node
        for _ in range(10):
            if current.op != "call_function":
                return None
            if current.target in (
                exir_ops.edge.aten.cos.default,
                exir_ops.edge.aten.sin.default,
            ):
                return current
            if current.target == exir_ops.edge.aten.unsqueeze_copy.default:
                current = current.args[0]
                continue
            return None
        return None

    @classmethod
    def _is_doubled_cat(cls, trig_node: torch.fx.Node) -> bool:
        """Check that a cos/sin node's input is cat([x, x]) with identical args."""
        cat_node = trig_node.args[0]
        if (
            cat_node.op != "call_function"
            or cat_node.target != exir_ops.edge.aten.cat.default
        ):
            return False
        tensors = cat_node.args[0]
        return len(tensors) == 2 and tensors[0] is tensors[1]

    @classmethod
    def _has_doubled_freqs(
        cls,
        cos_unsqueezed: torch.fx.Node,
        sin_unsqueezed: torch.fx.Node,
    ) -> bool:
        """Verify that the cos/sin frequencies are doubled (first half == second half).

        Traces back through unsqueeze_copy ops to find the cos/sin producer,
        then verifies its input is cat([x, x]) where both args are the same
        node, which structurally proves the first and second halves are identical.
        """
        cos_trig = cls._find_trig_source(cos_unsqueezed)
        sin_trig = cls._find_trig_source(sin_unsqueezed)

        if cos_trig is None or sin_trig is None:
            return False

        return cls._is_doubled_cat(cos_trig) and cls._is_doubled_cat(sin_trig)

    @staticmethod
    def _trace_through_permute(node: torch.fx.Node) -> torch.fx.Node | None:
        """If node is a permute_copy that swaps dims 1 and 2, return its input."""
        if (
            node.op == "call_function"
            and node.target == exir_ops.edge.aten.permute_copy.default
            and list(node.args[1]) == [0, 2, 1, 3]
        ):
            return node.args[0]
        return None

    @staticmethod
    def _get_layout(cos_unsqueezed: torch.fx.Node) -> _Layout | None:
        """Determine the tensor layout from the cos unsqueeze dimension."""
        if not (
            cos_unsqueezed.op == "call_function"
            and cos_unsqueezed.target == exir_ops.edge.aten.unsqueeze_copy.default
        ):
            return None
        unsqueeze_dim = cos_unsqueezed.args[1]
        ndim = len(cos_unsqueezed.meta["val"].shape)
        normalized = unsqueeze_dim if unsqueeze_dim >= 0 else unsqueeze_dim + ndim
        if normalized == 2:
            return _Layout.BSHD
        if normalized == 1:
            return _Layout.BHSD
        return None

    def create_rope(
        self,
        graph_module: torch.fx.GraphModule,
        match: InternalMatch,
    ) -> bool:
        """Replace a matched RoPE subgraph with a fused node. Returns True if fused."""
        logger.debug(f"Matched RoPE subgraph: {match}")

        # placeholder_nodes are in the order of the pattern's placeholder ops:
        # [x, cos_unsqueezed, sin_unsqueezed]
        x_node = match.placeholder_nodes[0]
        cos_unsqueezed = match.placeholder_nodes[1]
        sin_unsqueezed = match.placeholder_nodes[2]
        output_node = match.returning_nodes[0]

        # xnn_define_rope expects NTHC (batch, tokens, heads, channels) input.
        # BSHD (unsqueeze_dim=2) maps directly to NTHC.
        # BHSD (unsqueeze_dim=1) requires tracing through the BSHD→BHSD permute
        # to recover the BSHD input, then re-permuting the output back to BHSD.
        layout = self._get_layout(cos_unsqueezed)
        if layout == _Layout.BSHD:
            rope_input = x_node
        elif layout == _Layout.BHSD:
            rope_input = self._trace_through_permute(x_node)
            if rope_input is None:
                logger.debug("Skipping RoPE fusion: BHSD but x is not a permute_copy")
                return False
        else:
            logger.debug("Skipping RoPE fusion: unrecognized layout")
            return False

        cos_node = self._trace_through_unsqueezes(cos_unsqueezed)
        sin_node = self._trace_through_unsqueezes(sin_unsqueezed)

        if not self._has_doubled_freqs(cos_unsqueezed, sin_unsqueezed):
            logger.debug("Skipping RoPE fusion: cannot verify doubled frequencies")
            return False

        weights = self._build_weights(graph_module, cos_node, sin_node, output_node)

        with graph_module.graph.inserting_before(output_node):
            rope_node = graph_module.graph.create_node(
                "call_function",
                torch.ops.xnnpack.rope.default,
                args=(rope_input, weights),
            )

            if layout == _Layout.BHSD:
                permute_node = graph_module.graph.call_function(
                    exir_ops.edge.aten.permute_copy.default,
                    args=(rope_node, self._BHSD_TO_BSHD_PERM),
                )
                result_node = permute_node
            else:
                result_node = rope_node

        output_node.replace_all_uses_with(result_node)
        graph_module.graph.eliminate_dead_code()
        return True

    # override
    def call(self, graph_module: torch.fx.GraphModule):
        total_matches = 0
        total_fused = 0
        for pattern in rope.get_graphs():
            sm = SubgraphMatcher(pattern.graph, ignore_literals=True)
            matches = list(sm.match(graph_module.graph))
            total_matches += len(matches)
            for match in matches:
                try:
                    # create_rope may decline a match (unverifiable layout or
                    # frequencies); only count fusions it actually performed.
                    if self.create_rope(graph_module, match):
                        total_fused += 1
                except Exception:
                    logger.warning("Failed to fuse RoPE pattern", exc_info=True)
        logger.debug(f"Fused {total_fused} of {total_matches} matched RoPE subgraphs")
        graph_module.recompile()
        graph_module = super().call(graph_module).graph_module

        return PassResult(graph_module, True)
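Why slicing the unique halves is sound: when the frequencies are doubled, the HF rotate_half formulation is algebraically identical to the split-halves rotation the fused op computes. A minimal numerical self-check (illustrative only, not part of the commit):

    import torch

    def rotate_half(x):
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat([-x2, x1], dim=-1)

    B, T, H, D = 2, 5, 4, 8
    x = torch.randn(B, T, H, D)
    freqs = torch.randn(T, D // 2)
    cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1)  # doubled, like cat([x, x])
    sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1)

    # HF-style: x * cos + rotate_half(x) * sin, cos/sin unsqueezed over heads.
    hf_out = x * cos[:, None, :] + rotate_half(x) * sin[:, None, :]

    # Fused-style: rotate [x1, x2] by the unique cos/sin halves only.
    c, s = freqs.cos()[:, None, :], freqs.sin()[:, None, :]
    x1, x2 = x.chunk(2, dim=-1)
    fused_out = torch.cat([x1 * c - x2 * s, x1 * s + x2 * c], dim=-1)

    assert torch.allclose(hf_out, fused_out)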

backends/xnnpack/operators/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -43,6 +43,7 @@
     op_prelu,
     op_quant_dequant,
     op_relu,
+    op_rope,
     op_rsqrt,
     op_sigmoid,
     op_sin,
backends/xnnpack/operators/op_rope.py (new file)

Lines changed: 77 additions & 0 deletions

@@ -0,0 +1,77 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict

import torch
from executorch.backends.xnnpack.operators.node_visitor import (
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
    XNNGraph,
    XNNRope,
    XNode,
)
from executorch.backends.xnnpack.utils.utils import get_input_node

# Register the custom op used by the fusion pass. The fused node targets
# this op after ConvertToRopePass replaces the decomposed HF RoPE subgraph.
lib = torch.library.Library("xnnpack", "FRAGMENT")
lib.define("rope(Tensor input, Tensor weights) -> Tensor")


@torch.library.impl(lib, "rope", "CompositeExplicitAutograd")
def rope_impl(input: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
    # input: [..., tokens, heads, channels]; weights: [..., tokens, channels].
    # Insert a heads axis so the weights broadcast across the heads dimension
    # rather than misaligning tokens against heads.
    weights = weights.unsqueeze(-2)
    channels = input.shape[-1]
    half_c = channels // 2
    cos = weights[..., :half_c]
    sin = weights[..., half_c:]

    x_real = input[..., :half_c]
    x_imag = input[..., half_c:]

    out_real = x_real * cos - x_imag * sin
    out_imag = x_real * sin + x_imag * cos
    return torch.cat([out_real, out_imag], dim=-1)


@torch.library.impl(lib, "rope", "Meta")
def rope_meta(input: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
    return torch.empty_like(input)


@register_node_visitor
class RopeVisitor(NodeVisitor):
    target = "rope.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        xnn_graph: XNNGraph,
        vals_to_ids: Dict[torch.fx.Node, int],
        debug_handle: int,
    ) -> None:
        self.define_nodes_tensor_inputs_outputs(node, xnn_graph, vals_to_ids)

        input_id = vals_to_ids[get_input_node(node, 0)]
        weights_id = vals_to_ids[get_input_node(node, 1)]
        output_id = vals_to_ids[node]

        ser_node = XNode(
            xnode_union=XNNRope(
                max_tokens=0,
                input_id=input_id,
                weights_id=weights_id,
                output_id=output_id,
                flags=0,
            ),
            debug_handle=debug_handle,
        )
        xnn_graph.xnodes.append(ser_node)
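Once this module is imported, the lib.define call above registers the op, so the fused target can be exercised eagerly for a quick shape check (an illustrative sketch; the tensor values are arbitrary):

    import torch
    from executorch.backends.xnnpack.operators import op_rope  # noqa: F401 -- registers xnnpack::rope

    x = torch.randn(1, 6, 4, 8)  # [batch, tokens, heads, channels]
    w = torch.randn(6, 8)        # [tokens, channels] = [cos half | sin half]
    out = torch.ops.xnnpack.rope(x, w)
    assert out.shape == x.shape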

backends/xnnpack/operators/op_squeeze.py

Lines changed: 8 additions & 32 deletions

@@ -13,6 +13,7 @@
 )
 from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
     XNNGraph,
+    XNNStaticExpandDims,
     XNNStaticReshape,
     XNode,
 )
@@ -98,46 +99,21 @@ def define_node(
         vals_to_ids: Dict[torch.fx.Node, int],
         debug_handle: int,
     ) -> None:
-
-        dim = cast(int, node.args[1])
-        check_or_raise(
-            dim == -1 or dim == len(node.args[0].meta["val"].shape),
-            "XNNPACK currently only supports unsqueezing in last dimension",
-        )
-
         self.define_nodes_tensor_inputs_outputs(node, xnn_graph, vals_to_ids)
         input_node = get_input_node(node, 0)

-        # input
         input_id = vals_to_ids[input_node]
-
-        # output
         output_id = vals_to_ids[node]

-        check_or_raise(
-            "val" in input_node.meta,
-            "Missing val in tensor metadata for input when serializing XNNStaticReshape node",
-        )
-        dynamic_shape = node.meta["val"].shape
-        new_shape = []
-
-        num_dynamic_dims = 0
-        for dim in dynamic_shape:
-            if free_symbols(dim):
-                num_dynamic_dims += 1
-                new_shape.append(0)
-            else:
-                new_shape.append(dim)
-
-        check_or_raise(
-            num_dynamic_dims <= 1,
-            "XNNPACK reshape only supports 1 dynamic dimension. This may occur when ",
-        )
+        dim = cast(int, node.args[1])
+        input_ndim = len(input_node.meta["val"].shape)
+        if dim < 0:
+            # Negative dims are relative to the output rank (input_ndim + 1).
+            dim = input_ndim + 1 + dim

         ser_node = XNode(
-            xnode_union=XNNStaticReshape(
-                num_dims=len(new_shape),
-                new_shape=new_shape,
+            xnode_union=XNNStaticExpandDims(
+                num_new_axes=1,
+                new_axes=[dim],
                 input_id=input_id,
                 output_id=output_id,
                 flags=0,

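The negative-dim normalization above follows torch.unsqueeze semantics, where a negative dim indexes into the output rank (input rank + 1). A quick illustration (not part of the commit):

    import torch

    x = torch.randn(2, 3, 4)  # input_ndim = 3
    # dim=-1 means "append a new trailing axis": -1 + (3 + 1) = 3
    assert torch.unsqueeze(x, -1).shape == (2, 3, 4, 1)
    assert torch.unsqueeze(x, 3).shape == (2, 3, 4, 1)  # same normalized axis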