Commit 154321b

Use XNN fused RoPE for HF-style RoPE patterns

1 parent: 19bbeac

14 files changed · 620 additions & 46 deletions
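For context, the decomposed pattern this commit fuses is the Hugging Face rotate_half RoPE formulation. A minimal sketch (illustrative, not part of this diff; it assumes the standard transformers apply_rotary_pos_emb, and the names here are simplified):

    import torch

    def rotate_half(x):
        # Split channels in half and rotate: (x1, x2) -> (-x2, x1)
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    def apply_rotary_pos_emb(q, cos, sin, unsqueeze_dim=1):
        # cos/sin: [batch, seq, head_dim]. unsqueeze_dim=1 broadcasts over
        # heads for BHSD inputs; unsqueeze_dim=2 does the same for BSHD.
        cos = cos.unsqueeze(unsqueeze_dim)
        sin = sin.unsqueeze(unsqueeze_dim)
        return (q * cos) + (rotate_half(q) * sin)

Exported to Edge dialect, this lowers to the mul/slice/cat/unsqueeze subgraph that the new ConvertToRopePass below pattern-matches and replaces with a single fused XNNPACK RoPE node.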

backends/xnnpack/_passes/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -19,6 +19,7 @@
     Conv1dUnsqueezePass,
 )
 from executorch.backends.xnnpack._passes.convert_to_linear import ConvertToLinearPass
+from executorch.backends.xnnpack._passes.convert_to_rope import ConvertToRopePass
 from executorch.backends.xnnpack._passes.convert_to_sdpa import ConvertToSDPAPass
 from executorch.backends.xnnpack._passes.convert_to_upsample_bilinear2d import (
     ConvertToUpsampleBilinear2d,
@@ -75,6 +76,7 @@ def __init__(
                 ConvertToLinearPass,
                 PropagateCustomMetaPass,
                 ConvertToSDPAPass,
+                ConvertToRopePass,
                 ConstPropPass,
                 FuseBatchNormPass,
                 DecomposeBatchNorm,
backends/xnnpack/_passes/convert_to_rope.py (new file)

Lines changed: 193 additions & 0 deletions

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging

import torch
from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass
from executorch.backends.xnnpack.partition.graphs import rope
from executorch.exir.dialects._ops import ops as exir_ops

from torch.fx.passes.infra.pass_base import PassResult
from torch.fx.passes.utils.matcher_utils import InternalMatch, SubgraphMatcher

logger = logging.getLogger(__name__)


class ConvertToRopePass(XNNPACKPass):
    def _build_weights(
        self,
        graph_module: torch.fx.GraphModule,
        cos_node: torch.fx.Node,
        sin_node: torch.fx.Node,
        output_node: torch.fx.Node,
    ) -> torch.fx.Node:
        """
        Construct the XNNPACK RoPE weights tensor from cos and sin inputs.

        HF precompute_freqs_cis doubles the frequencies:
            cos/sin shape: [batch, seq, head_dim] where head_dim = 2 * (dim // 2)
            The first half and second half are identical.

        XNNPACK expects weights: [tokens, channels] where:
            weights[:, :C/2] = cos values (unique half)
            weights[:, C/2:] = sin values (unique half)

        We insert graph nodes to slice the unique halves and concatenate them.
        """
        cos_val = cos_node.meta.get("val")
        head_dim = cos_val.shape[-1]
        half_dim = head_dim // 2

        with graph_module.graph.inserting_before(output_node):
            cos_half = graph_module.graph.call_function(
                exir_ops.edge.aten.slice_copy.Tensor,
                args=(cos_node, -1, 0, half_dim),
            )
            cos_half.meta["val"] = cos_val.narrow(-1, 0, half_dim)

            sin_val = sin_node.meta.get("val")
            sin_half = graph_module.graph.call_function(
                exir_ops.edge.aten.slice_copy.Tensor,
                args=(sin_node, -1, 0, half_dim),
            )
            sin_half.meta["val"] = sin_val.narrow(-1, 0, half_dim)

            weights = graph_module.graph.call_function(
                exir_ops.edge.aten.cat.default,
                args=([cos_half, sin_half], -1),
            )
            weights.meta["val"] = torch.cat(
                [cos_half.meta["val"], sin_half.meta["val"]], -1
            )

        return weights

    @staticmethod
    def _get_unsqueeze_dim(node: torch.fx.Node) -> int:
        """Return the unsqueeze dim if node is an unsqueeze_copy, else -1."""
        if (
            node.op == "call_function"
            and node.target == exir_ops.edge.aten.unsqueeze_copy.default
        ):
            return node.args[1]
        return -1

    _BHSD_TO_BSHD_PERM = [0, 2, 1, 3]

    @staticmethod
    def _trace_through_unsqueeze(node: torch.fx.Node) -> torch.fx.Node:
        """If node is an unsqueeze_copy, return its input. Otherwise return node as-is."""
        if (
            node.op == "call_function"
            and node.target == exir_ops.edge.aten.unsqueeze_copy.default
        ):
            return node.args[0]
        return node

    @staticmethod
    def _trace_through_permute(node: torch.fx.Node) -> torch.fx.Node | None:
        """If node is a permute_copy that swaps dims 1 and 2, return its input."""
        if (
            node.op == "call_function"
            and node.target == exir_ops.edge.aten.permute_copy.default
            and list(node.args[1]) == [0, 2, 1, 3]
        ):
            return node.args[0]
        return None

    def _get_layout(self, cos_unsqueezed: torch.fx.Node) -> str | None:
        """Determine the tensor layout from the cos unsqueeze dimension.

        Returns "BSHD", "BHSD", or None if the layout cannot be determined.
        """
        unsqueeze_dim = self._get_unsqueeze_dim(cos_unsqueezed)
        if unsqueeze_dim == -1:
            return None
        ndim = len(cos_unsqueezed.meta["val"].shape)
        normalized = unsqueeze_dim if unsqueeze_dim >= 0 else unsqueeze_dim + ndim
        if normalized == 2:
            return "BSHD"
        if normalized == 1:
            return "BHSD"
        return None

    def create_rope(
        self,
        graph_module: torch.fx.GraphModule,
        match: InternalMatch,
    ):
        logger.debug(f"Matched RoPE subgraph: {match}")

        # placeholder_nodes are in the order of the pattern's placeholder ops:
        # [x, cos_unsqueezed, sin_unsqueezed]
        x_node = match.placeholder_nodes[0]
        cos_unsqueezed = match.placeholder_nodes[1]
        sin_unsqueezed = match.placeholder_nodes[2]
        output_node = match.returning_nodes[0]

        # xnn_define_rope expects NTHC (batch, tokens, heads, channels) input.
        # BSHD (unsqueeze_dim=2) maps directly to NTHC.
        # BHSD (unsqueeze_dim=1) requires tracing through the BSHD→BHSD permute
        # to recover the BSHD input, then re-permuting the output back to BHSD.
        layout = self._get_layout(cos_unsqueezed)
        if layout == "BSHD":
            rope_input = x_node
        elif layout == "BHSD":
            rope_input = self._trace_through_permute(x_node)
            if rope_input is None:
                logger.debug("Skipping RoPE fusion: BHSD but x is not a permute_copy")
                return
        else:
            logger.debug("Skipping RoPE fusion: unrecognized layout")
            return

        cos_node = self._trace_through_unsqueeze(cos_unsqueezed)
        sin_node = self._trace_through_unsqueeze(sin_unsqueezed)

        weights = self._build_weights(graph_module, cos_node, sin_node, output_node)

        with graph_module.graph.inserting_before(output_node):
            rope_node = graph_module.graph.create_node(
                "call_function",
                torch.ops.xnnpack.rope.default,
                args=(rope_input, weights),
            )
            rope_node.meta["val"] = torch.empty_like(rope_input.meta["val"])

            if layout == "BHSD":
                permute_node = graph_module.graph.call_function(
                    exir_ops.edge.aten.permute_copy.default,
                    args=(rope_node, self._BHSD_TO_BSHD_PERM),
                )
                permute_node.meta["val"] = rope_node.meta["val"].permute(
                    self._BHSD_TO_BSHD_PERM
                )
                result_node = permute_node
            else:
                result_node = rope_node

        output_node.replace_all_uses_with(result_node)
        graph_module.graph.eliminate_dead_code()

    # override
    def call(self, graph_module: torch.fx.GraphModule):
        total_matches = 0
        total_fused = 0
        for pattern in rope.get_graphs():
            sm = SubgraphMatcher(pattern.graph, ignore_literals=True)
            matches = list(sm.match(graph_module.graph))
            total_matches += len(matches)
            for match in matches:
                try:
                    self.create_rope(graph_module, match)
                    total_fused += 1
                except Exception:
                    logger.warning("Failed to fuse RoPE pattern", exc_info=True)
        graph_module.recompile()
        graph_module = super().call(graph_module).graph_module

        return PassResult(graph_module, True)
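To make the weights layout concrete, here is a small worked sketch (illustrative only, not part of this diff) of what _build_weights produces for HF-style doubled cos/sin tensors; the slice/cat below mirrors the graph nodes the pass inserts:

    import torch

    # Assume head_dim = 8, so HF cos/sin are [batch, seq, 8] with the two
    # halves identical; the unique half has 4 channels.
    batch, seq, head_dim = 1, 5, 8
    half = head_dim // 2
    freqs = torch.rand(batch, seq, half)
    cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1)  # HF-style doubling
    sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1)

    # What the pass builds: [..., :C/2] = unique cos half, [..., C/2:] = sin half.
    weights = torch.cat([cos[..., :half], sin[..., :half]], dim=-1)
    assert torch.equal(weights[..., :half], freqs.cos())
    assert torch.equal(weights[..., half:], freqs.sin())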

backends/xnnpack/operators/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -43,6 +43,7 @@
     op_prelu,
     op_quant_dequant,
     op_relu,
+    op_rope,
     op_rsqrt,
     op_sigmoid,
     op_sin,
backends/xnnpack/operators/op_rope.py (new file)

Lines changed: 77 additions & 0 deletions

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict

import torch
from executorch.backends.xnnpack.operators.node_visitor import (
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
    XNNGraph,
    XNNRope,
    XNode,
)
from executorch.backends.xnnpack.utils.utils import get_input_node

# Register the custom op used by the fusion pass. The fused node targets
# this op after ConvertToRopePass replaces the decomposed HF RoPE subgraph.
lib = torch.library.Library("xnnpack", "FRAGMENT")
lib.define("rope(Tensor input, Tensor weights) -> Tensor")


@torch.library.impl(lib, "rope", "CompositeExplicitAutograd")
def rope_impl(input: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
    channels = input.shape[-1]
    half_c = channels // 2
    cos = weights[..., :half_c]
    sin = weights[..., half_c:]

    x_real = input[..., :half_c]
    x_imag = input[..., half_c:]

    out_real = x_real * cos - x_imag * sin
    out_imag = x_real * sin + x_imag * cos
    return torch.cat([out_real, out_imag], dim=-1)


@torch.library.impl(lib, "rope", "Meta")
def rope_meta(input: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
    return torch.empty_like(input)


@register_node_visitor
class RopeVisitor(NodeVisitor):
    target = "rope.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        xnn_graph: XNNGraph,
        vals_to_ids: Dict[torch.fx.Node, int],
        debug_handle: int,
    ) -> None:
        self.define_nodes_tensor_inputs_outputs(node, xnn_graph, vals_to_ids)

        input_id = vals_to_ids[get_input_node(node, 0)]
        weights_id = vals_to_ids[get_input_node(node, 1)]
        output_id = vals_to_ids[node]

        ser_node = XNode(
            xnode_union=XNNRope(
                max_tokens=0,
                input_id=input_id,
                weights_id=weights_id,
                output_id=output_id,
                flags=0,
            ),
            debug_handle=debug_handle,
        )
        xnn_graph.xnodes.append(ser_node)
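As a sanity check (a sketch under the same half-split convention, not part of this diff), the reference rope_impl above agrees with the decomposed HF x * cos + rotate_half(x) * sin formulation once the doubled cos/sin are reduced to their unique halves:

    import torch

    torch.manual_seed(0)
    half = 4
    x = torch.randn(2, 3, 2 * half)  # [..., channels]
    cos = torch.rand(2, 3, half)
    sin = torch.rand(2, 3, half)

    # rope_impl's complex rotation on the two channel halves...
    x_real, x_imag = x[..., :half], x[..., half:]
    fused = torch.cat(
        [x_real * cos - x_imag * sin, x_real * sin + x_imag * cos], dim=-1
    )

    # ...equals the decomposed x * cos + rotate_half(x) * sin with doubled cos/sin.
    cos2 = torch.cat([cos, cos], dim=-1)
    sin2 = torch.cat([sin, sin], dim=-1)
    rotate_half = torch.cat([-x_imag, x_real], dim=-1)
    decomposed = x * cos2 + rotate_half * sin2
    assert torch.allclose(fused, decomposed)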

backends/xnnpack/operators/op_squeeze.py

Lines changed: 8 additions & 32 deletions
@@ -13,6 +13,7 @@
 )
 from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
     XNNGraph,
+    XNNStaticExpandDims,
     XNNStaticReshape,
     XNode,
 )
@@ -98,46 +99,21 @@ def define_node(
         vals_to_ids: Dict[torch.fx.Node, int],
         debug_handle: int,
     ) -> None:
-
-        dim = cast(int, node.args[1])
-        check_or_raise(
-            dim == -1 or dim == len(node.args[0].meta["val"].shape),
-            "XNNPACK currently only supports unsqueezing in last dimension",
-        )
-
         self.define_nodes_tensor_inputs_outputs(node, xnn_graph, vals_to_ids)
         input_node = get_input_node(node, 0)

-        # input
         input_id = vals_to_ids[input_node]
-
-        # output
         output_id = vals_to_ids[node]

-        check_or_raise(
-            "val" in input_node.meta,
-            "Missing val in tensor metadata for input when serializing XNNStaticReshape node",
-        )
-        dynamic_shape = node.meta["val"].shape
-        new_shape = []
-
-        num_dynamic_dims = 0
-        for dim in dynamic_shape:
-            if free_symbols(dim):
-                num_dynamic_dims += 1
-                new_shape.append(0)
-            else:
-                new_shape.append(dim)
-
-        check_or_raise(
-            num_dynamic_dims <= 1,
-            "XNNPACK reshape only supports 1 dynamic dimension. This may occur when ",
-        )
+        dim = cast(int, node.args[1])
+        input_ndim = len(input_node.meta["val"].shape)
+        if dim < 0:
+            dim = input_ndim + 1 + dim

         ser_node = XNode(
-            xnode_union=XNNStaticReshape(
-                num_dims=len(new_shape),
-                new_shape=new_shape,
+            xnode_union=XNNStaticExpandDims(
+                num_new_axes=1,
+                new_axes=[dim],
                 input_id=input_id,
                 output_id=output_id,
                 flags=0,
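A quick worked example of the new negative-dim normalization (illustrative only): unsqueeze resolves a negative dim against the output rank, which is input_ndim + 1, hence the + 1 in the formula above:

    import torch

    x = torch.zeros(2, 3)           # input_ndim = 2
    dim = -1
    normalized = x.dim() + 1 + dim  # 2 + 1 + (-1) = 2
    # Both forms place the new axis last: shape (2, 3, 1).
    assert torch.unsqueeze(x, dim).shape == torch.unsqueeze(x, normalized).shape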

backends/xnnpack/partition/config/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -55,6 +55,7 @@
     SubConfig,
     TanhConfig,
     ToDimOrderCopyConfig,
+    UnsqueezeCopyConfig,
     UpsampleBilinear2dConfig,
 )
 from executorch.backends.xnnpack.partition.config.node_configs import (
@@ -116,6 +117,7 @@
     SoftmaxConfig,
     SquareRootConfig,
     SubConfig,
+    UnsqueezeCopyConfig,
     UpsampleBilinear2dConfig,
     # Quant/Dequant Op Configs
     QuantizedPerTensorConfig,
