Update on "[ET-VK] Add fused HuggingFace RoPE operator (apply_rotary_emb_hf)"
Add a fused rotary positional embedding operator for the HuggingFace RoPE
convention used by Qwen3, Phi-4-mini, and other HF-based models.
The existing `et_vk.apply_rotary_emb` only matches the stock Meta/Llama RoPE
pattern (interleaved pairs via reshape+unbind+stack+flatten). HF models use a
different convention (split-half via slice+neg+cat), so the existing pattern
never matches them. As a result, Qwen3's RoPE decomposes into ~560 GPU
dispatches per decode step instead of 16 fused dispatches
(~1,295 µs/decode, 7% of total).
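To make the difference between the two conventions concrete, here is a minimal pure-Python sketch for a single head vector. The function names `rope_interleaved` and `rope_split_half` are hypothetical and are not taken from this commit. The split-half version follows the usual HF formulation `x * cos + rotate_half(x) * sin`, with cos/sin tables that repeat the half-dim angles twice.

```python
import math

def rope_interleaved(x, cos, sin):
    """Meta/Llama-style sketch: rotate adjacent pairs (x[2i], x[2i+1]).
    cos and sin hold len(x) // 2 entries, one per pair."""
    out = []
    for i in range(len(x) // 2):
        a, b = x[2 * i], x[2 * i + 1]
        out += [a * cos[i] - b * sin[i], a * sin[i] + b * cos[i]]
    return out

def rope_split_half(x, cos, sin):
    """HF-style sketch: x * cos + rotate_half(x) * sin, where
    rotate_half(x) = cat(-x[half:], x[:half]).
    cos and sin hold len(x) entries (the half-dim table repeated twice)."""
    half = len(x) // 2
    rotated = [-v for v in x[half:]] + x[:half]  # slice + neg + cat
    return [x[j] * cos[j] + rotated[j] * sin[j] for j in range(len(x))]

# Illustrative values only: head_dim = 4, two arbitrary rotation angles.
x = [1.0, 2.0, 3.0, 4.0]
theta = [0.3, 0.7]
cos_h = [math.cos(t) for t in theta]
sin_h = [math.sin(t) for t in theta]

inter = rope_interleaved(x, cos_h, sin_h)
# Same input and angles, but the split-half layout pairs (x[0], x[2]) and (x[1], x[3]).
split = rope_split_half(x, cos_h + cos_h, sin_h + sin_h)
# De-interleaving the input first makes the two conventions agree up to the same permutation.
split_perm = rope_split_half(x[0::2] + x[1::2], cos_h + cos_h, sin_h + sin_h)
```

The two functions compute the same rotations and differ only in which elements are paired, which is why a matcher written for one graph shape cannot recognize the other.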
This commit adds `et_vk.apply_rotary_emb_hf` with:
- Pattern matching: `HfRotaryEmbeddingPattern` in `patterns/rope_hf.py` uses
SubgraphMatcher to detect the HF RoPE graph and replace it with the fused op.
It supports both full rotation (freqs_dim == head_dim) and partial rotation
(freqs_dim < head_dim, e.g. Phi-4-mini with partial_rotary_factor=0.75)
by registering two pattern variants in `get_hf_rope_graphs()`.
- GLSL shader: `rotary_embedding_hf.glsl`, which pairs elements D/2 apart
(split-half) instead of adjacent elements and computes half_dim from the
metadata UBO to support dynamic shapes.
- C++ dispatch: `add_rotary_embedding_hf_node` with corrected assertion
(head_dim == freqs_dim, not freqs_dim*2) since HF freqs are full-dim
- Custom op registration in both xplat and fbcode
- Op tests covering multiple configurations and dynamic prefill→decode resize
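The pairing and partial-rotation behavior described in the list above can be sketched in pure Python for one head vector. This is only an illustration under stated assumptions, not the shader or the C++ dispatch from this commit: the helper names `hf_cos_sin` and `apply_rope_hf_partial` are hypothetical, and the base of 10000 is an assumed default that real models may override.

```python
import math

def hf_cos_sin(pos, rotary_dim, base=10000.0):
    """Build HF-style cos/sin tables of length rotary_dim:
    rotary_dim // 2 angles, repeated twice (full-dim freqs).
    base=10000.0 is an assumption for illustration."""
    half = rotary_dim // 2
    angles = [pos * base ** (-2.0 * i / rotary_dim) for i in range(half)]
    angles = angles + angles
    return [math.cos(a) for a in angles], [math.sin(a) for a in angles]

def apply_rope_hf_partial(x, cos, sin):
    """Rotate the first len(cos) elements of x with split-half pairing
    (element i is paired with element i + half) and pass the rest through."""
    rotary_dim = len(cos)
    assert len(sin) == rotary_dim
    assert rotary_dim % 2 == 0 and rotary_dim <= len(x)
    half = rotary_dim // 2
    out = list(x)  # elements at index >= rotary_dim stay unchanged
    for i in range(half):
        a, b = x[i], x[i + half]
        out[i] = a * cos[i] - b * sin[i]
        out[i + half] = b * cos[i + half] + a * sin[i + half]
    return out

# Illustrative sizes: head_dim = 8 with a 0.75 partial rotary factor gives rotary_dim = 6.
head = [float(v) for v in range(1, 9)]
cos6, sin6 = hf_cos_sin(pos=5, rotary_dim=6)
partial = apply_rope_hf_partial(head, cos6, sin6)

# Full rotation is the special case rotary_dim == head_dim.
cos8, sin8 = hf_cos_sin(pos=5, rotary_dim=8)
full = apply_rope_hf_partial(head, cos8, sin8)
```

Because the cos/sin tables already span the whole rotated width, the length check compares the table length against the head width directly rather than against twice the table length.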
Also adds a convert_phi4_mini_weights binary target to the phi_4_mini TARGETS
file to enable converting HF checkpoint weights to Meta format.
Authored with Claude.
Differential Revision: [D98741178](https://our.internmc.facebook.com/intern/diff/D98741178/)
[ghstack-poisoned]