
Commit f61350b

SS-JIA authored and committed
[ET-VK] Add apply_rotary_emb_interleaved fused operator
Pull Request resolved: #19115

Introduces `et_vk.apply_rotary_emb_interleaved`, a fused Vulkan custom operator for the "complex-number" RoPE variant used by SAM2/EdgeTAM's memory attention. It replaces a 12+ op layout-shuffle chain (`view/unbind/stack/view`, which lowers to `slice_copy + squeeze_copy + unsqueeze_copy + cat + view_copy`) with a single GPU dispatch.

**Math**: On pair-interleaved inputs where element `2k` is real and `2k+1` is imaginary, for each `k in [0, C/2)`:

    out[2k]   = x[2k] * cos[k] - x[2k+1] * sin[k]
    out[2k+1] = x[2k] * sin[k] + x[2k+1] * cos[k]

**Why a new op instead of reusing `et_vk.apply_rotary_emb`**: The existing LLM-oriented operator takes `(xq, xk)` pairs with separate `freqs_cos` / `freqs_sin` tensors and 4D `(B, S, H, D)` shapes, optimized for two-texel-per-thread reuse during LLM prefill. SAM2's memory attention passes a single 3D `(B, N, C)` tensor through RoPE (no heads dim) with a fused `[N, C/2, 2]` freqs tensor. Reusing the existing op would force runtime splits of the fused freqs tensor and separate dispatches for Q and K, defeating the fusion. A sibling shader is tighter for both workloads.

**Op contract**: `apply_rotary_emb_interleaved(x, freqs_cis) -> Tensor`, where `x` is `[B, N, C]` and `freqs_cis` may be any rank with `N*C` elements and the `cos`/`sin` values interleaved on the innermost dim. In EdgeTAM's memory attention the native shape is `[1, N, C/2, 2]`; passing it through without a reshape keeps the exported graph free of bracketing view_copy dispatches.

**Shader**: Single-dispatch kernel, one output texel per thread. Each thread reads one `x` texel (2 real/imag pairs) and the corresponding `freqs_cis` entries (2 cos/sin pairs), flat-indexed from buffer storage, then writes one output texel. `x` and the output support buffer + texture3d storage; `freqs_cis` is always buffer storage (it is a small tensor, and flat indexing is simplest). Supports fp16 and fp32 via the `FP_T` dtype iterator in the YAML.

**Op registration**: The `Meta` kernel returns `torch.empty_like(x)` to keep the op opaque during `torch.export`. The `CPU` kernel holds the reference math so non-Vulkan backends keep working. `op_registry.py` pins `freqs_cis` storage to `CONTIGUOUS_BUFFER` while leaving `x` at `CONTIGUOUS_ANY`.

ghstack-source-id: 373258231
@exported-using-ghexport

Differential Revision: [D102360202](https://our.internmc.facebook.com/intern/diff/D102360202/)
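For reference, a minimal eager-mode sketch of the math above, checking the defining property that each (real, imag) pair undergoes a norm-preserving rotation (complex multiplication by a unit-magnitude factor). The helper name `rope_interleaved_reference` is illustrative only; the shipped reference implementation lives in `custom_ops_lib.py` below.

```python
import torch


def rope_interleaved_reference(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    # Mirrors the commit's math: element 2k is real, 2k+1 is imag.
    B, N, C = x.shape
    xr, xi = x.view(B, N, C // 2, 2).unbind(-1)
    cs = freqs_cis.reshape(N, C // 2, 2)
    cos, sin = cs[..., 0], cs[..., 1]
    out = torch.stack((xr * cos - xi * sin, xr * sin + xi * cos), dim=-1)
    return out.view(B, N, C)


B, N, C = 1, 16, 64
x = torch.randn(B, N, C)
# EdgeTAM-style fused freqs: unit-magnitude complex exponentials viewed as
# (cos, sin) pairs, giving shape [1, N, C/2, 2].
angles = torch.randn(N, C // 2)
freqs_cis = torch.view_as_real(torch.polar(torch.ones_like(angles), angles)).unsqueeze(0)

out = rope_interleaved_reference(x, freqs_cis)
# A rotation preserves the norm of every (real, imag) pair.
assert torch.allclose(
    x.view(B, N, C // 2, 2).norm(dim=-1),
    out.view(B, N, C // 2, 2).norm(dim=-1),
    atol=1e-5,
)
```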
1 parent 0930b63 · commit f61350b

6 files changed

Lines changed: 524 additions & 0 deletions


backends/vulkan/custom_ops_lib.py

Lines changed: 49 additions & 0 deletions
@@ -828,6 +828,55 @@ def apply_rotary_emb_hf_impl(

```python
lib.impl(name, apply_rotary_emb_hf_impl, "CompositeExplicitAutograd")
apply_rotary_emb_hf_op = getattr(getattr(torch.ops, namespace), name)

##################################
## apply_rotary_emb_interleaved ##
##################################


def apply_rotary_emb_interleaved_impl(
    x: torch.Tensor, freqs_cis: torch.Tensor
) -> torch.Tensor:
    # EdgeTAM's pair-interleaved complex-number RoPE.
    # x: [B, N, C] with (real, imag) pairs interleaved along C
    # freqs_cis: any rank whose flattened layout is [N, C]. Commonly 2D
    #     [N, C] or 4D [1, N, C/2, 2] from
    #     `torch.view_as_real(...).unsqueeze(0)`. The (cos, sin)
    #     pairs are interleaved along the innermost axis in the
    #     flattened view.
    # Semantically equivalent to:
    #     freqs_cis.reshape(N, C // 2, 2) -> (cos, sin)
    #     out[2k]   = x[2k] * cos[k] - x[2k+1] * sin[k]
    #     out[2k+1] = x[2k] * sin[k] + x[2k+1] * cos[k]
    B, N, C = x.shape
    a_real, a_imag = x.view(B, N, C // 2, 2).unbind(-1)
    # Use reshape so callers may pass freqs_cis at any rank.
    cs = freqs_cis.reshape(N, C // 2, 2)
    b_real, b_imag = cs[..., 0], cs[..., 1]
    out = torch.stack(
        (a_real * b_real - a_imag * b_imag, a_real * b_imag + a_imag * b_real),
        dim=-1,
    )
    return out.view(B, N, C)


def apply_rotary_emb_interleaved_meta(
    x: torch.Tensor, freqs_cis: torch.Tensor
) -> torch.Tensor:
    # Meta kernel: shape-only. Keeps the op opaque during torch.export (no
    # inlining of view/reshape calls into the exported graph) and does not
    # constrain the rank of freqs_cis — any shape with N * C elements is
    # accepted by the Vulkan dispatcher.
    return torch.empty_like(x)


name = "apply_rotary_emb_interleaved"
lib.define(f"{name}(Tensor x, Tensor freqs_cis) -> Tensor")
# CPU kernel preserves eager-mode reference semantics.
lib.impl(name, apply_rotary_emb_interleaved_impl, "CPU")
# Meta kernel keeps the op opaque in the exported graph.
lib.impl(name, apply_rotary_emb_interleaved_meta, "Meta")
apply_rotary_emb_interleaved_op = getattr(getattr(torch.ops, namespace), name)

########################
##      q8ta_add      ##
########################
```
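After these registrations the op is callable in eager mode through `torch.ops`. A minimal sketch, assuming the import below triggers registration (the module path mirrors this file's location and may differ per build):

```python
import torch

# Importing the library runs the lib.define / lib.impl calls above.
from executorch.backends.vulkan import custom_ops_lib  # noqa: F401

x = torch.randn(1, 8, 32)
freqs_cis = torch.randn(1, 8, 16, 2)  # any rank with N * C elements works
out = torch.ops.et_vk.apply_rotary_emb_interleaved(x, freqs_cis)
assert out.shape == x.shape
```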

backends/vulkan/op_registry.py

Lines changed: 16 additions & 0 deletions
@@ -1110,6 +1110,22 @@ def register_apply_rotary_emb_hf():

```python
    )


@update_features(exir_ops.edge.et_vk.apply_rotary_emb_interleaved.default)
def register_apply_rotary_emb_interleaved():
    return OpFeatures(
        # freqs_cis is pinned to buffer storage so the shader can compute a
        # flat [N, C] linear address regardless of the tensor's declared rank
        # (callers commonly pass 4D [1, N, C/2, 2] without a preceding view).
        inputs_storage=[
            utils.CONTIGUOUS_ANY,  # x
            utils.CONTIGUOUS_BUFFER,  # freqs_cis
        ],
        inputs_dtypes=utils.FP_T,
        supports_resize=True,
        supports_highdim=True,
    )


# =============================================================================
# Permute.cpp
# =============================================================================
```
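The rank-agnostic claim rests on contiguity: reshaping a contiguous tensor never moves data, so a 2D `[N, C]` and a 4D `[1, N, C/2, 2]` freqs tensor over the same values read identically. A small self-contained check:

```python
import torch

N, C = 8, 32
flat = torch.arange(N * C, dtype=torch.float32)
as_2d = flat.view(N, C)
as_4d = flat.view(1, N, C // 2, 2)
# Both ranks flatten to the same [N, C/2, 2] (cos, sin) layout, which is
# why the shader can flat-index the buffer without a view_copy in the graph.
assert torch.equal(as_2d.reshape(N, C // 2, 2), as_4d.reshape(N, C // 2, 2))
```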
backends/vulkan/runtime/graph/ops/glsl/apply_rotary_emb_interleaved.glsl

Lines changed: 114 additions & 0 deletions
@@ -0,0 +1,114 @@
```glsl
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#version 450 core

${define_required_extensions(STORAGE, DTYPE)}
${define_required_extensions("buffer", DTYPE)}

#define PRECISION ${PRECISION}

#define VEC4_T ${texel_load_type(DTYPE, STORAGE)}
#define BUFFER_VEC4_T ${texel_load_type(DTYPE, "buffer")}

${define_active_storage_type(STORAGE)}

layout(std430) buffer;

#include "indexing.glslh"

${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE, is_scalar_array=False)}
${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE, is_scalar_array=False)}
// freqs_cis is always bound as a buffer so the shader can flat-index it
// regardless of the caller's declared rank (2D [N, C] or 4D [1, N, C/2, 2]).
${layout_declare_tensor(B, "r", "t_freqs", DTYPE, "buffer", is_scalar_array=False)}

$if STORAGE == "buffer":
  ${layout_declare_ubo(B, "BufferMetadata", "outp")}
$else:
  ${layout_declare_ubo(B, "TextureMetadata", "outp")}

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

${layout_declare_spec_const(C, "int", "outp_layout", "CONTIG_LAYOUT_INT")}

/*
 * Applies rotary positional embeddings to a tensor whose last dimension
 * contains pair-interleaved (real, imag) components. This matches EdgeTAM's
 * `apply_rotary_enc_without_complex` semantics, where the fused cos/sin
 * freqs tensor has a flattened [N, C] layout (cos, sin pairs interleaved).
 *
 * Inputs:
 *   t_in: [B, N, C] (last dim packed as [r0, i0, r1, i1, ...])
 *   t_freqs: contiguous memory with N * C elements. May arrive at any rank
 *            (e.g. 2D [N, C] or 4D [1, N, C/2, 2]). Physically the values
 *            are [cos0, sin0, cos1, sin1, ...] along the final axis.
 *
 * Output:
 *   t_out: same shape as t_in
 *
 * Math per k in [0, C/2):
 *   out[2k]   = x[2k] * cos[k] - x[2k+1] * sin[k]
 *   out[2k+1] = x[2k] * sin[k] + x[2k+1] * cos[k]
 *
 * Each thread processes one width-packed texel (4 elements = 2 (r, i) pairs).
 * All participating tensors are assumed to be width-packed with standard axis
 * maps.
 *
 * The freqs tensor is indexed using a flat (n_idx * C + c_offset) address to
 * remain correct regardless of input rank — the shape of t_freqs does not
 * need to match the logical [N, C] layout, only the underlying memory does.
 */
void main() {
  // Each thread computes one output texel of 4 elements along the last dim.
  TensorIndex4D out_tidx = zero_tensor4d_idx();
  out_tidx.data.x = int(gl_GlobalInvocationID.x) * 4;
  out_tidx.data.y = int(gl_GlobalInvocationID.y);
  out_tidx.data.z = int(gl_GlobalInvocationID.z);

  if (out_of_bounds(out_tidx, outp)) {
    return;
  }

  // Freqs tensor is always a contiguous buffer of N * C elements. Compute
  // a flat texel index directly from logical (n_idx, c_elem_idx / 4). The
  // logical width C comes from the output tensor metadata — both buffer
  // and texture metadata store this at index 0 (sizes[0][0] / sizes.x).
#ifdef USING_BUFFER
  const uint C_width = outp.sizes[0][0];
#else
  const uint C_width = uint(outp.sizes.x);
#endif
  const uint freqs_texel_bufi =
      uint(out_tidx.data.y) * div_4(C_width)
      + uint(gl_GlobalInvocationID.x);
  BUFFER_VEC4_T f_tex = t_freqs[freqs_texel_bufi];

#ifdef USING_BUFFER
  const uint out_texel_bufi =
      div_4(tensor4d_idx_to_linear_idx(outp, out_tidx));
  VEC4_T x_tex = t_in[out_texel_bufi];
#else // USING_TEXTURE
  const ivec3 out_pos =
      tensor4d_idx_to_texel_pos_simple(outp, out_tidx, outp_layout);
  VEC4_T x_tex = texelFetch(t_in, out_pos, 0);
#endif

  // x_tex = (r0, i0, r1, i1), f_tex = (c0, s0, c1, s1)
  VEC4_T out_tex;
  out_tex.x = x_tex.x * VEC4_T(f_tex).x - x_tex.y * VEC4_T(f_tex).y;
  out_tex.y = x_tex.x * VEC4_T(f_tex).y + x_tex.y * VEC4_T(f_tex).x;
  out_tex.z = x_tex.z * VEC4_T(f_tex).z - x_tex.w * VEC4_T(f_tex).w;
  out_tex.w = x_tex.z * VEC4_T(f_tex).w + x_tex.w * VEC4_T(f_tex).z;

#ifdef USING_BUFFER
  t_out[out_texel_bufi] = out_tex;
#else
  imageStore(t_out, out_pos, out_tex);
#endif
}
```
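To make the per-texel arithmetic concrete, a Python emulation of one thread's work (the helper name `shader_texel` is illustrative; the real kernel operates on vec4 registers). Each output pair is exactly the complex product of an input (real, imag) pair with its (cos, sin) pair:

```python
import torch


def shader_texel(x_tex: torch.Tensor, f_tex: torch.Tensor) -> torch.Tensor:
    # x_tex = (r0, i0, r1, i1), f_tex = (c0, s0, c1, s1), as in main() above.
    r0, i0, r1, i1 = x_tex
    c0, s0, c1, s1 = f_tex
    return torch.stack([
        r0 * c0 - i0 * s0, r0 * s0 + i0 * c0,
        r1 * c1 - i1 * s1, r1 * s1 + i1 * c1,
    ])


x_tex, f_tex = torch.randn(4), torch.randn(4)
# Cross-check against complex multiplication of the two (real, imag) pairs.
expected = torch.view_as_real(
    torch.view_as_complex(x_tex.view(2, 2)) * torch.view_as_complex(f_tex.view(2, 2))
).flatten()
assert torch.allclose(shader_texel(x_tex, f_tex), expected)
```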
backends/vulkan/runtime/graph/ops/glsl/apply_rotary_emb_interleaved.yaml

Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
```yaml
apply_rotary_emb_interleaved:
  parameter_names_with_default_values:
    DTYPE: float
    STORAGE: texture3d
  generate_variant_forall:
    STORAGE:
      - VALUE: texture3d
      - VALUE: buffer
    DTYPE:
      - VALUE: half
      - VALUE: float
  shader_variants:
    - NAME: apply_rotary_emb_interleaved
```
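The codegen expands this YAML into one shader variant per (STORAGE, DTYPE) combination. A sketch of the four resulting kernel names, assuming the suffix strings match the storage and dtype values in the order that `add_storage_type_suffix` / `add_dtype_suffix` apply them in the C++ below:

```python
# Hypothetical reconstruction of the generated variant names.
for storage in ("texture3d", "buffer"):
    for dtype in ("half", "float"):
        print(f"apply_rotary_emb_interleaved_{storage}_{dtype}")
```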

backends/vulkan/runtime/graph/ops/impl/RotaryEmbedding.cpp

Lines changed: 108 additions & 0 deletions
@@ -219,9 +219,117 @@ void apply_rotary_emb_hf(
```cpp
      graph, args[0], args[1], args[2], args[3], args[4], xq_out, xk_out);
}

//
// EdgeTAM-style RoPE variant with fused [cos, sin] freqs tensor
//
// Operates on a single tensor (Q or K) of shape [B, N, C] with pair-interleaved
// (real, imag) components along the last dim, and a freqs tensor with a total
// element count of N * C that packs (cos, sin) pairs in the same interleaved
// order as the x tensor. The freqs tensor may be passed in at any rank whose
// flattened layout is [N, C] — e.g. 2D `[N, C]` or 4D `[1, N, C/2, 2]`. This
// avoids callers having to emit a `view` dispatch (view_copy) purely to
// normalize rank.
//

void resize_rotary_embedding_interleaved_node(
    ComputeGraph* graph,
    const std::vector<ArgGroup>& args,
    const std::vector<ValueRef>& resize_args) {
  (void)resize_args;

  const ValueRef out = args.at(0).refs.at(0);
  const ValueRef in = args.at(1).refs.at(0);

  graph->virtual_resize(out, graph->sizes_of(in));
}

utils::uvec3 rotary_embedding_interleaved_global_wg_size(
    ComputeGraph* graph,
    const vkapi::ShaderInfo& shader,
    const std::vector<ArgGroup>& args,
    const std::vector<ValueRef>& resize_args) {
  (void)shader;
  (void)resize_args;

  const ValueRef out = args.at(0).refs.at(0);

  const std::vector<int64_t> out_sizes = graph->sizes_of(out);
  VK_CHECK_COND(out_sizes.size() == 3);

  const uint32_t B = static_cast<uint32_t>(out_sizes.at(0));
  const uint32_t N = static_cast<uint32_t>(out_sizes.at(1));
  const uint32_t C = static_cast<uint32_t>(out_sizes.at(2));

  // One thread per output texel of 4 elements along C.
  return {utils::div_up_4(C), N, B};
}

void add_rotary_embedding_interleaved_node(
    ComputeGraph& graph,
    const ValueRef x,
    const ValueRef freqs_cis,
    const ValueRef out) {
  const std::vector<int64_t> x_sizes = graph.sizes_of(x);
  const std::vector<int64_t> freqs_sizes = graph.sizes_of(freqs_cis);

  VK_CHECK_COND(x_sizes.size() == 3);
  VK_CHECK_COND(x_sizes.at(2) % 4 == 0);

  // freqs_cis may arrive at any rank (commonly 2D [N, C] or 4D [1, N, C/2, 2]
  // from `torch.view_as_real(...).unsqueeze(0)`). Validate via numel rather
  // than per-dim equality so callers do not need to emit a view_copy purely
  // to flatten the shape.
  int64_t freqs_numel = 1;
  for (const int64_t s : freqs_sizes) {
    freqs_numel *= s;
  }
  const int64_t expected_numel = x_sizes.at(1) * x_sizes.at(2);
  VK_CHECK_COND(freqs_numel == expected_numel);

  VK_CHECK_COND(graph.packed_dim_of(x) == WHCN::kWidthDim);
  VK_CHECK_COND(graph.packed_dim_of(out) == WHCN::kWidthDim);
  VK_CHECK_COND(graph.has_standard_axis_map(x));
  VK_CHECK_COND(graph.has_standard_axis_map(out));
  // freqs_cis is pinned to buffer storage via op_registry so the shader can
  // use flat (row, col) indexing regardless of its declared rank.
  VK_CHECK_COND(graph.is_buffer_storage(freqs_cis));

  std::string kernel_name = "apply_rotary_emb_interleaved";
  add_storage_type_suffix(kernel_name, graph.storage_type_of(out));
  add_dtype_suffix(kernel_name, graph.dtype_of(out));

  vkapi::ParamsBindList param_ubos = {graph.meta_ubo(out)};

  graph.execute_nodes().emplace_back(new DynamicDispatchNode(
      graph,
      VK_KERNEL_FROM_STR(kernel_name),
      rotary_embedding_interleaved_global_wg_size,
      default_pick_local_wg_size,
      // Inputs and Outputs
      {{out, vkapi::kWrite}, {{x, freqs_cis}, vkapi::kRead}},
      // Parameter buffers
      param_ubos,
      // Push Constants
      {},
      // Specialization Constants
      {graph.hashed_layout_of(out)},
      // Resize Args
      {},
      // Resizing Logic
      resize_rotary_embedding_interleaved_node));
}

void apply_rotary_emb_interleaved(
    ComputeGraph& graph,
    const std::vector<ValueRef>& args) {
  add_rotary_embedding_interleaved_node(graph, args[0], args[1], args[2]);
}

REGISTER_OPERATORS {
  VK_REGISTER_OP(et_vk.apply_rotary_emb.default, apply_rotary_emb);
  VK_REGISTER_OP(et_vk.apply_rotary_emb_hf.default, apply_rotary_emb_hf);
  VK_REGISTER_OP(
      et_vk.apply_rotary_emb_interleaved.default, apply_rotary_emb_interleaved);
}

} // namespace vkcompute
```
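A sketch of the dispatch geometry that `rotary_embedding_interleaved_global_wg_size` produces: one thread per 4-wide output texel along C, N threads along y, B along z (the shape numbers below are illustrative only):

```python
def global_wg_size(B: int, N: int, C: int) -> tuple[int, int, int]:
    # Mirrors {utils::div_up_4(C), N, B} from the C++ above.
    def div_up_4(v: int) -> int:
        return (v + 3) // 4
    return (div_up_4(C), N, B)


# e.g. B=1, N=16, C=64 -> 16 texels along x, one thread each
assert global_wg_size(1, 16, 64) == (16, 16, 1)
```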
