Skip to content

Commit e79a11d

Browse files
ssjiarascani
authored and committed
[ET-VK] Add fused HuggingFace RoPE operator (apply_rotary_emb_hf)
Pull Request resolved: pytorch#18592 Add a fused rotary positional embedding operator for the HuggingFace RoPE convention used by Qwen3, Phi-4-mini, and other HF-based models. The existing `et_vk.apply_rotary_emb` only matches the stock Meta/Llama RoPE pattern (interleaved pairs via reshape+unbind+stack+flatten). HF models use a different convention (split-half via slice+neg+cat), causing Qwen3's RoPE to decompose into ~560 GPU dispatches per decode step instead of 16 fused dispatches (~1,295 µs/decode, 7% of total). This commit adds `et_vk.apply_rotary_emb_hf` with: - Pattern matching: `HfRotaryEmbeddingPattern` in `patterns/rope_hf.py` using SubgraphMatcher to detect the HF RoPE graph and replace with fused op. Supports both full rotation (freqs_dim == head_dim) and partial rotation (freqs_dim < head_dim, e.g. Phi-4-mini with partial_rotary_factor=0.75) by registering two pattern variants in get_hf_rope_graphs(). - GLSL shader: `rotary_embedding_hf.glsl` which pairs elements at distance D/2 (half-apart) instead of adjacent pairs, computing half_dim from the metadata UBO for dynamic shape support - C++ dispatch: `add_rotary_embedding_hf_node` with corrected assertion (head_dim == freqs_dim, not freqs_dim*2) since HF freqs are full-dim - Custom op registration in both xplat and fbcode - Op tests covering multiple configurations and dynamic prefill→decode resize Also adds a convert_phi4_mini_weights binary target to the phi_4_mini TARGETS file to enable converting HF checkpoint weights to Meta format. Authored with Claude. ghstack-source-id: 359963407 @exported-using-ghexport Differential Revision: [D98741178](https://our.internmc.facebook.com/intern/diff/D98741178/)
1 parent 9a997e8 commit e79a11d

9 files changed

Lines changed: 1123 additions & 0 deletions

File tree

backends/vulkan/custom_ops_lib.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -802,6 +802,32 @@ def apply_rotary_emb_impl(
802802
lib.impl(name, apply_rotary_emb_impl, "CompositeExplicitAutograd")
803803
apply_rotary_emb_op = getattr(getattr(torch.ops, namespace), name)
804804

805+
#########################
## apply_rotary_emb_hf ##
#########################


def apply_rotary_emb_hf_impl(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cos: torch.Tensor,
    freqs_sin: torch.Tensor,
    start_pos: int,
):
    """Reference (eager) implementation of the fused HF-convention RoPE op.

    Selects the frequency-table rows for positions
    [start_pos, start_pos + seq_len) and delegates the rotation itself to
    HfRotaryEmbeddingPattern so the eager semantics stay in lockstep with
    the graph pattern that the Vulkan backend fuses.
    """
    end_pos = start_pos + xq.shape[1]
    cos_rows = freqs_cos[start_pos:end_pos]
    sin_rows = freqs_sin[start_pos:end_pos]
    rope = vk_patterns.HfRotaryEmbeddingPattern()
    return rope.forward(xq, xk, cos_rows, sin_rows)


name = "apply_rotary_emb_hf"
lib.define(
    f"{name}(Tensor xq, Tensor xk, Tensor freqs_cos, Tensor freqs_sin, SymInt start_pos) -> (Tensor, Tensor)"
)
lib.impl(name, apply_rotary_emb_hf_impl, "CompositeExplicitAutograd")
apply_rotary_emb_hf_op = getattr(getattr(torch.ops, namespace), name)
830+
805831
########################
806832
## q8ta_add ##
807833
########################

backends/vulkan/op_registry.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1086,6 +1086,16 @@ def register_apply_rotary_emb():
10861086
)
10871087

10881088

@update_features(exir_ops.edge.et_vk.apply_rotary_emb_hf.default)
def register_apply_rotary_emb_hf():
    """Declare Vulkan delegate support for the fused HF RoPE custom op.

    Contiguous storage, float dtypes, resize support (prefill -> decode
    shape changes) and high-dimensional tensor support.
    """
    features = OpFeatures(
        inputs_storage=utils.CONTIGUOUS_ANY,
        inputs_dtypes=utils.FP_T,
        supports_resize=True,
        supports_highdim=True,
    )
    return features
1097+
1098+
10891099
# =============================================================================
10901100
# Permute.cpp
10911101
# =============================================================================

backends/vulkan/patterns/BUCK

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ fbcode_target(_kind = runtime.python_library,
1010
"__init__.py",
1111
"pattern_registry.py",
1212
"rope.py",
13+
"rope_hf.py",
1314
"quantized_embedding.py",
1415
"quantized_linear.py",
1516
"quantized_convolution.py",

backends/vulkan/patterns/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818

1919
import executorch.backends.vulkan.patterns.rope # noqa
2020

21+
import executorch.backends.vulkan.patterns.rope_hf # noqa
22+
2123
import executorch.backends.vulkan.patterns.sdpa # noqa
2224

2325
import executorch.backends.vulkan.patterns.select_as_symint # noqa
@@ -37,6 +39,7 @@
3739
)
3840

3941
from executorch.backends.vulkan.patterns.rope import RotaryEmbeddingPattern
42+
from executorch.backends.vulkan.patterns.rope_hf import HfRotaryEmbeddingPattern
4043

4144
from executorch.exir import ExportedProgram
4245

@@ -49,6 +52,7 @@
4952
"DetectorFn",
5053
"CreateReplacementFn",
5154
"RotaryEmbeddingPattern",
55+
"HfRotaryEmbeddingPattern",
5256
"fusable_patterns",
5357
"register_pattern_graph",
5458
"register_pattern_detector",
Lines changed: 188 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,188 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import operator
8+
9+
from functools import lru_cache
10+
from typing import List, Optional
11+
12+
import torch
13+
14+
from executorch.backends.vulkan.patterns.pattern_registry import (
15+
PatternMatch,
16+
register_pattern_graph,
17+
register_pattern_replacement,
18+
)
19+
20+
from executorch.exir import EdgeCompileConfig, ExportedProgram, to_edge
21+
from executorch.exir.dialects._ops import ops as exir_ops
22+
23+
from torch.export import export
24+
25+
26+
class HfRotaryEmbeddingPattern(torch.nn.Module):
    """HuggingFace-convention rotary positional embedding (rotate_half).

    Matches hf_apply_rotary_emb in examples/models/llama/rope.py: channel i
    is paired with channel i + rotary_dim // 2 (split-half convention),
    unlike the Meta/Llama convention which pairs adjacent channels.
    Supports partial rotation: only the first freqs_cos.shape[-1] channels
    of each head are rotated; the remainder passes through unchanged.
    """

    def __init__(self):
        super().__init__()

    def forward(
        self,
        xq: torch.Tensor,
        xk: torch.Tensor,
        freqs_cos: torch.Tensor,
        freqs_sin: torch.Tensor,
    ):
        # Insert a singleton axis so the (seq_len, rotary_dim) tables
        # broadcast across the heads dimension of xq / xk.
        cos = freqs_cos.unsqueeze(1)
        sin = freqs_sin.unsqueeze(1)
        rotary_dim = cos.shape[-1]

        def rotate(t: torch.Tensor) -> torch.Tensor:
            # Rotate the leading rotary_dim channels; pass the rest through.
            rot, passthrough = t[..., :rotary_dim], t[..., rotary_dim:]
            mixed = (rot.float() * cos) + (self._rotate_half(rot.float()) * sin)
            return torch.cat([mixed, passthrough], dim=-1).type_as(t)

        return rotate(xq), rotate(xk)

    @staticmethod
    def _rotate_half(x: torch.Tensor) -> torch.Tensor:
        # (x1, x2) -> (-x2, x1): negate the back half, swap it to the front.
        half = x.shape[-1] // 2
        return torch.cat((-x[..., half:], x[..., :half]), dim=-1)
64+
65+
@lru_cache(maxsize=2)
@register_pattern_graph("hf_rope")
def get_hf_rope_graphs() -> List[torch.fx.GraphModule]:
    """Export the HF RoPE pattern graphs used for subgraph matching.

    Produces two variants:
      1. full rotation (partial_rotary_factor == 1.0): freqs dim == head_dim
      2. partial rotation (factor < 1.0, e.g. Phi-4-mini at 0.75):
         freqs dim < head_dim, so the q_pass/k_pass slices are non-empty

    The concrete sizes below only shape the example inputs for export; the
    exported graphs are matched structurally, independent of these values.
    """
    batch_size, seq_len = 1, 1
    n_heads, n_kv_heads = 4, 2
    head_dim = 32
    dtype = torch.float32

    graphs: List[torch.fx.GraphModule] = []
    # head_dim -> full rotation; 24 -> partial rotation (0.75 factor).
    for rotary_dim in (head_dim, 24):
        example_inputs = (
            torch.randn(batch_size, seq_len, n_heads, head_dim, dtype=dtype),
            torch.randn(batch_size, seq_len, n_kv_heads, head_dim, dtype=dtype),
            torch.randn(seq_len, rotary_dim, dtype=dtype),
            torch.randn(seq_len, rotary_dim, dtype=dtype),
        )
        edge = to_edge(
            export(
                HfRotaryEmbeddingPattern(),
                example_inputs,
                strict=True,
            ),
            compile_config=EdgeCompileConfig(_check_ir_validity=False),
        )
        graphs.append(edge.exported_program().graph_module)

    return graphs
115+
116+
117+
def identify_hf_rotary_emb_io_nodes(
    ep: ExportedProgram,
    graph_module: torch.fx.GraphModule,
    match: PatternMatch,
) -> Optional[List[torch.fx.Node]]:
    """Validate a matched HF RoPE subgraph and flatten its boundary nodes.

    Returns [xq, xk, freqs_cos, freqs_sin, xq_out, xk_out], or None when the
    match does not have exactly 4 inputs and 2 outputs.
    """
    if len(match.input_nodes) != 4 or len(match.output_nodes) != 2:
        return None
    return [*match.input_nodes, *match.output_nodes]
135+
136+
137+
@register_pattern_replacement("hf_rope")
def create_hf_rotary_emb_custom_op(
    ep: ExportedProgram,
    graph_module: torch.fx.GraphModule,
    match: PatternMatch,
):
    """Replace a matched HF RoPE subgraph with et_vk.apply_rotary_emb_hf.

    When the freq inputs are slices of a full table (the usual pattern when
    the caller indexes by position), the full tables plus the slice start
    are passed to the fused op so it can index by position directly.
    """
    io_nodes = identify_hf_rotary_emb_io_nodes(ep, graph_module, match)
    if io_nodes is None:
        return

    assert len(io_nodes) == 6
    xq, xk, freqs_cos, freqs_sin, xq_out, xk_out = io_nodes

    def _is_slice_copy(node: torch.fx.Node) -> bool:
        # One-liner: is this node an aten.slice_copy call?
        return (
            node.op == "call_function"
            and node.target == exir_ops.edge.aten.slice_copy.Tensor
        )

    # Check if freqs come from slice_copy and extract full table + start_pos.
    # Both cos AND sin must be slices: the original check looked only at
    # freqs_cos and then read freqs_sin.args[0] unconditionally, which
    # raises IndexError when freqs_sin is a placeholder (empty .args) and
    # silently grabs the wrong node for any other op.
    if _is_slice_copy(freqs_cos) and _is_slice_copy(freqs_sin):
        # slice_copy args are (input, dim, start, end, step); args[2] = start.
        start_pos = freqs_cos.args[2]
        freqs_cos = freqs_cos.args[0]
        freqs_sin = freqs_sin.args[0]
    else:
        start_pos = 0

    with graph_module.graph.inserting_before(xq_out):
        rotary_emb_node = graph_module.graph.create_node(
            "call_function",
            exir_ops.edge.et_vk.apply_rotary_emb_hf.default,
            args=(xq, xk, freqs_cos, freqs_sin, start_pos),
        )

    # The fused op returns a (q, k) tuple; unpack it with getitem nodes.
    with graph_module.graph.inserting_after(rotary_emb_node):
        getitem_0 = graph_module.graph.create_node(
            "call_function",
            operator.getitem,
            args=(rotary_emb_node, 0),
        )
        getitem_1 = graph_module.graph.create_node(
            "call_function",
            operator.getitem,
            args=(rotary_emb_node, 1),
        )

    # Propagate shape/dtype metadata so downstream passes see correct vals.
    if hasattr(xq_out, "meta") and "val" in xq_out.meta:
        getitem_0.meta["val"] = xq_out.meta["val"]
    if hasattr(xk_out, "meta") and "val" in xk_out.meta:
        getitem_1.meta["val"] = xk_out.meta["val"]

    xq_out.replace_all_uses_with(getitem_0)
    xk_out.replace_all_uses_with(getitem_1)

0 commit comments

Comments
 (0)