Skip to content

Commit 14f27e1

Browse files
author
ssjia
committed
Update on "[ET-VK] Add fused HuggingFace RoPE operator (apply_rotary_emb_hf)"
Add a fused rotary positional embedding operator for the HuggingFace RoPE convention used by Qwen3, Phi-4-mini, and other HF-based models. The existing `et_vk.apply_rotary_emb` only matches the stock Meta/Llama RoPE pattern (interleaved pairs via reshape+unbind+stack+flatten). HF models use a different convention (split-half via slice+neg+cat), causing Qwen3's RoPE to decompose into ~560 GPU dispatches per decode step instead of 16 fused dispatches (~1,295 µs/decode, 7% of total). This commit adds `et_vk.apply_rotary_emb_hf` with: - Pattern matching: `HfRotaryEmbeddingPattern` in `patterns/rope_hf.py` using SubgraphMatcher to detect the HF RoPE graph and replace with fused op. Supports both full rotation (freqs_dim == head_dim) and partial rotation (freqs_dim < head_dim, e.g. Phi-4-mini with partial_rotary_factor=0.75) by registering two pattern variants in get_hf_rope_graphs(). - GLSL shader: `rotary_embedding_hf.glsl` which pairs elements at distance D/2 (half-apart) instead of adjacent pairs, computing half_dim from the metadata UBO for dynamic shape support - C++ dispatch: `add_rotary_embedding_hf_node` with corrected assertion (head_dim == freqs_dim, not freqs_dim*2) since HF freqs are full-dim - Custom op registration in both xplat and fbcode - Op tests covering multiple configurations and dynamic prefill→decode resize Also adds a convert_phi4_mini_weights binary target to the phi_4_mini TARGETS file to enable converting HF checkpoint weights to Meta format. Authored with Claude. Differential Revision: [D98741178](https://our.internmc.facebook.com/intern/diff/D98741178/) [ghstack-poisoned]
2 parents 1bdfb33 + 8092bfe commit 14f27e1

37 files changed

Lines changed: 1453 additions & 610 deletions

.github/workflows/android-release-artifacts.yml

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -145,9 +145,9 @@ jobs:
145145
export BUILD_AAR_DIR=aar-out
146146
bash scripts/build_android_library.sh
147147
mkdir -p "${ARTIFACTS_DIR_NAME}"
148-
cp aar-out/executorch.aar "${ARTIFACTS_DIR_NAME}/executorch.aar"
148+
cp aar-out/executorch.aar "${ARTIFACTS_DIR_NAME}/executorch-${FLAVOR}.aar"
149149
150-
shasum -a 256 "${ARTIFACTS_DIR_NAME}/executorch.aar"
150+
shasum -a 256 "${ARTIFACTS_DIR_NAME}/executorch-${FLAVOR}.aar"
151151
152152
# Publish to maven staging
153153
UPLOAD_TO_MAVEN="${{ inputs.upload_to_maven }}"
@@ -172,11 +172,6 @@ jobs:
172172
- name: Upload AAR RC to AWS S3
173173
shell: bash
174174
run: |
175-
wget https://gha-artifacts.s3.amazonaws.com/${{ github.repository }}/${{ github.run_id }}/artifacts/executorch.aar
176-
shasum -a 256 executorch.aar > executorch.aar.sha256sums
177-
178-
pip install awscli==1.32.18
179-
AWS_CMD="aws s3 cp"
180175
VERSION="${{ inputs.version }}"
181176
FLAVOR="${{ inputs.flavor }}"
182177
if [ -z "$VERSION" ]; then
@@ -185,5 +180,11 @@ jobs:
185180
if [ -z "$FLAVOR" ]; then
186181
FLAVOR="xnnpack"
187182
fi
183+
wget https://gha-artifacts.s3.amazonaws.com/${{ github.repository }}/${{ github.run_id }}/artifacts/executorch-${FLAVOR}.aar
184+
mv executorch-${FLAVOR}.aar executorch.aar
185+
shasum -a 256 executorch.aar > executorch.aar.sha256sums
186+
187+
pip install awscli==1.32.18
188+
AWS_CMD="aws s3 cp"
188189
${AWS_CMD} executorch.aar s3://ossci-android/executorch/release/${VERSION}-${FLAVOR}/executorch.aar --acl public-read
189190
${AWS_CMD} executorch.aar.sha256sums s3://ossci-android/executorch/release/${VERSION}-${FLAVOR}/executorch.aar.sha256sums --acl public-read

.github/workflows/android-release-on-tag.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,10 @@ on:
66
- 'v*.*.*-rc*'
77
- 'v*.*.*'
88

9+
permissions:
10+
id-token: write
11+
contents: read
12+
913
jobs:
1014
prepare:
1115
runs-on: ubuntu-latest

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
from .decompose_cumsum_pass import DecomposeCumsumPass # noqa
4444
from .decompose_div_pass import DecomposeDivPass # noqa
4545
from .decompose_div_tensor_mode import DecomposeDivTensorModePass # noqa
46+
from .decompose_einsum_pass import DecomposeEinsumPass # noqa
4647
from .decompose_elu_pass import DecomposeEluPass # noqa
4748
from .decompose_embedding_pass import DecomposeEmbeddingPass # noqa # noqa
4849
from .decompose_erfinv_pass import DecomposeErfinvPass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
DecomposeCumsumPass,
5252
DecomposeDivPass,
5353
DecomposeDivTensorModePass,
54+
DecomposeEinsumPass,
5455
DecomposeEluPass,
5556
DecomposeEmbeddingPass,
5657
DecomposeErfinvPass,
@@ -560,6 +561,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
560561
DecomposeFloorDividePass(tfa_pass=True),
561562
DecomposeDivTensorModePass(tfa_pass=True),
562563
DecomposeWhereScalarOtherPass(tfa_pass=True),
564+
DecomposeEinsumPass(tfa_pass=True),
563565
RewriteInplaceArithmeticPass(tfa_pass=True),
564566
DecomposeAddSubAlphaPass(tfa_pass=True),
565567
DecomposeLeakyReLUPass(tfa_pass=True),
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import torch
7+
from executorch.backends.arm._passes.get_decomposition_pass import GetDecompositionPass
8+
9+
10+
class DecomposeEinsumPass(GetDecompositionPass):
11+
"""Decomposes aten.einsum.default into more primitive ops.
12+
13+
This pass is intended to be called in transform_for_annotation to prepare
14+
the graph for quantization. Einsum is not annotated directly by the Arm
15+
quantizer, but the decomposed ops are.
16+
17+
"""
18+
19+
targeted_ops = [torch.ops.aten.einsum.default]
20+
21+
def _get_input_tensors(self, node: torch.fx.Node) -> list:
22+
"""Override the base hook because aten.einsum.default takes (equation,
23+
[operands]), which cannot be handled by the generic one-arg-per-input
24+
logic.
25+
"""
26+
equation, operands = node.args # type: ignore[union-attr]
27+
fake_operands = [operand.meta["val"] for operand in operands] # type: ignore[union-attr]
28+
return [equation, fake_operands]
29+
30+
def _get_placeholder_map(
31+
self,
32+
node: torch.fx.Node,
33+
decomposed_module: torch.fx.GraphModule,
34+
) -> dict[str, torch.fx.Node]:
35+
"""Override the base hook because einsum does not trace placeholders
36+
one-to-one with node.args.
37+
38+
The traced graph includes arg0_1 for the equation string and arg1_i for
39+
each tensor inside the operand list, so we must skip the equation
40+
placeholder, which is not an original FX tensor node, and map each
41+
operand placeholder back to the corresponding original FX node.
42+
43+
"""
44+
_, operands = node.args
45+
name_to_input_tensor_map = {}
46+
47+
for decomposed_node in decomposed_module.graph.nodes:
48+
if decomposed_node.op != "placeholder":
49+
continue
50+
if decomposed_node.name == "arg0_1":
51+
continue
52+
if not decomposed_node.name.startswith("arg1_"):
53+
raise RuntimeError(
54+
f"Unexpected einsum placeholder name {decomposed_node.name!r}."
55+
)
56+
57+
operand_idx = int(decomposed_node.name.split("_")[1]) - 1
58+
name_to_input_tensor_map[decomposed_node.name] = operands[operand_idx] # type: ignore[index]
59+
60+
return name_to_input_tensor_map # type: ignore[return-value]
61+
62+
def _get_output_node(self, output_node: torch.fx.Node) -> torch.fx.Node:
63+
"""Return the traced value node for einsum graphs that emit
64+
output([node]).
65+
"""
66+
return output_node.args[0][0] # type: ignore[index, return-value]

backends/arm/_passes/get_decomposition_pass.py

Lines changed: 34 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,31 @@ def __init__(self, tfa_pass=False, *args, **kwargs):
3434
def _skip_pass(self, input_tensors: list) -> bool:
3535
return False
3636

37+
def _get_input_tensors(self, node: torch.fx.Node) -> list:
38+
input_tensors = []
39+
for arg in node.args:
40+
if hasattr(arg, "meta"):
41+
input_tensors.append(arg.meta["val"]) # type: ignore[union-attr]
42+
elif isinstance(arg, int):
43+
input_tensors.append(arg)
44+
return input_tensors
45+
46+
def _get_placeholder_map(
47+
self,
48+
node: torch.fx.Node,
49+
decomposed_module: torch.fx.GraphModule,
50+
) -> dict[str, torch.fx.Node]:
51+
# Keep decomposed_module in the hook signature so subclasses can inspect
52+
# traced placeholder structure when the mapping is not one-to-one.
53+
name_to_input_tensor_map = {}
54+
for i, arg in enumerate(node.args):
55+
name_to_input_tensor_map[f"arg{i}_1"] = arg
56+
return name_to_input_tensor_map # type: ignore[return-value]
57+
58+
def _get_output_node(self, output_node: torch.fx.Node) -> torch.fx.Node:
59+
"""Return the traced value node for graphs that emit output(node)."""
60+
return output_node.args[0] # type: ignore[return-value]
61+
3762
def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # noqa: C901
3863
modified = False
3964
for node in graph_module.graph.nodes:
@@ -44,13 +69,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # noqa: C901
4469
):
4570
continue
4671

47-
input_tensors = []
48-
for arg in node.args:
49-
if hasattr(arg, "meta"):
50-
input_tensors.append(arg.meta["val"])
51-
52-
elif isinstance(arg, int):
53-
input_tensors.append(arg)
72+
input_tensors = self._get_input_tensors(node)
5473

5574
if self._skip_pass(input_tensors):
5675
continue
@@ -70,22 +89,26 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # noqa: C901
7089
)(*input_tensors)
7190

7291
with graph_module.graph.inserting_before(node):
73-
name_to_input_tensor_map = {}
74-
for i, arg in enumerate(node.args):
75-
name_to_input_tensor_map[f"arg{i}_1"] = arg
92+
name_to_input_tensor_map = self._get_placeholder_map(
93+
node, decomposed_module
94+
)
7695

7796
decomposed_node_to_subgraph_node = {}
7897
last_decomposed_node = None
7998
# Create a mapping from input nodes in decomposed module to original nodes.
8099
# In decomposed module, there are only input tensors for placeholder op.
81100
for decomposed_node in decomposed_module.graph.nodes:
82101
if decomposed_node.op == "placeholder":
102+
# Some ops, such as einsum, trace extra placeholders that do
103+
# not map back to original graph tensor inputs.
104+
if decomposed_node.name not in name_to_input_tensor_map:
105+
continue
83106
decomposed_node_to_subgraph_node[decomposed_node] = (
84107
name_to_input_tensor_map[decomposed_node.name]
85108
)
86109

87110
if decomposed_node.op == "output":
88-
last_decomposed_node = decomposed_node.args[0]
111+
last_decomposed_node = self._get_output_node(decomposed_node)
89112

90113
# Copy node from decompose graph module
91114
for decomposed_node in decomposed_module.graph.nodes:

0 commit comments

Comments
 (0)