Skip to content

Commit 40f4fa7

Browse files
authored
[ET-VK] Add VK_KHR_cooperative_matrix dispatch for linear/matmul
Differential Revision: D103971112 Pull Request resolved: #19009
1 parent ce1af10 commit 40f4fa7

17 files changed

Lines changed: 1304 additions & 169 deletions

File tree

Lines changed: 314 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,314 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
/*
10+
* KHR Cooperative Matrix MM kernel — unified linear + matmul.
11+
*
12+
* Variants (set in coopmat_mm.yaml):
13+
* matmul_coopmat row_major weight, no bias (aten.mm runtime mat2)
14+
* linear_coopmat prepacked weight, no bias (aten.linear)
15+
* linear_coopmat_bias prepacked weight, +bias (aten.linear w/ bias)
16+
*
17+
* Computes: D = A * B[ + bias]
18+
* A is [M, K] (row-major).
19+
* B is either [K, N] row-major (matmul), or 4OC x 4IC blocked prepacked
20+
* with t_weight[(k4 * N4 + n4) * 4 + dk] returning a vec4 of 4 N-elements
21+
* at K-row k4*4+dk (linear).
22+
* D is [M, N], buffer storage.
23+
*
24+
* fp16 x fp16 -> fp32 MMA. When DTYPE=half, inputs/outputs are native fp16
25+
* (no conversion, half the bandwidth). When DTYPE=float, inputs are fp32
26+
* with on-the-fly packHalf2x16 conversion at the shared-memory load.
27+
*
28+
* When HAS_BIAS, bias is staged once into shared memory and broadcast into
29+
* each accumulator tile (stride-0 coopMatLoad) before coopMatStore, so
30+
* t_output is write-only.
31+
*
32+
* Tile hierarchy (configurable via yaml; defaults shown for Adreno):
33+
* MMA_* per-MMA-instruction shape (16x16x16 fp16)
34+
* WG_TILE_* output tile produced per workgroup (64x64; K-step 32)
35+
* SG_GRID_* subgroup grid inside the workgroup (2x2 = 4 subgroups)
36+
* SG_TILE_* per-subgroup output tile (= WG_TILE / SG_GRID; 32x32)
37+
* SUBGROUP_SIZE hardware subgroup width (64 on Adreno)
38+
* WG_SIZE threads per workgroup (= NUM_SUBGROUPS * SUBGROUP_SIZE)
39+
*/
40+
41+
#version 450 core
42+
43+
#extension GL_KHR_cooperative_matrix : require
44+
#extension GL_KHR_memory_scope_semantics : require
45+
#extension GL_KHR_shader_subgroup_basic : enable
46+
#extension GL_EXT_shader_explicit_arithmetic_types : require
47+
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
48+
#extension GL_EXT_control_flow_attributes : enable
49+
50+
#define PRECISION ${PRECISION}
51+
52+
// Codegen flags: each template parameter below is turned into a preprocessor
// symbol that the #ifdef / #ifndef sections of this shader test.
//   IS_FP16_INPUT / IS_FP32_INPUT  selected by DTYPE
//   HAS_BIAS                       selected by HAS_BIAS
//   WEIGHT_PREPACKED               selected by WEIGHT_LAYOUT == "prepacked"
$if DTYPE == "half":
53+
#define IS_FP16_INPUT
54+
$if DTYPE == "float":
55+
#define IS_FP32_INPUT
56+
57+
$if HAS_BIAS:
58+
#define HAS_BIAS
59+
60+
$if WEIGHT_LAYOUT == "prepacked":
61+
#define WEIGHT_PREPACKED
62+
63+
layout(std430) buffer;
64+
65+
#include "common.glslh"
66+
67+
// Bindings: output(0), mat1(1), weight(2), [bias(3)]
// t_output and t_bias are declared with is_scalar_array=True and are accessed
// one element at a time in main() (coopMatStore target / bias staging).
// t_mat1 and t_weight are declared with is_scalar_array=False; main() reads
// them as 4-component vectors (vec4 / f16vec4), two per invocation.
68+
${layout_declare_tensor(B, "w", "t_output", DTYPE, "buffer", is_scalar_array=True)}
69+
${layout_declare_tensor(B, "r", "t_mat1", DTYPE, "buffer", is_scalar_array=False)}
70+
${layout_declare_tensor(B, "r", "t_weight", DTYPE, "buffer", is_scalar_array=False)}
71+
$if HAS_BIAS:
72+
${layout_declare_tensor(B, "r", "t_bias", DTYPE, "buffer", is_scalar_array=True)}
73+
74+
// UBOs — N comes from out_sizes (linear) or mat2_sizes (matmul).
// main() reads mat1_sizes.x as K and mat1_sizes.y as M, and takes N from the
// .x component of out_sizes (prepacked weight) or mat2_sizes (row-major weight).
75+
${layout_declare_ubo(B, "ivec4", "mat1_sizes")}
76+
$if WEIGHT_LAYOUT == "prepacked":
77+
${layout_declare_ubo(B, "ivec4", "out_sizes")}
78+
$else:
79+
${layout_declare_ubo(B, "ivec4", "mat2_sizes")}
80+
81+
// Workgroup dimensions are supplied through specialization constants 0, 1, 2.
// main() indexes invocations by gl_LocalInvocationID.x only.
// NOTE(review): local_size_x is presumably set to WG_SIZE with y = z = 1 by the
// host-side dispatch — confirm against the C++ pick function.
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
82+
83+
// Cooperative-matrix instruction shape (must match a property enumerated by
84+
// vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR for this device).
85+
const uint MMA_M = ${MMA_M};
86+
const uint MMA_N = ${MMA_N};
87+
const uint MMA_K = ${MMA_K};
88+
89+
// Output tile produced per workgroup.
// WG_TILE_K is the depth of the K-chunk staged into shared memory per
// iteration of the main loop.
90+
const uint WG_TILE_M = ${WG_TILE_M};
91+
const uint WG_TILE_N = ${WG_TILE_N};
92+
const uint WG_TILE_K = ${WG_TILE_K};
93+
94+
// Subgroup grid inside the workgroup.
// main() maps gl_SubgroupID to a (x, y) cell of this grid with % and /
// by SG_GRID_X.
95+
const uint SG_GRID_X = ${SG_GRID_X};
96+
const uint SG_GRID_Y = ${SG_GRID_Y};
97+
const uint SUBGROUP_SIZE = ${SUBGROUP_SIZE};
98+
const uint NUM_SUBGROUPS = SG_GRID_X * SG_GRID_Y;
99+
const uint WG_SIZE = NUM_SUBGROUPS * SUBGROUP_SIZE;
100+
101+
// Derived: per-subgroup tile and MMAs per subgroup tile.
// All four are integer divisions, so any remainder is silently truncated:
// WG_TILE_* must be divisible by SG_GRID_* and SG_TILE_* by MMA_* for the
// subgroup tiles to cover the whole workgroup tile.
102+
const uint SG_TILE_M = WG_TILE_M / SG_GRID_Y;
103+
const uint SG_TILE_N = WG_TILE_N / SG_GRID_X;
104+
const uint MMAS_PER_SG_M = SG_TILE_M / MMA_M;
105+
const uint MMAS_PER_SG_N = SG_TILE_N / MMA_N;
106+
107+
// fp16: 8 elements per uvec4 (128-bit)
108+
const uint FP16_PER_VEC4 = 8;
109+
110+
// Shared memory with skew padding
// Row strides are expressed in uvec4 units. When the tile width is a multiple
// of FP16_PER_VEC4 the formula yields one uvec4 more than the data needs; that
// extra slot per row is the padding and is never written by main().
// NOTE(review): presumably the padding is there to avoid shared-memory bank
// conflicts on the strided coopMatLoad — the source only calls it "skew".
111+
const uint A_STRIDE_VEC4 = (WG_TILE_K + FP16_PER_VEC4) / FP16_PER_VEC4;
112+
const uint B_STRIDE_VEC4 = (WG_TILE_N + FP16_PER_VEC4) / FP16_PER_VEC4;
113+
114+
// Ash holds WG_TILE_M rows of the A chunk, Bsh holds WG_TILE_K rows of the
// B chunk; both store fp16 values packed 8 per uvec4.
shared uvec4 Ash[WG_TILE_M * A_STRIDE_VEC4];
115+
shared uvec4 Bsh[WG_TILE_K * B_STRIDE_VEC4];
116+
117+
#ifdef HAS_BIAS
118+
// fp32 staging buffer so coopMatLoad can broadcast directly into the
119+
// fp32 accumulator coopmat without a type conversion at the load.
120+
shared float bias_sh[WG_TILE_N];
121+
#endif
122+
123+
// Accumulator tiles (fp32)
// One accumulator per MMA tile of this subgroup's output tile, indexed
// [row tile][column tile]. Zeroed at the start of main().
124+
coopmat<float, gl_ScopeSubgroup, MMA_M, MMA_N, gl_MatrixUseAccumulator> result[MMAS_PER_SG_M][MMAS_PER_SG_N];
125+
126+
#ifdef IS_FP32_INPUT
127+
// Converts four fp32 lanes to fp16 and packs them into two 32-bit words:
// word 0 carries (x, y), word 1 carries (z, w).
uvec2 f32x4_to_f16x4(vec4 v) {
  uvec2 packed;
  packed.x = packHalf2x16(v.xy);
  packed.y = packHalf2x16(v.zw);
  return packed;
}
130+
#endif
131+
132+
/*
 * Computes one WG_TILE_M x WG_TILE_N tile of D = A * B (+ bias) per workgroup.
 *
 * For every K-chunk of depth WG_TILE_K:
 *   1. all invocations cooperatively stage the A and B chunk tiles into
 *      Ash / Bsh as packed fp16 (8 halves per uvec4);
 *   2. each subgroup issues its MMAS_PER_SG_M x MMAS_PER_SG_N grid of
 *      cooperative-matrix multiply-adds, accumulating in fp32 into result[][].
 * After the last chunk the optional bias is added and every accumulator tile
 * is written to t_output.
 *
 * Tile staging uses flattened loops strided by WG_SIZE, so it is correct for
 * any tile configuration in the yaml. With the defaults (256 invocations,
 * 64x32 and 32x64 tiles) each staging loop runs exactly once per invocation,
 * which is the original "single pass" behaviour.
 *
 * Assumptions that are NOT checked in the shader:
 *   - M % WG_TILE_M == 0 and N % WG_TILE_N == 0 (per the comment below, the
 *     C++ pick function gates on this).
 *   - NOTE(review): the K loop consumes whole WG_TILE_K chunks with no tail
 *     handling, so K is presumably gated to a multiple of WG_TILE_K as well —
 *     confirm against the dispatch code.
 *   - NOTE(review): gl_LocalInvocationID.x is presumably in [0, WG_SIZE) and
 *     the hardware subgroup size presumably equals SUBGROUP_SIZE — confirm.
 */
void main() {
  const uvec2 tileID = uvec2(gl_WorkGroupID.xy);
  // Position of this subgroup inside the SG_GRID_X x SG_GRID_Y grid.
  const uvec2 warpInTile = uvec2(
      gl_SubgroupID % SG_GRID_X,
      gl_SubgroupID / SG_GRID_X);

  const uint K = uint(mat1_sizes.x);
  const uint M = uint(mat1_sizes.y);
#ifdef WEIGHT_PREPACKED
  const uint N = uint(out_sizes.x);
#else
  const uint N = uint(mat2_sizes.x);
#endif
  // Row lengths of t_mat1 / t_weight in 4-element vectors.
  const uint K4 = (K + 3u) / 4u;
  const uint N4 = (N + 3u) / 4u;

  // Defensive: skip workgroups whose tile is out of bounds. The C++ pick
  // function dispatches exactly num_tiles_n x num_tiles_m workgroups under
  // the alignment-gated (M%WG_TILE_M==0, N%WG_TILE_N==0) inputs, so this
  // never triggers today; it guards against a future dispatch error.
  // The condition depends only on gl_WorkGroupID, so the whole workgroup
  // exits together and no invocation is left waiting at a later barrier().
  const uint num_tiles_n = (N + WG_TILE_N - 1u) / WG_TILE_N;
  const uint num_tiles_m = (M + WG_TILE_M - 1u) / WG_TILE_M;
  if (tileID.x >= num_tiles_n || tileID.y >= num_tiles_m) {
    return;
  }

  [[unroll]] for (uint i = 0u; i < MMAS_PER_SG_M; ++i) {
    [[unroll]] for (uint j = 0u; j < MMAS_PER_SG_N; ++j) {
      result[i][j] = coopmat<float, gl_ScopeSubgroup, MMA_M, MMA_N, gl_MatrixUseAccumulator>(0.0);
    }
  }

  // Number of uvec4 slots (8 fp16 each) per tile row and per whole tile.
  const uint INVS_PER_ROW_A = WG_TILE_K / FP16_PER_VEC4;
  const uint INVS_PER_ROW_B = WG_TILE_N / FP16_PER_VEC4;
  const uint A_TILE_SLOTS = WG_TILE_M * INVS_PER_ROW_A;
  const uint B_TILE_SLOTS = WG_TILE_K * INVS_PER_ROW_B;

  // Global origin of this workgroup's output tile.
  const uint a_row_base = WG_TILE_M * tileID.y;
  const uint b_col_base = WG_TILE_N * tileID.x;

  for (uint chunkK = 0u; chunkK < K; chunkK += WG_TILE_K) {

    // --- Load A tile -> shared ---
    // Each slot is one uvec4 of Ash: tile row a_row, 8 consecutive K elements
    // starting at column a_col * 8. Two 4-element vectors are read per slot.
    for (uint slot = gl_LocalInvocationID.x; slot < A_TILE_SLOTS; slot += WG_SIZE) {
      const uint a_row = slot / INVS_PER_ROW_A;
      const uint a_col = slot % INVS_PER_ROW_A;

      const uint row = a_row_base + a_row;
      const uint k_elem = chunkK + a_col * FP16_PER_VEC4;
      const uint src = row * K4 + k_elem / 4u;

#ifdef IS_FP16_INPUT
      const f16vec4 v0 = t_mat1[src];
      const f16vec4 v1 = t_mat1[src + 1u];
      Ash[a_row * A_STRIDE_VEC4 + a_col] = uvec4(
          packFloat2x16(v0.xy), packFloat2x16(v0.zw),
          packFloat2x16(v1.xy), packFloat2x16(v1.zw));
#else
      // fp32 input: narrow to fp16 while staging.
      const vec4 v0 = t_mat1[src];
      const vec4 v1 = t_mat1[src + 1u];
      Ash[a_row * A_STRIDE_VEC4 + a_col] =
          uvec4(f32x4_to_f16x4(v0), f32x4_to_f16x4(v1));
#endif
    }

    // --- Load B tile -> shared ---
    // Each slot is one uvec4 of Bsh: chunk row b_row (a K index), 8
    // consecutive N elements starting at column b_col * 8.
    for (uint slot = gl_LocalInvocationID.x; slot < B_TILE_SLOTS; slot += WG_SIZE) {
      const uint b_row = slot / INVS_PER_ROW_B;
      const uint b_col = slot % INVS_PER_ROW_B;

      const uint k_row = chunkK + b_row;
      const uint n_elem = b_col_base + b_col * FP16_PER_VEC4;
      const uint n4_0 = n_elem >> 2u;
#ifdef WEIGHT_PREPACKED
      // Prepacked: t_weight[(k4 * N4 + n4) * 4 + dk] yields vec4 of
      // 4 N-elements at K-row (k4*4+dk).
      const uint k4 = k_row >> 2u;
      const uint dk = k_row & 3u;
      const uint b_idx0 = (k4 * N4 + n4_0) * 4u + dk;
      const uint b_idx1 = (k4 * N4 + n4_0 + 1u) * 4u + dk;
#else
      // Row-major: t_weight[k_row * N4 + n4] yields vec4 of 4 N-elements.
      const uint b_idx0 = k_row * N4 + n4_0;
      const uint b_idx1 = k_row * N4 + n4_0 + 1u;
#endif

#ifdef IS_FP16_INPUT
      const f16vec4 v0 = t_weight[b_idx0];
      const f16vec4 v1 = t_weight[b_idx1];
      Bsh[b_row * B_STRIDE_VEC4 + b_col] = uvec4(
          packFloat2x16(v0.xy), packFloat2x16(v0.zw),
          packFloat2x16(v1.xy), packFloat2x16(v1.zw));
#else
      const vec4 v0 = t_weight[b_idx0];
      const vec4 v1 = t_weight[b_idx1];
      Bsh[b_row * B_STRIDE_VEC4 + b_col] =
          uvec4(f32x4_to_f16x4(v0), f32x4_to_f16x4(v1));
#endif
    }

    // Both tiles must be fully staged before any subgroup reads them.
    barrier();

    // --- Cooperative matrix MMA ---
    // Offsets and strides passed to coopMatLoad are in uvec4 units because
    // Ash / Bsh are uvec4 arrays.
    [[unroll]] for (uint k = 0u; k < WG_TILE_K / MMA_K; ++k) {
      const uint k_start = MMA_K * k;

      coopmat<float16_t, gl_ScopeSubgroup, MMA_M, MMA_K, gl_MatrixUseA> matA[MMAS_PER_SG_M];
      [[unroll]] for (uint i = 0u; i < MMAS_PER_SG_M; ++i) {
        const uint row_a = MMA_M * (MMAS_PER_SG_M * warpInTile.y + i);
        coopMatLoad(
            matA[i], Ash,
            row_a * A_STRIDE_VEC4 + k_start / FP16_PER_VEC4,
            A_STRIDE_VEC4,
            gl_CooperativeMatrixLayoutRowMajor);
      }

      // One B fragment is loaded per column tile and reused for every row tile.
      coopmat<float16_t, gl_ScopeSubgroup, MMA_K, MMA_N, gl_MatrixUseB> matB;
      [[unroll]] for (uint j = 0u; j < MMAS_PER_SG_N; ++j) {
        const uint col_b = MMA_N * (MMAS_PER_SG_N * warpInTile.x + j) / FP16_PER_VEC4;
        coopMatLoad(
            matB, Bsh,
            k_start * B_STRIDE_VEC4 + col_b,
            B_STRIDE_VEC4,
            gl_CooperativeMatrixLayoutRowMajor);

        [[unroll]] for (uint i = 0u; i < MMAS_PER_SG_M; ++i) {
          result[i][j] = coopMatMulAdd(matA[i], matB, result[i][j]);
        }
      }
    }

    // Nobody may overwrite Ash / Bsh for the next chunk while a subgroup is
    // still reading the current one.
    barrier();
  }

#ifdef HAS_BIAS
  // Stage one WG_TILE_N-wide row of bias into shared memory. The C++ dispatch
  // gate ensures N % WG_TILE_N == 0, so no per-element bounds check is needed.
  for (uint t = gl_LocalInvocationID.x; t < WG_TILE_N; t += WG_SIZE) {
    bias_sh[t] = float(t_bias[b_col_base + t]);
  }
  memoryBarrierShared();
  barrier();
#endif

  // --- Store result (with bias folded in pre-store, if present) ---
  // Column tiles form the outer loop so the bias tile, which depends only on
  // the column, is loaded once per column tile instead of once per (i, j).
  [[unroll]] for (uint j = 0u; j < MMAS_PER_SG_N; ++j) {
    // Column offset of this MMA tile inside the workgroup tile / in D.
    const uint local_n = MMA_N * (MMAS_PER_SG_N * warpInTile.x + j);
    const uint gj = b_col_base + local_n;

#ifdef HAS_BIAS
    // Stride-0 row-major load broadcasts MMA_N bias values across all
    // MMA_M rows of the accumulator tile.
    coopmat<float, gl_ScopeSubgroup, MMA_M, MMA_N, gl_MatrixUseAccumulator> bias_tile;
    coopMatLoad(
        bias_tile, bias_sh,
        local_n, /*stride=*/0u,
        gl_CooperativeMatrixLayoutRowMajor);
#endif

    [[unroll]] for (uint i = 0u; i < MMAS_PER_SG_M; ++i) {
      const uint gi = a_row_base + MMA_M * (MMAS_PER_SG_M * warpInTile.y + i);

#ifdef HAS_BIAS
      result[i][j] += bias_tile;
#endif

#ifdef IS_FP16_INPUT
      // Narrow the fp32 accumulator to fp16 to match the output buffer type.
      coopmat<float16_t, gl_ScopeSubgroup, MMA_M, MMA_N, gl_MatrixUseAccumulator> out_tile =
          coopmat<float16_t, gl_ScopeSubgroup, MMA_M, MMA_N, gl_MatrixUseAccumulator>(result[i][j]);
      coopMatStore(
          out_tile, t_output,
          gi * N + gj, N,
          gl_CooperativeMatrixLayoutRowMajor);
#else
      coopMatStore(
          result[i][j], t_output,
          gi * N + gj, N,
          gl_CooperativeMatrixLayoutRowMajor);
#endif
    }
  }
}
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# Unified KHR Cooperative Matrix MM kernel (linear + matmul).
8+
# Three shader variants over two weight layouts:
9+
# matmul_coopmat row_major weight, no bias (aten.mm runtime mat2)
10+
# linear_coopmat prepacked weight, no bias (aten.linear)
11+
# linear_coopmat_bias prepacked weight, +bias (aten.linear with bias)
12+
13+
coopmat_mm:
14+
# Default values for the ${...} template parameters referenced by the
# coopmat_mm shader (DTYPE, PRECISION, WEIGHT_LAYOUT, HAS_BIAS, tile shapes).
parameter_names_with_default_values:
15+
DTYPE: float
16+
PRECISION: highp
17+
WEIGHT_LAYOUT: row_major
18+
HAS_BIAS: false
19+
# Cooperative-matrix instruction shape.
MMA_M: 16
20+
MMA_N: 16
21+
MMA_K: 16
22+
# Output tile per workgroup and K-chunk depth.
WG_TILE_M: 64
23+
WG_TILE_N: 64
24+
WG_TILE_K: 32
25+
# Subgroup grid: 2 x 2 subgroups of 64 invocations gives WG_SIZE = 256 in
# the shader (NUM_SUBGROUPS * SUBGROUP_SIZE).
SG_GRID_X: 2
26+
SG_GRID_Y: 2
27+
SUBGROUP_SIZE: 64
28+
# NOTE(review): presumably every named variant below is generated once per
# DTYPE listed here (float and half) — confirm against the shader codegen.
generate_variant_forall:
29+
DTYPE:
30+
- VALUE: float
31+
- VALUE: half
32+
# NOTE(review): parameters not listed on a variant presumably fall back to the
# defaults above (e.g. HAS_BIAS: false) — confirm against the shader codegen.
shader_variants:
33+
- NAME: matmul_coopmat
34+
WEIGHT_LAYOUT: row_major
35+
- NAME: linear_coopmat
36+
WEIGHT_LAYOUT: prepacked
37+
- NAME: linear_coopmat_bias
38+
WEIGHT_LAYOUT: prepacked
39+
HAS_BIAS: true

0 commit comments

Comments
 (0)