Skip to content

Commit aa7a4c4

Browse files
committed
Use XNN fused RoPE for HF-style RoPE patterns
1 parent 28f3cf3 commit aa7a4c4

14 files changed

Lines changed: 508 additions & 45 deletions

File tree

backends/xnnpack/_passes/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
Conv1dUnsqueezePass,
2020
)
2121
from executorch.backends.xnnpack._passes.convert_to_linear import ConvertToLinearPass
22+
from executorch.backends.xnnpack._passes.convert_to_rope import ConvertToRopePass
2223
from executorch.backends.xnnpack._passes.convert_to_sdpa import ConvertToSDPAPass
2324
from executorch.backends.xnnpack._passes.convert_to_upsample_bilinear2d import (
2425
ConvertToUpsampleBilinear2d,
@@ -75,6 +76,7 @@ def __init__(
7576
ConvertToLinearPass,
7677
PropagateCustomMetaPass,
7778
ConvertToSDPAPass,
79+
ConvertToRopePass,
7880
ConstPropPass,
7981
FuseBatchNormPass,
8082
DecomposeBatchNorm,
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import logging
8+
9+
import torch
10+
from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass
11+
from executorch.backends.xnnpack.partition.graphs import rope
12+
from executorch.exir.dialects._ops import ops as exir_ops
13+
14+
from torch.fx.passes.infra.pass_base import PassResult
15+
from torch.fx.passes.utils.matcher_utils import InternalMatch, SubgraphMatcher
16+
17+
logger = logging.getLogger(__name__)
18+
19+
20+
class ConvertToRopePass(XNNPACKPass):
    """Fuse decomposed HF-style RoPE subgraphs into the fused ``xnnpack.rope`` op.

    Patterns are matched structurally against the graphs returned by
    ``rope.get_graphs()``; each match is replaced by a single call to
    ``torch.ops.xnnpack.rope`` fed by a weights tensor built from the
    pattern's cos/sin inputs.
    """

    def _build_weights(
        self,
        graph_module: torch.fx.GraphModule,
        cos_node: torch.fx.Node,
        sin_node: torch.fx.Node,
        output_node: torch.fx.Node,
    ) -> torch.fx.Node:
        """Construct the XNNPACK RoPE weights tensor from cos and sin inputs.

        HF ``precompute_freqs_cis`` doubles the frequencies: cos/sin have
        shape [batch, seq, head_dim] with head_dim = 2 * (dim // 2), and the
        first and second halves are identical.

        XNNPACK expects weights of shape [tokens, channels] where
        weights[:, :C/2] holds the cos values (unique half) and
        weights[:, C/2:] holds the sin values (unique half).

        Graph nodes are inserted (before ``output_node``) that slice out the
        unique halves and concatenate them; the new weights node is returned.
        """
        half_dim = cos_node.meta.get("val").shape[-1] // 2

        graph = graph_module.graph
        with graph.inserting_before(output_node):
            halves = []
            # Same slicing for cos and sin: keep only the unique first half.
            for source in (cos_node, sin_node):
                half = graph.call_function(
                    exir_ops.edge.aten.slice_copy.Tensor,
                    args=(source, -1, 0, half_dim),
                )
                half.meta["val"] = source.meta.get("val").narrow(-1, 0, half_dim)
                halves.append(half)

            weights = graph.call_function(
                exir_ops.edge.aten.cat.default,
                args=(halves, -1),
            )
            # Propagate fake-tensor metadata so downstream passes see a shape.
            weights.meta["val"] = torch.cat([h.meta["val"] for h in halves], -1)

        return weights

    @staticmethod
    def _trace_through_unsqueeze(node: torch.fx.Node) -> torch.fx.Node:
        """If node is an unsqueeze_copy, return its input. Otherwise return node as-is."""
        is_unsqueeze = (
            node.op == "call_function"
            and node.target == exir_ops.edge.aten.unsqueeze_copy.default
        )
        return node.args[0] if is_unsqueeze else node

    def create_rope(
        self,
        graph_module: torch.fx.GraphModule,
        match: InternalMatch,
    ):
        """Replace one matched RoPE subgraph with a fused xnnpack.rope node."""
        logger.debug(f"Matched RoPE subgraph: {match}")

        # placeholder_nodes are in the order of the pattern's placeholder ops:
        # [x, cos_unsqueezed, sin_unsqueezed]
        x_node = match.placeholder_nodes[0]
        output_node = match.returning_nodes[0]

        # Trace back through unsqueeze to get raw cos/sin for weight construction.
        # The pattern excludes unsqueeze ops (they're shared between q/k RoPE),
        # so the matched placeholders are the unsqueeze outputs.
        cos_node = self._trace_through_unsqueeze(match.placeholder_nodes[1])
        sin_node = self._trace_through_unsqueeze(match.placeholder_nodes[2])

        weights = self._build_weights(graph_module, cos_node, sin_node, output_node)

        with graph_module.graph.inserting_before(output_node):
            fused = graph_module.graph.create_node(
                "call_function",
                torch.ops.xnnpack.rope.default,
                args=(x_node, weights),
            )

        # The fused op is shape/dtype preserving w.r.t. its activation input.
        fused.meta["val"] = torch.empty_like(x_node.meta["val"])

        output_node.replace_all_uses_with(fused)
        graph_module.graph.eliminate_dead_code()

    # override
    def call(self, graph_module: torch.fx.GraphModule):
        """Run pattern matching over all known RoPE graphs and fuse matches.

        Fusion failures are logged (best-effort) rather than raised, so a
        non-matching edge case never breaks lowering.
        """
        matched_count = 0
        fused_count = 0
        for pattern in rope.get_graphs():
            matcher = SubgraphMatcher(pattern.graph, ignore_literals=True)
            found = list(matcher.match(graph_module.graph))
            matched_count += len(found)
            for candidate in found:
                try:
                    self.create_rope(graph_module, candidate)
                    fused_count += 1
                except Exception:
                    logger.warning("Failed to fuse RoPE pattern", exc_info=True)
        graph_module.recompile()
        graph_module = super().call(graph_module).graph_module

        return PassResult(graph_module, True)

backends/xnnpack/operators/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
op_prelu,
4444
op_quant_dequant,
4545
op_relu,
46+
op_rope,
4647
op_rsqrt,
4748
op_sigmoid,
4849
op_sin,
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from typing import Dict
8+
9+
import torch
10+
from executorch.backends.xnnpack.operators.node_visitor import (
11+
NodeVisitor,
12+
register_node_visitor,
13+
)
14+
from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
15+
XNNGraph,
16+
XNNRope,
17+
XNode,
18+
)
19+
from executorch.backends.xnnpack.utils.utils import get_input_node
20+
21+
# Register the custom op used by the fusion pass. The fused node targets
22+
# this op after ConvertToRopePass replaces the decomposed HF RoPE subgraph.
23+
lib = torch.library.Library("xnnpack", "FRAGMENT")
24+
lib.define("rope(Tensor input, Tensor weights) -> Tensor")
25+
26+
27+
@torch.library.impl(lib, "rope", "CompositeExplicitAutograd")
28+
def rope_impl(input: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
29+
channels = input.shape[-1]
30+
half_c = channels // 2
31+
cos = weights[..., :half_c]
32+
sin = weights[..., half_c:]
33+
34+
x_real = input[..., :half_c]
35+
x_imag = input[..., half_c:]
36+
37+
out_real = x_real * cos - x_imag * sin
38+
out_imag = x_real * sin + x_imag * cos
39+
return torch.cat([out_real, out_imag], dim=-1)
40+
41+
42+
@torch.library.impl(lib, "rope", "Meta")
43+
def rope_meta(input: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
44+
return torch.empty_like(input)
45+
46+
47+
@register_node_visitor
class RopeVisitor(NodeVisitor):
    """Serializes a fused ``xnnpack.rope`` node into an XNNRope graph node."""

    target = "rope.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        xnn_graph: XNNGraph,
        vals_to_ids: Dict[torch.fx.Node, int],
        debug_handle: int,
    ) -> None:
        """Emit the XNode for one rope call into the serialized graph.

        Tensor values for the input, weights, and output are defined first so
        their ids exist in ``vals_to_ids`` before the node references them.
        """
        self.define_nodes_tensor_inputs_outputs(node, xnn_graph, vals_to_ids)

        serialized = XNode(
            xnode_union=XNNRope(
                # NOTE(review): max_tokens=0 — presumably "derive capacity
                # from the input at runtime"; confirm against the XNNPACK API.
                max_tokens=0,
                input_id=vals_to_ids[get_input_node(node, 0)],
                weights_id=vals_to_ids[get_input_node(node, 1)],
                output_id=vals_to_ids[node],
                flags=0,
            ),
            debug_handle=debug_handle,
        )
        xnn_graph.xnodes.append(serialized)

backends/xnnpack/operators/op_squeeze.py

Lines changed: 8 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
)
1414
from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
1515
XNNGraph,
16+
XNNStaticExpandDims,
1617
XNNStaticReshape,
1718
XNode,
1819
)
@@ -98,46 +99,21 @@ def define_node(
9899
vals_to_ids: Dict[torch.fx.Node, int],
99100
debug_handle: int,
100101
) -> None:
101-
102-
dim = cast(int, node.args[1])
103-
check_or_raise(
104-
dim == -1 or dim == len(node.args[0].meta["val"].shape),
105-
"XNNPACK currently only supports unsqueezing in last dimension",
106-
)
107-
108102
self.define_nodes_tensor_inputs_outputs(node, xnn_graph, vals_to_ids)
109103
input_node = get_input_node(node, 0)
110104

111-
# input
112105
input_id = vals_to_ids[input_node]
113-
114-
# output
115106
output_id = vals_to_ids[node]
116107

117-
check_or_raise(
118-
"val" in input_node.meta,
119-
"Missing val in tensor metadata for input when serializing XNNStaticReshape node",
120-
)
121-
dynamic_shape = node.meta["val"].shape
122-
new_shape = []
123-
124-
num_dynamic_dims = 0
125-
for dim in dynamic_shape:
126-
if free_symbols(dim):
127-
num_dynamic_dims += 1
128-
new_shape.append(0)
129-
else:
130-
new_shape.append(dim)
131-
132-
check_or_raise(
133-
num_dynamic_dims <= 1,
134-
"XNNPACK reshape only supports 1 dynamic dimension. This may occur when ",
135-
)
108+
dim = cast(int, node.args[1])
109+
input_ndim = len(input_node.meta["val"].shape)
110+
if dim < 0:
111+
dim = input_ndim + 1 + dim
136112

137113
ser_node = XNode(
138-
xnode_union=XNNStaticReshape(
139-
num_dims=len(new_shape),
140-
new_shape=new_shape,
114+
xnode_union=XNNStaticExpandDims(
115+
num_new_axes=1,
116+
new_axes=[dim],
141117
input_id=input_id,
142118
output_id=output_id,
143119
flags=0,

backends/xnnpack/partition/config/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
SubConfig,
5656
TanhConfig,
5757
ToDimOrderCopyConfig,
58+
UnsqueezeCopyConfig,
5859
UpsampleBilinear2dConfig,
5960
)
6061
from executorch.backends.xnnpack.partition.config.node_configs import (
@@ -116,6 +117,7 @@
116117
SoftmaxConfig,
117118
SquareRootConfig,
118119
SubConfig,
120+
UnsqueezeCopyConfig,
119121
UpsampleBilinear2dConfig,
120122
# Quant/Dequant Op Configs
121123
QuantizedPerTensorConfig,

0 commit comments

Comments
 (0)