|
| 1 | +/* |
| 2 | + * Copyright (c) Meta Platforms, Inc. and affiliates. |
| 3 | + * All rights reserved. |
| 4 | + * |
| 5 | + * This source code is licensed under the BSD-style license found in the |
| 6 | + * LICENSE file in the root directory of this source tree. |
| 7 | + */ |
| 8 | + |
| 9 | +#version 450 core |
| 10 | + |
| 11 | +${define_required_extensions("buffer", DTYPE)} |
| 12 | + |
| 13 | +#extension GL_EXT_control_flow_attributes : require |
| 14 | +#extension GL_EXT_integer_dot_product : require |
| 15 | + |
| 16 | +#define PRECISION ${PRECISION} |
| 17 | +#define VEC4_T ${texel_load_type(DTYPE, "buffer")} |
| 18 | +#define T ${texel_load_component_type(DTYPE, "buffer")} |
| 19 | + |
| 20 | +${define_active_storage_type("buffer")} |
| 21 | + |
| 22 | +layout(std430) buffer; |
| 23 | + |
| 24 | +#include "indexing.glslh" |
| 25 | +#include "common.glslh" |
| 26 | +#include "conv2d_common.glslh" |
| 27 | + |
| 28 | +${layout_declare_tensor(B, "w", "t_packed_int8_output", "int", "buffer", is_scalar_array=True)} |
| 29 | +${layout_declare_tensor(B, "r", "t_packed_int8_input", "int", "buffer", is_scalar_array=True)} |
| 30 | +${layout_declare_tensor(B, "r", "t_packed_int8_weight", "int", "texture2d", is_scalar_array=False)} |
| 31 | +${layout_declare_tensor(B, "r", "t_weight_sums", "int", "buffer", is_scalar_array=False)} |
| 32 | +${layout_declare_tensor(B, "r", "t_weight_scales", DTYPE, "buffer", is_scalar_array=False)} |
| 33 | +${layout_declare_tensor(B, "r", "t_bias", DTYPE, "buffer", is_scalar_array=False)} |
| 34 | + |
| 35 | +// Metadata for input/output tensors (memory layout agnostic) |
| 36 | +${layout_declare_ubo(B, "BufferMetadata", "outp")} |
| 37 | +${layout_declare_ubo(B, "BufferMetadata", "inp")} |
| 38 | +${layout_declare_ubo(B, "Conv2DParams", "conv2d_params")} |
| 39 | + |
| 40 | +layout(push_constant) uniform restrict Block { |
| 41 | + float input_scale; |
| 42 | + int input_zp; |
| 43 | + float output_inv_scale; |
| 44 | + int output_zp; |
| 45 | +}; |
| 46 | + |
| 47 | +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; |
| 48 | + |
| 49 | +${layout_declare_spec_const(C, "int", "apply_bias", "1")} |
| 50 | + |
| 51 | +// Layout specialization constants |
| 52 | +${layout_declare_spec_const(C, "int", "inp_layout", "CONTIG_LAYOUT_INT")} |
| 53 | +${layout_declare_spec_const(C, "int", "outp_layout", "CONTIG_LAYOUT_INT")} |
| 54 | + |
| 55 | +// Load weight block for a given (ic4, kx, ky, oc4) position. |
| 56 | +// Weight texture layout (from pack_q8_conv2d_weights.glsl): |
| 57 | +// block_x = oc4 * K_w + kx |
| 58 | +// block_y = ky * IC4 + ic4 |
| 59 | +// Each texel ivec4 has 4 components (4 output channels), each component is |
| 60 | +// a packed int32 containing 4 int8 values for 4 consecutive input channels. |
| 61 | +ivec4 load_weight_block(int ic4, int kx, int ky, int oc4, int IC4, int KW) { |
| 62 | + const int block_x = oc4 * KW + kx; |
| 63 | + const int block_y = ky * IC4 + ic4; |
| 64 | + return texelFetch(t_packed_int8_weight, ivec2(block_x, block_y), 0); |
| 65 | +} |
| 66 | + |
| 67 | +ivec4 quantize(const vec4 texel, const float inv_scale, const int zp) { |
| 68 | + vec4 quantized = round(texel * inv_scale) + zp; |
| 69 | + return clamp(ivec4(quantized), -128, 127); |
| 70 | +} |
| 71 | + |
| 72 | +void main() { |
| 73 | + // Thread mapping |
| 74 | + int oc4 = int(gl_GlobalInvocationID.z); |
| 75 | + int w4 = int(gl_GlobalInvocationID.x); |
| 76 | + |
| 77 | + // Initialize output tensor index (WHCN order) |
| 78 | + // Each thread handles 4 adjacent widths starting at base_out_w |
| 79 | + TensorIndex4D outp_tidx; |
| 80 | + outp_tidx.data[0] = w4 * 4; |
| 81 | + outp_tidx.data[1] = int(gl_GlobalInvocationID.y); |
| 82 | + outp_tidx.data[2] = oc4 * 4; |
| 83 | + outp_tidx.data[3] = 0; |
| 84 | + |
| 85 | + const int W = int(outp.sizes[0][0]); |
| 86 | + const int OC = int(outp.sizes[0][2]); |
| 87 | + const int OC4 = int(div_up_4(OC)); |
| 88 | + |
| 89 | + // Bounds check |
| 90 | + if (any(greaterThanEqual(outp_tidx.data, ivec4(outp.sizes[0])))) { |
| 91 | + return; |
| 92 | + } |
| 93 | + |
| 94 | + // Input dimensions |
| 95 | + const int inp_W = int(inp.sizes[0][0]); |
| 96 | + const int inp_H = int(inp.sizes[0][1]); |
| 97 | + const int IC = int(inp.sizes[0][2]); |
| 98 | + |
| 99 | + // Compute channels per group |
| 100 | + const int OC_per_group = OC / conv2d_params.groups; |
| 101 | + const int IC_per_group = IC / conv2d_params.groups; |
| 102 | + const int IC4_per_group = div_up_4(IC_per_group); |
| 103 | + |
| 104 | + // Determine which group this output channel block belongs to |
| 105 | + const int group_idx = outp_tidx.data[2] / OC_per_group; |
| 106 | + const int ic_group_start = group_idx * IC_per_group; |
| 107 | + |
| 108 | + // Get strides for efficient indexing |
| 109 | + const int inp_w_stride = int(inp.strides[0][0]); |
| 110 | + const int inp_h_stride = int(inp.strides[0][1]); |
| 111 | + const int inp_c_stride = int(inp.strides[0][2]); |
| 112 | + const int w_texel_step = conv2d_params.dilation.x * inp_w_stride; |
| 113 | + const int h_texel_step = conv2d_params.dilation.y * inp_h_stride; |
| 114 | + const int subtile_w_step = conv2d_params.stride.x * inp_w_stride; |
| 115 | + |
| 116 | + // Compute base input position (for subtile_w=0, ic4=0) |
| 117 | + TensorIndex4D inp_tidx; |
| 118 | + inp_tidx.data[0] = outp_tidx.data[0] * conv2d_params.stride.x - conv2d_params.padding.x; |
| 119 | + inp_tidx.data[1] = outp_tidx.data[1] * conv2d_params.stride.y - conv2d_params.padding.y; |
| 120 | + inp_tidx.data[2] = ic_group_start; |
| 121 | + inp_tidx.data[3] = 0; |
| 122 | + |
| 123 | + int base_inp_texel_idx; |
| 124 | + if (get_outer_packed_dim_block_size(inp_layout) == 1) { |
| 125 | + base_inp_texel_idx = tensor4d_idx_to_texel_idx(inp, inp_tidx, inp_layout); |
| 126 | + } |
| 127 | + |
| 128 | + // Store base width to reset at beginning of each loop |
| 129 | + const int base_inp_w = inp_tidx.data[0]; |
| 130 | + |
| 131 | + // Create packed input zero point (4 copies of input_zp packed into int32) |
| 132 | + const int input_zp_packed = pack_into_int32(ivec4(input_zp)); |
| 133 | + |
| 134 | + // Initialize accumulators for 4 width positions × 4 output channels each |
| 135 | + ivec4 acc[4]; |
| 136 | + [[unroll]] for (int i = 0; i < 4; ++i) { |
| 137 | + acc[i] = ivec4(0); |
| 138 | + } |
| 139 | + |
| 140 | + // Perform convolution using packed int8 dot products |
| 141 | + for (int ky = 0; ky < conv2d_params.kernel_size.y; ky++) { |
| 142 | + const bool h_in_bounds = (inp_tidx.data[1] >= 0 && inp_tidx.data[1] < inp_H); |
| 143 | + |
| 144 | + // Process input channels in blocks of 4 |
| 145 | + for (int ic4 = 0; ic4 < IC4_per_group; ic4++) { |
| 146 | + // Input channel index for this block (base channel of the 4-channel block) |
| 147 | + inp_tidx.data[2] = ic_group_start + ic4 * 4; |
| 148 | + |
| 149 | + // Reset width coordinate at start of each ic4 iteration |
| 150 | + inp_tidx.data[0] = base_inp_w; |
| 151 | + |
| 152 | + for (int kx = 0; kx < conv2d_params.kernel_size.x; kx++) { |
| 153 | + // Load weight block: 4 output channels × 4 input channels |
| 154 | + // weight_block[oc] contains packed weights for ic4*4 to ic4*4+3 -> oc |
| 155 | + const ivec4 weight_block = load_weight_block(ic4, kx, ky, oc4, IC4_per_group, conv2d_params.kernel_size.x); |
| 156 | + |
| 157 | + // Process 4 adjacent width positions |
| 158 | + [[unroll]] for (int subtile_w = 0; subtile_w < 4; ++subtile_w) { |
| 159 | + // Load packed input (4 consecutive channels packed into one int32) |
| 160 | + // Use input_zp_packed for out-of-bounds positions |
| 161 | + int packed_input = input_zp_packed; |
| 162 | + if (h_in_bounds && inp_tidx.data[0] >= 0 && inp_tidx.data[0] < inp_W) { |
| 163 | + // Compute input texel index using base + offsets |
| 164 | + int inp_texel_idx; |
| 165 | + if (get_outer_packed_dim_block_size(inp_layout) == 1) { |
| 166 | + inp_texel_idx = base_inp_texel_idx + ic4 * inp_c_stride + kx * w_texel_step + subtile_w * subtile_w_step; |
| 167 | + } else { |
| 168 | + // inp_texel_idx = tensor4d_idx_to_texel_idx(inp, inp_tidx, inp_layout); |
| 169 | + const int w4 = div_4(inp_tidx.data[0]); |
| 170 | + const int inp_c4 = div_4(inp_tidx.data[2]); |
| 171 | + inp_texel_idx = (inp_tidx.data[1] * inp_h_stride + w4 * inp_w_stride + inp_c4) * 4 + mod_4(inp_tidx.data[0]); |
| 172 | + } |
| 173 | + packed_input = t_packed_int8_input[inp_texel_idx]; |
| 174 | + } |
| 175 | + |
| 176 | + // Accumulate using packed int8 dot product for each output channel |
| 177 | + // dotPacked4x8AccSatEXT computes: acc + dot(unpack(a), unpack(b)) |
| 178 | + [[unroll]] for (int oc_offset = 0; oc_offset < 4; ++oc_offset) { |
| 179 | + acc[subtile_w][oc_offset] = dotPacked4x8AccSatEXT( |
| 180 | + packed_input, |
| 181 | + weight_block[oc_offset], |
| 182 | + acc[subtile_w][oc_offset]); |
| 183 | + } |
| 184 | + |
| 185 | + // Advance to next output position's input coordinate |
| 186 | + inp_tidx.data[0] += conv2d_params.stride.x; |
| 187 | + } |
| 188 | + |
| 189 | + // Adjust for net dilation step |
| 190 | + inp_tidx.data[0] += conv2d_params.dilation.x - 4 * conv2d_params.stride.x; |
| 191 | + } |
| 192 | + } |
| 193 | + |
| 194 | + // Advance height by dilation for next kernel row |
| 195 | + inp_tidx.data[1] += conv2d_params.dilation.y; |
| 196 | + |
| 197 | + if (get_outer_packed_dim_block_size(inp_layout) == 1) { |
| 198 | + // Advance base index by height step for next kernel row |
| 199 | + base_inp_texel_idx += h_texel_step; |
| 200 | + } |
| 201 | + } |
| 202 | + |
| 203 | + // Apply input zero point correction via weight_sums |
| 204 | + const vec4 weight_sums = vec4(t_weight_sums[oc4]); |
| 205 | + const vec4 weight_scales = vec4(t_weight_scales[oc4]); |
| 206 | + |
| 207 | + // Convert to float, apply dequantization, and optionally add bias |
| 208 | + vec4 facc[4]; |
| 209 | + [[unroll]] for (int subtile_w = 0; subtile_w < 4; ++subtile_w) { |
| 210 | + facc[subtile_w] = vec4(acc[subtile_w]); |
| 211 | + facc[subtile_w] -= weight_sums * input_zp; |
| 212 | + facc[subtile_w] *= weight_scales * input_scale; |
| 213 | + } |
| 214 | + |
| 215 | + // Apply bias if enabled |
| 216 | + if (apply_bias > 0) { |
| 217 | + const vec4 bias = vec4(t_bias[oc4]); |
| 218 | + [[unroll]] for (int subtile_w = 0; subtile_w < 4; ++subtile_w) { |
| 219 | + facc[subtile_w] += bias; |
| 220 | + } |
| 221 | + } |
| 222 | + |
| 223 | + // Compute base output texel index (for subtile_w=0) |
| 224 | + const int base_outp_texel_idx = tensor4d_idx_to_texel_idx(outp, outp_tidx, outp_layout); |
| 225 | + const int out_w_stride = int(outp.strides[0][0]); |
| 226 | + |
| 227 | + // Quantize and store outputs using stride offsets |
| 228 | + [[unroll]] for (int subtile_w = 0; subtile_w < 4; ++subtile_w) { |
| 229 | + // Skip out-of-bounds width positions |
| 230 | + if (outp_tidx.data[0] >= W) { |
| 231 | + continue; |
| 232 | + } |
| 233 | + |
| 234 | + const ivec4 quantized_out = quantize(facc[subtile_w], output_inv_scale, output_zp); |
| 235 | + const int packed_out = pack_into_int32(quantized_out); |
| 236 | + |
| 237 | + // Store using stride offset from base |
| 238 | + int outp_texel_idx; |
| 239 | + if (get_outer_packed_dim_block_size(outp_layout) == 1) { |
| 240 | + outp_texel_idx = base_outp_texel_idx + subtile_w * out_w_stride; |
| 241 | + } else { |
| 242 | + outp_texel_idx = base_outp_texel_idx + subtile_w; |
| 243 | + } |
| 244 | + |
| 245 | + t_packed_int8_output[outp_texel_idx] = packed_out; |
| 246 | + |
| 247 | + outp_tidx.data[0] += 1; |
| 248 | + } |
| 249 | +} |
0 commit comments