Commit f4fbeb5

Use XNN fused RoPE for HF-style RoPE patterns
1 parent 19bbeac commit f4fbeb5

14 files changed: 699 additions & 46 deletions

backends/xnnpack/_passes/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -19,6 +19,7 @@
     Conv1dUnsqueezePass,
 )
 from executorch.backends.xnnpack._passes.convert_to_linear import ConvertToLinearPass
+from executorch.backends.xnnpack._passes.convert_to_rope import ConvertToRopePass
 from executorch.backends.xnnpack._passes.convert_to_sdpa import ConvertToSDPAPass
 from executorch.backends.xnnpack._passes.convert_to_upsample_bilinear2d import (
     ConvertToUpsampleBilinear2d,
@@ -75,6 +76,7 @@ def __init__(
             ConvertToLinearPass,
             PropagateCustomMetaPass,
             ConvertToSDPAPass,
+            ConvertToRopePass,
             ConstPropPass,
             FuseBatchNormPass,
             DecomposeBatchNorm,
backends/xnnpack/_passes/convert_to_rope.py

Lines changed: 247 additions & 0 deletions
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging

import torch
from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass
from executorch.backends.xnnpack.partition.graphs import rope
from executorch.exir.dialects._ops import ops as exir_ops

from torch.fx.passes.infra.pass_base import PassResult
from torch.fx.passes.utils.matcher_utils import InternalMatch, SubgraphMatcher

logger = logging.getLogger(__name__)


class ConvertToRopePass(XNNPACKPass):
    def _build_weights(
        self,
        graph_module: torch.fx.GraphModule,
        cos_node: torch.fx.Node,
        sin_node: torch.fx.Node,
        output_node: torch.fx.Node,
    ) -> torch.fx.Node:
        """
        Construct the XNNPACK RoPE weights tensor from cos and sin inputs.

        HF precompute_freqs_cis doubles the frequencies:
            cos/sin shape: [batch, seq, head_dim] where head_dim = 2 * (dim // 2)
            The first half and second half are identical.

        XNNPACK expects weights: [tokens, channels] where:
            weights[:, :C/2] = cos values (unique half)
            weights[:, C/2:] = sin values (unique half)

        We insert graph nodes to slice the unique halves and concatenate them.
        """
        cos_val = cos_node.meta.get("val")
        head_dim = cos_val.shape[-1]
        half_dim = head_dim // 2

        with graph_module.graph.inserting_before(output_node):
            cos_half = graph_module.graph.call_function(
                exir_ops.edge.aten.slice_copy.Tensor,
                args=(cos_node, -1, 0, half_dim),
            )
            cos_half.meta["val"] = cos_val.narrow(-1, 0, half_dim)

            sin_val = sin_node.meta.get("val")
            sin_half = graph_module.graph.call_function(
                exir_ops.edge.aten.slice_copy.Tensor,
                args=(sin_node, -1, 0, half_dim),
            )
            sin_half.meta["val"] = sin_val.narrow(-1, 0, half_dim)

            weights = graph_module.graph.call_function(
                exir_ops.edge.aten.cat.default,
                args=([cos_half, sin_half], -1),
            )
            weights.meta["val"] = torch.cat(
                [cos_half.meta["val"], sin_half.meta["val"]], -1
            )

        return weights
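For intuition, here is a standalone sketch of the tensor math those inserted slice/cat nodes perform; shapes and values are illustrative, not taken from the pass:

```python
import torch

# HF-style cos/sin: halves duplicated along the last dim ([batch, seq, head_dim]).
half = torch.arange(4.0)
cos = torch.cat([half, half], dim=-1).expand(1, 4, 8)
sin = torch.cat([half + 0.5, half + 0.5], dim=-1).expand(1, 4, 8)

half_dim = cos.shape[-1] // 2
cos_half = cos.narrow(-1, 0, half_dim)  # what the slice_copy nodes compute
sin_half = sin.narrow(-1, 0, half_dim)
weights = torch.cat([cos_half, sin_half], dim=-1)  # what the cat node computes

# Result: cos in the first half of the channel dim, sin in the second.
assert torch.equal(weights[..., :half_dim], cos_half)
assert torch.equal(weights[..., half_dim:], sin_half)
```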
    @staticmethod
    def _get_unsqueeze_dim(node: torch.fx.Node) -> int:
        """Return the unsqueeze dim if node is an unsqueeze_copy, else -1."""
        if (
            node.op == "call_function"
            and node.target == exir_ops.edge.aten.unsqueeze_copy.default
        ):
            return node.args[1]
        return -1

    # [0, 2, 1, 3] swaps dims 1 and 2; it is its own inverse, so the same
    # permutation maps BHSD to BSHD and back.
    _BHSD_TO_BSHD_PERM = [0, 2, 1, 3]

    @staticmethod
    def _trace_through_unsqueeze(node: torch.fx.Node) -> torch.fx.Node:
        """If node is an unsqueeze_copy, return its input. Otherwise return node as-is."""
        if (
            node.op == "call_function"
            and node.target == exir_ops.edge.aten.unsqueeze_copy.default
        ):
            return node.args[0]
        return node

    @staticmethod
    def _find_trig_source(node: torch.fx.Node) -> torch.fx.Node | None:
        """Walk backwards through unsqueeze_copy ops to find the cos/sin op."""
        current = node
        for _ in range(10):  # bounded walk; RoPE unsqueeze chains are short
            if current.op != "call_function":
                return None
            if current.target in (
                exir_ops.edge.aten.cos.default,
                exir_ops.edge.aten.sin.default,
            ):
                return current
            if current.target == exir_ops.edge.aten.unsqueeze_copy.default:
                current = current.args[0]
                continue
            return None
        return None

    @classmethod
    def _is_doubled_cat(cls, trig_node: torch.fx.Node) -> bool:
        """Check that a cos/sin node's input is cat(x, x) with identical args."""
        cat_node = trig_node.args[0]
        if (
            cat_node.op != "call_function"
            or cat_node.target != exir_ops.edge.aten.cat.default
        ):
            return False
        tensors = cat_node.args[0]
        return len(tensors) == 2 and tensors[0] is tensors[1]

    @classmethod
    def _has_doubled_freqs(
        cls,
        cos_unsqueezed: torch.fx.Node,
        sin_unsqueezed: torch.fx.Node,
    ) -> bool:
        """Verify that cos/sin frequencies are doubled (first half == second half).

        Traces back through unsqueeze_copy ops to find the cos/sin producer,
        then verifies its input is cat(x, x) where both args are the same
        node, which structurally proves that the first and second halves
        are identical.
        """
        cos_trig = cls._find_trig_source(cos_unsqueezed)
        sin_trig = cls._find_trig_source(sin_unsqueezed)

        if cos_trig is None or sin_trig is None:
            return False

        return cls._is_doubled_cat(cos_trig) and cls._is_doubled_cat(sin_trig)
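The doubled structure being detected comes from the HF rotary-embedding computation, sketched below in simplified form (the exported edge graph may decompose differently, but the single `freqs` node feeding both halves of the cat is exactly what `_is_doubled_cat` keys on):

```python
import torch

dim, seq_len = 64, 8
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, dim, 2).float() / dim))  # [dim/2]
position_ids = torch.arange(seq_len).float()[None, :]                  # [1, seq]

freqs = (inv_freq[None, :, None] @ position_ids[:, None, :]).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)  # cat(x, x): both args are one node
cos, sin = emb.cos(), emb.sin()

# The property _has_doubled_freqs verifies structurally:
assert torch.equal(cos[..., : dim // 2], cos[..., dim // 2 :])
```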
    @staticmethod
    def _trace_through_permute(node: torch.fx.Node) -> torch.fx.Node | None:
        """If node is a permute_copy that swaps dims 1 and 2, return its input."""
        if (
            node.op == "call_function"
            and node.target == exir_ops.edge.aten.permute_copy.default
            and list(node.args[1]) == [0, 2, 1, 3]
        ):
            return node.args[0]
        return None

    def _get_layout(self, cos_unsqueezed: torch.fx.Node) -> str | None:
        """Determine the tensor layout from the cos unsqueeze dimension.

        Returns "BSHD", "BHSD", or None if the layout cannot be determined.
        """
        unsqueeze_dim = self._get_unsqueeze_dim(cos_unsqueezed)
        if unsqueeze_dim == -1:
            return None
        ndim = len(cos_unsqueezed.meta["val"].shape)
        normalized = unsqueeze_dim if unsqueeze_dim >= 0 else unsqueeze_dim + ndim
        if normalized == 2:
            return "BSHD"
        if normalized == 1:
            return "BHSD"
        return None
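The unsqueeze-dim heuristic mirrors HF's apply_rotary_pos_emb, paraphrased below (simplified from transformers; not part of this commit): unsqueeze_dim inserts the missing heads axis into cos/sin so they broadcast against q/k, and its value therefore reveals the activation layout.

```python
import torch

def rotate_half(x):
    half = x.shape[-1] // 2
    return torch.cat((-x[..., half:], x[..., :half]), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    # unsqueeze_dim=1: heads axis added at dim 1, so q/k are [B, H, S, D] (BHSD).
    # unsqueeze_dim=2: heads axis added at dim 2, so q/k are [B, S, H, D] (BSHD).
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
```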
    def create_rope(
        self,
        graph_module: torch.fx.GraphModule,
        match: InternalMatch,
    ) -> bool:
        """Replace one matched RoPE subgraph with a fused node.

        Returns True if the fusion was applied, False if the match was skipped.
        """
        logger.debug(f"Matched RoPE subgraph: {match}")

        # placeholder_nodes are in the order of the pattern's placeholder ops:
        # [x, cos_unsqueezed, sin_unsqueezed]
        x_node = match.placeholder_nodes[0]
        cos_unsqueezed = match.placeholder_nodes[1]
        sin_unsqueezed = match.placeholder_nodes[2]
        output_node = match.returning_nodes[0]

        # xnn_define_rope expects NTHC (batch, tokens, heads, channels) input.
        # BSHD (unsqueeze_dim=2) maps directly to NTHC.
        # BHSD (unsqueeze_dim=1) requires tracing through the BSHD→BHSD permute
        # to recover the BSHD input, then re-permuting the output back to BHSD.
        layout = self._get_layout(cos_unsqueezed)
        if layout == "BSHD":
            rope_input = x_node
        elif layout == "BHSD":
            rope_input = self._trace_through_permute(x_node)
            if rope_input is None:
                logger.debug("Skipping RoPE fusion: BHSD but x is not a permute_copy")
                return False
        else:
            logger.debug("Skipping RoPE fusion: unrecognized layout")
            return False

        cos_node = self._trace_through_unsqueeze(cos_unsqueezed)
        sin_node = self._trace_through_unsqueeze(sin_unsqueezed)

        if not self._has_doubled_freqs(cos_unsqueezed, sin_unsqueezed):
            logger.debug("Skipping RoPE fusion: cannot verify doubled frequencies")
            return False

        weights = self._build_weights(graph_module, cos_node, sin_node, output_node)

        with graph_module.graph.inserting_before(output_node):
            rope_node = graph_module.graph.create_node(
                "call_function",
                torch.ops.xnnpack.rope.default,
                args=(rope_input, weights),
            )
            rope_node.meta["val"] = torch.empty_like(rope_input.meta["val"])

            if layout == "BHSD":
                permute_node = graph_module.graph.call_function(
                    exir_ops.edge.aten.permute_copy.default,
                    args=(rope_node, self._BHSD_TO_BSHD_PERM),
                )
                permute_node.meta["val"] = rope_node.meta["val"].permute(
                    self._BHSD_TO_BSHD_PERM
                )
                result_node = permute_node
            else:
                result_node = rope_node

        output_node.replace_all_uses_with(result_node)
        graph_module.graph.eliminate_dead_code()
        return True
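Why the rewrite is sound: with doubled frequencies, split x = [x1, x2] along channels and write cos = [c, c], sin = [s, s]. The matched subgraph computes x*cos + rotate_half(x)*sin = [x1*c - x2*s, x2*c + x1*s], while the fused op computes [x1*c - x2*s, x1*s + x2*c] from weights [c | s]; the two agree elementwise, so replacing the subgraph preserves the output exactly.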
    # override
    def call(self, graph_module: torch.fx.GraphModule):
        total_matches = 0
        total_fused = 0
        for pattern in rope.get_graphs():
            sm = SubgraphMatcher(pattern.graph, ignore_literals=True)
            matches = list(sm.match(graph_module.graph))
            total_matches += len(matches)
            for match in matches:
                try:
                    if self.create_rope(graph_module, match):
                        total_fused += 1
                except Exception:
                    logger.warning("Failed to fuse RoPE pattern", exc_info=True)
        logger.debug(f"RoPE fusion: fused {total_fused}/{total_matches} matched subgraphs")
        graph_module.recompile()
        graph_module = super().call(graph_module).graph_module

        return PassResult(graph_module, True)
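The pattern graphs come from rope.get_graphs() in backends/xnnpack/partition/graphs/rope.py, which is part of this commit but not shown on this page. A rough, hypothetical sketch of how such a pattern module is typically built (export the decomposed computation, lower it to the edge dialect, then match against its graph); every name and shape below is illustrative:

```python
import torch
from executorch.exir import to_edge

class _HFRopePattern(torch.nn.Module):
    def forward(self, x, cos_unsqueezed, sin_unsqueezed):
        half = x.shape[-1] // 2
        rotated = torch.cat((-x[..., half:], x[..., :half]), dim=-1)
        return x * cos_unsqueezed + rotated * sin_unsqueezed

def get_graphs():
    sample = (
        torch.randn(1, 8, 4, 64),   # x in BSHD
        torch.randn(1, 8, 1, 64),   # cos after unsqueeze(2)
        torch.randn(1, 8, 1, 64),   # sin after unsqueeze(2)
    )
    ep = torch.export.export(_HFRopePattern(), sample)
    return [to_edge(ep).exported_program().graph_module]
```

Note the placeholder order [x, cos_unsqueezed, sin_unsqueezed], which is the order create_rope reads out of match.placeholder_nodes.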

backends/xnnpack/operators/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -43,6 +43,7 @@
     op_prelu,
     op_quant_dequant,
     op_relu,
+    op_rope,
     op_rsqrt,
     op_sigmoid,
     op_sin,
backends/xnnpack/operators/op_rope.py

Lines changed: 77 additions & 0 deletions
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict

import torch
from executorch.backends.xnnpack.operators.node_visitor import (
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
    XNNGraph,
    XNNRope,
    XNode,
)
from executorch.backends.xnnpack.utils.utils import get_input_node

# Register the custom op used by the fusion pass. The fused node targets
# this op after ConvertToRopePass replaces the decomposed HF RoPE subgraph.
lib = torch.library.Library("xnnpack", "FRAGMENT")
lib.define("rope(Tensor input, Tensor weights) -> Tensor")


@torch.library.impl(lib, "rope", "CompositeExplicitAutograd")
def rope_impl(input: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
    channels = input.shape[-1]
    half_c = channels // 2
    # Insert a singleton heads axis so [tokens, channels] weights broadcast
    # against NTHC input ([batch, tokens, heads, channels]); without it the
    # tokens axis of weights would line up with the heads axis of input.
    cos = weights[..., :half_c].unsqueeze(-2)
    sin = weights[..., half_c:].unsqueeze(-2)

    x_real = input[..., :half_c]
    x_imag = input[..., half_c:]

    out_real = x_real * cos - x_imag * sin
    out_imag = x_real * sin + x_imag * cos
    return torch.cat([out_real, out_imag], dim=-1)
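A quick numeric check of the composite impl against the decomposed HF formula; this is an illustrative test, assuming the module above has been imported so the op is registered:

```python
import torch

def hf_rope(x, cos, sin):
    half = x.shape[-1] // 2
    rotated = torch.cat((-x[..., half:], x[..., :half]), dim=-1)
    return x * cos + rotated * sin

x = torch.randn(2, 8, 4, 64)             # NTHC: [batch, tokens, heads, channels]
base = torch.randn(8, 32)                # unique half per token
weights = torch.cat([base.cos(), base.sin()], dim=-1)  # [tokens, channels]

# The HF form uses doubled cos/sin, broadcast over the heads axis.
cos = torch.cat([base.cos(), base.cos()], dim=-1)[:, None, :]
sin = torch.cat([base.sin(), base.sin()], dim=-1)[:, None, :]

torch.testing.assert_close(torch.ops.xnnpack.rope(x, weights), hf_rope(x, cos, sin))
```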
@torch.library.impl(lib, "rope", "Meta")
def rope_meta(input: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
    return torch.empty_like(input)


@register_node_visitor
class RopeVisitor(NodeVisitor):
    target = "rope.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        xnn_graph: XNNGraph,
        vals_to_ids: Dict[torch.fx.Node, int],
        debug_handle: int,
    ) -> None:
        self.define_nodes_tensor_inputs_outputs(node, xnn_graph, vals_to_ids)

        input_id = vals_to_ids[get_input_node(node, 0)]
        weights_id = vals_to_ids[get_input_node(node, 1)]
        output_id = vals_to_ids[node]

        ser_node = XNode(
            xnode_union=XNNRope(
                max_tokens=0,
                input_id=input_id,
                weights_id=weights_id,
                output_id=output_id,
                flags=0,
            ),
            debug_handle=debug_handle,
        )
        xnn_graph.xnodes.append(ser_node)
