|
| 1 | +/* |
| 2 | + * Copyright (c) Meta Platforms, Inc. and affiliates. |
| 3 | + * All rights reserved. |
| 4 | + * |
| 5 | + * This source code is licensed under the BSD-style license found in the |
| 6 | + * LICENSE file in the root directory of this source tree. |
| 7 | + */ |
| 8 | + |
| 9 | +/* |
| 10 | + * KHR Cooperative Matrix linear shader for prepacked weights. |
| 11 | + * Drop-in replacement for linear_vec when storage=buffer and device |
| 12 | + * supports GL_KHR_cooperative_matrix. |
| 13 | + * |
| 14 | + * Computes: D = A * W_packed (A: [M, K], W_packed: 4OC x 4IC blocked, D: [M, N]) |
| 15 | + * |
| 16 | + * Weight is prepacked by pack_fp_linear_weight into a 4OC x 4IC blocked layout: |
| 17 | + * t_weight_packed[(k4 * N4 + n4) * 4 + dk] = vec4(w[k4*4+dk][n4*4+0..3]) |
| 18 | + * |
| 19 | + * fp16xfp16->fp32 MMA. When DTYPE=half, inputs are native fp16 (no |
| 20 | + * conversion, half the bandwidth). When DTYPE=float, inputs are fp32 |
| 21 | + * with on-the-fly packHalf2x16 conversion. |
| 22 | + * |
| 23 | + * Output is always fp32 (fp32 accumulator -> fp32 store) when DTYPE=float, |
| 24 | + * or fp16 when DTYPE=half. |
| 25 | + * |
| 26 | + * Optional bias: when HAS_BIAS is defined, bias is added post-store via |
| 27 | + * read-modify-write on the output buffer (one pass over the tile). |
| 28 | + */ |
| 29 | + |
| 30 | +#version 450 core |
| 31 | + |
| 32 | +#extension GL_KHR_cooperative_matrix : require |
| 33 | +#extension GL_KHR_memory_scope_semantics : require |
| 34 | +#extension GL_KHR_shader_subgroup_basic : enable |
| 35 | +#extension GL_EXT_shader_explicit_arithmetic_types : require |
| 36 | +#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require |
| 37 | +#extension GL_EXT_control_flow_attributes : enable |
| 38 | + |
| 39 | +#define PRECISION ${PRECISION} |
| 40 | + |
| 41 | +$if DTYPE == "half": |
| 42 | + #define IS_FP16_INPUT |
| 43 | +$if DTYPE == "float": |
| 44 | + #define IS_FP32_INPUT |
| 45 | + |
| 46 | +$if HAS_BIAS: |
| 47 | + #define HAS_BIAS |
| 48 | + |
| 49 | +layout(std430) buffer; |
| 50 | + |
| 51 | +#include "common.glslh" |
| 52 | + |
| 53 | +// Bindings: output(0), mat1(1), weight_packed(2), [bias(3)] |
| 54 | +$if HAS_BIAS: |
| 55 | + ${layout_declare_tensor(B, "rw", "t_output", DTYPE, "buffer", is_scalar_array=True)} |
| 56 | +$else: |
| 57 | + ${layout_declare_tensor(B, "w", "t_output", DTYPE, "buffer", is_scalar_array=True)} |
| 58 | +${layout_declare_tensor(B, "r", "t_mat1", DTYPE, "buffer", is_scalar_array=False)} |
| 59 | +${layout_declare_tensor(B, "r", "t_weight_packed", DTYPE, "buffer", is_scalar_array=False)} |
| 60 | +$if HAS_BIAS: |
| 61 | + ${layout_declare_tensor(B, "r", "t_bias", DTYPE, "buffer", is_scalar_array=True)} |
| 62 | + |
| 63 | +// UBOs |
| 64 | +${layout_declare_ubo(B, "ivec4", "mat1_sizes")} |
| 65 | +${layout_declare_ubo(B, "ivec4", "out_sizes")} |
| 66 | + |
| 67 | +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; |
| 68 | + |
| 69 | +// Tile dimensions (same as matmul_coopmat) |
| 70 | +const uint lM = 16; |
| 71 | +const uint lN = 16; |
| 72 | +const uint lK = 16; |
| 73 | +const uint TILE_M = 64; |
| 74 | +const uint TILE_N = 64; |
| 75 | +const uint TILE_K = 32; |
| 76 | + |
| 77 | +// Workgroup: 4 subgroups in 2x2 grid, 64 threads each = 256 total |
| 78 | +const uint WG_WIDTH = 2; |
| 79 | +const uint WG_HEIGHT = 2; |
| 80 | +const uint NUM_SUBGROUPS = 4; |
| 81 | +const uint INVOCATIONS = 64 * NUM_SUBGROUPS; |
| 82 | + |
| 83 | +// Result tiles per subgroup: 2x2 |
| 84 | +const uint C_ROWS = TILE_M / WG_HEIGHT / lM; // 2 |
| 85 | +const uint C_COLS = TILE_N / WG_WIDTH / lN; // 2 |
| 86 | + |
| 87 | +// fp16: 8 elements per uvec4 (128-bit) |
| 88 | +const uint FP16_PER_VEC4 = 8; |
| 89 | + |
| 90 | +// Shared memory with skew padding |
| 91 | +const uint A_STRIDE_VEC4 = (TILE_K + FP16_PER_VEC4) / FP16_PER_VEC4; // 5 |
| 92 | +const uint B_STRIDE_VEC4 = (TILE_N + FP16_PER_VEC4) / FP16_PER_VEC4; // 9 |
| 93 | + |
| 94 | +shared uvec4 Ash[TILE_M * A_STRIDE_VEC4]; // 5KB |
| 95 | +shared uvec4 Bsh[TILE_K * B_STRIDE_VEC4]; // 4.5KB |
| 96 | + |
| 97 | +// Accumulator tiles (fp32) |
| 98 | +coopmat<float, gl_ScopeSubgroup, lM, lN, gl_MatrixUseAccumulator> result[C_ROWS][C_COLS]; |
| 99 | + |
#ifdef IS_FP32_INPUT
// Convert four fp32 lanes to four fp16 values, returned as two 32-bit words
// (each word holds two packed halves, low lane in the low 16 bits).
uvec2 f32x4_to_f16x4(vec4 v) {
  const uint lo_pair = packHalf2x16(v.xy);
  const uint hi_pair = packHalf2x16(v.zw);
  return uvec2(lo_pair, hi_pair);
}
#endif
| 105 | + |
void main() {
  // Each workgroup computes one TILE_M x TILE_N output tile; workgroup XY
  // selects the tile, and the 4 subgroups split it in a 2x2 grid.
  const uvec2 tileID = uvec2(gl_WorkGroupID.xy);
  const uvec2 warpInTile = uvec2(
      gl_SubgroupID % WG_WIDTH,
      gl_SubgroupID / WG_WIDTH);

  const uint K = uint(mat1_sizes.x);
  const uint M = uint(mat1_sizes.y);
  const uint N = uint(out_sizes.x);
  const uint K4 = (K + 3u) / 4u;   // K in vec4 units (row stride of t_mat1)
  const uint N4 = (N + 3u) / 4u;   // N in vec4 units (used by packed-weight indexing)

  // Zero-initialize the per-subgroup fp32 accumulator tiles.
  [[unroll]] for (uint i = 0; i < C_ROWS; ++i) {
    [[unroll]] for (uint j = 0; j < C_COLS; ++j) {
      result[i][j] = coopmat<float, gl_ScopeSubgroup, lM, lN, gl_MatrixUseAccumulator>(0.0);
    }
  }

  // Thread assignment for A tile (64 rows x 4 uvec4/row = single pass)
  const uint INVS_PER_ROW_A = TILE_K / FP16_PER_VEC4; // 4
  const uint a_col = gl_LocalInvocationID.x % INVS_PER_ROW_A;
  const uint a_row_offset = gl_LocalInvocationID.x / INVS_PER_ROW_A;

  // Thread assignment for B tile (32 rows x 8 uvec4/row = single pass)
  const uint INVS_PER_ROW_B = TILE_N / FP16_PER_VEC4; // 8
  const uint b_col = gl_LocalInvocationID.x % INVS_PER_ROW_B;
  const uint b_row_offset = gl_LocalInvocationID.x / INVS_PER_ROW_B;

  const uint a_row_base = TILE_M * tileID.y;
  const uint b_col_base = TILE_N * tileID.x;

  // NOTE(review): the tile loads and the coopMatStore below carry no bounds
  // guards, so this kernel assumes M, N and K are padded to multiples of the
  // tile sizes (the bias pass DOES guard) — confirm against the dispatch code.
  for (uint chunkK = 0; chunkK < K; chunkK += TILE_K) {

    // --- Load A tile -> shared (same as matmul_coopmat) ---
    {
      uint row = a_row_base + a_row_offset;
      uint k_elem = chunkK + a_col * FP16_PER_VEC4;

#ifdef IS_FP16_INPUT
      // Native fp16 input: two f16vec4 loads give 8 consecutive K-elements,
      // repacked into one uvec4 for the coopMatLoad-friendly shared layout.
      uint k_hv4 = k_elem / 4;
      f16vec4 v0 = t_mat1[row * K4 + k_hv4];
      f16vec4 v1 = t_mat1[row * K4 + k_hv4 + 1];
      Ash[a_row_offset * A_STRIDE_VEC4 + a_col] = uvec4(
          packHalf2x16(vec2(v0.xy)), packHalf2x16(vec2(v0.zw)),
          packHalf2x16(vec2(v1.xy)), packHalf2x16(vec2(v1.zw)));
#else
      // fp32 input: convert to fp16 on the fly while staging into shared.
      uint k_vec4 = k_elem / 4;
      vec4 v0 = t_mat1[row * K4 + k_vec4];
      vec4 v1 = t_mat1[row * K4 + k_vec4 + 1];
      uvec2 h0 = f32x4_to_f16x4(v0);
      uvec2 h1 = f32x4_to_f16x4(v1);
      Ash[a_row_offset * A_STRIDE_VEC4 + a_col] = uvec4(h0, h1);
#endif
    }

    // --- Load B tile from packed weight -> shared ---
    // Packed weight format: t_weight_packed[(k4 * N4 + n4) * 4 + dk]
    // returns vec4 of 4 N-elements at K-row (k4*4+dk).
    // Load two vec4s to get 8 consecutive N-elements = one uvec4 in Bsh.
    {
      uint k_row = chunkK + b_row_offset;
      uint k4 = k_row >> 2u;
      uint dk = k_row & 3u;
      uint n_elem = b_col_base + b_col * FP16_PER_VEC4;
      uint n4_0 = n_elem >> 2u;

#ifdef IS_FP16_INPUT
      f16vec4 v0 = t_weight_packed[(k4 * N4 + n4_0) * 4u + dk];
      f16vec4 v1 = t_weight_packed[(k4 * N4 + n4_0 + 1u) * 4u + dk];
      Bsh[b_row_offset * B_STRIDE_VEC4 + b_col] = uvec4(
          packHalf2x16(vec2(v0.xy)), packHalf2x16(vec2(v0.zw)),
          packHalf2x16(vec2(v1.xy)), packHalf2x16(vec2(v1.zw)));
#else
      vec4 v0 = t_weight_packed[(k4 * N4 + n4_0) * 4u + dk];
      vec4 v1 = t_weight_packed[(k4 * N4 + n4_0 + 1u) * 4u + dk];
      uvec2 h0 = f32x4_to_f16x4(v0);
      uvec2 h1 = f32x4_to_f16x4(v1);
      Bsh[b_row_offset * B_STRIDE_VEC4 + b_col] = uvec4(h0, h1);
#endif
    }

    // All staging threads must finish writing Ash/Bsh before coopMatLoad.
    barrier();

    // --- Cooperative matrix MMA ---
    [[unroll]] for (uint k = 0; k < TILE_K / lK; ++k) {
      uint k_start = lK * k;

      coopmat<float16_t, gl_ScopeSubgroup, lM, lK, gl_MatrixUseA> matA[C_ROWS];
      [[unroll]] for (uint i = 0; i < C_ROWS; ++i) {
        uint row_a = lM * (C_ROWS * warpInTile.y + i);
        coopMatLoad(
            matA[i], Ash,
            row_a * A_STRIDE_VEC4 + k_start / FP16_PER_VEC4,
            A_STRIDE_VEC4,
            gl_CooperativeMatrixLayoutRowMajor);
      }

      coopmat<float16_t, gl_ScopeSubgroup, lK, lN, gl_MatrixUseB> matB;
      [[unroll]] for (uint j = 0; j < C_COLS; ++j) {
        uint col_b = lN * (C_COLS * warpInTile.x + j) / FP16_PER_VEC4;
        coopMatLoad(
            matB, Bsh,
            k_start * B_STRIDE_VEC4 + col_b,
            B_STRIDE_VEC4,
            gl_CooperativeMatrixLayoutRowMajor);

        [[unroll]] for (uint i = 0; i < C_ROWS; ++i) {
          result[i][j] = coopMatMulAdd(matA[i], matB, result[i][j]);
        }
      }
    }

    // Protect Ash/Bsh from being overwritten before all subgroups finish
    // consuming this K-chunk.
    barrier();
  }

  // --- Store result ---
  [[unroll]] for (uint i = 0; i < C_ROWS; ++i) {
    [[unroll]] for (uint j = 0; j < C_COLS; ++j) {
      uint gi = TILE_M * tileID.y + lM * (C_ROWS * warpInTile.y + i);
      uint gj = TILE_N * tileID.x + lN * (C_COLS * warpInTile.x + j);
#ifdef IS_FP16_INPUT
      // fp32 accumulator narrowed to fp16 for the fp16 output buffer.
      coopmat<float16_t, gl_ScopeSubgroup, lM, lN, gl_MatrixUseAccumulator> out_tile =
          coopmat<float16_t, gl_ScopeSubgroup, lM, lN, gl_MatrixUseAccumulator>(result[i][j]);
      coopMatStore(
          out_tile, t_output,
          gi * N + gj, N,
          gl_CooperativeMatrixLayoutRowMajor);
#else
      coopMatStore(
          result[i][j], t_output,
          gi * N + gj, N,
          gl_CooperativeMatrixLayoutRowMajor);
#endif
    }
  }

#ifdef HAS_BIAS
  // Add bias via read-modify-write on the output buffer.
  // The bias pass re-reads t_output elements that were just written by
  // coopMatStore, possibly by a different invocation of this workgroup.
  // barrier() alone only synchronizes control flow and shared memory;
  // buffer writes additionally need memoryBarrierBuffer() to be made
  // visible to the other invocations before they read.
  memoryBarrierBuffer();
  barrier();

  const uint tile_m_start = TILE_M * tileID.y;
  const uint tile_n_start = TILE_N * tileID.x;
  // 64x64 tile = 4096 elements, 256 threads -> 16 elements per thread
  for (uint idx = gl_LocalInvocationID.x; idx < TILE_M * TILE_N; idx += INVOCATIONS) {
    uint local_m = idx / TILE_N;
    uint local_n = idx % TILE_N;
    uint gm = tile_m_start + local_m;
    uint gn = tile_n_start + local_n;
    if (gm < M && gn < N) {
      uint out_idx = gm * N + gn;
      t_output[out_idx] = t_output[out_idx] + t_bias[gn];
    }
  }
#endif
}
0 commit comments