|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD-style license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +import logging |
| 8 | + |
| 9 | +import torch |
| 10 | +from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass |
| 11 | +from executorch.backends.xnnpack.partition.graphs import rope |
| 12 | +from executorch.exir.dialects._ops import ops as exir_ops |
| 13 | + |
| 14 | +from torch.fx.passes.infra.pass_base import PassResult |
| 15 | +from torch.fx.passes.utils.matcher_utils import InternalMatch, SubgraphMatcher |
| 16 | + |
| 17 | +logger = logging.getLogger(__name__) |
| 18 | + |
| 19 | + |
class ConvertToRopePass(XNNPACKPass):
    """Fuse matched RoPE (rotary position embedding) subgraphs into a single
    ``torch.ops.xnnpack.rope`` custom op.

    Candidate subgraphs come from
    ``executorch.backends.xnnpack.partition.graphs.rope``; each match is
    rewritten in place and dead nodes are removed afterwards.
    """

    # Permutation that swaps dims 1 and 2 of a 4-D tensor. It is its own
    # inverse, so the same list converts BHSD -> BSHD and BSHD -> BHSD.
    _BHSD_TO_BSHD_PERM = [0, 2, 1, 3]

    def _build_weights(
        self,
        graph_module: torch.fx.GraphModule,
        cos_node: torch.fx.Node,
        sin_node: torch.fx.Node,
        output_node: torch.fx.Node,
    ) -> torch.fx.Node:
        """
        Construct the XNNPACK RoPE weights tensor from cos and sin inputs.

        HF precompute_freqs_cis doubles the frequencies:
            cos/sin shape: [batch, seq, head_dim] where head_dim = 2 * (dim // 2)
            The first half and second half are identical.

        XNNPACK expects weights: [tokens, channels] where:
            weights[:, :C/2] = cos values (unique half)
            weights[:, C/2:] = sin values (unique half)

        We insert graph nodes to slice the unique halves and concatenate them.

        All new nodes are inserted immediately before ``output_node`` so the
        resulting graph stays topologically ordered.
        """
        # Raises if "val" metadata is missing; the caller wraps create_rope
        # in a try/except, so a malformed match is skipped rather than fatal.
        cos_val = cos_node.meta.get("val")
        head_dim = cos_val.shape[-1]
        half_dim = head_dim // 2

        with graph_module.graph.inserting_before(output_node):
            cos_half = graph_module.graph.call_function(
                exir_ops.edge.aten.slice_copy.Tensor,
                args=(cos_node, -1, 0, half_dim),
            )
            # Propagate fake-tensor metadata so later passes see shapes.
            cos_half.meta["val"] = cos_val.narrow(-1, 0, half_dim)

            sin_val = sin_node.meta.get("val")
            sin_half = graph_module.graph.call_function(
                exir_ops.edge.aten.slice_copy.Tensor,
                args=(sin_node, -1, 0, half_dim),
            )
            sin_half.meta["val"] = sin_val.narrow(-1, 0, half_dim)

            weights = graph_module.graph.call_function(
                exir_ops.edge.aten.cat.default,
                args=([cos_half, sin_half], -1),
            )
            weights.meta["val"] = torch.cat(
                [cos_half.meta["val"], sin_half.meta["val"]], -1
            )

        return weights

    @staticmethod
    def _get_unsqueeze_dim(node: torch.fx.Node) -> int | None:
        """Return the unsqueeze dim if node is an unsqueeze_copy, else None.

        None (rather than -1) is used as the sentinel because -1 is itself a
        valid unsqueeze dimension and must not be conflated with "no match".
        """
        if (
            node.op == "call_function"
            and node.target == exir_ops.edge.aten.unsqueeze_copy.default
        ):
            return node.args[1]
        return None

    @staticmethod
    def _trace_through_unsqueeze(node: torch.fx.Node) -> torch.fx.Node:
        """If node is an unsqueeze_copy, return its input. Otherwise return node as-is."""
        if (
            node.op == "call_function"
            and node.target == exir_ops.edge.aten.unsqueeze_copy.default
        ):
            return node.args[0]
        return node

    @staticmethod
    def _trace_through_permute(node: torch.fx.Node) -> torch.fx.Node | None:
        """If node is a permute_copy that swaps dims 1 and 2, return its input."""
        if (
            node.op == "call_function"
            and node.target == exir_ops.edge.aten.permute_copy.default
            and list(node.args[1]) == ConvertToRopePass._BHSD_TO_BSHD_PERM
        ):
            return node.args[0]
        return None

    def _get_layout(self, cos_unsqueezed: torch.fx.Node) -> str | None:
        """Determine the tensor layout from the cos unsqueeze dimension.

        Returns "BSHD", "BHSD", or None if the layout cannot be determined.
        """
        unsqueeze_dim = self._get_unsqueeze_dim(cos_unsqueezed)
        if unsqueeze_dim is None:
            return None
        # cos_unsqueezed's "val" is the unsqueeze *output*, so its rank is the
        # correct modulus for normalizing a negative unsqueeze dim.
        ndim = len(cos_unsqueezed.meta["val"].shape)
        normalized = unsqueeze_dim if unsqueeze_dim >= 0 else unsqueeze_dim + ndim
        if normalized == 2:
            return "BSHD"
        if normalized == 1:
            return "BHSD"
        return None

    def create_rope(
        self,
        graph_module: torch.fx.GraphModule,
        match: InternalMatch,
    ) -> bool:
        """Replace one matched RoPE subgraph with a fused xnnpack.rope node.

        Returns True if the match was fused, False if it was skipped because
        the input layout could not be handled.
        """
        logger.debug(f"Matched RoPE subgraph: {match}")

        # placeholder_nodes are in the order of the pattern's placeholder ops:
        # [x, cos_unsqueezed, sin_unsqueezed]
        x_node = match.placeholder_nodes[0]
        cos_unsqueezed = match.placeholder_nodes[1]
        sin_unsqueezed = match.placeholder_nodes[2]
        output_node = match.returning_nodes[0]

        # xnn_define_rope expects NTHC (batch, tokens, heads, channels) input.
        # BSHD (unsqueeze_dim=2) maps directly to NTHC.
        # BHSD (unsqueeze_dim=1) requires tracing through the BSHD→BHSD permute
        # to recover the BSHD input, then re-permuting the output back to BHSD.
        layout = self._get_layout(cos_unsqueezed)
        if layout == "BSHD":
            rope_input = x_node
        elif layout == "BHSD":
            rope_input = self._trace_through_permute(x_node)
            if rope_input is None:
                logger.debug("Skipping RoPE fusion: BHSD but x is not a permute_copy")
                return False
        else:
            logger.debug("Skipping RoPE fusion: unrecognized layout")
            return False

        cos_node = self._trace_through_unsqueeze(cos_unsqueezed)
        sin_node = self._trace_through_unsqueeze(sin_unsqueezed)

        weights = self._build_weights(graph_module, cos_node, sin_node, output_node)

        with graph_module.graph.inserting_before(output_node):
            rope_node = graph_module.graph.create_node(
                "call_function",
                torch.ops.xnnpack.rope.default,
                args=(rope_input, weights),
            )
            # rope preserves the (BSHD) shape of its input.
            rope_node.meta["val"] = torch.empty_like(rope_input.meta["val"])

            if layout == "BHSD":
                # Consumers expect BHSD, so permute the fused output back.
                permute_node = graph_module.graph.call_function(
                    exir_ops.edge.aten.permute_copy.default,
                    args=(rope_node, self._BHSD_TO_BSHD_PERM),
                )
                permute_node.meta["val"] = rope_node.meta["val"].permute(
                    self._BHSD_TO_BSHD_PERM
                )
                result_node = permute_node
            else:
                result_node = rope_node

        output_node.replace_all_uses_with(result_node)
        # Drop the now-unused original subgraph nodes.
        graph_module.graph.eliminate_dead_code()
        return True

    # override
    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """Run RoPE fusion over every known pattern, then re-trace metadata."""
        total_matches = 0
        total_fused = 0
        for pattern in rope.get_graphs():
            sm = SubgraphMatcher(pattern.graph, ignore_literals=True)
            matches = list(sm.match(graph_module.graph))
            total_matches += len(matches)
            for match in matches:
                try:
                    # Only count matches that were actually rewritten;
                    # create_rope returns False for skipped layouts.
                    if self.create_rope(graph_module, match):
                        total_fused += 1
                except Exception:
                    # A stale match (e.g. invalidated by an earlier fusion's
                    # dead-code elimination) should not abort the whole pass.
                    logger.warning("Failed to fuse RoPE pattern", exc_info=True)
        logger.debug(
            "RoPE fusion: fused %d of %d matched subgraph(s)",
            total_fused,
            total_matches,
        )
        graph_module.recompile()
        graph_module = super().call(graph_module).graph_module

        return PassResult(graph_module, True)
0 commit comments