Skip to content

Commit 67ff1b8

Browse files
ssjiaSS-JIA
authored andcommitted
[ET-VK][qconv] Add layout-agnostic general shader for quantized conv
Pull Request resolved: pytorch#17219 The existing quantized conv2d implementation (`conv2d_q8ta_q8csw_q8to`) only supports the 4W4C memory layout. This limits its use when models require different tensor layouts. This change introduces a new general-purpose quantized conv2d shader (`q8ta_conv2d`) that works with any memory layout by using BufferMetadata for tensor indexing. The routing logic determines which implementation to use based on input/output layouts: when both are 4W4C, the existing optimized path is used; otherwise, the new general shader handles the computation. This enables quantized conv2d to work seamlessly across 4C1W, 4W4C, and 4C memory layouts. Key changes: - New GLSL shader `q8ta_conv2d.glsl` using layout specialization constants - New `Q8taConv2d.cpp` with operator registration and workgroup size heuristics - Refactored routing in QuantizedConvolution.cpp to dispatch based on layout - Extended test coverage to validate all three memory layouts Authored with Claude. ghstack-source-id: 338638545 @exported-using-ghexport Differential Revision: [D92307252](https://our.internmc.facebook.com/intern/diff/D92307252/)
1 parent c415293 commit 67ff1b8

10 files changed

Lines changed: 952 additions & 138 deletions

File tree

Lines changed: 249 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,249 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#version 450 core
10+
11+
${define_required_extensions("buffer", DTYPE)}
12+
13+
#extension GL_EXT_control_flow_attributes : require
14+
#extension GL_EXT_integer_dot_product : require
15+
16+
#define PRECISION ${PRECISION}
17+
#define VEC4_T ${texel_load_type(DTYPE, "buffer")}
18+
#define T ${texel_load_component_type(DTYPE, "buffer")}
19+
20+
${define_active_storage_type("buffer")}
21+
22+
layout(std430) buffer;
23+
24+
#include "indexing.glslh"
25+
#include "common.glslh"
26+
#include "conv2d_common.glslh"
27+
28+
${layout_declare_tensor(B, "w", "t_packed_int8_output", "int", "buffer", is_scalar_array=True)}
29+
${layout_declare_tensor(B, "r", "t_packed_int8_input", "int", "buffer", is_scalar_array=True)}
30+
${layout_declare_tensor(B, "r", "t_packed_int8_weight", "int", "texture2d", is_scalar_array=False)}
31+
${layout_declare_tensor(B, "r", "t_weight_sums", "int", "buffer", is_scalar_array=False)}
32+
${layout_declare_tensor(B, "r", "t_weight_scales", DTYPE, "buffer", is_scalar_array=False)}
33+
${layout_declare_tensor(B, "r", "t_bias", DTYPE, "buffer", is_scalar_array=False)}
34+
35+
// Metadata for input/output tensors (memory layout agnostic)
36+
${layout_declare_ubo(B, "BufferMetadata", "outp")}
37+
${layout_declare_ubo(B, "BufferMetadata", "inp")}
38+
${layout_declare_ubo(B, "Conv2DParams", "conv2d_params")}
39+
40+
layout(push_constant) uniform restrict Block {
41+
float input_scale;
42+
int input_zp;
43+
float output_inv_scale;
44+
int output_zp;
45+
};
46+
47+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
48+
49+
${layout_declare_spec_const(C, "int", "apply_bias", "1")}
50+
51+
// Layout specialization constants
52+
${layout_declare_spec_const(C, "int", "inp_layout", "CONTIG_LAYOUT_INT")}
53+
${layout_declare_spec_const(C, "int", "outp_layout", "CONTIG_LAYOUT_INT")}
54+
55+
// Load weight block for a given (ic4, kx, ky, oc4) position.
56+
// Weight texture layout (from pack_q8_conv2d_weights.glsl):
57+
// block_x = oc4 * K_w + kx
58+
// block_y = ky * IC4 + ic4
59+
// Each texel ivec4 has 4 components (4 output channels), each component is
60+
// a packed int32 containing 4 int8 values for 4 consecutive input channels.
61+
ivec4 load_weight_block(int ic4, int kx, int ky, int oc4, int IC4, int KW) {
62+
const int block_x = oc4 * KW + kx;
63+
const int block_y = ky * IC4 + ic4;
64+
return texelFetch(t_packed_int8_weight, ivec2(block_x, block_y), 0);
65+
}
66+
67+
ivec4 quantize(const vec4 texel, const float inv_scale, const int zp) {
68+
vec4 quantized = round(texel * inv_scale) + zp;
69+
return clamp(ivec4(quantized), -128, 127);
70+
}
71+
72+
void main() {
73+
// Thread mapping
74+
int oc4 = int(gl_GlobalInvocationID.z);
75+
int w4 = int(gl_GlobalInvocationID.x);
76+
77+
// Initialize output tensor index (WHCN order)
78+
// Each thread handles 4 adjacent widths starting at base_out_w
79+
TensorIndex4D outp_tidx;
80+
outp_tidx.data[0] = w4 * 4;
81+
outp_tidx.data[1] = int(gl_GlobalInvocationID.y);
82+
outp_tidx.data[2] = oc4 * 4;
83+
outp_tidx.data[3] = 0;
84+
85+
const int W = int(outp.sizes[0][0]);
86+
const int OC = int(outp.sizes[0][2]);
87+
const int OC4 = int(div_up_4(OC));
88+
89+
// Bounds check
90+
if (any(greaterThanEqual(outp_tidx.data, ivec4(outp.sizes[0])))) {
91+
return;
92+
}
93+
94+
// Input dimensions
95+
const int inp_W = int(inp.sizes[0][0]);
96+
const int inp_H = int(inp.sizes[0][1]);
97+
const int IC = int(inp.sizes[0][2]);
98+
99+
// Compute channels per group
100+
const int OC_per_group = OC / conv2d_params.groups;
101+
const int IC_per_group = IC / conv2d_params.groups;
102+
const int IC4_per_group = div_up_4(IC_per_group);
103+
104+
// Determine which group this output channel block belongs to
105+
const int group_idx = outp_tidx.data[2] / OC_per_group;
106+
const int ic_group_start = group_idx * IC_per_group;
107+
108+
// Get strides for efficient indexing
109+
const int inp_w_stride = int(inp.strides[0][0]);
110+
const int inp_h_stride = int(inp.strides[0][1]);
111+
const int inp_c_stride = int(inp.strides[0][2]);
112+
const int w_texel_step = conv2d_params.dilation.x * inp_w_stride;
113+
const int h_texel_step = conv2d_params.dilation.y * inp_h_stride;
114+
const int subtile_w_step = conv2d_params.stride.x * inp_w_stride;
115+
116+
// Compute base input position (for subtile_w=0, ic4=0)
117+
TensorIndex4D inp_tidx;
118+
inp_tidx.data[0] = outp_tidx.data[0] * conv2d_params.stride.x - conv2d_params.padding.x;
119+
inp_tidx.data[1] = outp_tidx.data[1] * conv2d_params.stride.y - conv2d_params.padding.y;
120+
inp_tidx.data[2] = ic_group_start;
121+
inp_tidx.data[3] = 0;
122+
123+
int base_inp_texel_idx;
124+
if (get_outer_packed_dim_block_size(inp_layout) == 1) {
125+
base_inp_texel_idx = tensor4d_idx_to_texel_idx(inp, inp_tidx, inp_layout);
126+
}
127+
128+
// Store base width to reset at beginning of each loop
129+
const int base_inp_w = inp_tidx.data[0];
130+
131+
// Create packed input zero point (4 copies of input_zp packed into int32)
132+
const int input_zp_packed = pack_into_int32(ivec4(input_zp));
133+
134+
// Initialize accumulators for 4 width positions × 4 output channels each
135+
ivec4 acc[4];
136+
[[unroll]] for (int i = 0; i < 4; ++i) {
137+
acc[i] = ivec4(0);
138+
}
139+
140+
// Perform convolution using packed int8 dot products
141+
for (int ky = 0; ky < conv2d_params.kernel_size.y; ky++) {
142+
const bool h_in_bounds = (inp_tidx.data[1] >= 0 && inp_tidx.data[1] < inp_H);
143+
144+
// Process input channels in blocks of 4
145+
for (int ic4 = 0; ic4 < IC4_per_group; ic4++) {
146+
// Input channel index for this block (base channel of the 4-channel block)
147+
inp_tidx.data[2] = ic_group_start + ic4 * 4;
148+
149+
// Reset width coordinate at start of each ic4 iteration
150+
inp_tidx.data[0] = base_inp_w;
151+
152+
for (int kx = 0; kx < conv2d_params.kernel_size.x; kx++) {
153+
// Load weight block: 4 output channels × 4 input channels
154+
// weight_block[oc] contains packed weights for ic4*4 to ic4*4+3 -> oc
155+
const ivec4 weight_block = load_weight_block(ic4, kx, ky, oc4, IC4_per_group, conv2d_params.kernel_size.x);
156+
157+
// Process 4 adjacent width positions
158+
[[unroll]] for (int subtile_w = 0; subtile_w < 4; ++subtile_w) {
159+
// Load packed input (4 consecutive channels packed into one int32)
160+
// Use input_zp_packed for out-of-bounds positions
161+
int packed_input = input_zp_packed;
162+
if (h_in_bounds && inp_tidx.data[0] >= 0 && inp_tidx.data[0] < inp_W) {
163+
// Compute input texel index using base + offsets
164+
int inp_texel_idx;
165+
if (get_outer_packed_dim_block_size(inp_layout) == 1) {
166+
inp_texel_idx = base_inp_texel_idx + ic4 * inp_c_stride + kx * w_texel_step + subtile_w * subtile_w_step;
167+
} else {
168+
// inp_texel_idx = tensor4d_idx_to_texel_idx(inp, inp_tidx, inp_layout);
169+
const int w4 = div_4(inp_tidx.data[0]);
170+
const int inp_c4 = div_4(inp_tidx.data[2]);
171+
inp_texel_idx = (inp_tidx.data[1] * inp_h_stride + w4 * inp_w_stride + inp_c4) * 4 + mod_4(inp_tidx.data[0]);
172+
}
173+
packed_input = t_packed_int8_input[inp_texel_idx];
174+
}
175+
176+
// Accumulate using packed int8 dot product for each output channel
177+
// dotPacked4x8AccSatEXT computes: acc + dot(unpack(a), unpack(b))
178+
[[unroll]] for (int oc_offset = 0; oc_offset < 4; ++oc_offset) {
179+
acc[subtile_w][oc_offset] = dotPacked4x8AccSatEXT(
180+
packed_input,
181+
weight_block[oc_offset],
182+
acc[subtile_w][oc_offset]);
183+
}
184+
185+
// Advance to next output position's input coordinate
186+
inp_tidx.data[0] += conv2d_params.stride.x;
187+
}
188+
189+
// Adjust for net dilation step
190+
inp_tidx.data[0] += conv2d_params.dilation.x - 4 * conv2d_params.stride.x;
191+
}
192+
}
193+
194+
// Advance height by dilation for next kernel row
195+
inp_tidx.data[1] += conv2d_params.dilation.y;
196+
197+
if (get_outer_packed_dim_block_size(inp_layout) == 1) {
198+
// Advance base index by height step for next kernel row
199+
base_inp_texel_idx += h_texel_step;
200+
}
201+
}
202+
203+
// Apply input zero point correction via weight_sums
204+
const vec4 weight_sums = vec4(t_weight_sums[oc4]);
205+
const vec4 weight_scales = vec4(t_weight_scales[oc4]);
206+
207+
// Convert to float, apply dequantization, and optionally add bias
208+
vec4 facc[4];
209+
[[unroll]] for (int subtile_w = 0; subtile_w < 4; ++subtile_w) {
210+
facc[subtile_w] = vec4(acc[subtile_w]);
211+
facc[subtile_w] -= weight_sums * input_zp;
212+
facc[subtile_w] *= weight_scales * input_scale;
213+
}
214+
215+
// Apply bias if enabled
216+
if (apply_bias > 0) {
217+
const vec4 bias = vec4(t_bias[oc4]);
218+
[[unroll]] for (int subtile_w = 0; subtile_w < 4; ++subtile_w) {
219+
facc[subtile_w] += bias;
220+
}
221+
}
222+
223+
// Compute base output texel index (for subtile_w=0)
224+
const int base_outp_texel_idx = tensor4d_idx_to_texel_idx(outp, outp_tidx, outp_layout);
225+
const int out_w_stride = int(outp.strides[0][0]);
226+
227+
// Quantize and store outputs using stride offsets
228+
[[unroll]] for (int subtile_w = 0; subtile_w < 4; ++subtile_w) {
229+
// Skip out-of-bounds width positions
230+
if (outp_tidx.data[0] >= W) {
231+
continue;
232+
}
233+
234+
const ivec4 quantized_out = quantize(facc[subtile_w], output_inv_scale, output_zp);
235+
const int packed_out = pack_into_int32(quantized_out);
236+
237+
// Store using stride offset from base
238+
int outp_texel_idx;
239+
if (get_outer_packed_dim_block_size(outp_layout) == 1) {
240+
outp_texel_idx = base_outp_texel_idx + subtile_w * out_w_stride;
241+
} else {
242+
outp_texel_idx = base_outp_texel_idx + subtile_w;
243+
}
244+
245+
t_packed_int8_output[outp_texel_idx] = packed_out;
246+
247+
outp_tidx.data[0] += 1;
248+
}
249+
}
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
q8ta_conv2d:
8+
parameter_names_with_default_values:
9+
DTYPE: float
10+
generate_variant_forall:
11+
DTYPE:
12+
- VALUE: float
13+
shader_variants:
14+
- NAME: q8ta_conv2d

0 commit comments

Comments
 (0)