diff --git a/backends/xnnpack/_passes/__init__.py b/backends/xnnpack/_passes/__init__.py index 45560124f57..f1b9a3605fe 100644 --- a/backends/xnnpack/_passes/__init__.py +++ b/backends/xnnpack/_passes/__init__.py @@ -19,6 +19,7 @@ Conv1dUnsqueezePass, ) from executorch.backends.xnnpack._passes.convert_to_linear import ConvertToLinearPass +from executorch.backends.xnnpack._passes.convert_to_rope import ConvertToRopePass from executorch.backends.xnnpack._passes.convert_to_sdpa import ConvertToSDPAPass from executorch.backends.xnnpack._passes.convert_to_upsample_bilinear2d import ( ConvertToUpsampleBilinear2d, @@ -75,6 +76,7 @@ def __init__( ConvertToLinearPass, PropagateCustomMetaPass, ConvertToSDPAPass, + ConvertToRopePass, ConstPropPass, FuseBatchNormPass, DecomposeBatchNorm, diff --git a/backends/xnnpack/_passes/convert_to_rope.py b/backends/xnnpack/_passes/convert_to_rope.py new file mode 100644 index 00000000000..7eaf342c1fb --- /dev/null +++ b/backends/xnnpack/_passes/convert_to_rope.py @@ -0,0 +1,235 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import enum +import logging + +import torch +from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass +from executorch.backends.xnnpack.partition.graphs import rope +from executorch.exir.dialects._ops import ops as exir_ops + +from torch.fx.passes.infra.pass_base import PassResult +from torch.fx.passes.utils.matcher_utils import InternalMatch, SubgraphMatcher + +logger = logging.getLogger(__name__) + + +class _Layout(enum.Enum): + BSHD = enum.auto() + BHSD = enum.auto() + + +class ConvertToRopePass(XNNPACKPass): + _BHSD_TO_BSHD_PERM = [0, 2, 1, 3] + + def _build_weights( + self, + graph_module: torch.fx.GraphModule, + cos_node: torch.fx.Node, + sin_node: torch.fx.Node, + output_node: torch.fx.Node, + ) -> torch.fx.Node: + """ + Construct the XNNPACK RoPE weights tensor from cos and sin inputs. + + The most common HF RoPE pattern doubles the frequencies: + cos/sin shape: [batch, seq, head_dim] where head_dim = 2 * (dim // 2) + The first half and second half are identical. + + XNNPACK expects weights: [tokens, channels] where: + weights[:, :C/2] = cos values (unique half) + weights[:, C/2:] = sin values (unique half) + + We insert graph nodes to slice the unique halves and concatenate them. + + Note that this assumes that cos and sin come from a cat([x, x]) node for + this to be sound. We check this in the pass. 
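+
+        Illustrative example (hypothetical values), with head_dim = 4:
+            cos = [c0, c1, c0, c1], sin = [s0, s1, s0, s1]
+            weights = [c0, c1, s0, s1]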
+ """ + head_dim = cos_node.meta["val"].shape[-1] + half_dim = head_dim // 2 + + with graph_module.graph.inserting_before(output_node): + cos_half = graph_module.graph.call_function( + exir_ops.edge.aten.slice_copy.Tensor, + args=(cos_node, -1, 0, half_dim), + ) + sin_half = graph_module.graph.call_function( + exir_ops.edge.aten.slice_copy.Tensor, + args=(sin_node, -1, 0, half_dim), + ) + weights = graph_module.graph.call_function( + exir_ops.edge.aten.cat.default, + args=([cos_half, sin_half], -1), + ) + + return weights + + @staticmethod + def _trace_through_unsqueezes(node: torch.fx.Node) -> torch.fx.Node: + """Walk backwards through consecutive unsqueeze_copy ops to find the source.""" + current = node + while ( + current.op == "call_function" + and current.target == exir_ops.edge.aten.unsqueeze_copy.default + ): + current = current.args[0] + return current + + @staticmethod + def _find_trig_source(node: torch.fx.Node) -> torch.fx.Node | None: + """Walk backwards through unsqueeze_copy ops to find cos/sin op.""" + current = node + for _ in range(10): + if current.op != "call_function": + return None + if current.target in ( + exir_ops.edge.aten.cos.default, + exir_ops.edge.aten.sin.default, + ): + return current + if current.target == exir_ops.edge.aten.unsqueeze_copy.default: + current = current.args[0] + continue + return None + return None + + @classmethod + def _is_doubled_cat(cls, trig_node: torch.fx.Node) -> bool: + """Check that a cos/sin node's input is cat(x, x) with identical args.""" + cat_node = trig_node.args[0] + if ( + cat_node.op != "call_function" + or cat_node.target != exir_ops.edge.aten.cat.default + ): + return False + tensors = cat_node.args[0] + return len(tensors) == 2 and tensors[0] is tensors[1] + + @classmethod + def _has_doubled_freqs( + cls, + cos_unsqueezed: torch.fx.Node, + sin_unsqueezed: torch.fx.Node, + ) -> bool: + """Verify that cos/sin frequencies are doubled (first half == second half). + + Traces back through unsqueeze_copy ops to find the cos/sin producer, + then verifies its input is cat(x, x) where both args are the same + node — a structural proof that the first and second halves are identical. 
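+
+        The producer chain this accepts looks like (shapes illustrative):
+            freqs -> cat([freqs, freqs], dim=-1) -> cos/sin -> unsqueeze_copy(s)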
+ """ + cos_trig = cls._find_trig_source(cos_unsqueezed) + sin_trig = cls._find_trig_source(sin_unsqueezed) + + if cos_trig is None or sin_trig is None: + return False + + return cls._is_doubled_cat(cos_trig) and cls._is_doubled_cat(sin_trig) + + @staticmethod + def _trace_through_permute(node: torch.fx.Node) -> torch.fx.Node | None: + """If node is a permute_copy that swaps dims 1 and 2, return its input.""" + if ( + node.op == "call_function" + and node.target == exir_ops.edge.aten.permute_copy.default + and list(node.args[1]) == [0, 2, 1, 3] + ): + return node.args[0] + return None + + @staticmethod + def _get_layout(cos_unsqueezed: torch.fx.Node) -> _Layout | None: + """Determine the tensor layout from the cos unsqueeze dimension.""" + if not ( + cos_unsqueezed.op == "call_function" + and cos_unsqueezed.target == exir_ops.edge.aten.unsqueeze_copy.default + ): + return None + unsqueeze_dim = cos_unsqueezed.args[1] + ndim = len(cos_unsqueezed.meta["val"].shape) + normalized = unsqueeze_dim if unsqueeze_dim >= 0 else unsqueeze_dim + ndim + if normalized == 2: + return _Layout.BSHD + if normalized == 1: + return _Layout.BHSD + return None + + def create_rope( + self, + graph_module: torch.fx.GraphModule, + match: InternalMatch, + ): + logger.debug(f"Matched RoPE subgraph: {match}") + + # placeholder_nodes are in the order of the pattern's placeholder ops: + # [x, cos_unsqueezed, sin_unsqueezed] + x_node = match.placeholder_nodes[0] + cos_unsqueezed = match.placeholder_nodes[1] + sin_unsqueezed = match.placeholder_nodes[2] + output_node = match.returning_nodes[0] + + # xnn_define_rope expects NTHC (batch, tokens, heads, channels) input. + # BSHD (unsqueeze_dim=2) maps directly to NTHC. + # BHSD (unsqueeze_dim=1) requires tracing through the BSHD→BHSD permute + # to recover the BSHD input, then re-permuting the output back to BHSD. 
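+        # Sketch of the BHSD case (shape names illustrative):
+        #   matched x   = permute_copy(x_bshd, [0, 2, 1, 3])  # [B,S,H,D] -> [B,H,S,D]
+        #   fused graph = permute_copy(rope(x_bshd, weights), [0, 2, 1, 3])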
+ layout = self._get_layout(cos_unsqueezed) + if layout == _Layout.BSHD: + rope_input = x_node + elif layout == _Layout.BHSD: + rope_input = self._trace_through_permute(x_node) + if rope_input is None: + logger.debug("Skipping RoPE fusion: BHSD but x is not a permute_copy") + return + else: + logger.debug("Skipping RoPE fusion: unrecognized layout") + return + + cos_node = self._trace_through_unsqueezes(cos_unsqueezed) + sin_node = self._trace_through_unsqueezes(sin_unsqueezed) + + if not self._has_doubled_freqs(cos_unsqueezed, sin_unsqueezed): + logger.debug("Skipping RoPE fusion: cannot verify doubled frequencies") + return + + weights = self._build_weights(graph_module, cos_node, sin_node, output_node) + + with graph_module.graph.inserting_before(output_node): + rope_node = graph_module.graph.create_node( + "call_function", + torch.ops.xnnpack.rope.default, + args=(rope_input, weights), + ) + + if layout == _Layout.BHSD: + permute_node = graph_module.graph.call_function( + exir_ops.edge.aten.permute_copy.default, + args=(rope_node, self._BHSD_TO_BSHD_PERM), + ) + result_node = permute_node + else: + result_node = rope_node + + output_node.replace_all_uses_with(result_node) + graph_module.graph.eliminate_dead_code() + + # override + def call(self, graph_module: torch.fx.GraphModule): + total_matches = 0 + total_fused = 0 + for pattern in rope.get_graphs(): + sm = SubgraphMatcher(pattern.graph, ignore_literals=True) + matches = list(sm.match(graph_module.graph)) + total_matches += len(matches) + for match in matches: + try: + self.create_rope(graph_module, match) + total_fused += 1 + except Exception: + logger.warning("Failed to fuse RoPE pattern", exc_info=True) + graph_module.recompile() + graph_module = super().call(graph_module).graph_module + + return PassResult(graph_module, True) diff --git a/backends/xnnpack/operators/__init__.py b/backends/xnnpack/operators/__init__.py index e06f337f9ee..61da4082ea6 100644 --- a/backends/xnnpack/operators/__init__.py +++ b/backends/xnnpack/operators/__init__.py @@ -43,6 +43,7 @@ op_prelu, op_quant_dequant, op_relu, + op_rope, op_rsqrt, op_sigmoid, op_sin, diff --git a/backends/xnnpack/operators/op_rope.py b/backends/xnnpack/operators/op_rope.py new file mode 100644 index 00000000000..e5e2d58b1bf --- /dev/null +++ b/backends/xnnpack/operators/op_rope.py @@ -0,0 +1,77 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Dict + +import torch +from executorch.backends.xnnpack.operators.node_visitor import ( + NodeVisitor, + register_node_visitor, +) +from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import ( + XNNGraph, + XNNRope, + XNode, +) +from executorch.backends.xnnpack.utils.utils import get_input_node + +# Register the custom op used by the fusion pass. The fused node targets +# this op after ConvertToRopePass replaces the decomposed HF RoPE subgraph. 
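+# The eager reference below applies the split-half rotation, with cos/sin
+# taken from the two halves of the packed weights:
+#   out[..., :C/2] = x1 * cos - x2 * sin
+#   out[..., C/2:] = x1 * sin + x2 * cos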
+lib = torch.library.Library("xnnpack", "FRAGMENT") +lib.define("rope(Tensor input, Tensor weights) -> Tensor") + + +@torch.library.impl(lib, "rope", "CompositeExplicitAutograd") +def rope_impl(input: torch.Tensor, weights: torch.Tensor) -> torch.Tensor: + channels = input.shape[-1] + half_c = channels // 2 + cos = weights[..., :half_c] + sin = weights[..., half_c:] + + x_real = input[..., :half_c] + x_imag = input[..., half_c:] + + out_real = x_real * cos - x_imag * sin + out_imag = x_real * sin + x_imag * cos + return torch.cat([out_real, out_imag], dim=-1) + + +@torch.library.impl(lib, "rope", "Meta") +def rope_meta(input: torch.Tensor, weights: torch.Tensor) -> torch.Tensor: + return torch.empty_like(input) + + +@register_node_visitor +class RopeVisitor(NodeVisitor): + target = "rope.default" + + def __init__(self, *args) -> None: + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + xnn_graph: XNNGraph, + vals_to_ids: Dict[torch.fx.Node, int], + debug_handle: int, + ) -> None: + self.define_nodes_tensor_inputs_outputs(node, xnn_graph, vals_to_ids) + + input_id = vals_to_ids[get_input_node(node, 0)] + weights_id = vals_to_ids[get_input_node(node, 1)] + output_id = vals_to_ids[node] + + ser_node = XNode( + xnode_union=XNNRope( + max_tokens=0, + input_id=input_id, + weights_id=weights_id, + output_id=output_id, + flags=0, + ), + debug_handle=debug_handle, + ) + xnn_graph.xnodes.append(ser_node) diff --git a/backends/xnnpack/operators/op_squeeze.py b/backends/xnnpack/operators/op_squeeze.py index 3fd5a692e0c..55834aa91cf 100644 --- a/backends/xnnpack/operators/op_squeeze.py +++ b/backends/xnnpack/operators/op_squeeze.py @@ -13,6 +13,7 @@ ) from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import ( XNNGraph, + XNNStaticExpandDims, XNNStaticReshape, XNode, ) @@ -98,46 +99,21 @@ def define_node( vals_to_ids: Dict[torch.fx.Node, int], debug_handle: int, ) -> None: - - dim = cast(int, node.args[1]) - check_or_raise( - dim == -1 or dim == len(node.args[0].meta["val"].shape), - "XNNPACK currently only supports unsqueezing in last dimension", - ) - self.define_nodes_tensor_inputs_outputs(node, xnn_graph, vals_to_ids) input_node = get_input_node(node, 0) - # input input_id = vals_to_ids[input_node] - - # output output_id = vals_to_ids[node] - check_or_raise( - "val" in input_node.meta, - "Missing val in tensor metadata for input when serializing XNNStaticReshape node", - ) - dynamic_shape = node.meta["val"].shape - new_shape = [] - - num_dynamic_dims = 0 - for dim in dynamic_shape: - if free_symbols(dim): - num_dynamic_dims += 1 - new_shape.append(0) - else: - new_shape.append(dim) - - check_or_raise( - num_dynamic_dims <= 1, - "XNNPACK reshape only supports 1 dynamic dimension. 
This may occur when ", - ) + dim = cast(int, node.args[1]) + input_ndim = len(input_node.meta["val"].shape) + if dim < 0: + dim = input_ndim + 1 + dim ser_node = XNode( - xnode_union=XNNStaticReshape( - num_dims=len(new_shape), - new_shape=new_shape, + xnode_union=XNNStaticExpandDims( + num_new_axes=1, + new_axes=[dim], input_id=input_id, output_id=output_id, flags=0, diff --git a/backends/xnnpack/partition/config/__init__.py b/backends/xnnpack/partition/config/__init__.py index 26ac6275ef1..d7acddf7993 100644 --- a/backends/xnnpack/partition/config/__init__.py +++ b/backends/xnnpack/partition/config/__init__.py @@ -55,6 +55,7 @@ SubConfig, TanhConfig, ToDimOrderCopyConfig, + UnsqueezeCopyConfig, UpsampleBilinear2dConfig, ) from executorch.backends.xnnpack.partition.config.node_configs import ( @@ -116,6 +117,7 @@ SoftmaxConfig, SquareRootConfig, SubConfig, + UnsqueezeCopyConfig, UpsampleBilinear2dConfig, # Quant/Dequant Op Configs QuantizedPerTensorConfig, diff --git a/backends/xnnpack/partition/config/generic_node_configs.py b/backends/xnnpack/partition/config/generic_node_configs.py index 0e588af66cb..dfc78ee65d3 100644 --- a/backends/xnnpack/partition/config/generic_node_configs.py +++ b/backends/xnnpack/partition/config/generic_node_configs.py @@ -216,7 +216,7 @@ def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool: if not self.check_common_constraints(node, ep): return False - num_tensors = len(node.all_input_nodes) + num_tensors = len(node.args[0]) if not (num_tensors >= 2): why( @@ -581,8 +581,15 @@ class SliceCopyConfig(GenericNodePartitionerConfig): def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool: """ - Support slicing with stride = 1, no zero-dim tensors, Slice isn't supported - if the input or output is dynamic + Support slicing with stride = 1 and static begin/end offsets. + + XNNPACK's static_slice computes output shapes at runtime from the + actual input dimensions, so dynamic (symbolic) dims in non-sliced + dimensions are fine. 
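+        For example (shapes illustrative), x[:, :, 0:32] on a [1, s, 64]
+        input with symbolic s is delegated, while x[:, 0:4, :] is not,
+        since it slices the symbolic dim.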
We only reject when: + - stride != 1 + - the sliced dimension itself is symbolic in the input + - a begin or end arg is symbolic + - any output dim is zero """ if not self.check_common_constraints(node, ep): return False @@ -594,25 +601,51 @@ def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool: if stride != 1: return False - input_node = get_input_node(node, 0) - output_node = node + # Slice dim and begin/end must be static + dim_of_slice = cast(int, node.args[1]) if len(node.args) > 1 else 0 + + begin = node.args[2] if len(node.args) > 2 else None + if begin is not None and not isinstance(begin, int): + why(node, reason=f"slice begin is not static: {begin}") + return False + + end = node.args[3] if len(node.args) > 3 else None + if end is not None and not isinstance(end, int): + why(node, reason=f"slice end is not static: {end}") + return False + input_node = get_input_node(node, 0) input_shape = list(input_node.meta["val"].shape) - output_shape = list(output_node.meta["val"].shape) + output_shape = list(node.meta["val"].shape) + + # The sliced dimension must be static in the input + ndim = len(input_shape) + normalized_dim = dim_of_slice if dim_of_slice >= 0 else dim_of_slice + ndim + if 0 <= normalized_dim < ndim: + sliced_dim = input_shape[normalized_dim] + if not isinstance(sliced_dim, int) or sliced_dim == 0: + why( + node, + reason=f"sliced dimension {normalized_dim} has invalid size: {sliced_dim}", + ) + return False - for dim in input_shape: - if not isinstance(dim, int) or dim == 0: + # The output of the sliced dimension must also be static + if 0 <= normalized_dim < len(output_shape): + out_sliced = output_shape[normalized_dim] + if not isinstance(out_sliced, int) or out_sliced == 0: why( node, - reason=f"input tensor has invalid shape, dim: {dim} of type {type(dim)}. Expecting non-zero, int values.", + reason=f"output sliced dimension {normalized_dim} is not static: {out_sliced}", ) return False - for dim in output_shape: - if not isinstance(dim, int) or dim == 0: + # No zero-dim outputs + for i, dim in enumerate(output_shape): + if isinstance(dim, int) and dim == 0: why( node, - reason=f"output tensor has invalid shape, dim: {dim} of type {type(dim)}. Expecting non-zero, int values.", + reason=f"output dim {i} is zero", ) return False @@ -717,6 +750,13 @@ def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool: return True +class UnsqueezeCopyConfig(GenericNodePartitionerConfig): + target_name = "unsqueeze_copy.default" + + def supported_precision_types(self) -> List[ConfigPrecisionType]: + return [ConfigPrecisionType.FP32] + + class CosConfig(GenericNodePartitionerConfig): target_name = "cos.default" diff --git a/backends/xnnpack/partition/graphs/rope.py b/backends/xnnpack/partition/graphs/rope.py new file mode 100644 index 00000000000..5d0fcd3a1aa --- /dev/null +++ b/backends/xnnpack/partition/graphs/rope.py @@ -0,0 +1,52 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from functools import lru_cache +from typing import List + +import torch +from executorch.backends.xnnpack.utils.configs import get_xnnpack_edge_compile_config +from executorch.exir import to_edge +from torch.export import export + + +class HFRotaryEmbeddingPattern(torch.nn.Module): + """ + HuggingFace-style rotary embedding for a single tensor. 
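+    Computes x * cos + rotate_half(x) * sin, where rotate_half negates and
+    swaps the two channel halves: rotate_half(x) = cat((-x2, x1), dim=-1).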
+ + The pattern excludes unsqueeze ops because cos/sin unsqueezes are typically + shared between q and k RoPE applications. SubgraphMatcher's containment + check rejects matches where intermediate nodes have external users, so the + unsqueezes must be outside the pattern boundary. + """ + + def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor): + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + rot = torch.cat((-x2, x1), dim=-1) + return (x * cos) + (rot * sin) + + +@lru_cache(maxsize=None) +def get_graphs() -> List[torch.fx.GraphModule]: + """ + Returns decomposed edge-dialect graph(s) for the HF RoPE pattern. + """ + batch_size = 1 + seq_len = 8 + n_heads = 4 + head_dim = 32 + + x = torch.randn(batch_size, seq_len, n_heads, head_dim) + # cos/sin are post-unsqueeze: [batch, seq, 1, head_dim] + cos = torch.randn(batch_size, seq_len, 1, head_dim) + sin = torch.randn(batch_size, seq_len, 1, head_dim) + + edge = to_edge( + export(HFRotaryEmbeddingPattern(), (x, cos, sin), strict=True), + compile_config=get_xnnpack_edge_compile_config(), + ) + return [edge.exported_program().graph_module] diff --git a/backends/xnnpack/runtime/XNNCompiler.cpp b/backends/xnnpack/runtime/XNNCompiler.cpp index 4881844ac6d..63c75eada2e 100644 --- a/backends/xnnpack/runtime/XNNCompiler.cpp +++ b/backends/xnnpack/runtime/XNNCompiler.cpp @@ -1699,6 +1699,60 @@ Error defineGenericBinaryNode( node->debug_handle()); \ } +Error defineRopeNode( + xnn_subgraph_t subgraph_ptr, + const std::unordered_map& remapped_ids, + const NodePtr node, + const fb_xnnpack::XNNGraph* graph) noexcept { + MAYBE_UNUSED(graph); + + auto graph_node = node->xnode_union_as_XNNRope(); + + xnn_status status = xnn_define_rope( + subgraph_ptr, + graph_node->max_tokens(), + remapped_ids.at(graph_node->input_id()), + remapped_ids.at(graph_node->weights_id()), + remapped_ids.at(graph_node->output_id()), + graph_node->flags()); + + ET_CHECK_OR_RETURN_ERROR( + status == xnn_status_success, + Internal, + "Failed to create RoPE node %i with code: %s", + node->debug_handle(), + xnn_status_to_string(status)); + + return Error::Ok; +} + +Error defineStaticExpandDimsNode( + xnn_subgraph_t subgraph_ptr, + const std::unordered_map& remapped_ids, + const NodePtr node, + const fb_xnnpack::XNNGraph* graph) noexcept { + MAYBE_UNUSED(graph); + + auto graph_node = node->xnode_union_as_XNNStaticExpandDims(); + + std::vector new_axes = flatbufferDimsToVector(graph_node->new_axes()); + xnn_status status = xnn_define_static_expand_dims( + subgraph_ptr, + graph_node->num_new_axes(), + new_axes.data(), + remapped_ids.at(graph_node->input_id()), + remapped_ids.at(graph_node->output_id()), + graph_node->flags()); + ET_CHECK_OR_RETURN_ERROR( + status == xnn_status_success, + Internal, + "Failed to create static expand dims node %i with code: %s", + node->debug_handle(), + xnn_status_to_string(status)); + + return Error::Ok; +} + /* Returns the pointer to the defineNode function that handles the given XNode type @@ -1794,6 +1848,8 @@ DefineNodeFunc getDefineNodeFunc(fb_xnnpack::XNodeUnion nodeType) { _DEFINE(StaticSlice) _DEFINE(BatchMatrixMultiply) _DEFINE(Copy) + _DEFINE(Rope) + _DEFINE(StaticExpandDims) case fb_xnnpack::XNodeUnion::NONE: default: // Adding here as a catch all, just in case return &defineNotImplementedNode; diff --git a/backends/xnnpack/serialization/runtime_schema.fbs b/backends/xnnpack/serialization/runtime_schema.fbs index fb2c9b1598c..1ca6d020c1e 100644 --- a/backends/xnnpack/serialization/runtime_schema.fbs +++ 
b/backends/xnnpack/serialization/runtime_schema.fbs @@ -159,6 +159,24 @@ union XNodeUnion { XNNSin: _XNNNode1x1, XNNCopy: _XNNNode1x1, XNNCos: _XNNNode1x1, + XNNRope, + XNNStaticExpandDims, +} + +table XNNRope { + max_tokens: uint; + input_id: uint; + weights_id: uint; + output_id: uint; + flags: uint; +} + +table XNNStaticExpandDims { + num_new_axes: uint; + new_axes: [uint]; + input_id: uint; + output_id: uint; + flags: uint; } union XValueUnion { diff --git a/backends/xnnpack/serialization/schema.fbs b/backends/xnnpack/serialization/schema.fbs index 203469421d1..283fa6134cb 100644 --- a/backends/xnnpack/serialization/schema.fbs +++ b/backends/xnnpack/serialization/schema.fbs @@ -155,6 +155,24 @@ union XNodeUnion { XNNSin: _XNNNode1x1, XNNCopy: _XNNNode1x1, XNNCos: _XNNNode1x1, + XNNRope, + XNNStaticExpandDims, +} + +table XNNRope { + max_tokens: uint; + input_id: uint; + weights_id: uint; + output_id: uint; + flags: uint; +} + +table XNNStaticExpandDims { + num_new_axes: uint; + new_axes: [uint]; + input_id: uint; + output_id: uint; + flags: uint; } union XValueUnion { diff --git a/backends/xnnpack/serialization/xnnpack_graph_schema.py b/backends/xnnpack/serialization/xnnpack_graph_schema.py index e95a55e1c01..bfd57df8533 100644 --- a/backends/xnnpack/serialization/xnnpack_graph_schema.py +++ b/backends/xnnpack/serialization/xnnpack_graph_schema.py @@ -373,6 +373,24 @@ class XNNScaledDotProductAttention: flags: int +@dataclass +class XNNRope: + max_tokens: int + input_id: int + weights_id: int + output_id: int + flags: int + + +@dataclass +class XNNStaticExpandDims: + num_new_axes: int + new_axes: List[int] + input_id: int + output_id: int + flags: int + + XNodeUnion = Union[ XNNAdd, XNNFullyConnected, @@ -421,6 +439,8 @@ class XNNScaledDotProductAttention: XNNSin, XNNCopy, XNNCos, + XNNRope, + XNNStaticExpandDims, ] diff --git a/backends/xnnpack/test/ops/test_rope.py b/backends/xnnpack/test/ops/test_rope.py new file mode 100644 index 00000000000..881841462c2 --- /dev/null +++ b/backends/xnnpack/test/ops/test_rope.py @@ -0,0 +1,143 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +import torch +from executorch.backends.xnnpack.test.tester import Tester +from torch.export import Dim + + +def _hf_freqs(seq_len: int, head_dim: int, doubled: bool = True) -> tuple: + """Generate cos/sin frequencies. 
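+    Returns (cos, sin), each of shape [1, seq_len, head_dim].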
If doubled, first and second halves are identical.""" + half = head_dim // 2 + freqs = torch.randn(seq_len, half) + if doubled: + emb = torch.cat((freqs, freqs), dim=-1) + else: + emb = torch.cat((freqs, torch.randn(seq_len, half)), dim=-1) + return torch.cos(emb).unsqueeze(0), torch.sin(emb).unsqueeze(0) + + +class TestRope(unittest.TestCase): + def setUp(self): + torch._dynamo.reset() + + class HFRope(torch.nn.Module): + """HuggingFace-style rotary position embedding (split-half layout).""" + + def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor): + cos = cos.unsqueeze(2) + sin = sin.unsqueeze(2) + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + rot = torch.cat((-x2, x1), dim=-1) + return (x * cos) + (rot * sin) + + def _test_rope(self, inputs, dynamic_shapes=None): + ( + Tester(self.HFRope(), inputs, dynamic_shapes=dynamic_shapes) + .export() + .to_edge_transform_and_lower() + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .to_executorch() + .serialize() + .run_method_and_compare_outputs(inputs=inputs) + ) + + def test_fp32_rope(self): + batch, seq_len, n_heads, head_dim = 1, 8, 4, 32 + cos, sin = _hf_freqs(seq_len, head_dim) + inputs = (torch.randn(batch, seq_len, n_heads, head_dim), cos, sin) + self._test_rope(inputs) + + def test_fp32_rope_large_head_dim(self): + batch, seq_len, n_heads, head_dim = 1, 16, 8, 128 + cos, sin = _hf_freqs(seq_len, head_dim) + inputs = (torch.randn(batch, seq_len, n_heads, head_dim), cos, sin) + self._test_rope(inputs) + + def test_fp32_rope_dynamic_seq_len(self): + batch, seq_len, n_heads, head_dim = 1, 8, 4, 32 + cos, sin = _hf_freqs(seq_len, head_dim) + inputs = (torch.randn(batch, seq_len, n_heads, head_dim), cos, sin) + seq = Dim("seq", min=1, max=128) + dynamic_shapes = ( + {0: None, 1: seq, 2: None, 3: None}, + {0: None, 1: seq, 2: None}, + {0: None, 1: seq, 2: None}, + ) + self._test_rope(inputs, dynamic_shapes=dynamic_shapes) + + class HFRopeBHSD(torch.nn.Module): + """HF-style RoPE with BHSD layout (transpose before/after RoPE).""" + + def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor): + x = x.transpose(1, 2) + cos = cos.unsqueeze(1) + sin = sin.unsqueeze(1) + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + rot = torch.cat((-x2, x1), dim=-1) + out = (x * cos) + (rot * sin) + return out.transpose(1, 2) + + def _test_rope_bhsd(self, inputs, dynamic_shapes=None): + ( + Tester(self.HFRopeBHSD(), inputs, dynamic_shapes=dynamic_shapes) + .export() + .to_edge_transform_and_lower() + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .to_executorch() + .serialize() + .run_method_and_compare_outputs(inputs=inputs) + ) + + def test_fp32_rope_bhsd(self): + batch, seq_len, n_heads, head_dim = 1, 8, 4, 32 + cos, sin = _hf_freqs(seq_len, head_dim) + inputs = (torch.randn(batch, seq_len, n_heads, head_dim), cos, sin) + self._test_rope_bhsd(inputs) + + def test_fp32_rope_bhsd_large_head_dim(self): + batch, seq_len, n_heads, head_dim = 1, 16, 8, 128 + cos, sin = _hf_freqs(seq_len, head_dim) + inputs = (torch.randn(batch, seq_len, n_heads, head_dim), cos, sin) + self._test_rope_bhsd(inputs) + + def test_fp32_rope_bhsd_dynamic_seq_len(self): + batch, seq_len, n_heads, head_dim = 1, 8, 4, 32 + cos, sin = _hf_freqs(seq_len, head_dim) + inputs = (torch.randn(batch, seq_len, n_heads, head_dim), cos, sin) + seq = Dim("seq", min=1, max=128) + dynamic_shapes = ( + {0: None, 1: seq, 2: None, 3: None}, + {0: None, 1: seq, 2: None}, + {0: None, 1: 
seq, 2: None}, + ) + self._test_rope_bhsd(inputs, dynamic_shapes=dynamic_shapes) + + def test_non_doubled_freqs_not_fused(self): + """Non-doubled cos/sin must not be fused into xnnpack.rope. + + The fused op only uses the first half of cos/sin as weights. If + fusion fires on non-doubled frequencies (where the two halves + differ), the second half is silently discarded and the output is + wrong. This test catches that: run_method_and_compare_outputs will + fail if fusion incorrectly applies. + """ + batch, seq_len, n_heads, head_dim = 1, 8, 4, 32 + cos, sin = _hf_freqs(seq_len, head_dim, doubled=False) + inputs = (torch.randn(batch, seq_len, n_heads, head_dim), cos, sin) + ( + Tester(self.HFRope(), inputs) + .export() + .to_edge_transform_and_lower() + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .to_executorch() + .serialize() + .run_method_and_compare_outputs(inputs=inputs) + ) diff --git a/backends/xnnpack/test/ops/test_slice_copy.py b/backends/xnnpack/test/ops/test_slice_copy.py index f8189ab9862..d38540e2112 100644 --- a/backends/xnnpack/test/ops/test_slice_copy.py +++ b/backends/xnnpack/test/ops/test_slice_copy.py @@ -121,7 +121,7 @@ def forward(self, x): def test_fp32_static_slice_with_dynamic_dim(self): """ - XNNPACK does not support dynamic dims with static slice + Slices on static dims are delegated; slices on the dynamic dim are not. """ class SliceCopy(torch.nn.Module): @@ -137,7 +137,8 @@ def forward(self, x): ) .export() .to_edge_transform_and_lower() - .check_not(["torch.ops.higher_order.executorch_call_delegate"]) + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .check(["executorch_exir_dialects_edge__ops_aten_slice_copy_Tensor"]) ) # Note: Slice ends up as slice_copy later in the process, but during quantization,