# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import enum
import logging

import torch
from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass
from executorch.backends.xnnpack.partition.graphs import rope
from executorch.exir.dialects._ops import ops as exir_ops

from torch.fx.passes.infra.pass_base import PassResult
from torch.fx.passes.utils.matcher_utils import InternalMatch, SubgraphMatcher

logger = logging.getLogger(__name__)


class _Layout(enum.Enum):
    BSHD = enum.auto()
    BHSD = enum.auto()


class ConvertToRopePass(XNNPACKPass):
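    """Fuse HF-style rotary position embedding (RoPE) subgraphs into a single
    torch.ops.xnnpack.rope node.

    Candidate subgraphs come from rope.get_graphs(); a match is only fused
    after the doubled-frequency structure and a supported tensor layout have
    been verified.
    """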

    # Swaps dims 1 and 2. This permutation is its own inverse, so the same
    # list maps BSHD -> BHSD and BHSD -> BSHD.
    _BHSD_TO_BSHD_PERM = [0, 2, 1, 3]

    def _build_weights(
        self,
        graph_module: torch.fx.GraphModule,
        cos_node: torch.fx.Node,
        sin_node: torch.fx.Node,
        output_node: torch.fx.Node,
    ) -> torch.fx.Node:
        """
        Construct the XNNPACK RoPE weights tensor from cos and sin inputs.

        The most common HF RoPE pattern doubles the frequencies:
            cos/sin shape: [batch, seq, head_dim] where head_dim = 2 * (dim // 2)
            The first half and second half are identical.

        XNNPACK expects weights: [tokens, channels] where:
            weights[:, :C/2] = cos values (unique half)
            weights[:, C/2:] = sin values (unique half)

        We insert graph nodes to slice the unique halves and concatenate them.

        This is only sound when cos and sin come from a cat([x, x]) node;
        the pass verifies that before fusing.
        """
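        # Worked example (hypothetical shapes): with head_dim = 64, cos/sin
        # are [B, S, 64] and cos[..., :32] == cos[..., 32:]. The nodes below
        # build cat([cos[..., :32], sin[..., :32]], dim=-1), a [B, S, 64]
        # tensor in the cos-then-sin layout XNNPACK expects.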
        head_dim = cos_node.meta["val"].shape[-1]
        half_dim = head_dim // 2

        with graph_module.graph.inserting_before(output_node):
            cos_half = graph_module.graph.call_function(
                exir_ops.edge.aten.slice_copy.Tensor,
                args=(cos_node, -1, 0, half_dim),
            )
            sin_half = graph_module.graph.call_function(
                exir_ops.edge.aten.slice_copy.Tensor,
                args=(sin_node, -1, 0, half_dim),
            )
            weights = graph_module.graph.call_function(
                exir_ops.edge.aten.cat.default,
                args=([cos_half, sin_half], -1),
            )

        return weights

    @staticmethod
    def _trace_through_unsqueezes(node: torch.fx.Node) -> torch.fx.Node:
        """Walk backwards through consecutive unsqueeze_copy ops to find the source."""
        current = node
        while (
            current.op == "call_function"
            and current.target == exir_ops.edge.aten.unsqueeze_copy.default
        ):
            current = current.args[0]
        return current

    @staticmethod
    def _find_trig_source(node: torch.fx.Node) -> torch.fx.Node | None:
        """Walk backwards through unsqueeze_copy ops to the producing cos/sin
        op; return None if none is found within the walk limit."""
        current = node
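        # Bound the backwards walk; typical RoPE patterns insert only one or
        # two unsqueezes, so 10 hops is a generous safety limit.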
        for _ in range(10):
            if current.op != "call_function":
                return None
            if current.target in (
                exir_ops.edge.aten.cos.default,
                exir_ops.edge.aten.sin.default,
            ):
                return current
            if current.target == exir_ops.edge.aten.unsqueeze_copy.default:
                current = current.args[0]
                continue
            return None
        return None

    @classmethod
    def _is_doubled_cat(cls, trig_node: torch.fx.Node) -> bool:
        """Check that a cos/sin node's input is cat([x, x]) with identical args."""
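        # In HF-style RoPE this corresponds to
        # emb = torch.cat((freqs, freqs), dim=-1); cos = emb.cos(), so both
        # cat args are literally the same fx Node.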
        cat_node = trig_node.args[0]
        if not (
            isinstance(cat_node, torch.fx.Node)
            and cat_node.op == "call_function"
            and cat_node.target == exir_ops.edge.aten.cat.default
        ):
            return False
        tensors = cat_node.args[0]
        return len(tensors) == 2 and tensors[0] is tensors[1]

    @classmethod
    def _has_doubled_freqs(
        cls,
        cos_unsqueezed: torch.fx.Node,
        sin_unsqueezed: torch.fx.Node,
    ) -> bool:
        """Verify that cos/sin frequencies are doubled (first half == second half).

        Traces back through unsqueeze_copy ops to find the cos/sin producer,
        then verifies its input is cat([x, x]) where both args are the same
        node, which structurally proves the two halves are identical.
        """
        cos_trig = cls._find_trig_source(cos_unsqueezed)
        sin_trig = cls._find_trig_source(sin_unsqueezed)

        if cos_trig is None or sin_trig is None:
            return False

        return cls._is_doubled_cat(cos_trig) and cls._is_doubled_cat(sin_trig)
| 131 | + |
| 132 | + @staticmethod |
| 133 | + def _trace_through_permute(node: torch.fx.Node) -> torch.fx.Node | None: |
| 134 | + """If node is a permute_copy that swaps dims 1 and 2, return its input.""" |
| 135 | + if ( |
| 136 | + node.op == "call_function" |
| 137 | + and node.target == exir_ops.edge.aten.permute_copy.default |
| 138 | + and list(node.args[1]) == [0, 2, 1, 3] |
| 139 | + ): |
| 140 | + return node.args[0] |
| 141 | + return None |
| 142 | + |
| 143 | + @staticmethod |
| 144 | + def _get_layout(cos_unsqueezed: torch.fx.Node) -> _Layout | None: |
| 145 | + """Determine the tensor layout from the cos unsqueeze dimension.""" |
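        # In HF-style apply_rotary_pos_emb, cos/sin of shape [B, S, D] are
        # broadcast against q/k via unsqueeze: unsqueeze(1) -> [B, 1, S, D]
        # fits BHSD q/k, while unsqueeze(2) -> [B, S, 1, D] fits BSHD.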
        if not (
            cos_unsqueezed.op == "call_function"
            and cos_unsqueezed.target == exir_ops.edge.aten.unsqueeze_copy.default
        ):
            return None
        unsqueeze_dim = cos_unsqueezed.args[1]
        ndim = len(cos_unsqueezed.meta["val"].shape)
        normalized = unsqueeze_dim if unsqueeze_dim >= 0 else unsqueeze_dim + ndim
        if normalized == 2:
            return _Layout.BSHD
        if normalized == 1:
            return _Layout.BHSD
        return None

    def create_rope(
        self,
        graph_module: torch.fx.GraphModule,
        match: InternalMatch,
    ) -> bool:
        """Fuse one matched subgraph; return True iff a rope node was created."""
        logger.debug(f"Matched RoPE subgraph: {match}")

        # placeholder_nodes are in the order of the pattern's placeholder ops:
        # [x, cos_unsqueezed, sin_unsqueezed]
        x_node = match.placeholder_nodes[0]
        cos_unsqueezed = match.placeholder_nodes[1]
        sin_unsqueezed = match.placeholder_nodes[2]
        output_node = match.returning_nodes[0]

        # xnn_define_rope expects NTHC (batch, tokens, heads, channels) input.
        # BSHD (unsqueeze_dim=2) maps directly to NTHC.
        # BHSD (unsqueeze_dim=1) requires tracing through the BSHD→BHSD permute
        # to recover the BSHD input, then re-permuting the output back to BHSD.
        layout = self._get_layout(cos_unsqueezed)
        if layout == _Layout.BSHD:
            rope_input = x_node
        elif layout == _Layout.BHSD:
            rope_input = self._trace_through_permute(x_node)
            if rope_input is None:
                logger.debug("Skipping RoPE fusion: BHSD but x is not a permute_copy")
                return False
        else:
            logger.debug("Skipping RoPE fusion: unrecognized layout")
            return False

        cos_node = self._trace_through_unsqueezes(cos_unsqueezed)
        sin_node = self._trace_through_unsqueezes(sin_unsqueezed)

        if not self._has_doubled_freqs(cos_unsqueezed, sin_unsqueezed):
            logger.debug("Skipping RoPE fusion: cannot verify doubled frequencies")
            return False

        weights = self._build_weights(graph_module, cos_node, sin_node, output_node)

        with graph_module.graph.inserting_before(output_node):
            rope_node = graph_module.graph.create_node(
                "call_function",
                torch.ops.xnnpack.rope.default,
                args=(rope_input, weights),
            )

            if layout == _Layout.BHSD:
                permute_node = graph_module.graph.call_function(
                    exir_ops.edge.aten.permute_copy.default,
                    args=(rope_node, self._BHSD_TO_BSHD_PERM),
                )
                result_node = permute_node
            else:
                result_node = rope_node

        output_node.replace_all_uses_with(result_node)
        graph_module.graph.eliminate_dead_code()
        return True

    # override
    def call(self, graph_module: torch.fx.GraphModule):
        total_matches = 0
        total_fused = 0
        for pattern in rope.get_graphs():
            sm = SubgraphMatcher(pattern.graph, ignore_literals=True)
            matches = list(sm.match(graph_module.graph))
            total_matches += len(matches)
            for match in matches:
                try:
                    # create_rope returns False when it declines to fuse, so
                    # skipped matches are not counted as fused.
                    if self.create_rope(graph_module, match):
                        total_fused += 1
                except Exception:
                    logger.warning("Failed to fuse RoPE pattern", exc_info=True)
        logger.debug(f"Fused {total_fused} of {total_matches} matched RoPE subgraphs")
        graph_module.recompile()
        graph_module = super().call(graph_module).graph_module

        return PassResult(graph_module, True)
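

# Usage sketch (an assumption about the surrounding pipeline, not defined in
# this file): XNNPACKPass subclasses are PassBase callables constructed from
# an ExportedProgram, so the pass would run roughly as:
#   result = ConvertToRopePass(exported_program)(graph_module)
#   graph_module = result.graph_module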