|
| 1 | +/* |
| 2 | + * Copyright (c) Meta Platforms, Inc. and affiliates. |
| 3 | + * All rights reserved. |
| 4 | + * |
| 5 | + * This source code is licensed under the BSD-style license found in the |
| 6 | + * LICENSE file in the root directory of this source tree. |
| 7 | + */ |
| 8 | + |
| 9 | +/* |
| 10 | + * KHR Cooperative Matrix MM kernel — unified linear + matmul. |
| 11 | + * |
| 12 | + * Variants (set in coopmat_mm.yaml): |
| 13 | + * matmul_coopmat row_major weight, no bias (aten.mm runtime mat2) |
| 14 | + * linear_coopmat prepacked weight, no bias (aten.linear) |
| 15 | + * linear_coopmat_bias prepacked weight, +bias (aten.linear w/ bias) |
| 16 | + * |
| 17 | + * Computes: D = A * B[ + bias] |
| 18 | + * A is [M, K] (row-major). |
| 19 | + * B is either [K, N] row-major (matmul), or 4OC x 4IC blocked prepacked |
| 20 | + * with t_weight[(k4 * N4 + n4) * 4 + dk] returning a vec4 of 4 N-elements |
| 21 | + * at K-row k4*4+dk (linear). |
| 22 | + * D is [M, N], buffer storage. |
| 23 | + * |
| 24 | + * fp16 x fp16 -> fp32 MMA. When DTYPE=half, inputs/outputs are native fp16 |
| 25 | + * (no conversion, half the bandwidth). When DTYPE=float, inputs are fp32 |
| 26 | + * with on-the-fly packHalf2x16 conversion at the shared-memory load. |
| 27 | + * |
| 28 | + * When HAS_BIAS, bias is staged once into shared memory and broadcast into |
| 29 | + * each accumulator tile (stride-0 coopMatLoad) before coopMatStore, so |
| 30 | + * t_output is write-only. |
| 31 | + * |
| 32 | + * Tile hierarchy (configurable via yaml; defaults shown for Adreno): |
| 33 | + * MMA_* per-MMA-instruction shape (16x16x16 fp16) |
| 34 | + * WG_TILE_* output tile produced per workgroup (64x64; K-step 32) |
| 35 | + * SG_GRID_* subgroup grid inside the workgroup (2x2 = 4 subgroups) |
| 36 | + * SG_TILE_* per-subgroup output tile (= WG_TILE / SG_GRID; 32x32) |
| 37 | + * SUBGROUP_SIZE hardware subgroup width (64 on Adreno) |
| 38 | + * WG_SIZE threads per workgroup (= NUM_SUBGROUPS * SUBGROUP_SIZE) |
| 39 | + */ |
| 40 | + |
| 41 | +#version 450 core |
| 42 | + |
| 43 | +#extension GL_KHR_cooperative_matrix : require |
| 44 | +#extension GL_KHR_memory_scope_semantics : require |
| 45 | +#extension GL_KHR_shader_subgroup_basic : enable |
| 46 | +#extension GL_EXT_shader_explicit_arithmetic_types : require |
| 47 | +#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require |
| 48 | +#extension GL_EXT_control_flow_attributes : enable |
| 49 | + |
| 50 | +#define PRECISION ${PRECISION} |
| 51 | + |
| 52 | +$if DTYPE == "half": |
| 53 | + #define IS_FP16_INPUT |
| 54 | +$if DTYPE == "float": |
| 55 | + #define IS_FP32_INPUT |
| 56 | + |
| 57 | +$if HAS_BIAS: |
| 58 | + #define HAS_BIAS |
| 59 | + |
| 60 | +$if WEIGHT_LAYOUT == "prepacked": |
| 61 | + #define WEIGHT_PREPACKED |
| 62 | + |
| 63 | +layout(std430) buffer; |
| 64 | + |
| 65 | +#include "common.glslh" |
| 66 | + |
| 67 | +// Bindings: output(0), mat1(1), weight(2), [bias(3)] |
| 68 | +${layout_declare_tensor(B, "w", "t_output", DTYPE, "buffer", is_scalar_array=True)} |
| 69 | +${layout_declare_tensor(B, "r", "t_mat1", DTYPE, "buffer", is_scalar_array=False)} |
| 70 | +${layout_declare_tensor(B, "r", "t_weight", DTYPE, "buffer", is_scalar_array=False)} |
| 71 | +$if HAS_BIAS: |
| 72 | + ${layout_declare_tensor(B, "r", "t_bias", DTYPE, "buffer", is_scalar_array=True)} |
| 73 | + |
| 74 | +// UBOs — N comes from out_sizes (linear) or mat2_sizes (matmul). |
| 75 | +${layout_declare_ubo(B, "ivec4", "mat1_sizes")} |
| 76 | +$if WEIGHT_LAYOUT == "prepacked": |
| 77 | + ${layout_declare_ubo(B, "ivec4", "out_sizes")} |
| 78 | +$else: |
| 79 | + ${layout_declare_ubo(B, "ivec4", "mat2_sizes")} |
| 80 | + |
| 81 | +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; |
| 82 | + |
// ---------------------------------------------------------------------------
// Tile hierarchy. Every ${...} value is substituted by the shader codegen
// from the yaml variant; the remaining constants are derived from them.
// ---------------------------------------------------------------------------

// Shape of one cooperative-matrix multiply-accumulate. It has to be one of the
// shapes reported by vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR on the
// target device.
const uint MMA_M = ${MMA_M};
const uint MMA_N = ${MMA_N};
const uint MMA_K = ${MMA_K};

// Extent of the output tile owned by a workgroup (M x N), and the depth of K
// consumed per main-loop iteration.
const uint WG_TILE_M = ${WG_TILE_M};
const uint WG_TILE_N = ${WG_TILE_N};
const uint WG_TILE_K = ${WG_TILE_K};

// Arrangement of subgroups within a workgroup and the resulting thread count.
const uint SG_GRID_X = ${SG_GRID_X};
const uint SG_GRID_Y = ${SG_GRID_Y};
const uint SUBGROUP_SIZE = ${SUBGROUP_SIZE};
const uint NUM_SUBGROUPS = SG_GRID_X * SG_GRID_Y;
const uint WG_SIZE = NUM_SUBGROUPS * SUBGROUP_SIZE;

// Output tile handled by a single subgroup, and how many MMA tiles it spans
// along each axis.
const uint SG_TILE_M = WG_TILE_M / SG_GRID_Y;
const uint SG_TILE_N = WG_TILE_N / SG_GRID_X;
const uint MMAS_PER_SG_M = SG_TILE_M / MMA_M;
const uint MMAS_PER_SG_N = SG_TILE_N / MMA_N;

// A 128-bit uvec4 carries eight packed fp16 values.
const uint FP16_PER_VEC4 = 8;

// Row pitch of the shared tiles, in uvec4 units: the payload of one row plus
// one extra uvec4 of skew padding.
const uint A_STRIDE_VEC4 = WG_TILE_K / FP16_PER_VEC4 + 1u;
const uint B_STRIDE_VEC4 = WG_TILE_N / FP16_PER_VEC4 + 1u;

// Staged operand tiles (packed fp16): A holds WG_TILE_M rows, B holds
// WG_TILE_K rows.
shared uvec4 Ash[WG_TILE_M * A_STRIDE_VEC4];
shared uvec4 Bsh[WG_TILE_K * B_STRIDE_VEC4];

#ifdef HAS_BIAS
// Bias is staged as fp32 so that coopMatLoad can fill an fp32 accumulator
// coopmat from it directly, with no type conversion at the load.
shared float bias_sh[WG_TILE_N];
#endif

// fp32 accumulators: one MMA_M x MMA_N tile per (i, j) MMA position inside the
// subgroup tile.
coopmat<float, gl_ScopeSubgroup, MMA_M, MMA_N, gl_MatrixUseAccumulator> result[MMAS_PER_SG_M][MMAS_PER_SG_N];
| 125 | + |
#ifdef IS_FP32_INPUT
// Converts the four fp32 lanes of `v` to fp16 and packs them pairwise into two
// 32-bit words: .x holds (v.x, v.y) and .y holds (v.z, v.w).
uvec2 f32x4_to_f16x4(vec4 v) {
  const uint packed_xy = packHalf2x16(v.xy);
  const uint packed_zw = packHalf2x16(v.zw);
  return uvec2(packed_xy, packed_zw);
}
#endif
| 131 | + |
/*
 * Entry point. Each workgroup produces one WG_TILE_M x WG_TILE_N tile of D:
 *
 *   1. For every WG_TILE_K-deep slice of K, the workgroup stages the A and B
 *      slices into shared memory as packed fp16 (Ash / Bsh).
 *   2. Each subgroup loads its MMA operands from shared memory and
 *      accumulates into its fp32 `result` tiles with coopMatMulAdd.
 *   3. With HAS_BIAS, one WG_TILE_N-wide row of bias is staged into bias_sh
 *      and added to every accumulator tile before the store.
 *   4. The accumulators are written to t_output (narrowed to fp16 first when
 *      IS_FP16_INPUT is defined).
 *
 * None of the global loads or stores are bounds-checked. Per the dispatch
 * comment below, M and N are gated to multiples of the workgroup tile.
 * NOTE(review): the K loop reads whole WG_TILE_K slices, so K is presumably
 * gated to a multiple of WG_TILE_K as well - confirm against the C++ side.
 */
void main() {
  const uvec2 tileID = uvec2(gl_WorkGroupID.xy);
  // Column (x) and row (y) of this subgroup inside the SG_GRID_X x SG_GRID_Y
  // grid.
  // NOTE(review): assumes the runtime subgroup width equals SUBGROUP_SIZE, so
  // that gl_SubgroupID stays below NUM_SUBGROUPS - confirm at pipeline setup.
  const uvec2 warpInTile = uvec2(
      gl_SubgroupID % SG_GRID_X,
      gl_SubgroupID / SG_GRID_X);

  const uint K = uint(mat1_sizes.x);
  const uint M = uint(mat1_sizes.y);
#ifdef WEIGHT_PREPACKED
  const uint N = uint(out_sizes.x);
#else
  const uint N = uint(mat2_sizes.x);
#endif
  // Row lengths measured in 4-element texels of the vec4-typed input buffers.
  const uint K4 = (K + 3u) / 4u;
  const uint N4 = (N + 3u) / 4u;

  // Defensive: skip workgroups whose tile is out of bounds. The C++ pick
  // function dispatches exactly num_tiles_n x num_tiles_m workgroups under
  // the alignment-gated (M%WG_TILE_M==0, N%WG_TILE_N==0) inputs, so this
  // never triggers today; it guards against a future dispatch error.
  // The condition depends only on the workgroup ID, so either every
  // invocation of a workgroup returns here or none does, and the barriers
  // below are still reached uniformly.
  const uint num_tiles_n = (N + WG_TILE_N - 1u) / WG_TILE_N;
  const uint num_tiles_m = (M + WG_TILE_M - 1u) / WG_TILE_M;
  if (tileID.x >= num_tiles_n || tileID.y >= num_tiles_m) {
    return;
  }

  // Zero all accumulators owned by this subgroup.
  [[unroll]] for (uint i = 0; i < MMAS_PER_SG_M; ++i) {
    [[unroll]] for (uint j = 0; j < MMAS_PER_SG_N; ++j) {
      result[i][j] = coopmat<float, gl_ScopeSubgroup, MMA_M, MMA_N, gl_MatrixUseAccumulator>(0.0);
    }
  }

  // Payload uvec4 slots per row, and per whole tile, of each shared operand
  // tile (the skew padding slot of each row is never written).
  const uint INVS_PER_ROW_A = WG_TILE_K / FP16_PER_VEC4;
  const uint INVS_PER_ROW_B = WG_TILE_N / FP16_PER_VEC4;
  const uint A_TILE_VEC4S = WG_TILE_M * INVS_PER_ROW_A;
  const uint B_TILE_VEC4S = WG_TILE_K * INVS_PER_ROW_B;

  // Origin of this workgroup's output tile in D (row in M, column in N).
  const uint a_row_base = WG_TILE_M * tileID.y;
  const uint b_col_base = WG_TILE_N * tileID.x;

  for (uint chunkK = 0; chunkK < K; chunkK += WG_TILE_K) {

    // --- Load A tile -> shared ---
    // Slots are distributed over the workgroup with a WG_SIZE stride, the
    // same way the bias row is staged further down. With the default tile
    // configuration A_TILE_VEC4S == WG_SIZE, so every invocation fills
    // exactly one slot; any other configuration is still covered completely
    // and never writes outside Ash.
    // NOTE(review): like the bias loop, this assumes the dispatched local
    // size x equals WG_SIZE - confirm against the C++ dispatch.
    for (uint slot = gl_LocalInvocationID.x; slot < A_TILE_VEC4S; slot += WG_SIZE) {
      uint a_col = slot % INVS_PER_ROW_A;
      uint a_row_offset = slot / INVS_PER_ROW_A;

      uint row = a_row_base + a_row_offset;
      uint k_elem = chunkK + a_col * FP16_PER_VEC4;
      // Eight consecutive K elements span two adjacent 4-element texels.
      uint k_vec4 = k_elem / 4u;

#ifdef IS_FP16_INPUT
      f16vec4 v0 = t_mat1[row * K4 + k_vec4];
      f16vec4 v1 = t_mat1[row * K4 + k_vec4 + 1];
      Ash[a_row_offset * A_STRIDE_VEC4 + a_col] = uvec4(
          packFloat2x16(v0.xy), packFloat2x16(v0.zw),
          packFloat2x16(v1.xy), packFloat2x16(v1.zw));
#else
      vec4 v0 = t_mat1[row * K4 + k_vec4];
      vec4 v1 = t_mat1[row * K4 + k_vec4 + 1];
      uvec2 h0 = f32x4_to_f16x4(v0);
      uvec2 h1 = f32x4_to_f16x4(v1);
      Ash[a_row_offset * A_STRIDE_VEC4 + a_col] = uvec4(h0, h1);
#endif
    }

    // --- Load B tile -> shared ---
    // Same WG_SIZE-strided distribution as the A tile.
    for (uint slot = gl_LocalInvocationID.x; slot < B_TILE_VEC4S; slot += WG_SIZE) {
      uint b_col = slot % INVS_PER_ROW_B;
      uint b_row_offset = slot / INVS_PER_ROW_B;

      uint k_row = chunkK + b_row_offset;
      uint n_elem = b_col_base + b_col * FP16_PER_VEC4;
      uint n4_0 = n_elem >> 2u;
#ifdef WEIGHT_PREPACKED
      // Prepacked: t_weight[(k4 * N4 + n4) * 4 + dk] yields vec4 of
      // 4 N-elements at K-row (k4*4+dk).
      uint k4 = k_row >> 2u;
      uint dk = k_row & 3u;
      uint b_idx0 = (k4 * N4 + n4_0) * 4u + dk;
      uint b_idx1 = (k4 * N4 + n4_0 + 1u) * 4u + dk;
#else
      // Row-major: t_weight[k_row * N4 + n4] yields vec4 of 4 N-elements.
      uint b_idx0 = k_row * N4 + n4_0;
      uint b_idx1 = k_row * N4 + n4_0 + 1u;
#endif

#ifdef IS_FP16_INPUT
      f16vec4 v0 = t_weight[b_idx0];
      f16vec4 v1 = t_weight[b_idx1];
      Bsh[b_row_offset * B_STRIDE_VEC4 + b_col] = uvec4(
          packFloat2x16(v0.xy), packFloat2x16(v0.zw),
          packFloat2x16(v1.xy), packFloat2x16(v1.zw));
#else
      vec4 v0 = t_weight[b_idx0];
      vec4 v1 = t_weight[b_idx1];
      uvec2 h0 = f32x4_to_f16x4(v0);
      uvec2 h1 = f32x4_to_f16x4(v1);
      Bsh[b_row_offset * B_STRIDE_VEC4 + b_col] = uvec4(h0, h1);
#endif
    }

    // Both tiles must be fully staged before any subgroup reads them.
    barrier();

    // --- Cooperative matrix MMA ---
    [[unroll]] for (uint k = 0; k < WG_TILE_K / MMA_K; ++k) {
      uint k_start = MMA_K * k;

      // Load this subgroup's A operands once per K step; each is reused for
      // every N position below.
      coopmat<float16_t, gl_ScopeSubgroup, MMA_M, MMA_K, gl_MatrixUseA> matA[MMAS_PER_SG_M];
      [[unroll]] for (uint i = 0; i < MMAS_PER_SG_M; ++i) {
        uint row_a = MMA_M * (MMAS_PER_SG_M * warpInTile.y + i);
        // Offset and stride are expressed in uvec4 units, matching Ash.
        coopMatLoad(
            matA[i], Ash,
            row_a * A_STRIDE_VEC4 + k_start / FP16_PER_VEC4,
            A_STRIDE_VEC4,
            gl_CooperativeMatrixLayoutRowMajor);
      }

      coopmat<float16_t, gl_ScopeSubgroup, MMA_K, MMA_N, gl_MatrixUseB> matB;
      [[unroll]] for (uint j = 0; j < MMAS_PER_SG_N; ++j) {
        uint col_b = MMA_N * (MMAS_PER_SG_N * warpInTile.x + j) / FP16_PER_VEC4;
        coopMatLoad(
            matB, Bsh,
            k_start * B_STRIDE_VEC4 + col_b,
            B_STRIDE_VEC4,
            gl_CooperativeMatrixLayoutRowMajor);

        [[unroll]] for (uint i = 0; i < MMAS_PER_SG_M; ++i) {
          result[i][j] = coopMatMulAdd(matA[i], matB, result[i][j]);
        }
      }
    }

    // Nobody may overwrite Ash / Bsh for the next slice while another
    // subgroup is still reading the current one.
    barrier();
  }

#ifdef HAS_BIAS
  // Stage one WG_TILE_N-wide row of bias into shared memory. The C++ dispatch
  // gate ensures N % WG_TILE_N == 0, so no per-element bounds check is needed.
  for (uint t = gl_LocalInvocationID.x; t < WG_TILE_N; t += WG_SIZE) {
    bias_sh[t] = float(t_bias[b_col_base + t]);
  }
  memoryBarrierShared();
  barrier();
#endif

  // --- Store result (with bias folded in pre-store, if present) ---
  // Columns are the outer loop because the bias tile depends only on the
  // column position j: it is loaded once and reused for every row tile i.
  [[unroll]] for (uint j = 0; j < MMAS_PER_SG_N; ++j) {
    // Column of this MMA tile inside the workgroup tile, and inside D.
    uint local_n = MMA_N * (MMAS_PER_SG_N * warpInTile.x + j);
    uint gj = b_col_base + local_n;

#ifdef HAS_BIAS
    // Stride-0 row-major load broadcasts MMA_N bias values across all
    // MMA_M rows of the accumulator tile.
    coopmat<float, gl_ScopeSubgroup, MMA_M, MMA_N, gl_MatrixUseAccumulator> bias_tile;
    coopMatLoad(
        bias_tile, bias_sh,
        local_n, /*stride=*/0u,
        gl_CooperativeMatrixLayoutRowMajor);
#endif

    [[unroll]] for (uint i = 0; i < MMAS_PER_SG_M; ++i) {
      uint gi = a_row_base + MMA_M * (MMAS_PER_SG_M * warpInTile.y + i);

#ifdef HAS_BIAS
      result[i][j] += bias_tile;
#endif

#ifdef IS_FP16_INPUT
      // Narrow the fp32 accumulator to fp16 to match the output buffer.
      coopmat<float16_t, gl_ScopeSubgroup, MMA_M, MMA_N, gl_MatrixUseAccumulator> out_tile =
          coopmat<float16_t, gl_ScopeSubgroup, MMA_M, MMA_N, gl_MatrixUseAccumulator>(result[i][j]);
      coopMatStore(
          out_tile, t_output,
          gi * N + gj, N,
          gl_CooperativeMatrixLayoutRowMajor);
#else
      coopMatStore(
          result[i][j], t_output,
          gi * N + gj, N,
          gl_CooperativeMatrixLayoutRowMajor);
#endif
    }
  }
}
0 commit comments