|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD-style license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +import logging |
| 8 | + |
| 9 | +import torch |
| 10 | +from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass |
| 11 | +from executorch.backends.xnnpack.partition.graphs import rope |
| 12 | +from executorch.exir.dialects._ops import ops as exir_ops |
| 13 | + |
| 14 | +from torch.fx.passes.infra.pass_base import PassResult |
| 15 | +from torch.fx.passes.utils.matcher_utils import InternalMatch, SubgraphMatcher |
| 16 | + |
| 17 | +logger = logging.getLogger(__name__) |
| 18 | + |
| 19 | + |
| 20 | +class ConvertToRopePass(XNNPACKPass): |
| 21 | + def _build_weights( |
| 22 | + self, |
| 23 | + graph_module: torch.fx.GraphModule, |
| 24 | + cos_node: torch.fx.Node, |
| 25 | + sin_node: torch.fx.Node, |
| 26 | + output_node: torch.fx.Node, |
| 27 | + ) -> torch.fx.Node: |
| 28 | + """ |
| 29 | + Construct the XNNPACK RoPE weights tensor from cos and sin inputs. |
| 30 | +
|
| 31 | + HF precompute_freqs_cis doubles the frequencies: |
| 32 | + cos/sin shape: [batch, seq, head_dim] where head_dim = 2 * (dim // 2) |
| 33 | + The first half and second half are identical. |
| 34 | +
|
| 35 | + XNNPACK expects weights: [tokens, channels] where: |
| 36 | + weights[:, :C/2] = cos values (unique half) |
| 37 | + weights[:, C/2:] = sin values (unique half) |
| 38 | +
|
| 39 | + We insert graph nodes to slice the unique halves and concatenate them. |
| 40 | + """ |
| 41 | + cos_val = cos_node.meta.get("val") |
| 42 | + head_dim = cos_val.shape[-1] |
| 43 | + half_dim = head_dim // 2 |
| 44 | + |
| 45 | + with graph_module.graph.inserting_before(output_node): |
| 46 | + cos_half = graph_module.graph.call_function( |
| 47 | + exir_ops.edge.aten.slice_copy.Tensor, |
| 48 | + args=(cos_node, -1, 0, half_dim), |
| 49 | + ) |
| 50 | + cos_half.meta["val"] = cos_val.narrow(-1, 0, half_dim) |
| 51 | + |
| 52 | + sin_val = sin_node.meta.get("val") |
| 53 | + sin_half = graph_module.graph.call_function( |
| 54 | + exir_ops.edge.aten.slice_copy.Tensor, |
| 55 | + args=(sin_node, -1, 0, half_dim), |
| 56 | + ) |
| 57 | + sin_half.meta["val"] = sin_val.narrow(-1, 0, half_dim) |
| 58 | + |
| 59 | + weights = graph_module.graph.call_function( |
| 60 | + exir_ops.edge.aten.cat.default, |
| 61 | + args=([cos_half, sin_half], -1), |
| 62 | + ) |
| 63 | + weights.meta["val"] = torch.cat( |
| 64 | + [cos_half.meta["val"], sin_half.meta["val"]], -1 |
| 65 | + ) |
| 66 | + |
| 67 | + return weights |
| 68 | + |
| 69 | + @staticmethod |
| 70 | + def _trace_through_unsqueeze(node: torch.fx.Node) -> torch.fx.Node: |
| 71 | + """If node is an unsqueeze_copy, return its input. Otherwise return node as-is.""" |
| 72 | + if ( |
| 73 | + node.op == "call_function" |
| 74 | + and node.target == exir_ops.edge.aten.unsqueeze_copy.default |
| 75 | + ): |
| 76 | + return node.args[0] |
| 77 | + return node |
| 78 | + |
| 79 | + def create_rope( |
| 80 | + self, |
| 81 | + graph_module: torch.fx.GraphModule, |
| 82 | + match: InternalMatch, |
| 83 | + ): |
| 84 | + logger.debug(f"Matched RoPE subgraph: {match}") |
| 85 | + |
| 86 | + # placeholder_nodes are in the order of the pattern's placeholder ops: |
| 87 | + # [x, cos_unsqueezed, sin_unsqueezed] |
| 88 | + x_node = match.placeholder_nodes[0] |
| 89 | + cos_unsqueezed = match.placeholder_nodes[1] |
| 90 | + sin_unsqueezed = match.placeholder_nodes[2] |
| 91 | + output_node = match.returning_nodes[0] |
| 92 | + |
| 93 | + # Trace back through unsqueeze to get raw cos/sin for weight construction. |
| 94 | + # The pattern excludes unsqueeze ops (they're shared between q/k RoPE), |
| 95 | + # so the matched placeholders are the unsqueeze outputs. |
| 96 | + cos_node = self._trace_through_unsqueeze(cos_unsqueezed) |
| 97 | + sin_node = self._trace_through_unsqueeze(sin_unsqueezed) |
| 98 | + |
| 99 | + weights = self._build_weights(graph_module, cos_node, sin_node, output_node) |
| 100 | + |
| 101 | + with graph_module.graph.inserting_before(output_node): |
| 102 | + rope_node = graph_module.graph.create_node( |
| 103 | + "call_function", |
| 104 | + torch.ops.xnnpack.rope.default, |
| 105 | + args=(x_node, weights), |
| 106 | + ) |
| 107 | + |
| 108 | + rope_node.meta["val"] = torch.empty_like(x_node.meta["val"]) |
| 109 | + |
| 110 | + output_node.replace_all_uses_with(rope_node) |
| 111 | + graph_module.graph.eliminate_dead_code() |
| 112 | + |
| 113 | + # override |
| 114 | + def call(self, graph_module: torch.fx.GraphModule): |
| 115 | + total_matches = 0 |
| 116 | + total_fused = 0 |
| 117 | + for pattern in rope.get_graphs(): |
| 118 | + sm = SubgraphMatcher(pattern.graph, ignore_literals=True) |
| 119 | + matches = list(sm.match(graph_module.graph)) |
| 120 | + total_matches += len(matches) |
| 121 | + for match in matches: |
| 122 | + try: |
| 123 | + self.create_rope(graph_module, match) |
| 124 | + total_fused += 1 |
| 125 | + except Exception: |
| 126 | + logger.warning("Failed to fuse RoPE pattern", exc_info=True) |
| 127 | + graph_module.recompile() |
| 128 | + graph_module = super().call(graph_module).graph_module |
| 129 | + |
| 130 | + return PassResult(graph_module, True) |
0 commit comments