Skip to content

Commit 226a7b3

Browse files
authored
[ET-VK] Update fused SDPA operator to support ViT attention
Differential Revision: D102360200. Pull Request resolved: #19114.
1 parent 2458318 commit 226a7b3

25 files changed

Lines changed: 1442 additions & 470 deletions

backends/vulkan/custom_ops_lib.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -960,6 +960,34 @@ def select_as_symint_impl(x: torch.Tensor, dim: int, index: int):
960960
lib.impl(name, select_as_symint_impl, "Meta")
961961
select_as_symint_op = getattr(getattr(torch.ops, namespace), name)
962962

963+
##########
964+
## sdpa ##
965+
##########
966+
967+
968+
def sdpa_impl(
969+
q: torch.Tensor,
970+
k: torch.Tensor,
971+
v: torch.Tensor,
972+
attn_mask: Optional[torch.Tensor] = None,
973+
scale: Optional[float] = None,
974+
):
975+
if scale is None:
976+
scale = 1.0 / (q.size(-1) ** 0.5)
977+
attn = torch.matmul(q, k.transpose(-2, -1)) * scale
978+
if attn_mask is not None:
979+
attn = attn + attn_mask
980+
attn = torch.softmax(attn, dim=-1)
981+
return torch.matmul(attn, v)
982+
983+
984+
# Register the schema and a reference (CompositeExplicitAutograd) kernel for
# the custom sdpa op so it can be traced and exported; the Vulkan delegate
# replaces it with the fused shader at runtime.
name = "sdpa"
lib.define(
    f"{name}(Tensor q, Tensor k, Tensor v, Tensor? attn_mask = None, float? scale = None) -> Tensor"
)
lib.impl(name, sdpa_impl, "CompositeExplicitAutograd")
# Resolve the registered op handle (torch.ops.<namespace>.sdpa) for callers.
sdpa_op = getattr(getattr(torch.ops, namespace), name)
990+
963991
################
964992
## rms_norm ##
965993
################

backends/vulkan/op_registry.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1071,6 +1071,20 @@ def register_sdpa_cpp_ops():
10711071
)
10721072

10731073

1074+
# =============================================================================
# SDPA.cpp (fused SDPA entry point)
# =============================================================================


@update_features("et_vk::sdpa")
def register_general_sdpa():
    """Declare Vulkan-delegate support for the fused et_vk::sdpa operator."""
    # Floating-point inputs in any contiguous storage; output shapes may be
    # recomputed on resize.
    features = OpFeatures(
        inputs_storage=utils.CONTIGUOUS_ANY,
        inputs_dtypes=utils.FP_T,
        supports_resize=True,
    )
    return features
1086+
1087+
10741088
# =============================================================================
10751089
# RotaryEmbedding.cpp
10761090
# =============================================================================

backends/vulkan/runtime/graph/ops/glsl/linear_fp_input_tile.glslh

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,19 @@
1313
* Macro Settings:
1414
* - TILE_M
1515
* - TILE_K4
16+
*
17+
* Optional:
18+
* - LINEAR_FP_INPUT_TILE_VEC4_T — input tile vec4 type (default: VEC4_T).
1619
*/
1720

1821
#extension GL_EXT_control_flow_attributes : require
1922

23+
#ifndef LINEAR_FP_INPUT_TILE_VEC4_T
24+
#define LINEAR_FP_INPUT_TILE_VEC4_T VEC4_T
25+
#endif
26+
2027
struct FPInputTile {
21-
VEC4_T data[TILE_M][TILE_K4];
28+
LINEAR_FP_INPUT_TILE_VEC4_T data[TILE_M][TILE_K4];
2229
};
2330

2431
#ifdef DEBUG_MODE

backends/vulkan/runtime/graph/ops/glsl/linear_fp_input_tile_load.glslh

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,11 @@
2121

2222
#include "linear_fp_input_tile.glslh"
2323

24-
VEC4_T load_input_x4(const int k4, const int m, const int ntexels_k) {
24+
LINEAR_FP_INPUT_TILE_VEC4_T load_input_x4(const int k4, const int m, const int ntexels_k) {
2525
#ifdef INPUT_BUFFER
26-
return t_input[(m * ntexels_k) + k4];
26+
return LINEAR_FP_INPUT_TILE_VEC4_T(t_input[(m * ntexels_k) + k4]);
2727
#else
28-
return texelFetch(t_input, ivec3(k4, m, 0), 0);
28+
return LINEAR_FP_INPUT_TILE_VEC4_T(texelFetch(t_input, ivec3(k4, m, 0), 0));
2929
#endif
3030
}
3131

@@ -53,7 +53,7 @@ void load_input_tile_with_checks(
5353
if (m_start + m < M && k4_start + k4 < K4) {
5454
in_tile.data[m][k4] = load_input_x4(k4_start + k4, m_start + m, K4);
5555
} else {
56-
in_tile.data[m][k4] = VEC4_T(0.0);
56+
in_tile.data[m][k4] = LINEAR_FP_INPUT_TILE_VEC4_T(0.0);
5757
}
5858
}
5959
}

backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile.glslh

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,21 +10,31 @@
1010
* Macro Settings:
1111
* - TILE_M
1212
* - TILE_N4
13+
*
14+
* Optional:
15+
* - LINEAR_FP_OUTPUT_TILE_VEC4_T — accumulator vec4 type (default: VEC4_T).
16+
* Set this to `vec4` to force fp32 accumulation regardless of DTYPE; used
17+
* by fused SDPA QK to avoid fp16 overflow in Q@K^T.
1318
*/
1419

1520
#ifndef LINEAR_FP_OUTPUT_TILE_GLSLH
1621
#define LINEAR_FP_OUTPUT_TILE_GLSLH
1722

1823
#extension GL_EXT_control_flow_attributes : require
1924

25+
#ifndef LINEAR_FP_OUTPUT_TILE_VEC4_T
26+
#define LINEAR_FP_OUTPUT_TILE_VEC4_T VEC4_T
27+
#define LINEAR_FP_OUTPUT_TILE_VEC4_T_IS_DEFAULT
28+
#endif
29+
2030
struct FPOutTile {
21-
VEC4_T data[TILE_M][TILE_N4];
31+
LINEAR_FP_OUTPUT_TILE_VEC4_T data[TILE_M][TILE_N4];
2232
};
2333

2434
void initialize(out FPOutTile out_tile) {
2535
[[unroll]] for (int m = 0; m < TILE_M; ++m) {
2636
[[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) {
27-
out_tile.data[m][n4] = VEC4_T(0);
37+
out_tile.data[m][n4] = LINEAR_FP_OUTPUT_TILE_VEC4_T(0);
2838
}
2939
}
3040
}

backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_fp_compute.glslh

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,12 @@
2121
#include "linear_fp_per_out_channel_params.glslh"
2222
#include "linear_fp_weight_tile.glslh"
2323

24+
#if defined(LINEAR_FP_WEIGHT_TILE_VEC4_T_IS_DEFAULT) == defined(LINEAR_FP_OUTPUT_TILE_VEC4_T_IS_DEFAULT)
25+
#define MAYBE_CAST_WVEC4(x) (x)
26+
#else
27+
#define MAYBE_CAST_WVEC4(x) LINEAR_FP_OUTPUT_TILE_VEC4_T(x)
28+
#endif
29+
2430
void fp_accumulate_with_fp_weight(
2531
inout FPOutTile accum,
2632
FPInputTile in_tile,
@@ -29,23 +35,23 @@ void fp_accumulate_with_fp_weight(
2935
[[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) {
3036
[[unroll]] for (int k4 = 0; k4 < TILE_K4; ++k4) {
3137
accum.data[m][n4] =
32-
fma(VEC4_T(in_tile.data[m][k4][0]),
33-
w_tile.data[mul_4(k4)][n4],
38+
fma(LINEAR_FP_OUTPUT_TILE_VEC4_T(in_tile.data[m][k4][0]),
39+
MAYBE_CAST_WVEC4(w_tile.data[mul_4(k4)][n4]),
3440
accum.data[m][n4]);
3541

3642
accum.data[m][n4] =
37-
fma(VEC4_T(in_tile.data[m][k4][1]),
38-
w_tile.data[mul_4(k4) + 1][n4],
43+
fma(LINEAR_FP_OUTPUT_TILE_VEC4_T(in_tile.data[m][k4][1]),
44+
MAYBE_CAST_WVEC4(w_tile.data[mul_4(k4) + 1][n4]),
3945
accum.data[m][n4]);
4046

4147
accum.data[m][n4] =
42-
fma(VEC4_T(in_tile.data[m][k4][2]),
43-
w_tile.data[mul_4(k4) + 2][n4],
48+
fma(LINEAR_FP_OUTPUT_TILE_VEC4_T(in_tile.data[m][k4][2]),
49+
MAYBE_CAST_WVEC4(w_tile.data[mul_4(k4) + 2][n4]),
4450
accum.data[m][n4]);
4551

4652
accum.data[m][n4] =
47-
fma(VEC4_T(in_tile.data[m][k4][3]),
48-
w_tile.data[mul_4(k4) + 3][n4],
53+
fma(LINEAR_FP_OUTPUT_TILE_VEC4_T(in_tile.data[m][k4][3]),
54+
MAYBE_CAST_WVEC4(w_tile.data[mul_4(k4) + 3][n4]),
4955
accum.data[m][n4]);
5056
}
5157
}

backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_store.glslh

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,14 @@
2525
#include "linear_fp_output_tile.glslh"
2626

2727
void write_output_x4(
28-
const VEC4_T out_texel,
28+
const LINEAR_FP_OUTPUT_TILE_VEC4_T out_texel,
2929
const int n4,
3030
const int m,
3131
const int N4) {
3232
#ifdef OUTPUT_BUFFER
33-
t_output[m * N4 + n4] = out_texel;
33+
t_output[m * N4 + n4] = VEC4_T(out_texel);
3434
#else
35-
imageStore(t_output, ivec3(n4, m, 0), out_texel);
35+
imageStore(t_output, ivec3(n4, m, 0), VEC4_T(out_texel));
3636
#endif
3737
}
3838

backends/vulkan/runtime/graph/ops/glsl/linear_fp_packed_weight_tile_load.glslh

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,12 @@
2323

2424
#include "linear_fp_weight_tile.glslh"
2525

26-
VEC4_T load_packed_weight_x4(
26+
LINEAR_FP_WEIGHT_TILE_VEC4_T load_packed_weight_x4(
2727
const int n4, const int dk, const int k4, const int b, const int K4, const int N4) {
2828
#ifdef WEIGHT_BUFFER
29-
return t_weight_packed[((b * K4 + k4) * N4 + n4) * 4 + dk];
29+
return LINEAR_FP_WEIGHT_TILE_VEC4_T(t_weight_packed[((b * K4 + k4) * N4 + n4) * 4 + dk]);
3030
#else
31-
return VEC4_T(texelFetch(t_weight_packed, ivec2(n4 * 4 + dk, b * K4 + k4), 0));
31+
return LINEAR_FP_WEIGHT_TILE_VEC4_T(texelFetch(t_weight_packed, ivec2(n4 * 4 + dk, b * K4 + k4), 0));
3232
#endif
3333
}
3434

@@ -65,7 +65,7 @@ void load_packed_weight_tile_with_checks(
6565
if (k4 < K4 && n4_start + n4 < N4) {
6666
tile.data[k][n4] = load_packed_weight_x4(n4_start + n4, dk, k4, b, K4, N4);
6767
} else {
68-
tile.data[k][n4] = VEC4_T(0);
68+
tile.data[k][n4] = LINEAR_FP_WEIGHT_TILE_VEC4_T(0);
6969
}
7070
}
7171
}

backends/vulkan/runtime/graph/ops/glsl/linear_fp_weight_tile.glslh

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@
1010
* Macro Settings:
1111
* - TILE_K
1212
* - TILE_N4
13+
*
14+
* Optional:
15+
* - LINEAR_FP_WEIGHT_TILE_VEC4_T — weight tile vec4 type (default: VEC4_T).
1316
*/
1417

1518
#ifndef LINEAR_FP_WEIGHT_TILE_GLSLH
@@ -19,8 +22,13 @@
1922

2023
#include "common.glslh"
2124

25+
#ifndef LINEAR_FP_WEIGHT_TILE_VEC4_T
26+
#define LINEAR_FP_WEIGHT_TILE_VEC4_T VEC4_T
27+
#define LINEAR_FP_WEIGHT_TILE_VEC4_T_IS_DEFAULT
28+
#endif
29+
2230
struct FPWeightTile {
23-
VEC4_T data[TILE_K][TILE_N4];
31+
LINEAR_FP_WEIGHT_TILE_VEC4_T data[TILE_K][TILE_N4];
2432
};
2533

2634
#ifdef DEBUG_MODE

0 commit comments

Comments (0)