Skip to content

Commit b26728a

Browse files
committed
Add linear_coopmat + matmul_coopmat drop-in shader variants
Adds VK_KHR_cooperative_matrix GLSL variants of the tiled linear and matmul shaders. Dispatch is gated by Adapter::supports_cooperative_matrix() and buffer output storage, with automatic fallback to the tiled shader when unsupported. An M >= 64 guard avoids a known OOB in the current coopmat store; that guard will be removed once partial-tile bounds checking is added to the shader. Includes linear_coopmat_bench and matmul_coopmat_bench microbenchmarks that compare against linear_vec / matmul_vec across BERT and LLM-sized shapes using Vulkan query-pool timestamps.
1 parent 08d748d commit b26728a

12 files changed

Lines changed: 1390 additions & 24 deletions

File tree

Lines changed: 261 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,261 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

/*
 * KHR Cooperative Matrix linear shader for prepacked weights.
 * Drop-in replacement for linear_vec when storage=buffer and device
 * supports GL_KHR_cooperative_matrix.
 *
 * Computes: D = A * W_packed (A: [M, K], W_packed: 4OC x 4IC blocked, D: [M, N])
 *
 * Weight is prepacked by pack_fp_linear_weight into a 4OC x 4IC blocked layout:
 *   t_weight_packed[(k4 * N4 + n4) * 4 + dk] = vec4(w[k4*4+dk][n4*4+0..3])
 *
 * fp16xfp16->fp32 MMA. When DTYPE=half, inputs are native fp16 (no
 * conversion, half the bandwidth). When DTYPE=float, inputs are fp32
 * with on-the-fly packHalf2x16 conversion.
 *
 * Output is always fp32 (fp32 accumulator -> fp32 store) when DTYPE=float,
 * or fp16 when DTYPE=half.
 *
 * Optional bias: when HAS_BIAS is defined, bias is added post-store via
 * read-modify-write on the output buffer (one pass over the tile).
 *
 * NOTE(review): coopMatStore writes full lM x lN tiles with no bounds
 * checking, so partial output tiles (M % TILE_M != 0 on the row edge) write
 * out of bounds. Dispatch currently guards with M >= 64; remove that guard
 * only after adding a bounds-checked (shared-memory staged) store path.
 */

#version 450 core

#extension GL_KHR_cooperative_matrix : require
#extension GL_KHR_memory_scope_semantics : require
#extension GL_KHR_shader_subgroup_basic : enable
#extension GL_EXT_shader_explicit_arithmetic_types : require
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#extension GL_EXT_control_flow_attributes : enable

#define PRECISION ${PRECISION}

$if DTYPE == "half":
  #define IS_FP16_INPUT
$if DTYPE == "float":
  #define IS_FP32_INPUT

$if HAS_BIAS:
  #define HAS_BIAS

layout(std430) buffer;

#include "common.glslh"

// Bindings: output(0), mat1(1), weight_packed(2), [bias(3)]
// Output is "rw" when HAS_BIAS because the bias pass read-modify-writes it.
$if HAS_BIAS:
  ${layout_declare_tensor(B, "rw", "t_output", DTYPE, "buffer", is_scalar_array=True)}
$else:
  ${layout_declare_tensor(B, "w", "t_output", DTYPE, "buffer", is_scalar_array=True)}
${layout_declare_tensor(B, "r", "t_mat1", DTYPE, "buffer", is_scalar_array=False)}
${layout_declare_tensor(B, "r", "t_weight_packed", DTYPE, "buffer", is_scalar_array=False)}
$if HAS_BIAS:
  ${layout_declare_tensor(B, "r", "t_bias", DTYPE, "buffer", is_scalar_array=True)}

// UBOs
${layout_declare_ubo(B, "ivec4", "mat1_sizes")}
${layout_declare_ubo(B, "ivec4", "out_sizes")}

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

// Cooperative-matrix fragment dimensions (same as matmul_coopmat).
const uint lM = 16;
const uint lN = 16;
const uint lK = 16;
// Workgroup-level tile dimensions.
const uint TILE_M = 64;
const uint TILE_N = 64;
const uint TILE_K = 32;

// Workgroup: 4 subgroups in 2x2 grid, 64 threads each = 256 total.
const uint WG_WIDTH = 2;
const uint WG_HEIGHT = 2;
const uint NUM_SUBGROUPS = 4;
const uint INVOCATIONS = 64 * NUM_SUBGROUPS;

// Result tiles per subgroup: 2x2.
const uint C_ROWS = TILE_M / WG_HEIGHT / lM; // 2
const uint C_COLS = TILE_N / WG_WIDTH / lN; // 2

// fp16: 8 elements per uvec4 (128-bit).
const uint FP16_PER_VEC4 = 8;

// Shared-memory row strides with one extra uvec4 of skew padding per row
// to reduce bank conflicts on the coopMatLoad column accesses.
const uint A_STRIDE_VEC4 = (TILE_K + FP16_PER_VEC4) / FP16_PER_VEC4; // 5
const uint B_STRIDE_VEC4 = (TILE_N + FP16_PER_VEC4) / FP16_PER_VEC4; // 9

shared uvec4 Ash[TILE_M * A_STRIDE_VEC4]; // 5KB
shared uvec4 Bsh[TILE_K * B_STRIDE_VEC4]; // 4.5KB

// Accumulator tiles (fp32).
coopmat<float, gl_ScopeSubgroup, lM, lN, gl_MatrixUseAccumulator> result[C_ROWS][C_COLS];

#ifdef IS_FP32_INPUT
// Pack four fp32 values into four fp16 values stored in a uvec2.
uvec2 f32x4_to_f16x4(vec4 v) {
  return uvec2(packHalf2x16(v.xy), packHalf2x16(v.zw));
}
#endif

void main() {
  const uvec2 tileID = uvec2(gl_WorkGroupID.xy);
  // Position of this subgroup within the workgroup's 2x2 subgroup grid.
  const uvec2 warpInTile = uvec2(
      gl_SubgroupID % WG_WIDTH,
      gl_SubgroupID / WG_WIDTH);

  const uint K = uint(mat1_sizes.x);
  const uint M = uint(mat1_sizes.y);
  const uint N = uint(out_sizes.x);
  const uint K4 = (K + 3u) / 4u;
  const uint N4 = (N + 3u) / 4u;

  [[unroll]] for (uint i = 0; i < C_ROWS; ++i) {
    [[unroll]] for (uint j = 0; j < C_COLS; ++j) {
      result[i][j] = coopmat<float, gl_ScopeSubgroup, lM, lN, gl_MatrixUseAccumulator>(0.0);
    }
  }

  // Thread assignment for A tile (64 rows x 4 uvec4/row = single pass).
  const uint INVS_PER_ROW_A = TILE_K / FP16_PER_VEC4; // 4
  const uint a_col = gl_LocalInvocationID.x % INVS_PER_ROW_A;
  const uint a_row_offset = gl_LocalInvocationID.x / INVS_PER_ROW_A;

  // Thread assignment for B tile (32 rows x 8 uvec4/row = single pass).
  const uint INVS_PER_ROW_B = TILE_N / FP16_PER_VEC4; // 8
  const uint b_col = gl_LocalInvocationID.x % INVS_PER_ROW_B;
  const uint b_row_offset = gl_LocalInvocationID.x / INVS_PER_ROW_B;

  const uint a_row_base = TILE_M * tileID.y;
  const uint b_col_base = TILE_N * tileID.x;

  for (uint chunkK = 0; chunkK < K; chunkK += TILE_K) {

    // --- Load A tile -> shared (same as matmul_coopmat) ---
    // Out-of-bounds rows / K-chunks are zero-filled so partial M/K tiles
    // accumulate zeros instead of reading past the end of t_mat1.
    // Assumes the input buffer's last partial vec4 along K (if K % 4 != 0)
    // is zero-padded by the packing routine -- TODO confirm.
    {
      const uint row = a_row_base + a_row_offset;
      const uint k_elem = chunkK + a_col * FP16_PER_VEC4;
      const uint kv4 = k_elem / 4;
#ifdef IS_FP16_INPUT
      f16vec4 v0 = f16vec4(0.0);
      f16vec4 v1 = f16vec4(0.0);
#else
      vec4 v0 = vec4(0.0);
      vec4 v1 = vec4(0.0);
#endif
      if (row < M) {
        if (kv4 < K4) {
          v0 = t_mat1[row * K4 + kv4];
        }
        if (kv4 + 1u < K4) {
          v1 = t_mat1[row * K4 + kv4 + 1u];
        }
      }
#ifdef IS_FP16_INPUT
      Ash[a_row_offset * A_STRIDE_VEC4 + a_col] = uvec4(
          packHalf2x16(vec2(v0.xy)), packHalf2x16(vec2(v0.zw)),
          packHalf2x16(vec2(v1.xy)), packHalf2x16(vec2(v1.zw)));
#else
      Ash[a_row_offset * A_STRIDE_VEC4 + a_col] = uvec4(
          f32x4_to_f16x4(v0), f32x4_to_f16x4(v1));
#endif
    }

    // --- Load B tile from packed weight -> shared ---
    // Packed weight format: t_weight_packed[(k4 * N4 + n4) * 4 + dk]
    // returns vec4 of 4 N-elements at K-row (k4*4+dk).
    // Load two vec4s to get 8 consecutive N-elements = one uvec4 in Bsh.
    // Out-of-bounds K-rows / N-columns are zero-filled (zeros contribute
    // nothing to the MMA), avoiding OOB reads on partial K/N tiles.
    {
      const uint k_row = chunkK + b_row_offset;
      const uint k4 = k_row >> 2u;
      const uint dk = k_row & 3u;
      const uint n_elem = b_col_base + b_col * FP16_PER_VEC4;
      const uint n4_0 = n_elem >> 2u;
#ifdef IS_FP16_INPUT
      f16vec4 v0 = f16vec4(0.0);
      f16vec4 v1 = f16vec4(0.0);
#else
      vec4 v0 = vec4(0.0);
      vec4 v1 = vec4(0.0);
#endif
      if (k_row < K) {
        if (n4_0 < N4) {
          v0 = t_weight_packed[(k4 * N4 + n4_0) * 4u + dk];
        }
        if (n4_0 + 1u < N4) {
          v1 = t_weight_packed[(k4 * N4 + n4_0 + 1u) * 4u + dk];
        }
      }
#ifdef IS_FP16_INPUT
      Bsh[b_row_offset * B_STRIDE_VEC4 + b_col] = uvec4(
          packHalf2x16(vec2(v0.xy)), packHalf2x16(vec2(v0.zw)),
          packHalf2x16(vec2(v1.xy)), packHalf2x16(vec2(v1.zw)));
#else
      Bsh[b_row_offset * B_STRIDE_VEC4 + b_col] = uvec4(
          f32x4_to_f16x4(v0), f32x4_to_f16x4(v1));
#endif
    }

    barrier();

    // --- Cooperative matrix MMA ---
    [[unroll]] for (uint k = 0; k < TILE_K / lK; ++k) {
      const uint k_start = lK * k;

      coopmat<float16_t, gl_ScopeSubgroup, lM, lK, gl_MatrixUseA> matA[C_ROWS];
      [[unroll]] for (uint i = 0; i < C_ROWS; ++i) {
        const uint row_a = lM * (C_ROWS * warpInTile.y + i);
        coopMatLoad(
            matA[i], Ash,
            row_a * A_STRIDE_VEC4 + k_start / FP16_PER_VEC4,
            A_STRIDE_VEC4,
            gl_CooperativeMatrixLayoutRowMajor);
      }

      coopmat<float16_t, gl_ScopeSubgroup, lK, lN, gl_MatrixUseB> matB;
      [[unroll]] for (uint j = 0; j < C_COLS; ++j) {
        const uint col_b = lN * (C_COLS * warpInTile.x + j) / FP16_PER_VEC4;
        coopMatLoad(
            matB, Bsh,
            k_start * B_STRIDE_VEC4 + col_b,
            B_STRIDE_VEC4,
            gl_CooperativeMatrixLayoutRowMajor);

        [[unroll]] for (uint i = 0; i < C_ROWS; ++i) {
          result[i][j] = coopMatMulAdd(matA[i], matB, result[i][j]);
        }
      }
    }

    // Protect the shared tiles from being overwritten before every subgroup
    // has consumed them.
    barrier();
  }

  // --- Store result ---
  // NOTE(review): stores full lM x lN tiles unconditionally; see the header
  // comment about the M >= 64 dispatch guard for partial-tile OOB.
  [[unroll]] for (uint i = 0; i < C_ROWS; ++i) {
    [[unroll]] for (uint j = 0; j < C_COLS; ++j) {
      const uint gi = TILE_M * tileID.y + lM * (C_ROWS * warpInTile.y + i);
      const uint gj = TILE_N * tileID.x + lN * (C_COLS * warpInTile.x + j);
#ifdef IS_FP16_INPUT
      // Convert the fp32 accumulator to fp16 before storing.
      coopmat<float16_t, gl_ScopeSubgroup, lM, lN, gl_MatrixUseAccumulator> out_tile =
          coopmat<float16_t, gl_ScopeSubgroup, lM, lN, gl_MatrixUseAccumulator>(result[i][j]);
      coopMatStore(
          out_tile, t_output,
          gi * N + gj, N,
          gl_CooperativeMatrixLayoutRowMajor);
#else
      coopMatStore(
          result[i][j], t_output,
          gi * N + gj, N,
          gl_CooperativeMatrixLayoutRowMajor);
#endif
    }
  }

#ifdef HAS_BIAS
  // Add bias via read-modify-write on the output buffer.
  // barrier() alone only synchronizes execution and shared memory; the
  // coopMatStore writes above go to buffer storage, so an explicit buffer
  // memory barrier is required before other invocations in this workgroup
  // read them back.
  memoryBarrierBuffer();
  barrier();

  const uint tile_m_start = TILE_M * tileID.y;
  const uint tile_n_start = TILE_N * tileID.x;
  // 64x64 tile = 4096 elements, 256 threads -> 16 elements per thread.
  for (uint idx = gl_LocalInvocationID.x; idx < TILE_M * TILE_N; idx += INVOCATIONS) {
    const uint local_m = idx / TILE_N;
    const uint local_n = idx % TILE_N;
    const uint gm = tile_m_start + local_m;
    const uint gn = tile_n_start + local_n;
    if (gm < M && gn < N) {
      const uint out_idx = gm * N + gn;
      t_output[out_idx] = t_output[out_idx] + t_bias[gn];
    }
  }
#endif
}
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Codegen config for the cooperative-matrix linear shader.
# Generates fp32 and fp16 variants of linear_coopmat, each with and
# without a bias buffer (linear_coopmat_bias sets HAS_BIAS).
linear_coopmat:
  parameter_names_with_default_values:
    DTYPE: float       # element type of input/weight/output buffers
    PRECISION: highp   # default GLSL precision qualifier
    HAS_BIAS: false    # bias variant adds a post-store read-modify-write pass
  generate_variant_forall:
    DTYPE:
      - VALUE: float
      - VALUE: half
  shader_variants:
    - NAME: linear_coopmat
    - NAME: linear_coopmat_bias
      HAS_BIAS: true

0 commit comments

Comments
 (0)