Skip to content

Commit 616c756

Browse files
author
ssjia
committed
[ET-VK][qconv] Add software fallback for dotPacked4x8AccSatEXT in q8ta_ shaders
Devices that lack VK_KHR_shader_integer_dot_product (older GPUs, emulators) currently fail with ShaderNotSupportedError when running int8-quantized conv2d/linear because the q8ta_ shaders unconditionally require GL_EXT_integer_dot_product. This adds fallback SPIR-V variants that use a pure-GLSL software implementation so those devices can still execute the operators at a performance cost.

Approach: the compile-time macro USE_INT8_DOT_PRODUCT_EXT selects the implementation. Each affected YAML file gains a *_fallback shader variant compiled with USE_INT8_DOT_PRODUCT_EXT=0. At C++ dispatch time, adapter_ptr()->supports_int8_dot_product() picks the matching variant.

Changes:
- common.glslh: add dotPacked4x8Acc_fallback() and the dotPacked4x8AccSat() dispatch macro
- linear_fp_output_tile_int8_int8_compute.glslh: guard the extension and use the macro
- q8ta_conv2d/pw/linear/linear_gemv .glsl: inject the USE_INT8_DOT_PRODUCT_EXT template define, guard the extension, and replace direct EXT calls with the macro
- q8ta_conv2d/pw/linear/linear_gemv .yaml: add the USE_INT8_DOT_PRODUCT_EXT parameter and *_fallback shader variants
- Q8taConv2d/PW/Linear/LinearGemv .cpp: call supports_int8_dot_product() to select the hardware or fallback variant at runtime

Differential Revision: [D94314256](https://our.internmc.facebook.com/intern/diff/D94314256/)

[ghstack-poisoned]
1 parent bace537 commit 616c756

15 files changed

Lines changed: 70 additions & 20 deletions

backends/vulkan/runtime/graph/ops/glsl/common.glslh

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,10 @@ int extract_8bit_from_packed_int_le(const int packed, const int i) {
5151

5252
ivec4 unpack_int8x4(const int packed) {
5353
return ivec4(
54-
extract_8bit_from_packed_int_le(packed, 0),
55-
extract_8bit_from_packed_int_le(packed, 1),
56-
extract_8bit_from_packed_int_le(packed, 2),
57-
extract_8bit_from_packed_int_le(packed, 3));
54+
bitfieldExtract(packed, 0, 8),
55+
bitfieldExtract(packed, 8, 8),
56+
bitfieldExtract(packed, 16, 8),
57+
bitfieldExtract(packed, 24, 8));
5858
}
5959

6060
int pack_4xqint_into_int32(
@@ -89,6 +89,24 @@ int quantize_and_pack(const vec4 vals, const float inv_scale, const int zp) {
8989
return pack_into_int32(quantized);
9090
}
9191

92+
// Software fallback for dotPacked4x8AccSatEXT when GL_EXT_integer_dot_product
93+
// is unavailable. Saturation is omitted: for typical neural network inputs,
94+
// int32 overflow does not occur in practice.
//
// a, b: packed operands; each is expanded into four 8-bit lanes by
//       unpack_int8x4().
// acc:  running int accumulator that the lane-wise dot product is added to.
// Returns acc + dot(lanes(a), lanes(b)), using a plain (non-clamping) int add.
95+
int dotPacked4x8Acc_fallback(const int a, const int b, const int acc) {
96+
// Lanes are widened to float so the built-in dot() can be used. With 8-bit
// lanes the four-term sum of products stays far below 2^24, so the result is
// an exact integer -- assuming 32-bit float evaluation.
// NOTE(review): confirm no relaxed/medium precision qualifier applies here.
const vec4 fa = vec4(unpack_int8x4(a));
97+
const vec4 fb = vec4(unpack_int8x4(b));
98+
return acc + int(dot(fa, fb));
99+
}
100+
101+
// Dispatch macro resolved at GLSL compile time by USE_INT8_DOT_PRODUCT_EXT.
102+
// When USE_INT8_DOT_PRODUCT_EXT == 0, uses the software fallback.
103+
// All other cases (any nonzero value, or undefined) use the hardware intrinsic.
// Callers always write dotPacked4x8AccSat(a, b, acc). Despite the shared "Sat"
// name, only the hardware branch saturates; dotPacked4x8Acc_fallback() does a
// plain int add.
104+
#if defined(USE_INT8_DOT_PRODUCT_EXT) && USE_INT8_DOT_PRODUCT_EXT == 0
105+
#define dotPacked4x8AccSat(a, b, acc) dotPacked4x8Acc_fallback(a, b, acc)
106+
#else
107+
#define dotPacked4x8AccSat(a, b, acc) dotPacked4x8AccSatEXT(a, b, acc)
108+
#endif
109+
92110
#ifdef DEBUG_MODE
93111

94112
#define printf debugPrintfEXT

backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_int8_int8_compute.glslh

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@
1818
#define LINEAR_FP_OUTPUT_TILE_INT8_INT8_COMPUTE_GLSLH
1919

2020
#extension GL_EXT_control_flow_attributes : require
21+
#if !defined(USE_INT8_DOT_PRODUCT_EXT) || USE_INT8_DOT_PRODUCT_EXT != 0
2122
#extension GL_EXT_integer_dot_product : require
23+
#endif
2224

2325
#include "linear_common.glslh"
2426
#include "linear_fp_output_tile.glslh"
@@ -50,7 +52,7 @@ void int_accumulate_with_int8_weight(
5052
const int n4 = div_4(n);
5153
const int n4i = mod_4(n);
5254
[[unroll]] for (int k4 = 0; k4 < TILE_K4; ++k4) {
53-
accum.data[m][n4][n4i] = dotPacked4x8AccSatEXT(
55+
accum.data[m][n4][n4i] = dotPacked4x8AccSat(
5456
in_tile.data[m4][k4][m4i],
5557
w_tile.data[k4][n4][n4i],
5658
accum.data[m][n4][n4i]);

backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d.glsl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,11 @@
1010

1111
${define_required_extensions("buffer", DTYPE)}
1212

13+
#define USE_INT8_DOT_PRODUCT_EXT ${USE_INT8_DOT_PRODUCT_EXT}
14+
1315
#extension GL_EXT_control_flow_attributes : require
14-
#extension GL_EXT_integer_dot_product : require
16+
$if USE_INT8_DOT_PRODUCT_EXT == 1:
17+
#extension GL_EXT_integer_dot_product : require
1518

1619
#define PRECISION ${PRECISION}
1720
#define VEC4_T ${texel_load_type(DTYPE, "buffer")}
@@ -177,7 +180,7 @@ void main() {
177180
// Accumulate using packed int8 dot product for each output channel
178181
// dotPacked4x8AccSat (the EXT intrinsic or its software fallback) computes: acc + dot(unpack(a), unpack(b))
179182
[[unroll]] for (int oc_offset = 0; oc_offset < 4; ++oc_offset) {
180-
acc[subtile_w][oc_offset] = dotPacked4x8AccSatEXT(
183+
acc[subtile_w][oc_offset] = dotPacked4x8AccSat(
181184
packed_input,
182185
weight_block[oc_offset],
183186
acc[subtile_w][oc_offset]);

backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,11 @@
77
q8ta_conv2d:
88
parameter_names_with_default_values:
99
DTYPE: float
10+
USE_INT8_DOT_PRODUCT_EXT: 1
1011
generate_variant_forall:
1112
DTYPE:
1213
- VALUE: float
1314
shader_variants:
1415
- NAME: q8ta_conv2d
16+
- NAME: q8ta_conv2d_fallback
17+
USE_INT8_DOT_PRODUCT_EXT: 0

backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_pw.glsl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,11 @@
1010

1111
${define_required_extensions("buffer", DTYPE)}
1212

13+
#define USE_INT8_DOT_PRODUCT_EXT ${USE_INT8_DOT_PRODUCT_EXT}
14+
1315
#extension GL_EXT_control_flow_attributes : require
14-
#extension GL_EXT_integer_dot_product : require
16+
$if USE_INT8_DOT_PRODUCT_EXT == 1:
17+
#extension GL_EXT_integer_dot_product : require
1518

1619
#define PRECISION ${PRECISION}
1720
#define VEC4_T ${texel_load_type(DTYPE, "buffer")}
@@ -146,7 +149,7 @@ void main() {
146149
[[unroll]] for (int m = 0; m < TILE_M; ++m) {
147150
[[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) {
148151
[[unroll]] for (int n4i = 0; n4i < 4; ++n4i) {
149-
out_accum[m][n4][n4i] = dotPacked4x8AccSatEXT(
152+
out_accum[m][n4][n4i] = dotPacked4x8AccSat(
150153
int8_input_tile[m],
151154
int8_weight_tile[n4][n4i],
152155
out_accum[m][n4][n4i]);

backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_pw.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,11 @@
77
q8ta_conv2d_pw:
88
parameter_names_with_default_values:
99
DTYPE: float
10+
USE_INT8_DOT_PRODUCT_EXT: 1
1011
generate_variant_forall:
1112
DTYPE:
1213
- VALUE: float
1314
shader_variants:
1415
- NAME: q8ta_conv2d_pw
16+
- NAME: q8ta_conv2d_pw_fallback
17+
USE_INT8_DOT_PRODUCT_EXT: 0

backends/vulkan/runtime/graph/ops/glsl/q8ta_linear.glsl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,11 @@
1010

1111
${define_required_extensions("buffer", DTYPE)}
1212

13+
#define USE_INT8_DOT_PRODUCT_EXT ${USE_INT8_DOT_PRODUCT_EXT}
14+
1315
#extension GL_EXT_control_flow_attributes : require
14-
#extension GL_EXT_integer_dot_product : require
16+
$if USE_INT8_DOT_PRODUCT_EXT == 1:
17+
#extension GL_EXT_integer_dot_product : require
1518

1619
#define PRECISION ${PRECISION}
1720
#define VEC4_T ${texel_load_type(DTYPE, "buffer")}

backends/vulkan/runtime/graph/ops/glsl/q8ta_linear.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,11 @@ q8ta_linear:
1111
TILE_M4: 1
1212
TILE_N4: 2
1313
TILE_K4: 1
14+
USE_INT8_DOT_PRODUCT_EXT: 1
1415
generate_variant_forall:
1516
DTYPE:
1617
- VALUE: float
1718
shader_variants:
1819
- NAME: q8ta_linear
20+
- NAME: q8ta_linear_fallback
21+
USE_INT8_DOT_PRODUCT_EXT: 0

backends/vulkan/runtime/graph/ops/glsl/q8ta_linear_gemv.glsl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,11 @@
1010

1111
${define_required_extensions("buffer", DTYPE)}
1212

13+
#define USE_INT8_DOT_PRODUCT_EXT ${USE_INT8_DOT_PRODUCT_EXT}
14+
1315
#extension GL_EXT_control_flow_attributes : require
14-
#extension GL_EXT_integer_dot_product : require
16+
$if USE_INT8_DOT_PRODUCT_EXT == 1:
17+
#extension GL_EXT_integer_dot_product : require
1518

1619
#define PRECISION ${PRECISION}
1720
#define VEC4_T ${texel_load_type(DTYPE, "buffer")}
@@ -94,7 +97,7 @@ void main() {
9497
[[unroll]] for (int n = 0; n < TILE_N; ++n) {
9598
const int tile_n4 = div_4(n);
9699
const int n4i = mod_4(n);
97-
out_accum.data[0][tile_n4][n4i] = dotPacked4x8AccSatEXT(
100+
out_accum.data[0][tile_n4][n4i] = dotPacked4x8AccSat(
98101
packed_input,
99102
int8_weight_tile.data[0][tile_n4][n4i],
100103
out_accum.data[0][tile_n4][n4i]);

backends/vulkan/runtime/graph/ops/glsl/q8ta_linear_gemv.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,11 @@ q8ta_linear_gemv:
1111
TILE_K4: 1
1212
TILE_N4: 2
1313
WGS: 64
14+
USE_INT8_DOT_PRODUCT_EXT: 1
1415
generate_variant_forall:
1516
DTYPE:
1617
- VALUE: float
1718
shader_variants:
1819
- NAME: q8ta_linear_gemv
20+
- NAME: q8ta_linear_gemv_fallback
21+
USE_INT8_DOT_PRODUCT_EXT: 0

0 commit comments

Comments
 (0)