
Commit 0930b63

SS-JIA authored and committed
[ET-VK] Update fused SDPA operator to support ViT attention
Pull Request resolved: #19114

This diff extends the ET-VK fused SDPA operator so it can be used for the ViT attention blocks in the EdgeTAM ViT-S encoder. The main correctness problem is that Q@K^T dot products in ViT attention can exceed the fp16 max (65504), so fp32 accumulation is required.

**fp16 overflow fix**: The intermediate `attn_weights` buffer is now always fp32 regardless of input dtype. Previously the QK shader accumulated in fp32 but stored to an fp16 buffer, causing overflow. The softmax shader reads fp32 attention weights and writes fp16 softmax output (safe since values are in [0, 1]).

**Texture support**: The QK and AV shaders support both buffer and texture3d storage for Q/K/V/output. The intermediate `attn_weights` and `attn_weights_softmax` tensors now inherit the storage type of the input/output (q_projected for the LLM path, out for the fused path), so the entire fused SDPA pipeline runs in a uniform storage type and no SDPA-internal layout transitions are needed.

ghstack-source-id: 373258239
@exported-using-ghexport

Differential Revision: [D102360200](https://our.internmc.facebook.com/intern/diff/D102360200/)
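To see why fp32 storage for `attn_weights` matters, here is a minimal sketch (the head_dim and element magnitudes are illustrative, not taken from the commit): even when a Q@K^T dot product is accumulated in fp32, casting the result into an fp16 buffer overflows.

```python
import torch

head_dim = 64
q_row = torch.full((head_dim,), 40.0)  # fp32; |q_i| = 40 is an assumed magnitude
k_row = torch.full((head_dim,), 40.0)

acc = torch.dot(q_row, k_row)  # fp32 accumulation: 64 * 40 * 40 = 102400.0
print(acc.item())              # 102400.0 -- fine in fp32
print(acc.half().item())       # inf -- 102400 > 65504, the old fp16 attn_weights bug
```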
1 parent 8ec6e85 commit 0930b63

25 files changed: 1442 additions & 470 deletions

backends/vulkan/custom_ops_lib.py

Lines changed: 28 additions & 0 deletions
@@ -960,6 +960,34 @@ def select_as_symint_impl(x: torch.Tensor, dim: int, index: int):
 lib.impl(name, select_as_symint_impl, "Meta")
 select_as_symint_op = getattr(getattr(torch.ops, namespace), name)
 
+##########
+## sdpa ##
+##########
+
+
+def sdpa_impl(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
+    scale: Optional[float] = None,
+):
+    if scale is None:
+        scale = 1.0 / (q.size(-1) ** 0.5)
+    attn = torch.matmul(q, k.transpose(-2, -1)) * scale
+    if attn_mask is not None:
+        attn = attn + attn_mask
+    attn = torch.softmax(attn, dim=-1)
+    return torch.matmul(attn, v)
+
+
+name = "sdpa"
+lib.define(
+    f"{name}(Tensor q, Tensor k, Tensor v, Tensor? attn_mask = None, float? scale = None) -> Tensor"
+)
+lib.impl(name, sdpa_impl, "CompositeExplicitAutograd")
+sdpa_op = getattr(getattr(torch.ops, namespace), name)
+
 ################
 ## rms_norm ##
 ################

backends/vulkan/op_registry.py

Lines changed: 14 additions & 0 deletions
@@ -1071,6 +1071,20 @@ def register_sdpa_cpp_ops():
     )
 
 
+# =============================================================================
+# SDPA.cpp (fused SDPA entry point)
+# =============================================================================
+
+
+@update_features("et_vk::sdpa")
+def register_general_sdpa():
+    return OpFeatures(
+        inputs_storage=utils.CONTIGUOUS_ANY,
+        inputs_dtypes=utils.FP_T,
+        supports_resize=True,
+    )
+
+
 # =============================================================================
 # RotaryEmbedding.cpp
 # =============================================================================
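Note that `inputs_storage=utils.CONTIGUOUS_ANY` is what lets `et_vk::sdpa` accept either buffer or texture3d storage, matching the QK/AV shader texture support described in the commit message.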

backends/vulkan/runtime/graph/ops/glsl/linear_fp_input_tile.glslh

Lines changed: 8 additions & 1 deletion
@@ -13,12 +13,19 @@
  * Macro Settings:
  * - TILE_M
  * - TILE_K4
+ *
+ * Optional:
+ * - LINEAR_FP_INPUT_TILE_VEC4_T — input tile vec4 type (default: VEC4_T).
  */
 
 #extension GL_EXT_control_flow_attributes : require
 
+#ifndef LINEAR_FP_INPUT_TILE_VEC4_T
+#define LINEAR_FP_INPUT_TILE_VEC4_T VEC4_T
+#endif
+
 struct FPInputTile {
-  VEC4_T data[TILE_M][TILE_K4];
+  LINEAR_FP_INPUT_TILE_VEC4_T data[TILE_M][TILE_K4];
 };
 
 #ifdef DEBUG_MODE
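The same optional-typedef pattern recurs in the weight and output tile headers below: each tile's vec4 type defaults to `VEC4_T` but can be overridden per shader, with explicit constructor casts at every load and store so mixed-precision tile combinations stay well-typed.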

backends/vulkan/runtime/graph/ops/glsl/linear_fp_input_tile_load.glslh

Lines changed: 4 additions & 4 deletions
@@ -21,11 +21,11 @@
 
 #include "linear_fp_input_tile.glslh"
 
-VEC4_T load_input_x4(const int k4, const int m, const int ntexels_k) {
+LINEAR_FP_INPUT_TILE_VEC4_T load_input_x4(const int k4, const int m, const int ntexels_k) {
 #ifdef INPUT_BUFFER
-  return t_input[(m * ntexels_k) + k4];
+  return LINEAR_FP_INPUT_TILE_VEC4_T(t_input[(m * ntexels_k) + k4]);
 #else
-  return texelFetch(t_input, ivec3(k4, m, 0), 0);
+  return LINEAR_FP_INPUT_TILE_VEC4_T(texelFetch(t_input, ivec3(k4, m, 0), 0));
 #endif
 }
 
@@ -53,7 +53,7 @@ void load_input_tile_with_checks(
     if (m_start + m < M && k4_start + k4 < K4) {
       in_tile.data[m][k4] = load_input_x4(k4_start + k4, m_start + m, K4);
     } else {
-      in_tile.data[m][k4] = VEC4_T(0.0);
+      in_tile.data[m][k4] = LINEAR_FP_INPUT_TILE_VEC4_T(0.0);
     }
   }
 }

backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile.glslh

Lines changed: 12 additions & 2 deletions
@@ -10,21 +10,31 @@
  * Macro Settings:
  * - TILE_M
  * - TILE_N4
+ *
+ * Optional:
+ * - LINEAR_FP_OUTPUT_TILE_VEC4_T — accumulator vec4 type (default: VEC4_T).
+ *   Set this to `vec4` to force fp32 accumulation regardless of DTYPE; used
+ *   by fused SDPA QK to avoid fp16 overflow in Q@K^T.
  */
 
 #ifndef LINEAR_FP_OUTPUT_TILE_GLSLH
 #define LINEAR_FP_OUTPUT_TILE_GLSLH
 
 #extension GL_EXT_control_flow_attributes : require
 
+#ifndef LINEAR_FP_OUTPUT_TILE_VEC4_T
+#define LINEAR_FP_OUTPUT_TILE_VEC4_T VEC4_T
+#define LINEAR_FP_OUTPUT_TILE_VEC4_T_IS_DEFAULT
+#endif
+
 struct FPOutTile {
-  VEC4_T data[TILE_M][TILE_N4];
+  LINEAR_FP_OUTPUT_TILE_VEC4_T data[TILE_M][TILE_N4];
 };
 
 void initialize(out FPOutTile out_tile) {
   [[unroll]] for (int m = 0; m < TILE_M; ++m) {
     [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) {
-      out_tile.data[m][n4] = VEC4_T(0);
+      out_tile.data[m][n4] = LINEAR_FP_OUTPUT_TILE_VEC4_T(0);
     }
   }
 }
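A shader that needs fp32 accumulation, such as the fused SDPA QK shader, can therefore `#define LINEAR_FP_OUTPUT_TILE_VEC4_T vec4` before including this header while leaving the input and weight tiles at the default `VEC4_T`.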

backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_fp_compute.glslh

Lines changed: 14 additions & 8 deletions
@@ -21,6 +21,12 @@
 #include "linear_fp_per_out_channel_params.glslh"
 #include "linear_fp_weight_tile.glslh"
 
+#if defined(LINEAR_FP_WEIGHT_TILE_VEC4_T_IS_DEFAULT) == defined(LINEAR_FP_OUTPUT_TILE_VEC4_T_IS_DEFAULT)
+#define MAYBE_CAST_WVEC4(x) (x)
+#else
+#define MAYBE_CAST_WVEC4(x) LINEAR_FP_OUTPUT_TILE_VEC4_T(x)
+#endif
+
 void fp_accumulate_with_fp_weight(
     inout FPOutTile accum,
     FPInputTile in_tile,
@@ -29,23 +35,23 @@ void fp_accumulate_with_fp_weight(
   [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) {
     [[unroll]] for (int k4 = 0; k4 < TILE_K4; ++k4) {
       accum.data[m][n4] =
-          fma(VEC4_T(in_tile.data[m][k4][0]),
-              w_tile.data[mul_4(k4)][n4],
+          fma(LINEAR_FP_OUTPUT_TILE_VEC4_T(in_tile.data[m][k4][0]),
+              MAYBE_CAST_WVEC4(w_tile.data[mul_4(k4)][n4]),
               accum.data[m][n4]);
 
       accum.data[m][n4] =
-          fma(VEC4_T(in_tile.data[m][k4][1]),
-              w_tile.data[mul_4(k4) + 1][n4],
+          fma(LINEAR_FP_OUTPUT_TILE_VEC4_T(in_tile.data[m][k4][1]),
+              MAYBE_CAST_WVEC4(w_tile.data[mul_4(k4) + 1][n4]),
               accum.data[m][n4]);
 
       accum.data[m][n4] =
-          fma(VEC4_T(in_tile.data[m][k4][2]),
-              w_tile.data[mul_4(k4) + 2][n4],
+          fma(LINEAR_FP_OUTPUT_TILE_VEC4_T(in_tile.data[m][k4][2]),
+              MAYBE_CAST_WVEC4(w_tile.data[mul_4(k4) + 2][n4]),
              accum.data[m][n4]);
 
       accum.data[m][n4] =
-          fma(VEC4_T(in_tile.data[m][k4][3]),
-              w_tile.data[mul_4(k4) + 3][n4],
+          fma(LINEAR_FP_OUTPUT_TILE_VEC4_T(in_tile.data[m][k4][3]),
+              MAYBE_CAST_WVEC4(w_tile.data[mul_4(k4) + 3][n4]),
               accum.data[m][n4]);
     }
   }
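The `#if defined(...) == defined(...)` comparison works because `defined(X)` evaluates to 1 or 0 inside `#if` expressions: `MAYBE_CAST_WVEC4` is a no-op when the weight and output tile types are both at their defaults (or both overridden), and an explicit `LINEAR_FP_OUTPUT_TILE_VEC4_T(...)` conversion when exactly one was overridden, so `fma` always sees matching operand types.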

backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_store.glslh

Lines changed: 3 additions & 3 deletions
@@ -25,14 +25,14 @@
 #include "linear_fp_output_tile.glslh"
 
 void write_output_x4(
-    const VEC4_T out_texel,
+    const LINEAR_FP_OUTPUT_TILE_VEC4_T out_texel,
     const int n4,
     const int m,
     const int N4) {
 #ifdef OUTPUT_BUFFER
-  t_output[m * N4 + n4] = out_texel;
+  t_output[m * N4 + n4] = VEC4_T(out_texel);
 #else
-  imageStore(t_output, ivec3(n4, m, 0), out_texel);
+  imageStore(t_output, ivec3(n4, m, 0), VEC4_T(out_texel));
 #endif
 }
 
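The accumulator is converted back to `VEC4_T` only at the store, so a shader can accumulate in fp32 yet write an fp16 output; per the commit message this is safe for the softmax output, whose values lie in [0, 1].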

backends/vulkan/runtime/graph/ops/glsl/linear_fp_packed_weight_tile_load.glslh

Lines changed: 4 additions & 4 deletions
@@ -23,12 +23,12 @@
 
 #include "linear_fp_weight_tile.glslh"
 
-VEC4_T load_packed_weight_x4(
+LINEAR_FP_WEIGHT_TILE_VEC4_T load_packed_weight_x4(
     const int n4, const int dk, const int k4, const int b, const int K4, const int N4) {
 #ifdef WEIGHT_BUFFER
-  return t_weight_packed[((b * K4 + k4) * N4 + n4) * 4 + dk];
+  return LINEAR_FP_WEIGHT_TILE_VEC4_T(t_weight_packed[((b * K4 + k4) * N4 + n4) * 4 + dk]);
 #else
-  return VEC4_T(texelFetch(t_weight_packed, ivec2(n4 * 4 + dk, b * K4 + k4), 0));
+  return LINEAR_FP_WEIGHT_TILE_VEC4_T(texelFetch(t_weight_packed, ivec2(n4 * 4 + dk, b * K4 + k4), 0));
 #endif
 }
 
@@ -65,7 +65,7 @@ void load_packed_weight_tile_with_checks(
     if (k4 < K4 && n4_start + n4 < N4) {
       tile.data[k][n4] = load_packed_weight_x4(n4_start + n4, dk, k4, b, K4, N4);
     } else {
-      tile.data[k][n4] = VEC4_T(0);
+      tile.data[k][n4] = LINEAR_FP_WEIGHT_TILE_VEC4_T(0);
     }
   }
 }

backends/vulkan/runtime/graph/ops/glsl/linear_fp_weight_tile.glslh

Lines changed: 9 additions & 1 deletion
@@ -10,6 +10,9 @@
  * Macro Settings:
  * - TILE_K
  * - TILE_N4
+ *
+ * Optional:
+ * - LINEAR_FP_WEIGHT_TILE_VEC4_T — weight tile vec4 type (default: VEC4_T).
  */
 
 #ifndef LINEAR_FP_WEIGHT_TILE_GLSLH
@@ -19,8 +22,13 @@
 
 #include "common.glslh"
 
+#ifndef LINEAR_FP_WEIGHT_TILE_VEC4_T
+#define LINEAR_FP_WEIGHT_TILE_VEC4_T VEC4_T
+#define LINEAR_FP_WEIGHT_TILE_VEC4_T_IS_DEFAULT
+#endif
+
 struct FPWeightTile {
-  VEC4_T data[TILE_K][TILE_N4];
+  LINEAR_FP_WEIGHT_TILE_VEC4_T data[TILE_K][TILE_N4];
 };
 
 #ifdef DEBUG_MODE
